|  | @@ -16,6 +16,7 @@
 | 
	
		
			
				|  |  |  import asyncio
 | 
	
		
			
				|  |  |  from functools import partial
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  | +import enum
 | 
	
		
			
				|  |  |  from typing import AsyncIterable, Awaitable, Dict, Optional
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import grpc
 | 
	
	
		
			
				|  | @@ -238,6 +239,12 @@ class Call:
 | 
	
		
			
				|  |  |          return self._repr()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class _APIStyle(enum.IntEnum):
 | 
	
		
			
				|  |  | +    UNKNOWN = 0
 | 
	
		
			
				|  |  | +    ASYNC_GENERATOR = 1
 | 
	
		
			
				|  |  | +    READER_WRITER = 2
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class _UnaryResponseMixin(Call):
 | 
	
		
			
				|  |  |      _call_response: asyncio.Task
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -283,10 +290,19 @@ class _UnaryResponseMixin(Call):
 | 
	
		
			
				|  |  |  class _StreamResponseMixin(Call):
 | 
	
		
			
				|  |  |      _message_aiter: AsyncIterable[ResponseType]
 | 
	
		
			
				|  |  |      _preparation: asyncio.Task
 | 
	
		
			
				|  |  | +    _response_style: _APIStyle
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _init_stream_response_mixin(self, preparation: asyncio.Task):
 | 
	
		
			
				|  |  |          self._message_aiter = None
 | 
	
		
			
				|  |  |          self._preparation = preparation
 | 
	
		
			
				|  |  | +        self._response_style = _APIStyle.UNKNOWN
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def _update_response_style(self, style: _APIStyle):
 | 
	
		
			
				|  |  | +        if self._response_style is _APIStyle.UNKNOWN:
 | 
	
		
			
				|  |  | +            self._response_style = style
 | 
	
		
			
				|  |  | +        elif self._response_style is not style:
 | 
	
		
			
				|  |  | +            raise cygrpc.UsageError(
 | 
	
		
			
				|  |  | +                'Please don\'t mix two styles of API for streaming responses')
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def cancel(self) -> bool:
 | 
	
		
			
				|  |  |          if super().cancel():
 | 
	
	
		
			
				|  | @@ -302,6 +318,7 @@ class _StreamResponseMixin(Call):
 | 
	
		
			
				|  |  |              message = await self._read()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def __aiter__(self) -> AsyncIterable[ResponseType]:
 | 
	
		
			
				|  |  | +        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
 | 
	
		
			
				|  |  |          if self._message_aiter is None:
 | 
	
		
			
				|  |  |              self._message_aiter = self._fetch_stream_responses()
 | 
	
		
			
				|  |  |          return self._message_aiter
 | 
	
	
		
			
				|  | @@ -328,6 +345,7 @@ class _StreamResponseMixin(Call):
 | 
	
		
			
				|  |  |          if self.done():
 | 
	
		
			
				|  |  |              await self._raise_for_status()
 | 
	
		
			
				|  |  |              return cygrpc.EOF
 | 
	
		
			
				|  |  | +        self._update_response_style(_APIStyle.READER_WRITER)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          response_message = await self._read()
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -339,20 +357,28 @@ class _StreamResponseMixin(Call):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class _StreamRequestMixin(Call):
 | 
	
		
			
				|  |  |      _metadata_sent: asyncio.Event
 | 
	
		
			
				|  |  | -    _done_writing: bool
 | 
	
		
			
				|  |  | +    _done_writing_flag: bool
 | 
	
		
			
				|  |  |      _async_request_poller: Optional[asyncio.Task]
 | 
	
		
			
				|  |  | +    _request_style: _APIStyle
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _init_stream_request_mixin(
 | 
	
		
			
				|  |  |              self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
 | 
	
		
			
				|  |  |          self._metadata_sent = asyncio.Event(loop=self._loop)
 | 
	
		
			
				|  |  | -        self._done_writing = False
 | 
	
		
			
				|  |  | +        self._done_writing_flag = False
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # If user passes in an async iterator, create a consumer Task.
 | 
	
		
			
				|  |  |          if request_async_iterator is not None:
 | 
	
		
			
				|  |  |              self._async_request_poller = self._loop.create_task(
 | 
	
		
			
				|  |  |                  self._consume_request_iterator(request_async_iterator))
 | 
	
		
			
				|  |  | +            self._request_style = _APIStyle.ASYNC_GENERATOR
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              self._async_request_poller = None
 | 
	
		
			
				|  |  | +            self._request_style = _APIStyle.READER_WRITER
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def _raise_for_different_style(self, style: _APIStyle):
 | 
	
		
			
				|  |  | +        if self._request_style is not style:
 | 
	
		
			
				|  |  | +            raise cygrpc.UsageError(
 | 
	
		
			
				|  |  | +                'Please don\'t mix two styles of API for streaming requests')
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def cancel(self) -> bool:
 | 
	
		
			
				|  |  |          if super().cancel():
 | 
	
	
		
			
				|  | @@ -369,8 +395,8 @@ class _StreamRequestMixin(Call):
 | 
	
		
			
				|  |  |              self, request_async_iterator: AsyncIterable[RequestType]) -> None:
 | 
	
		
			
				|  |  |          try:
 | 
	
		
			
				|  |  |              async for request in request_async_iterator:
 | 
	
		
			
				|  |  | -                await self.write(request)
 | 
	
		
			
				|  |  | -            await self.done_writing()
 | 
	
		
			
				|  |  | +                await self._write(request)
 | 
	
		
			
				|  |  | +            await self._done_writing()
 | 
	
		
			
				|  |  |          except AioRpcError as rpc_error:
 | 
	
		
			
				|  |  |              # Rpc status should be exposed through other API. Exceptions raised
 | 
	
		
			
				|  |  |              # within this Task won't be retrieved by another coroutine. It's
 | 
	
	
		
			
				|  | @@ -378,10 +404,10 @@ class _StreamRequestMixin(Call):
 | 
	
		
			
				|  |  |              _LOGGER.debug('Exception while consuming the request_iterator: %s',
 | 
	
		
			
				|  |  |                            rpc_error)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def write(self, request: RequestType) -> None:
 | 
	
		
			
				|  |  | +    async def _write(self, request: RequestType) -> None:
 | 
	
		
			
				|  |  |          if self.done():
 | 
	
		
			
				|  |  |              raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
 | 
	
		
			
				|  |  | -        if self._done_writing:
 | 
	
		
			
				|  |  | +        if self._done_writing_flag:
 | 
	
		
			
				|  |  |              raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
 | 
	
		
			
				|  |  |          if not self._metadata_sent.is_set():
 | 
	
		
			
				|  |  |              await self._metadata_sent.wait()
 | 
	
	
		
			
				|  | @@ -398,14 +424,13 @@ class _StreamRequestMixin(Call):
 | 
	
		
			
				|  |  |                  self.cancel()
 | 
	
		
			
				|  |  |              await self._raise_for_status()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def done_writing(self) -> None:
 | 
	
		
			
				|  |  | -        """Implementation of done_writing is idempotent."""
 | 
	
		
			
				|  |  | +    async def _done_writing(self) -> None:
 | 
	
		
			
				|  |  |          if self.done():
 | 
	
		
			
				|  |  |              # If the RPC is finished, do nothing.
 | 
	
		
			
				|  |  |              return
 | 
	
		
			
				|  |  | -        if not self._done_writing:
 | 
	
		
			
				|  |  | +        if not self._done_writing_flag:
 | 
	
		
			
				|  |  |              # If the done writing is not sent before, try to send it.
 | 
	
		
			
				|  |  | -            self._done_writing = True
 | 
	
		
			
				|  |  | +            self._done_writing_flag = True
 | 
	
		
			
				|  |  |              try:
 | 
	
		
			
				|  |  |                  await self._cython_call.send_receive_close()
 | 
	
		
			
				|  |  |              except asyncio.CancelledError:
 | 
	
	
		
			
				|  | @@ -413,6 +438,15 @@ class _StreamRequestMixin(Call):
 | 
	
		
			
				|  |  |                      self.cancel()
 | 
	
		
			
				|  |  |                  await self._raise_for_status()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    async def write(self, request: RequestType) -> None:
 | 
	
		
			
				|  |  | +        self._raise_for_different_style(_APIStyle.READER_WRITER)
 | 
	
		
			
				|  |  | +        await self._write(request)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def done_writing(self) -> None:
 | 
	
		
			
				|  |  | +        """Implementation of done_writing is idempotent."""
 | 
	
		
			
				|  |  | +        self._raise_for_different_style(_APIStyle.READER_WRITER)
 | 
	
		
			
				|  |  | +        await self._done_writing()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
 | 
	
		
			
				|  |  |      """Object for managing unary-unary RPC calls.
 |