Browse Source

Add stream-unary

Richard Belleville 5 years ago
parent
commit
38bef98463

+ 3 - 3
src/python/grpcio/grpc/__init__.py

@@ -2035,9 +2035,9 @@ __all__ = (
     'unary_unary',
 )
 
-if sys.version_info[0] > 2:
-    from grpc._simple_stubs import unary_unary, unary_stream
-    __all__ = __all__ + (unary_unary, unary_stream)
+if sys.version_info[0] >= 3:
+    from grpc._simple_stubs import unary_unary, unary_stream, stream_unary
+    __all__ = __all__ + (unary_unary, unary_stream, stream_unary)
 
 ############################### Extension Shims ################################
 

+ 36 - 5
src/python/grpcio/grpc/_simple_stubs.py

@@ -7,7 +7,7 @@ import logging
 import threading
 
 import grpc
-from typing import Any, AnyStr, Callable, Iterator, Optional, Sequence, Tuple, Union
+from typing import Any, AnyStr, Callable, Iterator, Optional, Sequence, Tuple, TypeVar, Union
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -122,8 +122,10 @@ class ChannelCache:
         with self._lock:
             return len(self._mapping)
 
+RequestType = TypeVar('RequestType')
+ResponseType = TypeVar('ResponseType')
 
-def unary_unary(request: Any,
+def unary_unary(request: RequestType,
                 target: str,
                 method: str,
                 request_serializer: Optional[Callable[[Any], bytes]] = None,
@@ -135,7 +137,7 @@ def unary_unary(request: Any,
                 compression: Optional[grpc.Compression] = None,
                 wait_for_ready: Optional[bool] = None,
                 timeout: Optional[float] = None,
-                metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Any:
+                metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> ResponseType:
     """Invokes a unary RPC without an explicitly specified channel.
 
     This is backed by a cache of channels evicted by a background thread
@@ -152,7 +154,7 @@ def unary_unary(request: Any,
                          timeout=timeout)
 
 
-def unary_stream(request: Any,
+def unary_stream(request: RequestType,
                  target: str,
                  method: str,
                  request_serializer: Optional[Callable[[Any], bytes]] = None,
@@ -164,7 +166,7 @@ def unary_stream(request: Any,
                  compression: Optional[grpc.Compression] = None,
                  wait_for_ready: Optional[bool] = None,
                  timeout: Optional[float] = None,
-                 metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Iterator[Any]:
+                 metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Iterator[ResponseType]:
     """Invokes a unary-stream RPC without an explicitly specified channel.
 
     This is backed by a cache of channels evicted by a background thread
@@ -179,3 +181,32 @@ def unary_stream(request: Any,
                          wait_for_ready=wait_for_ready,
                          credentials=call_credentials,
                          timeout=timeout)
+
+
+def stream_unary(request_iterator: Iterator[RequestType],
+                 target: str,
+                 method: str,
+                 request_serializer: Optional[Callable[[Any], bytes]] = None,
+                 request_deserializer: Optional[Callable[[bytes], Any]] = None,
+                 options: Sequence[Tuple[AnyStr, AnyStr]] = (),
+                 # TODO: Somehow make insecure_channel opt-in, not the default.
+                 channel_credentials: Optional[grpc.ChannelCredentials] = None,
+                 call_credentials: Optional[grpc.CallCredentials] = None,
+                 compression: Optional[grpc.Compression] = None,
+                 wait_for_ready: Optional[bool] = None,
+                 timeout: Optional[float] = None,
+                 metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> ResponseType:
+    """Invokes a stream-unary RPC without an explicitly specified channel.
+
+    This is backed by a cache of channels evicted by a background thread
+    on a periodic basis.
+
+    TODO: Document the parameters and return value.
+    """
+    channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
+    multicallable = channel.stream_unary(method, request_serializer, request_deserializer)
+    return multicallable(request_iterator,
+                         metadata=metadata,
+                         wait_for_ready=wait_for_ready,
+                         credentials=call_credentials,
+                         timeout=timeout)

+ 24 - 0
src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py

@@ -37,6 +37,7 @@ _CACHE_EPOCHS = 8
 _CACHE_TRIALS = 6
 
 _SERVER_RESPONSE_COUNT = 10
+_CLIENT_REQUEST_COUNT = _SERVER_RESPONSE_COUNT
 
 _UNARY_UNARY = "/test/UnaryUnary"
 _UNARY_STREAM = "/test/UnaryStream"
@@ -51,12 +52,21 @@ def _unary_stream_handler(request, context):
         yield request
 
 
+def _stream_unary_handler(request_iterator, context):
+    request = None
+    for single_request in request_iterator:
+        request = single_request
+    return request
+
+
 class _GenericHandler(grpc.GenericRpcHandler):
     def service(self, handler_call_details):
         if handler_call_details.method == _UNARY_UNARY:
             return grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
         elif handler_call_details.method == _UNARY_STREAM:
             return grpc.unary_stream_rpc_method_handler(_unary_stream_handler)
+        elif handler_call_details.method == _STREAM_UNARY:
+            return grpc.stream_unary_rpc_method_handler(_stream_unary_handler)
         else:
             raise NotImplementedError()
 
@@ -195,6 +205,20 @@ class SimpleStubsTest(unittest.TestCase):
                                              channel_credentials=grpc.local_channel_credentials()):
                 self.assertEqual(request, response)
 
+    def test_stream_unary(self):
+        def request_iter():
+            for _ in range(_CLIENT_REQUEST_COUNT):
+                yield request
+        with _server(grpc.local_server_credentials()) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            response = grpc.stream_unary(request_iter(),
+                                         target,
+                                         _UNARY_STREAM,
+                                         channel_credentials=grpc.local_channel_credentials())
+            self.assertEqual(request, response)
+
+
 
     # TODO: Test request_serializer
     # TODO: Test request_deserializer