Browse Source

Add stream_stream

Richard Belleville 5 years ago
parent
commit
1abe2f0ac2

+ 2 - 2
src/python/grpcio/grpc/__init__.py

@@ -2036,8 +2036,8 @@ __all__ = (
 )
 
 if sys.version_info[0] >= 3:
-    from grpc._simple_stubs import unary_unary, unary_stream, stream_unary
-    __all__ = __all__ + (unary_unary, unary_stream, stream_unary)
+    from grpc._simple_stubs import unary_unary, unary_stream, stream_unary, stream_stream
+    __all__ = __all__ + (unary_unary, unary_stream, stream_unary, stream_stream)
 
 ############################### Extension Shims ################################
 

+ 29 - 0
src/python/grpcio/grpc/_simple_stubs.py

@@ -210,3 +210,32 @@ def stream_unary(request_iterator: Iterator[RequestType],
                          wait_for_ready=wait_for_ready,
                          credentials=call_credentials,
                          timeout=timeout)
+
+
+def stream_stream(request_iterator: Iterator[RequestType],
+                  target: str,
+                  method: str,
+                  request_serializer: Optional[Callable[[Any], bytes]] = None,
+                  request_deserializer: Optional[Callable[[bytes], Any]] = None,
+                  options: Sequence[Tuple[AnyStr, AnyStr]] = (),
+                  # TODO: Somehow make insecure_channel opt-in, not the default.
+                  channel_credentials: Optional[grpc.ChannelCredentials] = None,
+                  call_credentials: Optional[grpc.CallCredentials] = None,
+                  compression: Optional[grpc.Compression] = None,
+                  wait_for_ready: Optional[bool] = None,
+                  timeout: Optional[float] = None,
+                  metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Iterator[ResponseType]:
+    """Invokes a stream-stream RPC without an explicitly specified channel.
+
+    This is backed by a cache of channels evicted by a background thread
+    on a periodic basis.
+
+    TODO: Document the parameters and return value.
+    """
+    channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
+    multicallable = channel.stream_stream(method, request_serializer, request_deserializer)
+    return multicallable(request_iterator,
+                         metadata=metadata,
+                         wait_for_ready=wait_for_ready,
+                         credentials=call_credentials,
+                         timeout=timeout)

+ 22 - 1
src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py

@@ -41,6 +41,8 @@ _CLIENT_REQUEST_COUNT = _SERVER_RESPONSE_COUNT
 
 _UNARY_UNARY = "/test/UnaryUnary"
 _UNARY_STREAM = "/test/UnaryStream"
+_STREAM_UNARY = "/test/StreamUnary"
+_STREAM_STREAM = "/test/StreamStream"
 
 
 def _unary_unary_handler(request, context):
@@ -59,6 +61,11 @@ def _stream_unary_handler(request_iterator, context):
     return request
 
 
+def _stream_stream_handler(request_iterator, context):
+    for request in request_iterator:
+        yield request
+
+
 class _GenericHandler(grpc.GenericRpcHandler):
     def service(self, handler_call_details):
         if handler_call_details.method == _UNARY_UNARY:
@@ -67,6 +74,8 @@ class _GenericHandler(grpc.GenericRpcHandler):
             return grpc.unary_stream_rpc_method_handler(_unary_stream_handler)
         elif handler_call_details.method == _STREAM_UNARY:
             return grpc.stream_unary_rpc_method_handler(_stream_unary_handler)
+        elif handler_call_details.method == _STREAM_STREAM:
+            return grpc.stream_stream_rpc_method_handler(_stream_stream_handler)
         else:
             raise NotImplementedError()
 
@@ -214,10 +223,22 @@ class SimpleStubsTest(unittest.TestCase):
             request = b'0000'
             response = grpc.stream_unary(request_iter(),
                                          target,
-                                         _UNARY_STREAM,
+                                         _STREAM_UNARY,
                                          channel_credentials=grpc.local_channel_credentials())
             self.assertEqual(request, response)
 
+    def test_stream_stream(self):
+        def request_iter():
+            for _ in range(_CLIENT_REQUEST_COUNT):
+                yield request
+        with _server(grpc.local_server_credentials()) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            for response in grpc.stream_stream(request_iter(),
+                                               target,
+                                               _STREAM_STREAM,
+                                               channel_credentials=grpc.local_channel_credentials()):
+                self.assertEqual(request, response)
 
 
     # TODO: Test request_serializer