|
@@ -41,6 +41,8 @@ _CLIENT_REQUEST_COUNT = _SERVER_RESPONSE_COUNT
|
|
|
|
|
|
_UNARY_UNARY = "/test/UnaryUnary"
|
|
_UNARY_UNARY = "/test/UnaryUnary"
|
|
_UNARY_STREAM = "/test/UnaryStream"
|
|
_UNARY_STREAM = "/test/UnaryStream"
|
|
|
|
+_STREAM_UNARY = "/test/StreamUnary"
|
|
|
|
+_STREAM_STREAM = "/test/StreamStream"
|
|
|
|
|
|
|
|
|
|
def _unary_unary_handler(request, context):
|
|
def _unary_unary_handler(request, context):
|
|
@@ -59,6 +61,11 @@ def _stream_unary_handler(request_iterator, context):
|
|
return request
|
|
return request
|
|
|
|
|
|
|
|
|
|
|
|
+def _stream_stream_handler(request_iterator, context):
|
|
|
|
+ for request in request_iterator:
|
|
|
|
+ yield request
|
|
|
|
+
|
|
|
|
+
|
|
class _GenericHandler(grpc.GenericRpcHandler):
|
|
class _GenericHandler(grpc.GenericRpcHandler):
|
|
def service(self, handler_call_details):
|
|
def service(self, handler_call_details):
|
|
if handler_call_details.method == _UNARY_UNARY:
|
|
if handler_call_details.method == _UNARY_UNARY:
|
|
@@ -67,6 +74,8 @@ class _GenericHandler(grpc.GenericRpcHandler):
|
|
return grpc.unary_stream_rpc_method_handler(_unary_stream_handler)
|
|
return grpc.unary_stream_rpc_method_handler(_unary_stream_handler)
|
|
elif handler_call_details.method == _STREAM_UNARY:
|
|
elif handler_call_details.method == _STREAM_UNARY:
|
|
return grpc.stream_unary_rpc_method_handler(_stream_unary_handler)
|
|
return grpc.stream_unary_rpc_method_handler(_stream_unary_handler)
|
|
|
|
+ elif handler_call_details.method == _STREAM_STREAM:
|
|
|
|
+ return grpc.stream_stream_rpc_method_handler(_stream_stream_handler)
|
|
else:
|
|
else:
|
|
raise NotImplementedError()
|
|
raise NotImplementedError()
|
|
|
|
|
|
@@ -214,10 +223,22 @@ class SimpleStubsTest(unittest.TestCase):
|
|
request = b'0000'
|
|
request = b'0000'
|
|
response = grpc.stream_unary(request_iter(),
|
|
response = grpc.stream_unary(request_iter(),
|
|
target,
|
|
target,
|
|
- _UNARY_STREAM,
|
|
|
|
|
|
+ _STREAM_UNARY,
|
|
channel_credentials=grpc.local_channel_credentials())
|
|
channel_credentials=grpc.local_channel_credentials())
|
|
self.assertEqual(request, response)
|
|
self.assertEqual(request, response)
|
|
|
|
|
|
|
|
+ def test_stream_stream(self):
|
|
|
|
+ def request_iter():
|
|
|
|
+ for _ in range(_CLIENT_REQUEST_COUNT):
|
|
|
|
+ yield request
|
|
|
|
+ with _server(grpc.local_server_credentials()) as (_, port):
|
|
|
|
+ target = f'localhost:{port}'
|
|
|
|
+ request = b'0000'
|
|
|
|
+ for response in grpc.stream_stream(request_iter(),
|
|
|
|
+ target,
|
|
|
|
+ _STREAM_STREAM,
|
|
|
|
+ channel_credentials=grpc.local_channel_credentials()):
|
|
|
|
+ self.assertEqual(request, response)
|
|
|
|
|
|
|
|
|
|
# TODO: Test request_serializer
|
|
# TODO: Test request_serializer
|