sp_rpc.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. #include "precompile.h"
  2. #include "sp_def.h"
  3. #include "sp_svc.h"
  4. #include "sp_rpc.h"
  5. #include "sp_dbg_export.h"
  6. #include "list.h"
  7. #include "memutil.h"
  8. #include "spinlock.h"
  9. #include "refcnt.h"
  10. #include <winpr/synch.h>
  11. #define BUCKET_SIZE 127
  12. /*
  13. Create +------------+ SENT +------------+ ANS +------------+ Destroy +------------+
  14. ------> | INIT | ----> | SENT | ----->| CALLED | -------->| TERM |
  15. +------------+ +------------+ +------------+ +------------+
  16. */
  17. #define STATE_INIT 0
  18. #define STATE_SENT 1
  19. #define STATE_CALLED 2
  20. #define STATE_TERM 3
  21. #define STATE_ERROR 4
  22. #define RPC_CMD_INFO 0
  23. #define RPC_CMD_REQ 1
  24. #define RPC_CMD_ANS 2
/*
 * RPC server endpoint. It registers an SP_PKT_RPC packet handler on `svc`
 * and forwards RPC_CMD_INFO / RPC_CMD_REQ packets to the user callbacks
 * from the service thread pool. Lifetime is reference counted.
 */
struct sp_rpc_server_t
{
	int stop;                  /* 1 after sp_rpc_server_stop(), 0 after sp_rpc_server_start(); not consulted on the packet path */
	sp_rpc_server_callback cb; /* copy of the callbacks passed to sp_rpc_server_create() */
	sp_svc_t *svc;             /* service the packet handler is registered on */
	DECLARE_REF_COUNT_MEMBER(ref_cnt);
};
DECLARE_REF_COUNT_STATIC(sp_rpc_server, sp_rpc_server_t)
  33. static void __threadpool_server_on_pkt(threadpool_t *threadpool, void *arg, param_size_t param1, param_size_t param2)
  34. {
  35. sp_rpc_server_t *server = (sp_rpc_server_t *)arg;
  36. iobuffer_t *pkt = (iobuffer_t*)param1;
  37. int epid;
  38. int svc_id;
  39. int pkt_type;
  40. int pkt_id;
  41. int cmd_type;
  42. iobuffer_read(pkt, IOBUF_T_I4, &epid, 0);
  43. iobuffer_read(pkt, IOBUF_T_I4, &svc_id, 0);
  44. iobuffer_read(pkt, IOBUF_T_I4, &pkt_type, 0);
  45. iobuffer_read(pkt, IOBUF_T_I4, &pkt_id, 0);
  46. cmd_type = SP_GET_TYPE(pkt_type);
  47. if (cmd_type == RPC_CMD_INFO) {
  48. server->cb.on_info(server, epid, svc_id, pkt_id, &pkt, server->cb.user_data);
  49. } else if (cmd_type == RPC_CMD_REQ) {
  50. int call_type;
  51. iobuffer_read(pkt, IOBUF_T_I4, &call_type, NULL);
  52. server->cb.on_req(server, epid, svc_id, pkt_id, call_type, &pkt, server->cb.user_data);
  53. } else {
  54. sp_dbg_warn("RPC CMD unknown types!");
  55. }
  56. sp_rpc_server_dec_ref(server); // @
  57. if (pkt)
  58. iobuffer_dec_ref(pkt);
  59. }
  60. static int server_on_pkt(sp_svc_t *svc, int epid, int svc_id, int pkt_type, int pkt_id, iobuffer_t **p_pkt, void *user_data)
  61. {
  62. sp_rpc_server_t *server = (sp_rpc_server_t*)user_data;
  63. int rc;
  64. iobuffer_t *pkt;
  65. pkt = *p_pkt;
  66. *p_pkt = NULL;
  67. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_id, 0);
  68. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_type, 0);
  69. iobuffer_write_head(pkt, IOBUF_T_I4, &svc_id, 0);
  70. iobuffer_write_head(pkt, IOBUF_T_I4, &epid, 0);
  71. sp_rpc_server_inc_ref(server); // @
  72. rc = threadpool_queue_workitem2(sp_svc_get_threadpool(svc), NULL, &__threadpool_server_on_pkt, server, (param_size_t)pkt, 0);
  73. if (rc != 0) {
  74. sp_rpc_server_dec_ref(server); // @
  75. iobuffer_dec_ref(pkt);
  76. }
  77. return FALSE;
  78. }
  79. int sp_rpc_server_create(sp_svc_t *svc, sp_rpc_server_callback *cb, sp_rpc_server_t **p_server)
  80. {
  81. sp_rpc_server_t *server = MALLOC_T(sp_rpc_server_t);
  82. server->stop = 0;
  83. memcpy(&server->cb, cb, sizeof(sp_rpc_server_callback));
  84. server->svc = svc;
  85. REF_COUNT_INIT(&server->ref_cnt);
  86. *p_server = server;
  87. return 0;
  88. }
  89. void sp_rpc_server_destroy(sp_rpc_server_t *server)
  90. {
  91. sp_rpc_server_dec_ref(server);
  92. }
  93. int sp_rpc_server_start(sp_rpc_server_t *server)
  94. {
  95. server->stop = 0;
  96. return sp_svc_add_pkt_handler(server->svc, (int)server, SP_PKT_RPC, &server_on_pkt, server);
  97. }
  98. int sp_rpc_server_stop(sp_rpc_server_t *server)
  99. {
  100. // BugFix [4/5/2020 11:55 Gifur]
  101. if (/*!*/server->stop)
  102. return Error_Bug;
  103. server->stop = 1;
  104. return sp_svc_remove_pkt_handler(server->svc, (int)server, SP_PKT_RPC);
  105. }
  106. sp_svc_t *sp_rpc_server_get_svc(sp_rpc_server_t *server)
  107. {
  108. return server->svc;
  109. }
  110. int sp_rpc_server_send_answer(sp_rpc_server_t *server, int epid, int svc_id, int rpc_id, iobuffer_t **ans_pkt)
  111. {
  112. return sp_svc_post(server->svc, epid, svc_id, SP_PKT_RPC | RPC_CMD_ANS, rpc_id, ans_pkt);
  113. }
  114. static void __sp_rpc_destroy(sp_rpc_server_t *server)
  115. {
  116. if (server->cb.on_destroy) {
  117. (*server->cb.on_destroy)(server, server->cb.user_data);
  118. }
  119. free(server);
  120. }
  121. IMPLEMENT_REF_COUNT_MT(sp_rpc_server, sp_rpc_server_t, ref_cnt, __sp_rpc_destroy)
