소스 검색

Implement abort mechanism for server side

Lidi Zheng 5 년 전
부모
커밋
cddd0a0419

+ 20 - 0
src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi

@@ -173,3 +173,23 @@ async def _receive_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
     cdef tuple ops = (op,)
     await execute_batch(grpc_call_wrapper, ops, loop)
     return op.initial_metadata()
+
+async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
+                                         grpc_status_code code,
+                                         str details,
+                                         tuple trailing_metadata,
+                                         bint metadata_sent,
+                                         object loop):
+    assert code != StatusCode.ok, 'Expecting non-ok status code.'
+    cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
+        trailing_metadata,
+        code,
+        details,
+        _EMPTY_FLAGS,
+    )
+    cdef tuple ops
+    if metadata_sent:
+        ops = (op,)
+    else:
+        ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAG))
+    await execute_batch(grpc_call_wrapper, ops, loop)

+ 3 - 0
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi

@@ -21,6 +21,9 @@ cdef class RPCState(GrpcCallWrapper):
     cdef grpc_call_details details
     cdef grpc_metadata_array request_metadata
     cdef AioServer server
+    cdef object abort_exception
+    cdef bint metadata_sent
+    cdef bint status_sent
 
     cdef bytes method(self)
 

+ 131 - 13
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi

@@ -14,6 +14,7 @@
 
 
 import inspect
+import traceback
 
 
 # TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
@@ -34,6 +35,9 @@ cdef class RPCState:
         self.server = server
         grpc_metadata_array_init(&self.request_metadata)
         grpc_call_details_init(&self.details)
+        self.abort_exception = None
+        self.metadata_sent = False
+        self.status_sent = False
 
     cdef bytes method(self):
       return _slice_bytes(self.details.method)
@@ -46,10 +50,54 @@ cdef class RPCState:
             grpc_call_unref(self.call)
 
 
+# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve
+# current code structure to make it happen.
+class AbortError(Exception): pass
+
+
+def _raise_if_aborted(RPCState rpc_state):
+    """Raise AbortError if RPC is aborted.
+
+    Server method handlers may suppress the abort exception. We need to halt
+    the RPC execution in that case. This function needs to be called after
+    running application code.
+    """
+    if rpc_state.abort_exception is not None:
+        raise rpc_state.abort_exception
+
+
+async def _perform_abort(RPCState rpc_state,
+                         grpc_status_code code,
+                         str details, 
+                         tuple trailing_metadata,
+                         object loop):
+    """Perform the abort logic.
+
+    Sends final status to the client, and then set the RPC into corresponding
+    state.
+    """
+    if rpc_state.abort_exception is not None:
+        raise RuntimeError('Abort already called!')
+    else:
+        # Keeps track of the exception object. After abort happen, the RPC
+        # should stop execution. However, if users decided to suppress it, it
+        # could lead to undefined behavior.
+        rpc_state.abort_exception = AbortError('Locally aborted.')
+
+    rpc_state.status_sent = True
+    await _send_error_status_from_server(
+        rpc_state,
+        code,
+        details,
+        trailing_metadata,
+        rpc_state.metadata_sent,
+        loop
+    )
+
+
 cdef class _ServicerContext:
     cdef RPCState _rpc_state
     cdef object _loop
-    cdef bint _metadata_sent
     cdef object _request_deserializer
     cdef object _response_serializer
 
@@ -62,27 +110,46 @@ cdef class _ServicerContext:
         self._request_deserializer = request_deserializer
         self._response_serializer = response_serializer
         self._loop = loop
-        self._metadata_sent = False
 
     async def read(self):
+        if self._rpc_state.status_sent:
+            raise RuntimeError('RPC already finished.')
         cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop)
         return deserialize(self._request_deserializer,
                            raw_message)
 
     async def write(self, object message):
+        if self._rpc_state.status_sent:
+            raise RuntimeError('RPC already finished.')
         await _send_message(self._rpc_state,
                             serialize(self._response_serializer, message),
-                            self._metadata_sent,
+                            self._rpc_state.metadata_sent,
                             self._loop)
-        if not self._metadata_sent:
-            self._metadata_sent = True
+        if not self._rpc_state.metadata_sent:
+            self._rpc_state.metadata_sent = True
 
     async def send_initial_metadata(self, tuple metadata):
-        if self._metadata_sent:
+        if self._rpc_state.status_sent:
+            raise RuntimeError('RPC already finished.')
+        elif self._rpc_state.metadata_sent:
             raise RuntimeError('Send initial metadata failed: already sent')
         else:
             _send_initial_metadata(self._rpc_state, self._loop)
