Ver código fonte

Adopt reviewer's advice

Richard Belleville 6 anos atrás
pai
commit
f7249fcd3a

+ 9 - 15
src/python/grpcio/grpc/_channel.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Invocation-side implementation of gRPC Python."""
 
+import copy
 import functools
 import logging
 import sys
@@ -266,27 +267,20 @@ def _rpc_state_string(class_name, rpc_state):
 class _RpcError(grpc.RpcError, grpc.Call, grpc.Future):
     """An RPC error not tied to the execution of a particular RPC.
 
-    The state passed to _RpcError must be guaranteed not to be accessed by any
-    other threads.
-
-    The RPC represented by the state object must not be in-progress.
+    The RPC represented by the state object must not be in-progress or
+    cancelled.
 
     Attributes:
       _state: An instance of _RPCState.
     """
 
     def __init__(self, state):
-        if state.cancelled:
-            raise ValueError(
-                "Cannot instantiate an _RpcError for a cancelled RPC.")
-        if state.code is grpc.StatusCode.OK:
-            raise ValueError(
-                "Cannot instantiate an _RpcError for a successfully completed RPC."
-            )
-        if state.code is None:
-            raise ValueError(
-                "Cannot instantiate an _RpcError for an incomplete RPC.")
-        self._state = state
+        with state.condition:
+            self._state = _RPCState((), copy.deepcopy(state.initial_metadata),
+                                    copy.deepcopy(state.trailing_metadata),
+                                    state.code, copy.deepcopy(state.details))
+            self._state.response = copy.copy(state.response)
+            self._state.debug_error_string = copy.copy(state.debug_error_string)
 
     def initial_metadata(self):
         return self._state.initial_metadata

+ 44 - 10
src/python/grpcio_tests/tests/unit/_interceptor_test.py

@@ -40,6 +40,10 @@ _STREAM_UNARY = '/test/StreamUnary'
 _STREAM_STREAM = '/test/StreamStream'
 
 
+class _ApplicationErrorStandin(Exception):
+    pass
+
+
 class _Callback(object):
 
     def __init__(self):
@@ -73,12 +77,12 @@ class _Handler(object):
                 'testvalue',
             ),))
         if request == _EXCEPTION_REQUEST:
-            raise RuntimeError()
+            raise _ApplicationErrorStandin()
         return request
 
     def handle_unary_stream(self, request, servicer_context):
         if request == _EXCEPTION_REQUEST:
-            raise RuntimeError()
+            raise _ApplicationErrorStandin()
         for _ in range(test_constants.STREAM_LENGTH):
             self._control.control()
             yield request
@@ -104,7 +108,7 @@ class _Handler(object):
                 'testvalue',
             ),))
         if _EXCEPTION_REQUEST in response_elements:
-            raise RuntimeError()
+            raise _ApplicationErrorStandin()
         return b''.join(response_elements)
 
     def handle_stream_stream(self, request_iterator, servicer_context):
@@ -116,7 +120,7 @@ class _Handler(object):
             ),))
         for request in request_iterator:
             if request == _EXCEPTION_REQUEST:
-                raise RuntimeError()
+                raise _ApplicationErrorStandin()
             self._control.control()
             yield request
         self._control.control()
@@ -245,10 +249,12 @@ class _LoggingInterceptor(
         result = continuation(client_call_details, request)
         assert isinstance(
             result,
-            grpc.Call), '{} is not an instance of grpc.Call'.format(result)
+            grpc.Call), '{} ({}) is not an instance of grpc.Call'.format(
+                result, type(result))
         assert isinstance(
             result,
-            grpc.Future), '{} is not an instance of grpc.Future'.format(result)
+            grpc.Future), '{} ({}) is not an instance of grpc.Future'.format(
+                result, type(result))
         return result
 
     def intercept_unary_stream(self, continuation, client_call_details,
@@ -476,11 +482,18 @@ class InterceptorTest(unittest.TestCase):
                                              'c2', self._record))
 
         multi_callable = _unary_unary_multi_callable(channel)
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(grpc.RpcError) as exception_context:
             multi_callable(
                 request,
                 metadata=(('test',
                            'InterceptedUnaryRequestBlockingUnaryResponse'),))
+        exception = exception_context.exception
+        self.assertFalse(exception.cancelled())
+        self.assertFalse(exception.running())
+        self.assertTrue(exception.done())
+        with self.assertRaises(grpc.RpcError):
+            exception.result()
+        self.assertIsInstance(exception.exception(), grpc.RpcError)
 
     def testInterceptedUnaryRequestBlockingUnaryResponseWithCall(self):
         request = b'\x07\x08'
@@ -561,8 +574,15 @@ class InterceptorTest(unittest.TestCase):
         response_iterator = multi_callable(
             request,
             metadata=(('test', 'InterceptedUnaryRequestStreamResponse'),))
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(grpc.RpcError) as exception_context:
             tuple(response_iterator)
+        exception = exception_context.exception
+        self.assertFalse(exception.cancelled())
+        self.assertFalse(exception.running())
+        self.assertTrue(exception.done())
+        with self.assertRaises(grpc.RpcError):
+            exception.result()
+        self.assertIsInstance(exception.exception(), grpc.RpcError)
 
     def testInterceptedStreamRequestBlockingUnaryResponse(self):
         requests = tuple(
@@ -650,8 +670,15 @@ class InterceptorTest(unittest.TestCase):
         response_future = multi_callable.future(
             request_iterator,
             metadata=(('test', 'InterceptedStreamRequestFutureUnaryResponse'),))
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(grpc.RpcError) as exception_context:
             response_future.result()
+        exception = exception_context.exception
+        self.assertFalse(exception.cancelled())
+        self.assertFalse(exception.running())
+        self.assertTrue(exception.done())
+        with self.assertRaises(grpc.RpcError):
+            exception.result()
+        self.assertIsInstance(exception.exception(), grpc.RpcError)
 
     def testInterceptedStreamRequestStreamResponse(self):
         requests = tuple(
@@ -692,8 +719,15 @@ class InterceptorTest(unittest.TestCase):
         response_iterator = multi_callable(
             request_iterator,
             metadata=(('test', 'InterceptedStreamRequestStreamResponse'),))
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(grpc.RpcError) as exception_context:
             tuple(response_iterator)
+        exception = exception_context.exception
+        self.assertFalse(exception.cancelled())
+        self.assertFalse(exception.running())
+        self.assertTrue(exception.done())
+        with self.assertRaises(grpc.RpcError):
+            exception.result()
+        self.assertIsInstance(exception.exception(), grpc.RpcError)
 
 
 if __name__ == '__main__':