cygrpc_test.py 16 KB


  1. # Copyright 2015 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import time
  15. import threading
  16. import unittest
  17. import platform
  18. from grpc._cython import cygrpc
  19. from tests.unit._cython import test_utilities
  20. from tests.unit import test_common
  21. from tests.unit import resources
  # SSL target name override used by the secure server/client test case.
  _SSL_HOST_OVERRIDE = b'foo.test.google.fr'

  # Metadata pair that _metadata_plugin hands back through its callback.
  _CALL_CREDENTIALS_METADATA_KEY, _CALL_CREDENTIALS_METADATA_VALUE = (
      'call-creds-key',
      'call-creds-value',
  )

  # Flags argument passed to every batch operation in this module (no flags).
  _EMPTY_FLAGS = 0
  def _metadata_plugin(context, callback):
      """Metadata plugin that always supplies the call-credentials pair.

      `context` is ignored. `callback` is invoked once with a single
      (key, value) metadata entry, an OK status code and empty details.
      """
      metadata = ((_CALL_CREDENTIALS_METADATA_KEY,
                   _CALL_CREDENTIALS_METADATA_VALUE),)
      callback(metadata, cygrpc.StatusCode.ok, b'')
  class TypeSmokeTest(unittest.TestCase):
      """Creates and releases basic cygrpc objects without sending traffic."""

      def testCompletionQueueUpDown(self):
          """A completion queue can be created and released."""
          queue = cygrpc.CompletionQueue()
          del queue

      def testServerUpDown(self):
          """An unstarted server built from a set of channel args is released."""
          reuseport_off = (b'grpc.so_reuseport', 0)
          srv = cygrpc.Server({reuseport_off}, False)
          del srv

      def testChannelUpDown(self):
          """A channel to a wildcard address can be opened and then closed."""
          chan = cygrpc.Channel(b'[::]:0', None, None)
          chan.close(cygrpc.StatusCode.cancelled, 'Test method anyway!')

      def test_metadata_plugin_call_credentials_up_down(self):
          """Call credentials can wrap this module's metadata plugin."""
          cygrpc.MetadataPluginCallCredentials(_metadata_plugin,
                                               b'test plugin name!')

      def testServerStartNoExplicitShutdown(self):
          """A started server may be released without calling shutdown()."""
          srv = cygrpc.Server([(b'grpc.so_reuseport', 0)], False)
          queue = cygrpc.CompletionQueue()
          srv.register_completion_queue(queue)
          bound_port = srv.add_http2_port(b'[::]:0')
          self.assertIsInstance(bound_port, int)
          srv.start()
          del srv

      def testServerStartShutdown(self):
          """shutdown() posts an operation_complete event carrying its tag."""
          queue = cygrpc.CompletionQueue()
          srv = cygrpc.Server([(b'grpc.so_reuseport', 0)], False)
          srv.add_http2_port(b'[::]:0')
          srv.register_completion_queue(queue)
          srv.start()
          tag = object()
          srv.shutdown(queue, tag)
          shutdown_event = queue.poll()
          self.assertEqual(cygrpc.CompletionType.operation_complete,
                           shutdown_event.completion_type)
          self.assertIs(tag, shutdown_event.tag)
          # Release the server before its completion queue, as before.
          del srv
          del queue
  class ServerClientMixin(object):
      """Client/server tests run against a live cygrpc server and channel.

      Concrete subclasses also derive from unittest.TestCase, call
      setUpMixin() from setUp() and tearDownMixin() from tearDown().
      """

      def setUpMixin(self, server_credentials, client_credentials,
                     host_override):
          """Start a server on port 0 and connect a client channel to it.

          Args:
            server_credentials: Credentials for a secure server port, or a
              falsy value for an insecure port.
            client_credentials: Credentials for a secure client channel, or a
              falsy value for an insecure channel.
            host_override: Value of the ssl_target_name_override channel
              argument (only applied with client_credentials). When set, calls
              use the default host and the server is expected to observe this
              value as the host; otherwise calls send b'hostess'.
          """
          self.server_completion_queue = cygrpc.CompletionQueue()
          self.server = cygrpc.Server([(
              b'grpc.so_reuseport',
              0,
          )], False)
          self.server.register_completion_queue(self.server_completion_queue)
          if server_credentials:
              self.port = self.server.add_http2_port(b'[::]:0',
                                                     server_credentials)
          else:
              self.port = self.server.add_http2_port(b'[::]:0')
          self.server.start()
          self.client_completion_queue = cygrpc.CompletionQueue()
          target = 'localhost:{}'.format(self.port).encode()
          if client_credentials:
              client_channel_arguments = ((
                  cygrpc.ChannelArgKey.ssl_target_name_override,
                  host_override,
              ),)
              self.client_channel = cygrpc.Channel(target,
                                                   client_channel_arguments,
                                                   client_credentials)
          else:
              self.client_channel = cygrpc.Channel(target, set(), None)
          if host_override:
              self.host_argument = None  # default host
              self.expected_host = host_override
          else:
              # arbitrary host name necessitating no further identification
              self.host_argument = b'hostess'
              self.expected_host = self.host_argument

      def tearDownMixin(self):
          """Close the client channel and drop every object setUpMixin made."""
          self.client_channel.close(cygrpc.StatusCode.ok,
                                    'test being torn down!')
          del self.client_channel
          del self.server
          del self.client_completion_queue
          del self.server_completion_queue

      def _perform_queue_operations(self, operations, call, queue, deadline,
                                    description):
          """Perform the operations with given call, queue, and deadline.

          Invocation errors are reported as an exception with `description`
          in the message. Performs the operations asynchronously, returning a
          future.

          Args:
            operations: Batch operations to start on `call`.
            call: Call on which the batch is started.
            queue: Completion queue polled for the batch's completion event.
            deadline: Deadline passed to `queue.poll`.
            description: Label included in the message of any raised error.

          Returns:
            A test_utilities.SimpleFuture wrapping the batch; its result is
            the completion event.
          """

          def performer():
              tag = object()
              try:
                  call_result = call.start_client_batch(operations, tag)
                  self.assertEqual(cygrpc.CallError.ok, call_result)
                  event = queue.poll(deadline=deadline)
                  self.assertEqual(cygrpc.CompletionType.operation_complete,
                                   event.completion_type)
                  self.assertTrue(event.success)
                  self.assertIs(tag, event.tag)
              except Exception as error:
                  # Format the exception itself: exceptions have no `.message`
                  # attribute on Python 3, so reading it raised AttributeError
                  # and hid the real failure.
                  raise Exception("Error in '{}': {}".format(description, error))
              return event

          return test_utilities.SimpleFuture(performer)

      def test_echo(self):
          """Exchange metadata, one message each way and a status in one batch."""
          DEADLINE = time.time() + 5
          DEADLINE_TOLERANCE = 0.25
          CLIENT_METADATA_ASCII_KEY = 'key'
          CLIENT_METADATA_ASCII_VALUE = 'val'
          CLIENT_METADATA_BIN_KEY = 'key-bin'
          CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
          SERVER_INITIAL_METADATA_KEY = 'init_me_me_me'
          SERVER_INITIAL_METADATA_VALUE = 'whodawha?'
          SERVER_TRAILING_METADATA_KEY = 'california_is_in_a_drought'
          SERVER_TRAILING_METADATA_VALUE = 'zomg it is'
          SERVER_STATUS_CODE = cygrpc.StatusCode.ok
          SERVER_STATUS_DETAILS = 'our work is never over'
          REQUEST = b'in death a member of project mayhem has a name'
          RESPONSE = b'his name is robert paulson'
          METHOD = b'twinkies'

          # Ask the server for the next incoming call before the client starts.
          server_request_tag = object()
          request_call_result = self.server.request_call(
              self.server_completion_queue, self.server_completion_queue,
              server_request_tag)
          self.assertEqual(cygrpc.CallError.ok, request_call_result)

          client_call_tag = object()
          client_initial_metadata = (
              (
                  CLIENT_METADATA_ASCII_KEY,
                  CLIENT_METADATA_ASCII_VALUE,
              ),
              (
                  CLIENT_METADATA_BIN_KEY,
                  CLIENT_METADATA_BIN_VALUE,
              ),
          )
          client_call = self.client_channel.integrated_call(
              0, METHOD, self.host_argument, DEADLINE, client_initial_metadata,
              None, [
                  (
                      [
                          cygrpc.SendInitialMetadataOperation(
                              client_initial_metadata, _EMPTY_FLAGS),
                          cygrpc.SendMessageOperation(REQUEST, _EMPTY_FLAGS),
                          cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
                          cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                          cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
                          cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                      ],
                      client_call_tag,
                  ),
              ])
          # Collect the client's batch event asynchronously while the server
          # side of the exchange runs below.
          client_event_future = test_utilities.SimpleFuture(
              self.client_channel.next_call_event)

          request_event = self.server_completion_queue.poll(deadline=DEADLINE)
          self.assertEqual(cygrpc.CompletionType.operation_complete,
                           request_event.completion_type)
          self.assertIsInstance(request_event.call, cygrpc.Call)
          self.assertIs(server_request_tag, request_event.tag)
          self.assertTrue(
              test_common.metadata_transmitted(
                  client_initial_metadata, request_event.invocation_metadata))
          self.assertEqual(METHOD, request_event.call_details.method)
          self.assertEqual(self.expected_host, request_event.call_details.host)
          self.assertLess(abs(DEADLINE - request_event.call_details.deadline),
                          DEADLINE_TOLERANCE)

          server_call_tag = object()
          server_call = request_event.call
          server_initial_metadata = ((
              SERVER_INITIAL_METADATA_KEY,
              SERVER_INITIAL_METADATA_VALUE,
          ),)
          server_trailing_metadata = ((
              SERVER_TRAILING_METADATA_KEY,
              SERVER_TRAILING_METADATA_VALUE,
          ),)
          server_start_batch_result = server_call.start_server_batch([
              cygrpc.SendInitialMetadataOperation(server_initial_metadata,
                                                  _EMPTY_FLAGS),
              cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
              cygrpc.SendMessageOperation(RESPONSE, _EMPTY_FLAGS),
              cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
              cygrpc.SendStatusFromServerOperation(
                  server_trailing_metadata, SERVER_STATUS_CODE,
                  SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
          ], server_call_tag)
          self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)

          server_event = self.server_completion_queue.poll(deadline=DEADLINE)
          # Verify the polled event is the completion of the server batch
          # started above, as _perform_queue_operations does for its batches.
          self.assertEqual(cygrpc.CompletionType.operation_complete,
                           server_event.completion_type)
          self.assertTrue(server_event.success)
          self.assertIs(server_call_tag, server_event.tag)
          client_event = client_event_future.result()

          self.assertEqual(6, len(client_event.batch_operations))
          found_client_op_types = set()
          for client_result in client_event.batch_operations:
              # we expect each op type to be unique
              self.assertNotIn(client_result.type(), found_client_op_types)
              found_client_op_types.add(client_result.type())
              if client_result.type(
              ) == cygrpc.OperationType.receive_initial_metadata:
                  self.assertTrue(
                      test_common.metadata_transmitted(
                          server_initial_metadata,
                          client_result.initial_metadata()))
              elif client_result.type() == cygrpc.OperationType.receive_message:
                  self.assertEqual(RESPONSE, client_result.message())
              elif client_result.type(
              ) == cygrpc.OperationType.receive_status_on_client:
                  self.assertTrue(
                      test_common.metadata_transmitted(
                          server_trailing_metadata,
                          client_result.trailing_metadata()))
                  self.assertEqual(SERVER_STATUS_DETAILS,
                                   client_result.details())
                  self.assertEqual(SERVER_STATUS_CODE, client_result.code())
          self.assertEqual(
              set([
                  cygrpc.OperationType.send_initial_metadata,
                  cygrpc.OperationType.send_message,
                  cygrpc.OperationType.send_close_from_client,
                  cygrpc.OperationType.receive_initial_metadata,
                  cygrpc.OperationType.receive_message,
                  cygrpc.OperationType.receive_status_on_client
              ]), found_client_op_types)

          self.assertEqual(5, len(server_event.batch_operations))
          found_server_op_types = set()
          for server_result in server_event.batch_operations:
              self.assertNotIn(server_result.type(), found_server_op_types)
              found_server_op_types.add(server_result.type())
              if server_result.type() == cygrpc.OperationType.receive_message:
                  self.assertEqual(REQUEST, server_result.message())
              elif server_result.type(
              ) == cygrpc.OperationType.receive_close_on_server:
                  self.assertFalse(server_result.cancelled())
          self.assertEqual(
              set([
                  cygrpc.OperationType.send_initial_metadata,
                  cygrpc.OperationType.receive_message,
                  cygrpc.OperationType.send_message,
                  cygrpc.OperationType.receive_close_on_server,
                  cygrpc.OperationType.send_status_from_server
              ]), found_server_op_types)

          del client_call
          del server_call

      def test_6522(self):
          """Run ten empty-message round trips on one segregated call.

          NOTE(review): the name presumably refers to a gRPC issue number;
          confirm against the upstream tracker.
          """
          DEADLINE = time.time() + 5
          METHOD = b'twinkies'

          empty_metadata = ()

          # Prologue
          server_request_tag = object()
          request_call_result = self.server.request_call(
              self.server_completion_queue, self.server_completion_queue,
              server_request_tag)
          self.assertEqual(cygrpc.CallError.ok, request_call_result)
          client_call = self.client_channel.segregated_call(
              0, METHOD, self.host_argument, DEADLINE, None, None, [
                  (
                      [
                          cygrpc.SendInitialMetadataOperation(
                              empty_metadata, _EMPTY_FLAGS),
                          cygrpc.ReceiveInitialMetadataOperation(_EMPTY_FLAGS),
                      ],
                      object(),
                  ),
                  (
                      [
                          cygrpc.ReceiveStatusOnClientOperation(_EMPTY_FLAGS),
                      ],
                      object(),
                  ),
              ])

          client_initial_metadata_event_future = test_utilities.SimpleFuture(
              client_call.next_event)

          request_event = self.server_completion_queue.poll(deadline=DEADLINE)
          self.assertEqual(cygrpc.CompletionType.operation_complete,
                           request_event.completion_type)
          self.assertIs(server_request_tag, request_event.tag)
          server_call = request_event.call

          def perform_server_operations(operations, description):
              return self._perform_queue_operations(operations, server_call,
                                                    self.server_completion_queue,
                                                    DEADLINE, description)

          server_event_future = perform_server_operations([
              cygrpc.SendInitialMetadataOperation(empty_metadata, _EMPTY_FLAGS),
          ], "Server prologue")

          client_initial_metadata_event_future.result()  # force completion
          server_event_future.result()

          # Messaging
          for _ in range(10):
              client_call.operate([
                  cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                  cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
              ], "Client message")
              client_message_event_future = test_utilities.SimpleFuture(
                  client_call.next_event)
              server_event_future = perform_server_operations([
                  cygrpc.SendMessageOperation(b'', _EMPTY_FLAGS),
                  cygrpc.ReceiveMessageOperation(_EMPTY_FLAGS),
              ], "Server receive")

              client_message_event_future.result()  # force completion
              server_event_future.result()

          # Epilogue
          client_call.operate([
              cygrpc.SendCloseFromClientOperation(_EMPTY_FLAGS),
          ], "Client epilogue")
          # One for ReceiveStatusOnClient, one for SendCloseFromClient. A tuple
          # (rather than a set) avoids requiring the events to be hashable.
          client_events_future = test_utilities.SimpleFuture(lambda: (
              client_call.next_event(),
              client_call.next_event(),
          ))

          server_event_future = perform_server_operations([
              cygrpc.ReceiveCloseOnServerOperation(_EMPTY_FLAGS),
              cygrpc.SendStatusFromServerOperation(
                  empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
          ], "Server epilogue")

          client_events_future.result()  # force completion
          server_event_future.result()
  class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
      """Runs the ServerClientMixin tests over an insecure server and channel."""

      def setUp(self):
          # No server credentials, no client credentials and no host override.
          self.setUpMixin(server_credentials=None,
                          client_credentials=None,
                          host_override=None)

      def tearDown(self):
          self.tearDownMixin()
  class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
      """Runs the ServerClientMixin tests over SSL using the test resources."""

      def setUp(self):
          key_cert_pairs = [
              cygrpc.SslPemKeyCertPair(resources.private_key(),
                                       resources.certificate_chain()),
          ]
          server_creds = cygrpc.server_credentials_ssl(None, key_cert_pairs,
                                                       False)
          client_creds = cygrpc.SSLChannelCredentials(
              resources.test_root_certificates(), None, None)
          # The override is used as the client's SSL target name.
          self.setUpMixin(server_creds, client_creds, _SSL_HOST_OVERRIDE)

      def tearDown(self):
          self.tearDownMixin()
  if __name__ == '__main__':
      # Run every TestCase in this module, reporting each test by name.
      unittest.main(verbosity=2)