-            self._metadata_sent = True
+            self._rpc_state.metadata_sent = True
+
+    async def abort(self,
+              object code,
+              str details='',
+              tuple trailing_metadata=_EMPTY_METADATA):
+        await _perform_abort(
+            self._rpc_state,
+            code.value[0],
+            details,
+            trailing_metadata,
+            self._loop
+        )
+
+        raise self._rpc_state.abort_exception
 
 
 cdef _find_method_handler(str method, list generic_handlers):
@@ -120,6 +187,9 @@ async def _handle_unary_unary_rpc(object method_handler,
         ),
     )
 
+    # Raises exception if aborted
+    _raise_if_aborted(rpc_state)
+
     # Serializes the response message
     cdef bytes response_raw = serialize(
         method_handler.response_serializer,
@@ -138,6 +208,7 @@ async def _handle_unary_unary_rpc(object method_handler,
         SendMessageOperation(response_raw, _EMPTY_FLAGS),
     )
     await execute_batch(rpc_state, send_ops, loop)
+    rpc_state.status_sent = True
 
 
 async def _handle_unary_stream_rpc(object method_handler,
@@ -167,6 +238,9 @@ async def _handle_unary_stream_rpc(object method_handler,
             request_message,
             servicer_context,
         )
+
+        # Raises exception if aborted
+        _raise_if_aborted(rpc_state)
     else:
         # The handler uses async generator API
         async_response_generator = method_handler.unary_stream(
@@ -176,6 +250,9 @@ async def _handle_unary_stream_rpc(object method_handler,
 
         # Consumes messages from the generator
         async for response_message in async_response_generator:
+            # Raises exception if aborted
+            _raise_if_aborted(rpc_state)
+
             if rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
                 # The async generator might yield much much later after the
                 # server is destroied. If we proceed, Core will crash badly.
@@ -194,6 +271,34 @@ async def _handle_unary_stream_rpc(object method_handler,
 
     cdef tuple ops = (op,)
     await execute_batch(rpc_state, ops, loop)
+    rpc_state.status_sent = True
+
+
+async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
+    try:
+        try:
+            await rpc_coro
+        except AbortError as e:
+            # Caught AbortError check if it is the same one
+            assert rpc_state.abort_exception is e, 'Abort error has been replaced!'
+            return
+        else:
+            # Check if the abort exception got suppressed
+            if rpc_state.abort_exception is not None:
+                _LOGGER.error(
+                    'Abort error unexpectedly suppressed: %s',
+                    traceback.format_exception(rpc_state.abort_exception)
+                )
+    except Exception as e:
+        _LOGGER.exception(e)
+        if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED:
+            await _perform_abort(
+                rpc_state,
+                StatusCode.unknown,
+                '%s: %s' % (type(e), e),
+                _EMPTY_METADATA,
+                loop
+            )
 
 
 async def _handle_cancellation_from_core(object rpc_task,
@@ -213,7 +318,11 @@ async def _schedule_rpc_coro(object rpc_coro,
                              RPCState rpc_state,
                              object loop):
     # Schedules the RPC coroutine.
-    cdef object rpc_task = loop.create_task(rpc_coro)
+    cdef object rpc_task = loop.create_task(_handle_exceptions(
+        rpc_state,
+        rpc_coro,
+        loop,
+    ))
     await _handle_cancellation_from_core(rpc_task, rpc_state, loop)
 
 
@@ -224,14 +333,23 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
         generic_handlers,
     )
     if method_handler is None:
-        # TODO(lidiz) return unimplemented error to client side
-        raise NotImplementedError()
+        await _perform_abort(
+            rpc_state,
+            StatusCode.unimplemented,
+            b'Method not found!',
+            _EMPTY_METADATA,
+            loop
+        )
+        return
 
     # TODO(lidiz) extend to all 4 types of RPC
     if not method_handler.request_streaming and method_handler.response_streaming:
-        await _handle_unary_stream_rpc(method_handler,
-                                       rpc_state,
-                                       loop)
+        try:
+            await _handle_unary_stream_rpc(method_handler,
+                                        rpc_state,
+                                        loop)
+        except Exception as e:
+            raise
     elif not method_handler.request_streaming and not method_handler.response_streaming:
         await _handle_unary_unary_rpc(method_handler,
                                       rpc_state,

+ 154 - 0
src/python/grpcio_tests/tests_aio/unit/abort_test.py

@@ -0,0 +1,154 @@
+# Copyright 2019 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+import unittest
+import time
+import gc
+
+import grpc
+from grpc.experimental import aio
+from tests_aio.unit._test_base import AioTestBase
+from tests.unit.framework.common import test_constants
+
+_UNARY_UNARY_ABORT = '/test/UnaryUnaryAbort'
+_SUPPRESS_ABORT = '/test/SuppressAbort'
+_REPLACE_ABORT = '/test/ReplaceAbort'
+_ABORT_AFTER_REPLY = '/test/AbortAfterReply'
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x01\x01\x01'
+_NUM_STREAM_RESPONSES = 5
+
+_ABORT_CODE = grpc.StatusCode.RESOURCE_EXHAUSTED
+_ABORT_DETAILS = 'Dummy error details'
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+    @staticmethod
+    async def _unary_unary_abort(unused_request, context):
+        await context.abort(_ABORT_CODE, _ABORT_DETAILS)
+        raise RuntimeError('This line should not be executed')
+
+    @staticmethod
+    async def _suppress_abort(unused_request, context):
+        try:
+            await context.abort(_ABORT_CODE, _ABORT_DETAILS)
+        except Exception as e:
+            pass
+        return _RESPONSE
+
+    @staticmethod
+    async def _replace_abort(unused_request, context):
+        try:
+            await context.abort(_ABORT_CODE, _ABORT_DETAILS)
+        except Exception as e:
+            await context.abort(grpc.StatusCode.INVALID_ARGUMENT,
+                                'Override abort!')
+
+    @staticmethod
+    async def _abort_after_reply(unused_request, context):
+        yield _RESPONSE
+        await context.abort(_ABORT_CODE, _ABORT_DETAILS)
+        raise RuntimeError('This line should not be executed')
+
+    def service(self, handler_details):
+        if handler_details.method == _UNARY_UNARY_ABORT:
+            return grpc.unary_unary_rpc_method_handler(self._unary_unary_abort)
+        if handler_details.method == _SUPPRESS_ABORT:
+            return grpc.unary_unary_rpc_method_handler(self._suppress_abort)
+        if handler_details.method == _REPLACE_ABORT:
+            return grpc.unary_unary_rpc_method_handler(self._replace_abort)
+        if handler_details.method == _ABORT_AFTER_REPLY:
+            return grpc.unary_stream_rpc_method_handler(self._abort_after_reply)
+
+
+async def _start_test_server():
+    server = aio.server()
+    port = server.add_insecure_port('[::]:0')
+    server.add_generic_rpc_handlers((_GenericHandler(),))
+    await server.start()
+    return 'localhost:%d' % port, server
+
+
+class TestServer(AioTestBase):
+
+    async def setUp(self):
+        address, self._server = await _start_test_server()
+        self._channel = aio.insecure_channel(address)
+
+    async def tearDown(self):
+        await self._channel.close()
+        await self._server.stop(None)
+
+    async def test_unary_unary_abort(self):
+        method = self._channel.unary_unary(_UNARY_UNARY_ABORT)
+        call = method(_REQUEST)
+
+        self.assertEqual(_ABORT_CODE, await call.code())
+        self.assertEqual(_ABORT_DETAILS, await call.details())
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            await call
+
+        rpc_error = exception_context.exception
+        rpc_error.code()
+        self.assertEqual(_ABORT_CODE, rpc_error.code())
+        self.assertEqual(_ABORT_DETAILS, rpc_error.details())
+
+    async def test_suppress_abort(self):
+        method = self._channel.unary_unary(_SUPPRESS_ABORT)
+        call = method(_REQUEST)
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            await call
+
+        rpc_error = exception_context.exception
+        rpc_error.code()
+        self.assertEqual(_ABORT_CODE, rpc_error.code())
+        self.assertEqual(_ABORT_DETAILS, rpc_error.details())
+
+    async def test_replace_abort(self):
+        method = self._channel.unary_unary(_REPLACE_ABORT)
+        call = method(_REQUEST)
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            await call
+
+        rpc_error = exception_context.exception
+        rpc_error.code()
+        self.assertEqual(_ABORT_CODE, rpc_error.code())
+        self.assertEqual(_ABORT_DETAILS, rpc_error.details())
+
+    async def test_abort_after_reply(self):
+        method = self._channel.unary_stream(_ABORT_AFTER_REPLY)
+        call = method(_REQUEST)
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            await call.read()
+
+        rpc_error = exception_context.exception
+        rpc_error.code()
+        self.assertEqual(_ABORT_CODE, rpc_error.code())
+        self.assertEqual(_ABORT_DETAILS, rpc_error.details())
+
+        self.assertEqual(_ABORT_CODE, await call.code())
+        self.assertEqual(_ABORT_DETAILS, await call.details())
+
+
+if __name__ == '__main__':
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main(verbosity=2)