Browse Source

fix metadata

Zhanghui Mao 5 years ago
parent
commit
6d556914d0

+ 2 - 0
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi

@@ -28,8 +28,10 @@ cdef class RPCState(GrpcCallWrapper):
     cdef object abort_exception
     cdef bint metadata_sent
     cdef bint status_sent
+    cdef tuple trailing_metadata
 
     cdef bytes method(self)
+    cdef tuple invocation_metadata(self)
 
 
 cdef enum AioServerStatus:

+ 13 - 6
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi

@@ -40,9 +40,13 @@ cdef class RPCState:
         self.abort_exception = None
         self.metadata_sent = False
         self.status_sent = False
+        self.trailing_metadata = tuple()
 
     cdef bytes method(self):
-      return _slice_bytes(self.details.method)
+        return _slice_bytes(self.details.method)
+
+    cdef tuple invocation_metadata(self):
+        return _metadata(&self.request_metadata)
 
     def __dealloc__(self):
         """Cleans the Core objects."""
@@ -146,8 +150,11 @@ cdef class _ServicerContext:
 
             raise self._rpc_state.abort_exception
 
+    def set_trailing_metadata(self, tuple metadata):
+        self._rpc_state.trailing_metadata = metadata
+
     def invocation_metadata(self):
-        return _metadata(&self._rpc_state.request_metadata)
+        return self._rpc_state.invocation_metadata()
 
 
 cdef _find_method_handler(str method, list generic_handlers):
@@ -192,10 +199,10 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
 
     # Assembles the batch operations
     cdef Operation send_status_op = SendStatusFromServerOperation(
-        tuple(),
-            StatusCode.ok,
-            b'',
-            _EMPTY_FLAGS,
+        rpc_state.trailing_metadata,
+        StatusCode.ok,
+        b'',
+        _EMPTY_FLAGS,
     )
     cdef tuple finish_ops
     if not rpc_state.metadata_sent:

+ 6 - 5
src/python/grpcio_tests/tests_aio/unit/_test_server.py

@@ -34,15 +34,16 @@ async def _maybe_echo_metadata(servicer_context):
         initial_metadatum = (_INITIAL_METADATA_KEY,
                              invocation_metadata[_INITIAL_METADATA_KEY])
         await servicer_context.send_initial_metadata((initial_metadatum,))
-    # if _TRAILING_METADATA_KEY in invocation_metadata:
-    #     trailing_metadatum = (_TRAILING_METADATA_KEY,
-    #                           invocation_metadata[_TRAILING_METADATA_KEY])
-    #     servicer_context.set_trailing_metadata((trailing_metadatum,))
+    if _TRAILING_METADATA_KEY in invocation_metadata:
+        trailing_metadatum = (_TRAILING_METADATA_KEY,
+                              invocation_metadata[_TRAILING_METADATA_KEY])
+        servicer_context.set_trailing_metadata((trailing_metadatum,))
 
 
 class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
 
-    async def UnaryCall(self, unused_request, unused_context):
+    async def UnaryCall(self, unused_request, context):
+        await _maybe_echo_metadata(context)
         return messages_pb2.SimpleResponse()
 
     async def StreamingOutputCall(

+ 4 - 0
src/python/grpcio_tests/tests_aio/unit/channel_test.py

@@ -112,8 +112,12 @@ class TestChannel(AioTestBase):
             call = hi(messages_pb2.SimpleRequest(),
                       metadata=_INVOCATION_METADATA)
             initial_metadata = await call.initial_metadata()
+            trailing_metadata = await call.trailing_metadata()
 
             self.assertIsInstance(initial_metadata, tuple)
+            self.assertEqual(_INVOCATION_METADATA[0], initial_metadata[0])
+            self.assertIsInstance(trailing_metadata, tuple)
+            self.assertEqual(_INVOCATION_METADATA[1], trailing_metadata[0])
 
     async def test_unary_stream(self):
         channel = aio.insecure_channel(self._server_target)