/*
 * One RPC call issued through a sp_rpc_client_mgr_t. While alive it is
 * linked into mgr->rpc_buckets so that answers (matched by rpc_id) and
 * manager-wide errors can reach it. `state` takes the STATE_* values and is
 * read and written under `lock`.
 */
struct sp_rpc_client_t
{
	struct hlist_node hentry; // link in sp_rpc_client_mgr_t->rpc_buckets[rpc_id % BUCKET_SIZE]
	int state;                // STATE_* value, guarded by `lock`
	int remote_epid;          // destination endpoint of the request
	int remote_svc_id;        // destination service of the request
	unsigned int rpc_id;      // issued from mgr->local_seq; sent as pkt_id and matched against RPC_CMD_ANS pkt_id
	int call_type;            // prepended to the request payload by sp_rpc_client_async_call()
	spinlock_t lock;
	sp_rpc_client_callback cb; // copy of the callbacks passed to sp_rpc_client_create()
	sp_rpc_client_mgr_t *mgr;  // owning manager; the client holds a manager reference until __client_destroy()
	DECLARE_REF_COUNT_MEMBER(ref_cnt);
};
DECLARE_REF_COUNT_STATIC(sp_rpc_client, sp_rpc_client_t)
/*
 * Client-side RPC manager. It tracks live sp_rpc_client_t objects in a small
 * hash table keyed by rpc_id and receives SP_PKT_RPC answers/requests plus
 * bus state changes from `svc`. Lifetime is reference counted; every client
 * holds one reference.
 */
