|  | @@ -15,15 +15,15 @@
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import asyncio
 | 
	
		
			
				|  |  |  from functools import partial
 | 
	
		
			
				|  |  | -from typing import AsyncIterable, Dict, Optional
 | 
	
		
			
				|  |  | +from typing import AsyncIterable, Awaitable, Dict, Optional
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import grpc
 | 
	
		
			
				|  |  |  from grpc import _common
 | 
	
		
			
				|  |  |  from grpc._cython import cygrpc
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from . import _base_call
 | 
	
		
			
				|  |  | -from ._typing import (DeserializingFunction, MetadataType, RequestType,
 | 
	
		
			
				|  |  | -                      ResponseType, SerializingFunction, DoneCallbackType)
 | 
	
		
			
				|  |  | +from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
 | 
	
		
			
				|  |  | +                      RequestType, ResponseType, SerializingFunction)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -145,7 +145,7 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType],
 | 
	
		
			
				|  |  |                         status.trailing_metadata())
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class Call(_base_call.Call):
 | 
	
		
			
				|  |  | +class Call:
 | 
	
		
			
				|  |  |      """Base implementation of client RPC Call object.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      Implements logic around final status, metadata and cancellation.
 | 
	
	
		
			
				|  | @@ -153,11 +153,19 @@ class Call(_base_call.Call):
 | 
	
		
			
				|  |  |      _loop: asyncio.AbstractEventLoop
 | 
	
		
			
				|  |  |      _code: grpc.StatusCode
 | 
	
		
			
				|  |  |      _cython_call: cygrpc._AioCall
 | 
	
		
			
				|  |  | +    _metadata: MetadataType
 | 
	
		
			
				|  |  | +    _request_serializer: SerializingFunction
 | 
	
		
			
				|  |  | +    _response_deserializer: DeserializingFunction
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __init__(self, cython_call: cygrpc._AioCall,
 | 
	
		
			
				|  |  | +    def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType,
 | 
	
		
			
				|  |  | +                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | +                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  |                   loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  |          self._loop = loop
 | 
	
		
			
				|  |  |          self._cython_call = cython_call
 | 
	
		
			
				|  |  | +        self._metadata = metadata
 | 
	
		
			
				|  |  | +        self._request_serializer = request_serializer
 | 
	
		
			
				|  |  | +        self._response_deserializer = response_deserializer
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def __del__(self) -> None:
 | 
	
		
			
				|  |  |          if not self._cython_call.done():
 | 
	
	
		
			
				|  | @@ -221,63 +229,24 @@ class Call(_base_call.Call):
 | 
	
		
			
				|  |  |          return self._repr()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
 | 
	
		
			
				|  |  | -    """Object for managing unary-unary RPC calls.
 | 
	
		
			
				|  |  | +class _UnaryResponseMixin(Call):
 | 
	
		
			
				|  |  | +    _call_finisher: asyncio.Task
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
 | 
	
		
			
				|  |  | -    """
 | 
	
		
			
				|  |  | -    _request: RequestType
 | 
	
		
			
				|  |  | -    _metadata: Optional[MetadataType]
 | 
	
		
			
				|  |  | -    _request_serializer: SerializingFunction
 | 
	
		
			
				|  |  | -    _response_deserializer: DeserializingFunction
 | 
	
		
			
				|  |  | -    _call: asyncio.Task
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | -    def __init__(self, request: RequestType, deadline: Optional[float],
 | 
	
		
			
				|  |  | -                 metadata: MetadataType,
 | 
	
		
			
				|  |  | -                 credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  | -                 channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  | -                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | -                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  | -                 loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | -        super().__init__(channel.call(method, deadline, credentials), loop)
 | 
	
		
			
				|  |  | -        self._request = request
 | 
	
		
			
				|  |  | -        self._metadata = metadata
 | 
	
		
			
				|  |  | -        self._request_serializer = request_serializer
 | 
	
		
			
				|  |  | -        self._response_deserializer = response_deserializer
 | 
	
		
			
				|  |  | -        self._call = loop.create_task(self._invoke())
 | 
	
		
			
				|  |  | +    def _init_unary_response_mixin(self,
 | 
	
		
			
				|  |  | +                                   response_coro: Awaitable[ResponseType]):
 | 
	
		
			
				|  |  | +        self._call_finisher = self._loop.create_task(response_coro)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def cancel(self) -> bool:
 | 
	
		
			
				|  |  |          if super().cancel():
 | 
	
		
			
				|  |  | -            self._call.cancel()
 | 
	
		
			
				|  |  | +            self._call_finisher.cancel()
 | 
	
		
			
				|  |  |              return True
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              return False
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _invoke(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        serialized_request = _common.serialize(self._request,
 | 
	
		
			
				|  |  | -                                               self._request_serializer)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
 | 
	
		
			
				|  |  | -        # because the asyncio.Task class do not cache the exception object.
 | 
	
		
			
				|  |  | -        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            serialized_response = await self._cython_call.unary_unary(
 | 
	
		
			
				|  |  | -                serialized_request, self._metadata)
 | 
	
		
			
				|  |  | -        except asyncio.CancelledError:
 | 
	
		
			
				|  |  | -            if not self.cancelled():
 | 
	
		
			
				|  |  | -                self.cancel()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        # Raises here if RPC failed or cancelled
 | 
	
		
			
				|  |  | -        await self._raise_for_status()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        return _common.deserialize(serialized_response,
 | 
	
		
			
				|  |  | -                                   self._response_deserializer)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      def __await__(self) -> ResponseType:
 | 
	
		
			
				|  |  |          """Wait till the ongoing RPC request finishes."""
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            response = yield from self._call
 | 
	
		
			
				|  |  | +            response = yield from self._call_finisher
 | 
	
		
			
				|  |  |          except asyncio.CancelledError:
 | 
	
		
			
				|  |  |              # Even if we caught all other CancelledError, there is still
 | 
	
		
			
				|  |  |              # this corner case. If the application cancels immediately after
 | 
	
	
		
			
				|  | @@ -289,53 +258,21 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
 | 
	
		
			
				|  |  |          return response
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
 | 
	
		
			
				|  |  | -    """Object for managing unary-stream RPC calls.
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    Returned when an instance of `UnaryStreamMultiCallable` object is called.
 | 
	
		
			
				|  |  | -    """
 | 
	
		
			
				|  |  | -    _request: RequestType
 | 
	
		
			
				|  |  | -    _metadata: MetadataType
 | 
	
		
			
				|  |  | -    _request_serializer: SerializingFunction
 | 
	
		
			
				|  |  | -    _response_deserializer: DeserializingFunction
 | 
	
		
			
				|  |  | -    _send_unary_request_task: asyncio.Task
 | 
	
		
			
				|  |  | +class _StreamResponseMixin(Call):
 | 
	
		
			
				|  |  |      _message_aiter: AsyncIterable[ResponseType]
 | 
	
		
			
				|  |  | +    _prerequisite: asyncio.Task
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | -    def __init__(self, request: RequestType, deadline: Optional[float],
 | 
	
		
			
				|  |  | -                 metadata: MetadataType,
 | 
	
		
			
				|  |  | -                 credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  | -                 channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  | -                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | -                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  | -                 loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | -        super().__init__(channel.call(method, deadline, credentials), loop)
 | 
	
		
			
				|  |  | -        self._request = request
 | 
	
		
			
				|  |  | -        self._metadata = metadata
 | 
	
		
			
				|  |  | -        self._request_serializer = request_serializer
 | 
	
		
			
				|  |  | -        self._response_deserializer = response_deserializer
 | 
	
		
			
				|  |  | -        self._send_unary_request_task = loop.create_task(
 | 
	
		
			
				|  |  | -            self._send_unary_request())
 | 
	
		
			
				|  |  | +    def _init_stream_response_mixin(self, prerequisite: asyncio.Task):
 | 
	
		
			
				|  |  |          self._message_aiter = None
 | 
	
		
			
				|  |  | +        self._prerequisite = prerequisite
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def cancel(self) -> bool:
 | 
	
		
			
				|  |  |          if super().cancel():
 | 
	
		
			
				|  |  | -            self._send_unary_request_task.cancel()
 | 
	
		
			
				|  |  | +            self._prerequisite.cancel()
 | 
	
		
			
				|  |  |              return True
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              return False
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _send_unary_request(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        serialized_request = _common.serialize(self._request,
 | 
	
		
			
				|  |  | -                                               self._request_serializer)
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            await self._cython_call.initiate_unary_stream(
 | 
	
		
			
				|  |  | -                serialized_request, self._metadata)
 | 
	
		
			
				|  |  | -        except asyncio.CancelledError:
 | 
	
		
			
				|  |  | -            if not self.cancelled():
 | 
	
		
			
				|  |  | -                self.cancel()
 | 
	
		
			
				|  |  | -            raise
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      async def _fetch_stream_responses(self) -> ResponseType:
 | 
	
		
			
				|  |  |          message = await self._read()
 | 
	
		
			
				|  |  |          while message is not cygrpc.EOF:
 | 
	
	
		
			
				|  | @@ -349,7 +286,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def _read(self) -> ResponseType:
 | 
	
		
			
				|  |  |          # Wait for the request being sent
 | 
	
		
			
				|  |  | -        await self._send_unary_request_task
 | 
	
		
			
				|  |  | +        await self._prerequisite
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Reads response message from Core
 | 
	
		
			
				|  |  |          try:
 | 
	
	
		
			
				|  | @@ -366,7 +303,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
 | 
	
		
			
				|  |  |                                         self._response_deserializer)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def read(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        if self._cython_call.done():
 | 
	
		
			
				|  |  | +        if self.done():
 | 
	
		
			
				|  |  |              await self._raise_for_status()
 | 
	
		
			
				|  |  |              return cygrpc.EOF
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -378,39 +315,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
 | 
	
		
			
				|  |  |          return response_message
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
 | 
	
		
			
				|  |  | -    """Object for managing stream-unary RPC calls.
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    Returned when an instance of `StreamUnaryMultiCallable` object is called.
 | 
	
		
			
				|  |  | -    """
 | 
	
		
			
				|  |  | -    _metadata: MetadataType
 | 
	
		
			
				|  |  | -    _request_serializer: SerializingFunction
 | 
	
		
			
				|  |  | -    _response_deserializer: DeserializingFunction
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +class _StreamRequestMixin(Call):
 | 
	
		
			
				|  |  |      _metadata_sent: asyncio.Event
 | 
	
		
			
				|  |  |      _done_writing: bool
 | 
	
		
			
				|  |  | -    _call_finisher: asyncio.Task
 | 
	
		
			
				|  |  | -    _async_request_poller: asyncio.Task
 | 
	
		
			
				|  |  | +    _async_request_poller: Optional[asyncio.Task]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | -    def __init__(self,
 | 
	
		
			
				|  |  | -                 request_async_iterator: Optional[AsyncIterable[RequestType]],
 | 
	
		
			
				|  |  | -                 deadline: Optional[float], metadata: MetadataType,
 | 
	
		
			
				|  |  | -                 credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  | -                 channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  | -                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | -                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  | -                 loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | -        super().__init__(channel.call(method, deadline, credentials), loop)
 | 
	
		
			
				|  |  | -        self._metadata = metadata
 | 
	
		
			
				|  |  | -        self._request_serializer = request_serializer
 | 
	
		
			
				|  |  | -        self._response_deserializer = response_deserializer
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        self._metadata_sent = asyncio.Event(loop=loop)
 | 
	
		
			
				|  |  | +    def _init_stream_request_mixin(
 | 
	
		
			
				|  |  | +            self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
 | 
	
		
			
				|  |  | +        self._metadata_sent = asyncio.Event(loop=self._loop)
 | 
	
		
			
				|  |  |          self._done_writing = False
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        self._call_finisher = loop.create_task(self._conduct_rpc())
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |          # If user passes in an async iterator, create a consumer Task.
 | 
	
		
			
				|  |  |          if request_async_iterator is not None:
 | 
	
		
			
				|  |  |              self._async_request_poller = self._loop.create_task(
 | 
	
	
		
			
				|  | @@ -420,7 +334,6 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def cancel(self) -> bool:
 | 
	
		
			
				|  |  |          if super().cancel():
 | 
	
		
			
				|  |  | -            self._call_finisher.cancel()
 | 
	
		
			
				|  |  |              if self._async_request_poller is not None:
 | 
	
		
			
				|  |  |                  self._async_request_poller.cancel()
 | 
	
		
			
				|  |  |              return True
 | 
	
	
		
			
				|  | @@ -430,38 +343,14 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
 | 
	
		
			
				|  |  |      def _metadata_sent_observer(self):
 | 
	
		
			
				|  |  |          self._metadata_sent.set()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _conduct_rpc(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            serialized_response = await self._cython_call.stream_unary(
 | 
	
		
			
				|  |  | -                self._metadata, self._metadata_sent_observer)
 | 
	
		
			
				|  |  | -        except asyncio.CancelledError:
 | 
	
		
			
				|  |  | -            if not self.cancelled():
 | 
	
		
			
				|  |  | -                self.cancel()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        # Raises RpcError if the RPC failed or cancelled
 | 
	
		
			
				|  |  | -        await self._raise_for_status()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        return _common.deserialize(serialized_response,
 | 
	
		
			
				|  |  | -                                   self._response_deserializer)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      async def _consume_request_iterator(
 | 
	
		
			
				|  |  |              self, request_async_iterator: AsyncIterable[RequestType]) -> None:
 | 
	
		
			
				|  |  |          async for request in request_async_iterator:
 | 
	
		
			
				|  |  |              await self.write(request)
 | 
	
		
			
				|  |  |          await self.done_writing()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __await__(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        """Wait till the ongoing RPC request finishes."""
 | 
	
		
			
				|  |  | -        try:
 | 
	
		
			
				|  |  | -            response = yield from self._call_finisher
 | 
	
		
			
				|  |  | -        except asyncio.CancelledError:
 | 
	
		
			
				|  |  | -            if not self.cancelled():
 | 
	
		
			
				|  |  | -                self.cancel()
 | 
	
		
			
				|  |  | -            raise
 | 
	
		
			
				|  |  | -        return response
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      async def write(self, request: RequestType) -> None:
 | 
	
		
			
				|  |  | -        if self._cython_call.done():
 | 
	
		
			
				|  |  | +        if self.done():
 | 
	
		
			
				|  |  |              raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
 | 
	
		
			
				|  |  |          if self._done_writing:
 | 
	
		
			
				|  |  |              raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
 | 
	
	
		
			
				|  | @@ -480,7 +369,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def done_writing(self) -> None:
 | 
	
		
			
				|  |  |          """Implementation of done_writing is idempotent."""
 | 
	
		
			
				|  |  | -        if self._cython_call.done():
 | 
	
		
			
				|  |  | +        if self.done():
 | 
	
		
			
				|  |  |              # If the RPC is finished, do nothing.
 | 
	
		
			
				|  |  |              return
 | 
	
		
			
				|  |  |          if not self._done_writing:
 | 
	
	
		
			
				|  | @@ -494,152 +383,153 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
 | 
	
		
			
				|  |  |                  await self._raise_for_status()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class StreamStreamCall(Call, _base_call.StreamStreamCall):
 | 
	
		
			
				|  |  | -    """Object for managing stream-stream RPC calls.
 | 
	
		
			
				|  |  | +class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
 | 
	
		
			
				|  |  | +    """Object for managing unary-unary RPC calls.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    Returned when an instance of `StreamStreamMultiCallable` object is called.
 | 
	
		
			
				|  |  | +    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
 | 
	
		
			
				|  |  |      """
 | 
	
		
			
				|  |  | -    _metadata: MetadataType
 | 
	
		
			
				|  |  | -    _request_serializer: SerializingFunction
 | 
	
		
			
				|  |  | -    _response_deserializer: DeserializingFunction
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    _metadata_sent: asyncio.Event
 | 
	
		
			
				|  |  | -    _done_writing: bool
 | 
	
		
			
				|  |  | -    _initializer: asyncio.Task
 | 
	
		
			
				|  |  | -    _async_request_poller: asyncio.Task
 | 
	
		
			
				|  |  | -    _message_aiter: AsyncIterable[ResponseType]
 | 
	
		
			
				|  |  | +    _request: RequestType
 | 
	
		
			
				|  |  | +    _call: asyncio.Task
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | -    def __init__(self,
 | 
	
		
			
				|  |  | -                 request_async_iterator: Optional[AsyncIterable[RequestType]],
 | 
	
		
			
				|  |  | -                 deadline: Optional[float], metadata: MetadataType,
 | 
	
		
			
				|  |  | +    def __init__(self, request: RequestType, deadline: Optional[float],
 | 
	
		
			
				|  |  | +                 metadata: MetadataType,
 | 
	
		
			
				|  |  |                   credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  |                   channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  |                   request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  |                   response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  |                   loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | -        super().__init__(channel.call(method, deadline, credentials), loop)
 | 
	
		
			
				|  |  | -        self._metadata = metadata
 | 
	
		
			
				|  |  | -        self._request_serializer = request_serializer
 | 
	
		
			
				|  |  | -        self._response_deserializer = response_deserializer
 | 
	
		
			
				|  |  | +        super().__init__(channel.call(method, deadline, credentials), metadata,
 | 
	
		
			
				|  |  | +                         request_serializer, response_deserializer, loop)
 | 
	
		
			
				|  |  | +        self._request = request
 | 
	
		
			
				|  |  | +        self._init_unary_response_mixin(self._invoke())
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        self._metadata_sent = asyncio.Event(loop=loop)
 | 
	
		
			
				|  |  | -        self._done_writing = False
 | 
	
		
			
				|  |  | +    async def _invoke(self) -> ResponseType:
 | 
	
		
			
				|  |  | +        serialized_request = _common.serialize(self._request,
 | 
	
		
			
				|  |  | +                                               self._request_serializer)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        self._initializer = self._loop.create_task(self._prepare_rpc())
 | 
	
		
			
				|  |  | +        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
 | 
	
		
			
				|  |  | +        # because the asyncio.Task class do not cache the exception object.
 | 
	
		
			
				|  |  | +        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
 | 
	
		
			
				|  |  | +        try:
 | 
	
		
			
				|  |  | +            serialized_response = await self._cython_call.unary_unary(
 | 
	
		
			
				|  |  | +                serialized_request, self._metadata)
 | 
	
		
			
				|  |  | +        except asyncio.CancelledError:
 | 
	
		
			
				|  |  | +            if not self.cancelled():
 | 
	
		
			
				|  |  | +                self.cancel()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        # If user passes in an async iterator, create a consumer coroutine.
 | 
	
		
			
				|  |  | -        if request_async_iterator is not None:
 | 
	
		
			
				|  |  | -            self._async_request_poller = loop.create_task(
 | 
	
		
			
				|  |  | -                self._consume_request_iterator(request_async_iterator))
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            self._async_request_poller = None
 | 
	
		
			
				|  |  | -        self._message_aiter = None
 | 
	
		
			
				|  |  | +        # Raises here if RPC failed or cancelled
 | 
	
		
			
				|  |  | +        await self._raise_for_status()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def cancel(self) -> bool:
 | 
	
		
			
				|  |  | -        if super().cancel():
 | 
	
		
			
				|  |  | -            self._initializer.cancel()
 | 
	
		
			
				|  |  | -            if self._async_request_poller is not None:
 | 
	
		
			
				|  |  | -                self._async_request_poller.cancel()
 | 
	
		
			
				|  |  | -            return True
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            return False
 | 
	
		
			
				|  |  | +        return _common.deserialize(serialized_response,
 | 
	
		
			
				|  |  | +                                   self._response_deserializer)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def _metadata_sent_observer(self):
 | 
	
		
			
				|  |  | -        self._metadata_sent.set()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _prepare_rpc(self):
 | 
	
		
			
				|  |  | -        """This method prepares the RPC for receiving/sending messages.
 | 
	
		
			
				|  |  | +class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
 | 
	
		
			
				|  |  | +    """Object for managing unary-stream RPC calls.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        All other operations around the stream should only happen after the
 | 
	
		
			
				|  |  | -        completion of this method.
 | 
	
		
			
				|  |  | -        """
 | 
	
		
			
				|  |  | +    Returned when an instance of `UnaryStreamMultiCallable` object is called.
 | 
	
		
			
				|  |  | +    """
 | 
	
		
			
				|  |  | +    _request: RequestType
 | 
	
		
			
				|  |  | +    _send_unary_request_task: asyncio.Task
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | +    def __init__(self, request: RequestType, deadline: Optional[float],
 | 
	
		
			
				|  |  | +                 metadata: MetadataType,
 | 
	
		
			
				|  |  | +                 credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  | +                 channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  | +                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | +                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  | +                 loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | +        super().__init__(channel.call(method, deadline, credentials), metadata,
 | 
	
		
			
				|  |  | +                         request_serializer, response_deserializer, loop)
 | 
	
		
			
				|  |  | +        self._request = request
 | 
	
		
			
				|  |  | +        self._send_unary_request_task = loop.create_task(
 | 
	
		
			
				|  |  | +            self._send_unary_request())
 | 
	
		
			
				|  |  | +        self._init_stream_response_mixin(self._send_unary_request_task)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def _send_unary_request(self) -> ResponseType:
 | 
	
		
			
				|  |  | +        serialized_request = _common.serialize(self._request,
 | 
	
		
			
				|  |  | +                                               self._request_serializer)
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            await self._cython_call.initiate_stream_stream(
 | 
	
		
			
				|  |  | -                self._metadata, self._metadata_sent_observer)
 | 
	
		
			
				|  |  | +            await self._cython_call.initiate_unary_stream(
 | 
	
		
			
				|  |  | +                serialized_request, self._metadata)
 | 
	
		
			
				|  |  |          except asyncio.CancelledError:
 | 
	
		
			
				|  |  |              if not self.cancelled():
 | 
	
		
			
				|  |  |                  self.cancel()
 | 
	
		
			
				|  |  | -            # No need to raise RpcError here, because no one will `await` this task.
 | 
	
		
			
				|  |  | +            raise
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _consume_request_iterator(
 | 
	
		
			
				|  |  | -            self, request_async_iterator: Optional[AsyncIterable[RequestType]]
 | 
	
		
			
				|  |  | -    ) -> None:
 | 
	
		
			
				|  |  | -        async for request in request_async_iterator:
 | 
	
		
			
				|  |  | -            await self.write(request)
 | 
	
		
			
				|  |  | -        await self.done_writing()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def write(self, request: RequestType) -> None:
 | 
	
		
			
				|  |  | -        if self._cython_call.done():
 | 
	
		
			
				|  |  | -            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
 | 
	
		
			
				|  |  | -        if self._done_writing:
 | 
	
		
			
				|  |  | -            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
 | 
	
		
			
				|  |  | -        if not self._metadata_sent.is_set():
 | 
	
		
			
				|  |  | -            await self._metadata_sent.wait()
 | 
	
		
			
				|  |  | +class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
 | 
	
		
			
				|  |  | +                      _base_call.StreamUnaryCall):
 | 
	
		
			
				|  |  | +    """Object for managing stream-unary RPC calls.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        serialized_request = _common.serialize(request,
 | 
	
		
			
				|  |  | -                                               self._request_serializer)
 | 
	
		
			
				|  |  | +    Returned when an instance of `StreamUnaryMultiCallable` object is called.
 | 
	
		
			
				|  |  | +    """
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | +    def __init__(self,
 | 
	
		
			
				|  |  | +                 request_async_iterator: Optional[AsyncIterable[RequestType]],
 | 
	
		
			
				|  |  | +                 deadline: Optional[float], metadata: MetadataType,
 | 
	
		
			
				|  |  | +                 credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  | +                 channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  | +                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | +                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  | +                 loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | +        super().__init__(channel.call(method, deadline, credentials), metadata,
 | 
	
		
			
				|  |  | +                         request_serializer, response_deserializer, loop)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self._init_stream_request_mixin(request_async_iterator)
 | 
	
		
			
				|  |  | +        self._init_unary_response_mixin(self._conduct_rpc())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def _conduct_rpc(self) -> ResponseType:
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            await self._cython_call.send_serialized_message(serialized_request)
 | 
	
		
			
				|  |  | +            serialized_response = await self._cython_call.stream_unary(
 | 
	
		
			
				|  |  | +                self._metadata, self._metadata_sent_observer)
 | 
	
		
			
				|  |  |          except asyncio.CancelledError:
 | 
	
		
			
				|  |  |              if not self.cancelled():
 | 
	
		
			
				|  |  |                  self.cancel()
 | 
	
		
			
				|  |  | -            await self._raise_for_status()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def done_writing(self) -> None:
 | 
	
		
			
				|  |  | -        """Implementation of done_writing is idempotent."""
 | 
	
		
			
				|  |  | -        if self._cython_call.done():
 | 
	
		
			
				|  |  | -            # If the RPC is finished, do nothing.
 | 
	
		
			
				|  |  | -            return
 | 
	
		
			
				|  |  | -        if not self._done_writing:
 | 
	
		
			
				|  |  | -            # If the done writing is not sent before, try to send it.
 | 
	
		
			
				|  |  | -            self._done_writing = True
 | 
	
		
			
				|  |  | -            try:
 | 
	
		
			
				|  |  | -                await self._cython_call.send_receive_close()
 | 
	
		
			
				|  |  | -            except asyncio.CancelledError:
 | 
	
		
			
				|  |  | -                if not self.cancelled():
 | 
	
		
			
				|  |  | -                    self.cancel()
 | 
	
		
			
				|  |  | -                await self._raise_for_status()
 | 
	
		
			
				|  |  | +        # Raises RpcError if the RPC failed or cancelled
 | 
	
		
			
				|  |  | +        await self._raise_for_status()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _fetch_stream_responses(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        """The async generator that yields responses from peer."""
 | 
	
		
			
				|  |  | -        message = await self._read()
 | 
	
		
			
				|  |  | -        while message is not cygrpc.EOF:
 | 
	
		
			
				|  |  | -            yield message
 | 
	
		
			
				|  |  | -            message = await self._read()
 | 
	
		
			
				|  |  | +        return _common.deserialize(serialized_response,
 | 
	
		
			
				|  |  | +                                   self._response_deserializer)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __aiter__(self) -> AsyncIterable[ResponseType]:
 | 
	
		
			
				|  |  | -        if self._message_aiter is None:
 | 
	
		
			
				|  |  | -            self._message_aiter = self._fetch_stream_responses()
 | 
	
		
			
				|  |  | -        return self._message_aiter
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def _read(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        # Wait for the setup
 | 
	
		
			
				|  |  | -        await self._initializer
 | 
	
		
			
				|  |  | +class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
 | 
	
		
			
				|  |  | +                       _base_call.StreamStreamCall):
 | 
	
		
			
				|  |  | +    """Object for managing stream-stream RPC calls.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        # Reads response message from Core
 | 
	
		
			
				|  |  | +    Returned when an instance of `StreamStreamMultiCallable` object is called.
 | 
	
		
			
				|  |  | +    """
 | 
	
		
			
				|  |  | +    _initializer: asyncio.Task
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    # pylint: disable=too-many-arguments
 | 
	
		
			
				|  |  | +    def __init__(self,
 | 
	
		
			
				|  |  | +                 request_async_iterator: Optional[AsyncIterable[RequestType]],
 | 
	
		
			
				|  |  | +                 deadline: Optional[float], metadata: MetadataType,
 | 
	
		
			
				|  |  | +                 credentials: Optional[grpc.CallCredentials],
 | 
	
		
			
				|  |  | +                 channel: cygrpc.AioChannel, method: bytes,
 | 
	
		
			
				|  |  | +                 request_serializer: SerializingFunction,
 | 
	
		
			
				|  |  | +                 response_deserializer: DeserializingFunction,
 | 
	
		
			
				|  |  | +                 loop: asyncio.AbstractEventLoop) -> None:
 | 
	
		
			
				|  |  | +        super().__init__(channel.call(method, deadline, credentials), metadata,
 | 
	
		
			
				|  |  | +                         request_serializer, response_deserializer, loop)
 | 
	
		
			
				|  |  | +        self._initializer = self._loop.create_task(self._prepare_rpc())
 | 
	
		
			
				|  |  | +        self._init_stream_request_mixin(request_async_iterator)
 | 
	
		
			
				|  |  | +        self._init_stream_response_mixin(self._initializer)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def _prepare_rpc(self):
 | 
	
		
			
				|  |  | +        """This method prepares the RPC for receiving/sending messages.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        All other operations around the stream should only happen after the
 | 
	
		
			
				|  |  | +        completion of this method.
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  | -            raw_response = await self._cython_call.receive_serialized_message()
 | 
	
		
			
				|  |  | +            await self._cython_call.initiate_stream_stream(
 | 
	
		
			
				|  |  | +                self._metadata, self._metadata_sent_observer)
 | 
	
		
			
				|  |  |          except asyncio.CancelledError:
 | 
	
		
			
				|  |  |              if not self.cancelled():
 | 
	
		
			
				|  |  |                  self.cancel()
 | 
	
		
			
				|  |  | -            await self._raise_for_status()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        if raw_response is cygrpc.EOF:
 | 
	
		
			
				|  |  | -            return cygrpc.EOF
 | 
	
		
			
				|  |  | -        else:
 | 
	
		
			
				|  |  | -            return _common.deserialize(raw_response,
 | 
	
		
			
				|  |  | -                                       self._response_deserializer)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    async def read(self) -> ResponseType:
 | 
	
		
			
				|  |  | -        if self._cython_call.done():
 | 
	
		
			
				|  |  | -            await self._raise_for_status()
 | 
	
		
			
				|  |  | -            return cygrpc.EOF
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        response_message = await self._read()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        if response_message is cygrpc.EOF:
 | 
	
		
			
				|  |  | -            # If the read operation failed, Core should explain why.
 | 
	
		
			
				|  |  | -            await self._raise_for_status()
 | 
	
		
			
				|  |  | -        return response_message
 | 
	
		
			
				|  |  | +            # No need to raise RpcError here, because no one will `await` this task.
 |