_metadata_flags_test.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Copyright 2018 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests metadata flags feature by testing wait-for-ready semantics"""
  15. import time
  16. import weakref
  17. import unittest
  18. import threading
  19. import logging
  20. import socket
  21. from six.moves import queue
  22. import grpc
  23. from tests.unit import test_common
  24. from tests.unit.framework.common import test_constants
  25. import tests.unit.framework.common
  26. from tests.unit.framework.common import get_socket
  27. _UNARY_UNARY = '/test/UnaryUnary'
  28. _UNARY_STREAM = '/test/UnaryStream'
  29. _STREAM_UNARY = '/test/StreamUnary'
  30. _STREAM_STREAM = '/test/StreamStream'
  31. _REQUEST = b'\x00\x00\x00'
  32. _RESPONSE = b'\x00\x00\x00'
  33. def handle_unary_unary(test, request, servicer_context):
  34. return _RESPONSE
  35. def handle_unary_stream(test, request, servicer_context):
  36. for _ in range(test_constants.STREAM_LENGTH):
  37. yield _RESPONSE
  38. def handle_stream_unary(test, request_iterator, servicer_context):
  39. for _ in request_iterator:
  40. pass
  41. return _RESPONSE
  42. def handle_stream_stream(test, request_iterator, servicer_context):
  43. for _ in request_iterator:
  44. yield _RESPONSE
  45. class _MethodHandler(grpc.RpcMethodHandler):
  46. def __init__(self, test, request_streaming, response_streaming):
  47. self.request_streaming = request_streaming
  48. self.response_streaming = response_streaming
  49. self.request_deserializer = None
  50. self.response_serializer = None
  51. self.unary_unary = None
  52. self.unary_stream = None
  53. self.stream_unary = None
  54. self.stream_stream = None
  55. if self.request_streaming and self.response_streaming:
  56. self.stream_stream = lambda req, ctx: handle_stream_stream(
  57. test, req, ctx)
  58. elif self.request_streaming:
  59. self.stream_unary = lambda req, ctx: handle_stream_unary(
  60. test, req, ctx)
  61. elif self.response_streaming:
  62. self.unary_stream = lambda req, ctx: handle_unary_stream(
  63. test, req, ctx)
  64. else:
  65. self.unary_unary = lambda req, ctx: handle_unary_unary(
  66. test, req, ctx)
  67. class _GenericHandler(grpc.GenericRpcHandler):
  68. def __init__(self, test):
  69. self._test = test
  70. def service(self, handler_call_details):
  71. if handler_call_details.method == _UNARY_UNARY:
  72. return _MethodHandler(self._test, False, False)
  73. elif handler_call_details.method == _UNARY_STREAM:
  74. return _MethodHandler(self._test, False, True)
  75. elif handler_call_details.method == _STREAM_UNARY:
  76. return _MethodHandler(self._test, True, False)
  77. elif handler_call_details.method == _STREAM_STREAM:
  78. return _MethodHandler(self._test, True, True)
  79. else:
  80. return None
  81. def create_dummy_channel():
  82. """Creating dummy channels is a workaround for retries"""
  83. host, port, sock = get_socket()
  84. sock.close()
  85. return grpc.insecure_channel('{}:{}'.format(host, port))
  86. def perform_unary_unary_call(channel, wait_for_ready=None):
  87. channel.unary_unary(_UNARY_UNARY).__call__(
  88. _REQUEST,
  89. timeout=test_constants.LONG_TIMEOUT,
  90. wait_for_ready=wait_for_ready)
  91. def perform_unary_unary_with_call(channel, wait_for_ready=None):
  92. channel.unary_unary(_UNARY_UNARY).with_call(
  93. _REQUEST,
  94. timeout=test_constants.LONG_TIMEOUT,
  95. wait_for_ready=wait_for_ready)
  96. def perform_unary_unary_future(channel, wait_for_ready=None):
  97. channel.unary_unary(_UNARY_UNARY).future(
  98. _REQUEST,
  99. timeout=test_constants.LONG_TIMEOUT,
  100. wait_for_ready=wait_for_ready).result(
  101. timeout=test_constants.LONG_TIMEOUT)
  102. def perform_unary_stream_call(channel, wait_for_ready=None):
  103. response_iterator = channel.unary_stream(_UNARY_STREAM).__call__(
  104. _REQUEST,
  105. timeout=test_constants.LONG_TIMEOUT,
  106. wait_for_ready=wait_for_ready)
  107. for _ in response_iterator:
  108. pass
  109. def perform_stream_unary_call(channel, wait_for_ready=None):
  110. channel.stream_unary(_STREAM_UNARY).__call__(
  111. iter([_REQUEST] * test_constants.STREAM_LENGTH),
  112. timeout=test_constants.LONG_TIMEOUT,
  113. wait_for_ready=wait_for_ready)
  114. def perform_stream_unary_with_call(channel, wait_for_ready=None):
  115. channel.stream_unary(_STREAM_UNARY).with_call(
  116. iter([_REQUEST] * test_constants.STREAM_LENGTH),
  117. timeout=test_constants.LONG_TIMEOUT,
  118. wait_for_ready=wait_for_ready)
  119. def perform_stream_unary_future(channel, wait_for_ready=None):
  120. channel.stream_unary(_STREAM_UNARY).future(
  121. iter([_REQUEST] * test_constants.STREAM_LENGTH),
  122. timeout=test_constants.LONG_TIMEOUT,
  123. wait_for_ready=wait_for_ready).result(
  124. timeout=test_constants.LONG_TIMEOUT)
  125. def perform_stream_stream_call(channel, wait_for_ready=None):
  126. response_iterator = channel.stream_stream(_STREAM_STREAM).__call__(
  127. iter([_REQUEST] * test_constants.STREAM_LENGTH),
  128. timeout=test_constants.LONG_TIMEOUT,
  129. wait_for_ready=wait_for_ready)
  130. for _ in response_iterator:
  131. pass
  132. _ALL_CALL_CASES = [
  133. perform_unary_unary_call, perform_unary_unary_with_call,
  134. perform_unary_unary_future, perform_unary_stream_call,
  135. perform_stream_unary_call, perform_stream_unary_with_call,
  136. perform_stream_unary_future, perform_stream_stream_call
  137. ]
  138. class MetadataFlagsTest(unittest.TestCase):
  139. def check_connection_does_failfast(self, fn, channel, wait_for_ready=None):
  140. try:
  141. fn(channel, wait_for_ready)
  142. self.fail("The Call should fail")
  143. except BaseException as e: # pylint: disable=broad-except
  144. self.assertIs(grpc.StatusCode.UNAVAILABLE, e.code())
  145. def test_call_wait_for_ready_default(self):
  146. for perform_call in _ALL_CALL_CASES:
  147. with create_dummy_channel() as channel:
  148. self.check_connection_does_failfast(perform_call, channel)
  149. def test_call_wait_for_ready_disabled(self):
  150. for perform_call in _ALL_CALL_CASES:
  151. with create_dummy_channel() as channel:
  152. self.check_connection_does_failfast(perform_call,
  153. channel,
  154. wait_for_ready=False)
  155. def test_call_wait_for_ready_enabled(self):
  156. # To test the wait mechanism, Python thread is required to make
  157. # client set up first without handling them case by case.
  158. # Also, Python thread don't pass the unhandled exceptions to
  159. # main thread. So, it need another method to store the
  160. # exceptions and raise them again in main thread.
  161. unhandled_exceptions = queue.Queue()
  162. # We just need an unused TCP port
  163. host, port, sock = get_socket()
  164. sock.close()
  165. addr = '{}:{}'.format(host, port)
  166. wg = test_common.WaitGroup(len(_ALL_CALL_CASES))
  167. def wait_for_transient_failure(channel_connectivity):
  168. if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
  169. wg.done()
  170. def test_call(perform_call):
  171. with grpc.insecure_channel(addr) as channel:
  172. try:
  173. channel.subscribe(wait_for_transient_failure)
  174. perform_call(channel, wait_for_ready=True)
  175. except BaseException as e: # pylint: disable=broad-except
  176. # If the call failed, the thread would be destroyed. The
  177. # channel object can be collected before calling the
  178. # callback, which will result in a deadlock.
  179. wg.done()
  180. unhandled_exceptions.put(e, True)
  181. test_threads = []
  182. for perform_call in _ALL_CALL_CASES:
  183. test_thread = threading.Thread(target=test_call,
  184. args=(perform_call,),
  185. daemon=True)
  186. test_thread.exception = None
  187. test_thread.start()
  188. test_threads.append(test_thread)
  189. # Start the server after the connections are waiting
  190. wg.wait()
  191. server = test_common.test_server(reuse_port=True)
  192. server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
  193. server.add_insecure_port(addr)
  194. server.start()
  195. for test_thread in test_threads:
  196. test_thread.join()
  197. # Stop the server to make test end properly
  198. server.stop(0)
  199. if not unhandled_exceptions.empty():
  200. raise unhandled_exceptions.get(True)
  201. if __name__ == '__main__':
  202. logging.basicConfig(level=logging.DEBUG)
  203. unittest.main(verbosity=2)