|  | @@ -14,6 +14,7 @@
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  # TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
 | 
	
		
			
				|  |  |  _LOGGER = logging.getLogger(__name__)
 | 
	
		
			
				|  |  | +cdef int _EMPTY_FLAG = 0
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  cdef class _HandlerCallDetails:
 | 
	
	
		
			
				|  | @@ -171,6 +172,9 @@ async def _handle_unary_unary_rpc(object method_handler,
 | 
	
		
			
				|  |  |      await callback_start_batch(rpc_state, send_ops, loop)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
 | 
	
		
			
				|  |  |      # Finds the method handler (application logic)
 | 
	
		
			
				|  |  |      cdef object method_handler = _find_method_handler(
 | 
	
	
		
			
				|  | @@ -180,6 +184,7 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
 | 
	
		
			
				|  |  |      if method_handler is None:
 | 
	
		
			
				|  |  |          # TODO(lidiz) return unimplemented error to client side
 | 
	
		
			
				|  |  |          raise NotImplementedError()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      # TODO(lidiz) extend to all 4 types of RPC
 | 
	
		
			
				|  |  |      if method_handler.request_streaming or method_handler.response_streaming:
 | 
	
		
			
				|  |  |          raise NotImplementedError()
 | 
	
	
		
			
				|  | @@ -223,6 +228,16 @@ async def _server_call_request_call(Server server,
 | 
	
		
			
				|  |  |      return rpc_state
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +async def _handle_cancellation_from_core(object rpc_task,
 | 
	
		
			
				|  |  | +                                          RPCState rpc_state,
 | 
	
		
			
				|  |  | +                                          object loop):
 | 
	
		
			
				|  |  | +    cdef ReceiveCloseOnServerOperation op = ReceiveCloseOnServerOperation(_EMPTY_FLAG)
 | 
	
		
			
				|  |  | +    cdef tuple ops = (op,)
 | 
	
		
			
				|  |  | +    await callback_start_batch(rpc_state, ops, loop)
 | 
	
		
			
				|  |  | +    if op.cancelled() and not rpc_task.done():
 | 
	
		
			
				|  |  | +        rpc_task.cancel()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  cdef _CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = _CallbackFailureHandler(
 | 
	
		
			
				|  |  |      'grpc_completion_queue_shutdown',
 | 
	
		
			
				|  |  |      'Unknown',
 | 
	
	
		
			
				|  | @@ -277,7 +292,7 @@ cdef class AioServer:
 | 
	
		
			
				|  |  |          self.add_generic_rpc_handlers(generic_handlers)
 | 
	
		
			
				|  |  |          self._serving_task = None
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        self._shutdown_lock = asyncio.Lock()
 | 
	
		
			
				|  |  | +        self._shutdown_lock = asyncio.Lock(loop=self._loop)
 | 
	
		
			
				|  |  |          self._shutdown_completed = self._loop.create_future()
 | 
	
		
			
				|  |  |          self._shutdown_callback_wrapper = CallbackWrapper(
 | 
	
		
			
				|  |  |              self._shutdown_completed,
 | 
	
	
		
			
				|  | @@ -320,10 +335,20 @@ cdef class AioServer:
 | 
	
		
			
				|  |  |                  self._cq,
 | 
	
		
			
				|  |  |                  self._loop)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            self._loop.create_task(_handle_rpc(
 | 
	
		
			
				|  |  | -                self._generic_handlers,
 | 
	
		
			
				|  |  | -                rpc_state,
 | 
	
		
			
				|  |  | -                self._loop))
 | 
	
		
			
				|  |  | +            rpc_task = self._loop.create_task(
 | 
	
		
			
				|  |  | +                _handle_rpc(
 | 
	
		
			
				|  |  | +                    self._generic_handlers,
 | 
	
		
			
				|  |  | +                    rpc_state,
 | 
	
		
			
				|  |  | +                    self._loop
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +            self._loop.create_task(
 | 
	
		
			
				|  |  | +                _handle_cancellation_from_core(
 | 
	
		
			
				|  |  | +                    rpc_task,
 | 
	
		
			
				|  |  | +                    rpc_state,
 | 
	
		
			
				|  |  | +                    self._loop
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def _serving_task_crash_handler(self, object task):
 | 
	
		
			
				|  |  |          """Shutdown the server immediately if unexpectedly exited."""
 | 
	
	
		
			
				|  | @@ -389,7 +414,14 @@ cdef class AioServer:
 | 
	
		
			
				|  |  |              await self._shutdown_completed
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              try:
 | 
	
		
			
				|  |  | -                await asyncio.wait_for(asyncio.shield(self._shutdown_completed), grace)
 | 
	
		
			
				|  |  | +                await asyncio.wait_for(
 | 
	
		
			
				|  |  | +                    asyncio.shield(
 | 
	
		
			
				|  |  | +                        self._shutdown_completed,
 | 
	
		
			
				|  |  | +                        loop=self._loop
 | 
	
		
			
				|  |  | +                    ),
 | 
	
		
			
				|  |  | +                    grace,
 | 
	
		
			
				|  |  | +                    loop=self._loop,
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  |              except asyncio.TimeoutError:
 | 
	
		
			
				|  |  |                  # Cancels all ongoing calls by the end of grace period.
 | 
	
		
			
				|  |  |                  grpc_server_cancel_all_calls(self._server.c_server)
 | 
	
	
		
			
				|  | @@ -410,7 +442,14 @@ cdef class AioServer:
 | 
	
		
			
				|  |  |              await self._shutdown_completed
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              try:
 | 
	
		
			
				|  |  | -                await asyncio.wait_for(asyncio.shield(self._shutdown_completed), timeout)
 | 
	
		
			
				|  |  | +                await asyncio.wait_for(
 | 
	
		
			
				|  |  | +                    asyncio.shield(
 | 
	
		
			
				|  |  | +                        self._shutdown_completed,
 | 
	
		
			
				|  |  | +                        loop=self._loop,
 | 
	
		
			
				|  |  | +                    ),
 | 
	
		
			
				|  |  | +                    timeout,
 | 
	
		
			
				|  |  | +                    loop=self._loop,
 | 
	
		
			
				|  |  | +                )
 | 
	
		
			
				|  |  |              except asyncio.TimeoutError:
 | 
	
		
			
				|  |  |                  if self._crash_exception is not None:
 | 
	
		
			
				|  |  |                      raise self._crash_exception
 |