struct sp_rpc_client_mgr_t
{
	struct hlist_head rpc_buckets[BUCKET_SIZE]; // list of sp_rpc_client_t, indexed by rpc_id % BUCKET_SIZE
	sp_svc_t *svc;
	int rpc_cnt;   // number of clients currently linked in rpc_buckets
	int stop;      // reset by sp_rpc_client_mgr_start(); not read in this file
	int local_seq; // last rpc_id handed out (advanced with InterlockedIncrement)
	sp_rpc_client_mgr_callback cb;
	CRITICAL_SECTION lock; // taken around rpc_buckets / rpc_cnt updates and traversals
	DECLARE_REF_COUNT_MEMBER(ref_cnt);
};
DECLARE_REF_COUNT_STATIC(sp_rpc_client_mgr, sp_rpc_client_mgr_t)
  148. static __inline void mgr_lock(sp_rpc_client_mgr_t *mgr)
  149. {
  150. EnterCriticalSection(&mgr->lock);
  151. }
  152. static __inline void mgr_unlock(sp_rpc_client_mgr_t *mgr)
  153. {
  154. LeaveCriticalSection(&mgr->lock);
  155. }
  156. static __inline void client_lock(sp_rpc_client_t *client)
  157. {
  158. spinlock_enter(&client->lock, -1);
  159. }
  160. static __inline void client_unlock(sp_rpc_client_t *client)
  161. {
  162. spinlock_leave(&client->lock);
  163. }
  164. static void client_set_error(sp_rpc_client_t *client, int error);
  165. static void client_process_ans(sp_rpc_client_t *client, iobuffer_t **ans_pkt);
  166. static void __threadpool_mgr_on_req(threadpool_t *threadpool, void *arg, param_size_t param1, param_size_t param2)
  167. {
  168. sp_rpc_client_mgr_t *mgr = (sp_rpc_client_mgr_t *)arg;
  169. iobuffer_t *pkt = (iobuffer_t*)param1;
  170. int epid;
  171. int svc_id;
  172. int pkt_type;
  173. int pkt_id;
  174. int cmd_type;
  175. iobuffer_read(pkt, IOBUF_T_I4, &epid, 0);
  176. iobuffer_read(pkt, IOBUF_T_I4, &svc_id, 0);
  177. iobuffer_read(pkt, IOBUF_T_I4, &pkt_type, 0);
  178. iobuffer_read(pkt, IOBUF_T_I4, &pkt_id, 0);
  179. cmd_type = SP_GET_TYPE(pkt_type);
  180. if (cmd_type == RPC_CMD_REQ && mgr->cb.on_req)
  181. {
  182. int call_type;
  183. iobuffer_read(pkt, IOBUF_T_I4, &call_type, NULL);
  184. mgr->cb.on_req(mgr, epid, svc_id, pkt_id, call_type, &pkt, mgr->cb.user_data);
  185. }
  186. else
  187. {
  188. sp_dbg_warn("RPC CMD unknown types!");
  189. }
  190. sp_rpc_client_mgr_dec_ref(mgr); // @
  191. if (pkt)
  192. iobuffer_dec_ref(pkt);
  193. }
  194. static int mgr_on_pkt(sp_svc_t *svc,int epid, int svc_id, int pkt_type, int pkt_id, iobuffer_t **p_pkt, void *user_data)
  195. {
  196. sp_rpc_client_mgr_t *mgr = (sp_rpc_client_mgr_t*)user_data;
  197. sp_dbg_debug("sp_rpc::mgr_on_pkt: epid:%d, svc_id: %d, pkt_type:0x%08X, pkt_id: %d, rpc:%d", epid, svc_id, pkt_type, pkt_id, SP_GET_TYPE(pkt_type));
  198. if (SP_GET_TYPE(pkt_type) == RPC_CMD_ANS) {
  199. int rpc_id = pkt_id;
  200. int slot = ((unsigned int)rpc_id) % BUCKET_SIZE;
  201. sp_rpc_client_t *tpos;
  202. struct hlist_node *pos, *n;
  203. mgr_lock(mgr);
  204. hlist_for_each_entry_safe(tpos, pos, n, &mgr->rpc_buckets[slot], sp_rpc_client_t, hentry) {
  205. if (tpos->rpc_id == rpc_id) {
  206. client_process_ans(tpos, p_pkt);
  207. break;
  208. }
  209. }
  210. mgr_unlock(mgr);
  211. return FALSE;
  212. }
  213. else if (SP_GET_TYPE(pkt_type) == RPC_CMD_REQ)
  214. {
  215. int rc;
  216. iobuffer_t *pkt = *p_pkt;
  217. *p_pkt = NULL;
  218. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_id, 0);
  219. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_type, 0);
  220. iobuffer_write_head(pkt, IOBUF_T_I4, &svc_id, 0);
  221. iobuffer_write_head(pkt, IOBUF_T_I4, &epid, 0);
  222. sp_rpc_client_mgr_inc_ref(mgr);
  223. rc = threadpool_queue_workitem2(sp_svc_get_threadpool(svc), NULL, &__threadpool_mgr_on_req, mgr, (param_size_t)pkt, 0);
  224. if (rc != 0) {
  225. sp_rpc_client_mgr_dec_ref(mgr); // @
  226. iobuffer_dec_ref(pkt);
  227. }
  228. }
  229. return TRUE;
  230. }
  231. static void mgr_on_sys(sp_svc_t *svc,int epid, int state, void *user_data)
  232. {
  233. sp_rpc_client_mgr_t *mgr = (sp_rpc_client_mgr_t*)user_data;
  234. if (state == BUS_STATE_OFF) {
  235. int i;
  236. sp_rpc_client_t *tpos;
  237. struct hlist_node *pos, *n;
  238. mgr_lock(mgr);
  239. for (i = 0; i < BUCKET_SIZE; ++i) {
  240. hlist_for_each_entry_safe(tpos, pos, n, &mgr->rpc_buckets[i], sp_rpc_client_t, hentry) {
  241. if (tpos->remote_epid == epid) {
  242. client_set_error(tpos, Error_NetBroken);
  243. }
  244. }
  245. }
  246. mgr_unlock(mgr);
  247. }
  248. }
  249. int sp_rpc_client_mgr_create(sp_svc_t *svc, sp_rpc_client_mgr_callback *cb, sp_rpc_client_mgr_t **p_mgr)
  250. {
  251. int i;
  252. sp_rpc_client_mgr_t *mgr = MALLOC_T(sp_rpc_client_mgr_t);
  253. mgr->local_seq = 0;
  254. mgr->rpc_cnt = 0;
  255. mgr->stop = 0;
  256. mgr->svc = svc;
  257. memcpy(&mgr->cb, cb, sizeof(sp_rpc_client_mgr_callback));
  258. for (i = 0;i < BUCKET_SIZE; ++i) {
  259. INIT_HLIST_HEAD(&mgr->rpc_buckets[i]);
  260. }
  261. InitializeCriticalSection(&mgr->lock);
  262. REF_COUNT_INIT(&mgr->ref_cnt);
  263. *p_mgr = mgr;
  264. return 0;
  265. }
  266. // {bug} not delete rpc_buckets arrary
  267. void sp_rpc_client_mgr_destroy(sp_rpc_client_mgr_t *mgr)
  268. {
  269. sp_rpc_client_mgr_dec_ref(mgr);
  270. }
  271. int sp_rpc_client_mgr_start(sp_rpc_client_mgr_t *mgr)
  272. {
  273. mgr->stop = 0;
  274. sp_svc_add_pkt_handler(mgr->svc, (int)mgr, SP_PKT_RPC, &mgr_on_pkt, mgr);
  275. sp_svc_add_sys_handler(mgr->svc, (int)mgr, &mgr_on_sys, mgr);
  276. return 0;
  277. }
  278. int sp_rpc_client_mgr_stop(sp_rpc_client_mgr_t *mgr)
  279. {
  280. sp_svc_remove_pkt_handler(mgr->svc, (int)mgr, SP_PKT_RPC);
  281. sp_svc_remove_sys_handler(mgr->svc, (int)mgr);
  282. return 0;
  283. }
  284. sp_svc_t *sp_rpc_client_mgr_get_svc(sp_rpc_client_mgr_t *mgr)
  285. {
  286. return mgr->svc;
  287. }
  288. int sp_rpc_client_mgr_cancel_all(sp_rpc_client_mgr_t *mgr)
  289. {
  290. int i;
  291. mgr_lock(mgr);
  292. for (i = 0; i < BUCKET_SIZE; ++i) {
  293. sp_rpc_client_t *tpos;
  294. struct hlist_node *pos;
  295. hlist_for_each_entry(tpos, pos, &mgr->rpc_buckets[i], sp_rpc_client_t, hentry) {
  296. client_set_error(tpos, Error_Cancel);
  297. }
  298. }
  299. mgr_unlock(mgr);
  300. return 0;
  301. }
  302. int sp_rpc_client_mgr_get_client_cnt(sp_rpc_client_mgr_t *mgr)
  303. {
  304. return mgr->rpc_cnt;
  305. }
  306. int sp_rpc_client_mgr_one_way_call(sp_rpc_client_mgr_t *mgr, int epid, int svc_id, int call_type, iobuffer_t **info_pkt)
  307. {
  308. return sp_svc_post(mgr->svc, epid, svc_id, SP_PKT_RPC| RPC_CMD_INFO, call_type, info_pkt);
  309. }
  310. int sp_rpc_client_mgr_send_answer(sp_rpc_client_mgr_t *mgr, int epid, int svc_id, int rpc_id, iobuffer_t **ans_pkt)
  311. {
  312. return sp_svc_post(mgr->svc, epid, svc_id, SP_PKT_RPC | RPC_CMD_ANS, rpc_id, ans_pkt);
  313. }
  314. static void __sp_rpc_client_mgr_destroy(sp_rpc_client_mgr_t *mgr)
  315. {
  316. if (mgr->cb.on_destroy)
  317. mgr->cb.on_destroy(mgr, mgr->cb.user_data);
  318. DeleteCriticalSection(&mgr->lock);
  319. free(mgr);
  320. }
  321. IMPLEMENT_REF_COUNT_MT_STATIC(sp_rpc_client_mgr, sp_rpc_client_mgr_t, ref_cnt, __sp_rpc_client_mgr_destroy)
  322. int sp_rpc_client_create(sp_rpc_client_mgr_t *mgr, int epid, int svc_id, int call_type, sp_rpc_client_callback *cb, sp_rpc_client_t **p_client)
  323. {
  324. sp_rpc_client_t *client = MALLOC_T(sp_rpc_client_t);
  325. client->mgr = mgr;
  326. client->remote_epid = epid;
  327. client->remote_svc_id = svc_id;
  328. client->call_type = call_type;
  329. memcpy(&client->cb, cb, sizeof(sp_rpc_client_callback));
  330. client->rpc_id = (int)InterlockedIncrement((LONG*)&mgr->local_seq);
  331. spinlock_init(&client->lock);
  332. client->state = STATE_INIT;
  333. REF_COUNT_INIT(&client->ref_cnt);
  334. sp_rpc_client_mgr_inc_ref(mgr);
  335. sp_rpc_client_inc_ref(client);
  336. mgr_lock(mgr);
  337. hlist_add_head(&client->hentry, &mgr->rpc_buckets[client->rpc_id % BUCKET_SIZE]);
  338. client->mgr->rpc_cnt++;
  339. mgr_unlock(mgr);
  340. *p_client = client;
  341. return 0;
  342. }
  343. int sp_rpc_client_close(sp_rpc_client_t *client)
  344. {
  345. int rc;
  346. client_lock(client);
  347. if (client->state != STATE_TERM && client->state != STATE_ERROR) {
  348. client->state = STATE_ERROR;
  349. rc = 0;
  350. } else {
  351. rc = Error_Duplication;
  352. }
  353. client_unlock(client);
  354. return rc;
  355. }
  356. void sp_rpc_client_destroy(sp_rpc_client_t *client)
  357. {
  358. mgr_lock(client->mgr);
  359. client->mgr->rpc_cnt --;
  360. hlist_del(&client->hentry);
  361. mgr_unlock(client->mgr);
  362. sp_rpc_client_dec_ref(client);
  363. client_lock(client);
  364. client->state = STATE_TERM;
  365. client_unlock(client);
  366. sp_rpc_client_dec_ref(client);
  367. }
  368. int sp_rpc_client_async_call(sp_rpc_client_t *client, iobuffer_t **req_pkt)
  369. {
  370. sp_rpc_client_mgr_t *mgr = client->mgr;
  371. int rc = 0;
  372. if (client->state != STATE_INIT)
  373. return Error_Bug;
  374. client_lock(client);
  375. if (client->state == STATE_INIT) {
  376. client->state = STATE_SENT;
  377. sp_rpc_client_inc_ref(client); // @
  378. iobuffer_write_head(*req_pkt, IOBUF_T_I4, &client->call_type, 0);
  379. rc = sp_svc_post(mgr->svc, client->remote_epid, client->remote_svc_id, SP_PKT_RPC|RPC_CMD_REQ, client->rpc_id, req_pkt);
  380. if (rc != 0) {
  381. sp_rpc_client_dec_ref(client); // @
  382. client->state = STATE_ERROR;
  383. }
  384. } else {
  385. rc = Error_NetBroken;
  386. }
  387. client_unlock(client);
  388. return rc;
  389. }
  390. int sp_rpc_client_get_rpc_id(sp_rpc_client_t *client)
  391. {
  392. return client->rpc_id;
  393. }
  394. int sp_rpc_client_get_remote_epid(sp_rpc_client_t *client)
  395. {
  396. return client->remote_epid;
  397. }
  398. int sp_rpc_client_get_remote_svc_id(sp_rpc_client_t *client)
  399. {
  400. return client->remote_svc_id;
  401. }
  402. static void client_set_error(sp_rpc_client_t *client, int error)
  403. {
  404. if (client->state != STATE_ERROR && client->state != STATE_TERM) {
  405. client_lock(client);
  406. if (client->state != STATE_ERROR && client->state != STATE_TERM) {
  407. if (client->state == STATE_SENT) {
  408. if (client->cb.on_ans) {
  409. sp_dbg_debug("%s::on_ans(%d) ", __FUNCTION__, error);
  410. client->cb.on_ans(client, error, NULL, client->cb.user_data);
  411. }
  412. } else {
  413. client->state = STATE_ERROR;
  414. }
  415. }
  416. client_unlock(client);
  417. }
  418. sp_rpc_client_dec_ref(client); // @
  419. }
  420. static void client_process_ans(sp_rpc_client_t *client, iobuffer_t **ans_pkt)
  421. {
  422. if (client->state == STATE_SENT) {
  423. client_lock(client);
  424. if (client->state == STATE_SENT) {
  425. client->state = STATE_CALLED;
  426. if (client->cb.on_ans) {
  427. sp_dbg_debug("%s::on_ans() ", __FUNCTION__);
  428. client->cb.on_ans(client, 0, ans_pkt, client->cb.user_data);
  429. }
  430. }
  431. client_unlock(client);
  432. }
  433. sp_rpc_client_dec_ref(client); // @
  434. }
  435. static void __client_destroy(sp_rpc_client_t *client)
  436. {
  437. if (client->cb.on_destroy)
  438. client->cb.on_destroy(client, client->cb.user_data);
  439. sp_rpc_client_mgr_dec_ref(client->mgr);
  440. free(client);
  441. }
  442. IMPLEMENT_REF_COUNT_MT_STATIC(sp_rpc_client, sp_rpc_client_t, ref_cnt, __client_destroy)