Bladeren bron

Fix a bug that prevents metadata modification in interceptors

Lidi Zheng 6 jaren geleden
bovenliggende
commit
435cf89108

+ 5 - 0
src/python/grpcio/grpc/_cython/_cygrpc/metadata.pyx.pxi

@@ -41,6 +41,11 @@ cdef void _store_c_metadata(
       for index, (key, value) in enumerate(metadata):
         encoded_key = _encode(key)
         encoded_value = value if encoded_key[-4:] == b'-bin' else _encode(value)
+        if type(encoded_value) != bytes:
+          raise TypeError('Binary metadata key="%s" expected bytes, got %s' % (
+            key,
+            type(encoded_value)
+          ))
         c_metadata[0][index].key = _slice_from_bytes(encoded_key)
         c_metadata[0][index].value = _slice_from_bytes(encoded_value)
 

+ 2 - 1
src/python/grpcio/grpc/experimental/aio/_interceptor.py

@@ -150,7 +150,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
             else:
                 return UnaryUnaryCall(
                     request, _timeout_to_deadline(client_call_details.timeout),
-                    metadata, client_call_details.credentials, self._channel,
+                    client_call_details.metadata,
+                    client_call_details.credentials, self._channel,
                     client_call_details.method, request_serializer,
                     response_deserializer)
 

+ 7 - 0
src/python/grpcio_tests/tests_aio/unit/BUILD.bazel

@@ -43,6 +43,12 @@ py_library(
     srcs_version = "PY3",
 )
 
+py_library(
+    name = "_common",
+    srcs = ["_common.py"],
+    srcs_version = "PY3",
+)
+
 [
     py_test(
         name = test_file_name[:-3],
@@ -55,6 +61,7 @@ py_library(
         main = test_file_name,
         python_version = "PY3",
         deps = [
+            ":_common",
             ":_constants",
             ":_test_base",
             ":_test_server",

+ 24 - 0
src/python/grpcio_tests/tests_aio/unit/_common.py

@@ -0,0 +1,24 @@
+# Copyright 2020 The gRPC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def seen_metadata(expected, actual):
+    metadata_dict = dict(actual)
+    if type(expected[0]) != tuple:
+        return metadata_dict.get(expected[0]) == expected[1]
+    else:
+        for metadatum in expected:
+            if metadata_dict.get(metadatum[0]) != metadatum[1]:
+                return False
+        return True

+ 53 - 8
src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@@ -18,11 +18,18 @@ import unittest
 import grpc
 
 from grpc.experimental import aio
-from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
+from tests_aio.unit._test_server import start_test_server, _INITIAL_METADATA_KEY, _TRAILING_METADATA_KEY
+from tests_aio.unit import _constants
+from tests_aio.unit import _common
 from tests_aio.unit._test_base import AioTestBase
-from src.proto.grpc.testing import messages_pb2
+from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
+
 
 _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
+_INITIAL_METADATA_TO_INJECT = (
+    (_INITIAL_METADATA_KEY, 'extra info'),
+    (_TRAILING_METADATA_KEY, b'\x13\x37'),
+)
 
 
 class TestUnaryUnaryClientInterceptor(AioTestBase):
@@ -124,7 +131,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
                                             client_call_details, request):
                 new_client_call_details = aio.ClientCallDetails(
                     method=client_call_details.method,
-                    timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
+                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                     metadata=client_call_details.metadata,
                     credentials=client_call_details.credentials)
                 return await continuation(new_client_call_details, request)
@@ -165,7 +172,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
 
                 new_client_call_details = aio.ClientCallDetails(
                     method=client_call_details.method,
-                    timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
+                    timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2,
                     metadata=client_call_details.metadata,
                     credentials=client_call_details.credentials)
 
@@ -342,8 +349,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
                 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                 response_deserializer=messages_pb2.SimpleResponse.FromString)
 
-            call = multicallable(messages_pb2.SimpleRequest(),
-                                 timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
+            call = multicallable(
+                messages_pb2.SimpleRequest(),
+                timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
 
             with self.assertRaises(aio.AioRpcError) as exception_context:
                 await call
@@ -375,8 +383,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
                 request_serializer=messages_pb2.SimpleRequest.SerializeToString,
                 response_deserializer=messages_pb2.SimpleResponse.FromString)
 
-            call = multicallable(messages_pb2.SimpleRequest(),
-                                 timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
+            call = multicallable(
+                messages_pb2.SimpleRequest(),
+                timeout=_constants.UNARY_CALL_WITH_SLEEP_VALUE / 2)
 
             with self.assertRaises(aio.AioRpcError) as exception_context:
                 await call
@@ -532,6 +541,42 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
             self.assertEqual(await call.initial_metadata(), tuple())
             self.assertEqual(await call.trailing_metadata(), None)
 
+    async def test_initial_metadata_modification(self):
+
+        class Interceptor(aio.UnaryUnaryClientInterceptor):
+
+            async def intercept_unary_unary(self, continuation,
+                                            client_call_details, request):
+                if client_call_details.metadata is not None:
+                    new_metadata = client_call_details.metadata + _INITIAL_METADATA_TO_INJECT
+                else:
+                    new_metadata = _INITIAL_METADATA_TO_INJECT
+                new_details = aio.ClientCallDetails(
+                    method=client_call_details.method,
+                    timeout=client_call_details.timeout,
+                    metadata=new_metadata,
+                    credentials=client_call_details.credentials,
+                )
+                return await continuation(new_details, request)
+
+        async with aio.insecure_channel(self._server_target,
+                                        interceptors=[Interceptor()
+                                                     ]) as channel:
+            stub = test_pb2_grpc.TestServiceStub(channel)
+            call = stub.UnaryCall(messages_pb2.SimpleRequest())
+
+            # Expected to see the echoed initial metadata
+            self.assertTrue(
+                _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[0], await
+                                      call.initial_metadata()))
+
+            # Expected to see the echoed trailing metadata
+            self.assertTrue(
+                _common.seen_metadata(_INITIAL_METADATA_TO_INJECT[1], await
+                                      call.trailing_metadata()))
+
+            self.assertEqual(await call.code(), grpc.StatusCode.OK)
+
 
 if __name__ == '__main__':
     logging.basicConfig()

+ 5 - 12
src/python/grpcio_tests/tests_aio/unit/metadata_test.py

@@ -23,6 +23,7 @@ import grpc
 from grpc.experimental import aio
 
 from tests_aio.unit._test_base import AioTestBase
+from tests_aio.unit import _common
 
 _TEST_CLIENT_TO_SERVER = '/test/TestClientToServer'
 _TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
@@ -69,21 +70,13 @@ _INVALID_METADATA_TEST_CASES = (
 )
 
 
-def _seen_metadata(expected, actual):
-    metadata_dict = dict(actual)
-    for metadatum in expected:
-        if metadata_dict.get(metadatum[0]) != metadatum[1]:
-            return False
-    return True
-
-
 class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
 
     @staticmethod
     async def _test_client_to_server(request, context):
         assert _REQUEST == request
-        assert _seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
-                              context.invocation_metadata())
+        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
+                             context.invocation_metadata())
         return _RESPONSE
 
     @staticmethod
@@ -120,8 +113,8 @@ class _TestGenericHandlerItself(grpc.GenericRpcHandler):
         return _RESPONSE
 
     def service(self, handler_details):
-        assert _seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
-                              handler_details.invocation_metadata)
+        assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
+                             handler_details.invocation_metadata)
         return grpc.unary_unary_rpc_method_handler(self._method)