Эх сурвалжийг харах

Add failure handling mechanism to CallbackWrapper

Lidi Zheng 6 жил өмнө
parent
commit
980bcaf076

+ 2 - 2
src/python/grpcio/grpc/_cython/_cygrpc/aio/callbackcontext.pxd.pxi

@@ -16,5 +16,5 @@ cimport cpython
 
 
 cdef struct CallbackContext:
 cdef struct CallbackContext:
     grpc_experimental_completion_queue_functor functor
     grpc_experimental_completion_queue_functor functor
-    cpython.PyObject *waiter
-
+    cpython.PyObject *waiter  # asyncio.Future
+    cpython.PyObject *failure_handler  # cygrpc.CallbackFailureHandler

+ 6 - 4
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi

@@ -27,7 +27,8 @@ cdef class RPCState:
 
 
 cdef class CallbackWrapper:
 cdef class CallbackWrapper:
     cdef CallbackContext context
     cdef CallbackContext context
-    cdef object _reference
+    cdef object _reference_of_future
+    cdef object _reference_of_failure_handler
 
 
     @staticmethod
     @staticmethod
     cdef void functor_run(
     cdef void functor_run(
@@ -48,9 +49,9 @@ cdef enum AioServerStatus:
 cdef class _CallbackCompletionQueue:
 cdef class _CallbackCompletionQueue:
     cdef grpc_completion_queue *_cq
     cdef grpc_completion_queue *_cq
     cdef grpc_completion_queue* c_ptr(self)
     cdef grpc_completion_queue* c_ptr(self)
-    cdef object _shutdown_completed
+    cdef object _shutdown_completed  # asyncio.Future
     cdef CallbackWrapper _wrapper
     cdef CallbackWrapper _wrapper
-    cdef object _loop
+    cdef object _loop  # asyncio.EventLoop
 
 
 
 
 cdef class AioServer:
 cdef class AioServer:
@@ -58,4 +59,5 @@ cdef class AioServer:
     cdef _CallbackCompletionQueue _cq
     cdef _CallbackCompletionQueue _cq
     cdef list _generic_handlers
     cdef list _generic_handlers
     cdef AioServerStatus _status
     cdef AioServerStatus _status
-    cdef object _loop
+    cdef object _loop  # asyncio.EventLoop
+    cdef object _serving_task  # asyncio.Task

+ 97 - 38
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi

@@ -14,6 +14,7 @@
 
 
 _LOGGER = logging.getLogger(__name__)
 _LOGGER = logging.getLogger(__name__)
 
 
+
 cdef class _HandlerCallDetails:
 cdef class _HandlerCallDetails:
     def __cinit__(self, str method, tuple invocation_metadata):
     def __cinit__(self, str method, tuple invocation_metadata):
         self.method = method
         self.method = method
@@ -23,22 +24,60 @@ cdef class _HandlerCallDetails:
 class _ServicerContextPlaceHolder(object): pass
 class _ServicerContextPlaceHolder(object): pass
 
 
 
 
+cdef class CallbackFailureHandler:
+    cdef str _c_core_api
+    cdef object _error_details
+    cdef object _exception_type
+    cdef object _callback  # Callable[[Future], None]
+
+    def __cinit__(self,
+                  str c_core_api="",
+                  object error_details="UNKNOWN",
+                  object exception_type=RuntimeError,
+                  object callback=None):
+        """Handles failure by raising an exception or executing a callback.
+
+        The callback accepts a future and returns nothing. The callback is
+        expected to finish the future with either "set_result" or "set_exception".
+        """
+        if callback is None:
+            self._c_core_api = c_core_api
+            self._error_details = error_details
+            self._exception_type = exception_type    
+            self._callback = self._raise_exception
+        else:
+            self._callback = callback
+
+    def _raise_exception(self, object future):
+        future.set_exception(self._exception_type(
+            'Failed "%s": %s' % (self._c_core_api, self._error_details)
+        ))
+
+    cdef handle(self, object future):
+        self._callback(future)
+
+
 # TODO(https://github.com/grpc/grpc/issues/20669)
 # TODO(https://github.com/grpc/grpc/issues/20669)
 # Apply this to the client-side
 # Apply this to the client-side
 cdef class CallbackWrapper:
 cdef class CallbackWrapper:
 
 
-    def __cinit__(self, object future):
+    def __cinit__(self, object future, CallbackFailureHandler failure_handler):
         self.context.functor.functor_run = self.functor_run
         self.context.functor.functor_run = self.functor_run
-        self.context.waiter = <cpython.PyObject*>(future)
-        self._reference = future
+        self.context.waiter = <cpython.PyObject*>future
+        self.context.failure_handler = <cpython.PyObject*>failure_handler
+        # NOTE(lidiz) Not using a list here, because this class is critical in
+        # data path. We should make it as efficient as possible.
+        self._reference_of_future = future
+        self._reference_of_failure_handler = failure_handler
 
 
     @staticmethod
     @staticmethod
     cdef void functor_run(
     cdef void functor_run(
             grpc_experimental_completion_queue_functor* functor,
             grpc_experimental_completion_queue_functor* functor,
             int success):
             int success):
         cdef CallbackContext *context = <CallbackContext *>functor
         cdef CallbackContext *context = <CallbackContext *>functor
-        if success == 0:
-            (<object>context.waiter).set_exception(RuntimeError())
+        if success == 0:
+            (<CallbackFailureHandler>context.failure_handler).handle(
+                <object>context.waiter)
         else:
         else:
             (<object>context.waiter).set_result(None)
             (<object>context.waiter).set_result(None)
 
 
@@ -85,7 +124,9 @@ async def callback_start_batch(RPCState rpc_state,
     batch_operation_tag.prepare()
     batch_operation_tag.prepare()
 
 
     cdef object future = loop.create_future()
     cdef object future = loop.create_future()
-    cdef CallbackWrapper wrapper = CallbackWrapper(future)
+    cdef CallbackWrapper wrapper = CallbackWrapper(
+        future,
+        CallbackFailureHandler('callback_start_batch', operations))
     # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
     # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
     # when calling "await". This is an over-optimization by Cython.
     # when calling "await". This is an over-optimization by Cython.
     cpython.Py_INCREF(wrapper)
     cpython.Py_INCREF(wrapper)
@@ -162,13 +203,21 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
         )
         )
 
 
 
 
+def _FINISH_FUTURE(future):
+    future.set_result(None)
+
+cdef CallbackFailureHandler IGNORE_FAILURE = CallbackFailureHandler(callback=_FINISH_FUTURE)
+
+
 async def _server_call_request_call(Server server,
 async def _server_call_request_call(Server server,
                                     _CallbackCompletionQueue cq,
                                     _CallbackCompletionQueue cq,
                                     object loop):
                                     object loop):
     cdef grpc_call_error error
     cdef grpc_call_error error
     cdef RPCState rpc_state = RPCState()
     cdef RPCState rpc_state = RPCState()
     cdef object future = loop.create_future()
     cdef object future = loop.create_future()
-    cdef CallbackWrapper wrapper = CallbackWrapper(future)
+    cdef CallbackWrapper wrapper = CallbackWrapper(
+        future,
+        IGNORE_FAILURE)
     # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
     # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
     # when calling "await". This is an over-optimization by Cython.
     # when calling "await". This is an over-optimization by Cython.
     cpython.Py_INCREF(wrapper)
     cpython.Py_INCREF(wrapper)
@@ -186,27 +235,7 @@ async def _server_call_request_call(Server server,
     return rpc_state
     return rpc_state
 
 
 
 
-async def _server_main_loop(object loop,
-                            Server server,
-                            _CallbackCompletionQueue cq,
-                            list generic_handlers):
-    cdef RPCState rpc_state
-
-    while True:
-        rpc_state = await _server_call_request_call(
-            server,
-            cq,
-            loop)
-
-        loop.create_task(_handle_rpc(generic_handlers, rpc_state, loop))
-
-
-async def _server_start(object loop,
-                        Server server,
-                        _CallbackCompletionQueue cq,
-                        list generic_handlers):
-    server.start(backup_queue=False)
-    await _server_main_loop(loop, server, cq, generic_handlers)
+cdef CallbackFailureHandler CQ_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler('grpc_completion_queue_shutdown')
 
 
 
 
 cdef class _CallbackCompletionQueue:
 cdef class _CallbackCompletionQueue:
@@ -214,7 +243,9 @@ cdef class _CallbackCompletionQueue:
     def __cinit__(self, object loop):
     def __cinit__(self, object loop):
         self._loop = loop
         self._loop = loop
         self._shutdown_completed = loop.create_future()
         self._shutdown_completed = loop.create_future()
-        self._wrapper = CallbackWrapper(self._shutdown_completed)
+        self._wrapper = CallbackWrapper(
+            self._shutdown_completed,
+            CQ_SHUTDOWN_FAILURE_HANDLER)
         self._cq = grpc_completion_queue_create_for_callback(
         self._cq = grpc_completion_queue_create_for_callback(
             self._wrapper.c_functor(),
             self._wrapper.c_functor(),
             NULL
             NULL
@@ -229,12 +260,13 @@ cdef class _CallbackCompletionQueue:
         grpc_completion_queue_destroy(self._cq)
         grpc_completion_queue_destroy(self._cq)
 
 
 
 
+cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHandler('grpc_server_shutdown_and_notify')
+
+
 cdef class AioServer:
 cdef class AioServer:
 
 
     def __init__(self, loop, thread_pool, generic_handlers, interceptors,
     def __init__(self, loop, thread_pool, generic_handlers, interceptors,
                  options, maximum_concurrent_rpcs, compression):
                  options, maximum_concurrent_rpcs, compression):
-        self._loop = loop
-
         # C-Core objects won't be deallocated automatically.
         # C-Core objects won't be deallocated automatically.
         self._server = Server(options)
         self._server = Server(options)
         self._cq = _CallbackCompletionQueue(loop)
         self._cq = _CallbackCompletionQueue(loop)
@@ -244,9 +276,11 @@ cdef class AioServer:
             NULL
             NULL
         )
         )
 
 
+        self._loop = loop
         self._status = AIO_SERVER_STATUS_READY
         self._status = AIO_SERVER_STATUS_READY
         self._generic_handlers = []
         self._generic_handlers = []
         self.add_generic_rpc_handlers(generic_handlers)
         self.add_generic_rpc_handlers(generic_handlers)
+        self._serving_task = None
 
 
         if interceptors:
         if interceptors:
             raise NotImplementedError()
             raise NotImplementedError()
@@ -268,6 +302,27 @@ cdef class AioServer:
         return self._server.add_http2_port(address,
         return self._server.add_http2_port(address,
                                           server_credentials._credentials)
                                           server_credentials._credentials)
 
 
+    async def _server_main_loop(self,
+                                object server_started):
+        self._server.start(backup_queue=False)
+        server_started.set_result(True)
+        cdef RPCState rpc_state
+
+        while True:
+            # When shutdown process starts, no more new connections.
+            if self._status != AIO_SERVER_STATUS_RUNNING:
+                break
+
+            rpc_state = await _server_call_request_call(
+                self._server,
+                self._cq,
+                self._loop)
+
+            self._loop.create_task(_handle_rpc(
+                self._generic_handlers,
+                rpc_state,
+                self._loop))
+
     async def start(self):
     async def start(self):
         if self._status == AIO_SERVER_STATUS_RUNNING:
         if self._status == AIO_SERVER_STATUS_RUNNING:
             return
             return
@@ -275,12 +330,11 @@ cdef class AioServer:
             raise RuntimeError('Server not in ready state')
             raise RuntimeError('Server not in ready state')
 
 
         self._status = AIO_SERVER_STATUS_RUNNING
         self._status = AIO_SERVER_STATUS_RUNNING
-        self._loop.create_task(_server_start(
-            self._loop,
-            self._server,
-            self._cq,
-            self._generic_handlers,
-        ))
+        cdef object server_started = self._loop.create_future()
+        self._serving_task = self._loop.create_task(self._server_main_loop(server_started))
+        # Explicitly wait for the server to start up; otherwise, the
+        # actual start time of the server is uncontrollable.
+        await server_started
 
 
     async def shutdown(self, grace):
     async def shutdown(self, grace):
         """Gracefully shutdown the C-Core server.
         """Gracefully shutdown the C-Core server.
@@ -295,7 +349,9 @@ cdef class AioServer:
             # The server either is shutting down, or not started.
             # The server either is shutting down, or not started.
             return
             return
         cdef object shutdown_completed = self._loop.create_future()
         cdef object shutdown_completed = self._loop.create_future()
-        cdef CallbackWrapper wrapper = CallbackWrapper(shutdown_completed)
+        cdef CallbackWrapper wrapper = CallbackWrapper(
+            shutdown_completed,
+            SERVER_SHUTDOWN_FAILURE_HANDLER)
         # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
         # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
         # when calling "await". This is an over-optimization by Cython.
         # when calling "await". This is an over-optimization by Cython.
         cpython.Py_INCREF(wrapper)
         cpython.Py_INCREF(wrapper)
@@ -309,6 +365,9 @@ cdef class AioServer:
         self._server.is_shutting_down = True
         self._server.is_shutting_down = True
         self._status = AIO_SERVER_STATUS_STOPPING
         self._status = AIO_SERVER_STATUS_STOPPING
 
 
+        # Ensures the serving task (coroutine) exits normally
+        await self._serving_task
+
         if grace is None:
         if grace is None:
             # Directly cancels all calls
             # Directly cancels all calls
             grpc_server_cancel_all_calls(self._server.c_server)
             grpc_server_cancel_all_calls(self._server.c_server)