|  | @@ -30,6 +30,9 @@ _TEST_SERVER_TO_CLIENT = '/test/TestServerToClient'
 | 
	
		
			
				|  |  |  _TEST_TRAILING_METADATA = '/test/TestTrailingMetadata'
 | 
	
		
			
				|  |  |  _TEST_ECHO_INITIAL_METADATA = '/test/TestEchoInitialMetadata'
 | 
	
		
			
				|  |  |  _TEST_GENERIC_HANDLER = '/test/TestGenericHandler'
 | 
	
		
			
				|  |  | +_TEST_UNARY_STREAM = '/test/TestUnaryStream'
 | 
	
		
			
				|  |  | +_TEST_STREAM_UNARY = '/test/TestStreamUnary'
 | 
	
		
			
				|  |  | +_TEST_STREAM_STREAM = '/test/TestStreamStream'
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  _REQUEST = b'\x00\x00\x00'
 | 
	
		
			
				|  |  |  _RESPONSE = b'\x01\x01\x01'
 | 
	
	
		
			
				|  | @@ -72,6 +75,25 @@ _INVALID_METADATA_TEST_CASES = (
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def __init__(self):
 | 
	
		
			
				|  |  | +        self._routing_table = {
 | 
	
		
			
				|  |  | +            _TEST_CLIENT_TO_SERVER:
 | 
	
		
			
				|  |  | +                grpc.unary_unary_rpc_method_handler(self._test_client_to_server
 | 
	
		
			
				|  |  | +                                                   ),
 | 
	
		
			
				|  |  | +            _TEST_SERVER_TO_CLIENT:
 | 
	
		
			
				|  |  | +                grpc.unary_unary_rpc_method_handler(self._test_server_to_client
 | 
	
		
			
				|  |  | +                                                   ),
 | 
	
		
			
				|  |  | +            _TEST_TRAILING_METADATA:
 | 
	
		
			
				|  |  | +                grpc.unary_unary_rpc_method_handler(self._test_trailing_metadata
 | 
	
		
			
				|  |  | +                                                   ),
 | 
	
		
			
				|  |  | +            _TEST_UNARY_STREAM:
 | 
	
		
			
				|  |  | +                grpc.unary_stream_rpc_method_handler(self._test_unary_stream),
 | 
	
		
			
				|  |  | +            _TEST_STREAM_UNARY:
 | 
	
		
			
				|  |  | +                grpc.stream_unary_rpc_method_handler(self._test_stream_unary),
 | 
	
		
			
				|  |  | +            _TEST_STREAM_STREAM:
 | 
	
		
			
				|  |  | +                grpc.stream_stream_rpc_method_handler(self._test_stream_stream),
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  |      async def _test_client_to_server(request, context):
 | 
	
		
			
				|  |  |          assert _REQUEST == request
 | 
	
	
		
			
				|  | @@ -92,17 +114,44 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler):
 | 
	
		
			
				|  |  |          context.set_trailing_metadata(_TRAILING_METADATA)
 | 
	
		
			
				|  |  |          return _RESPONSE
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def service(self, handler_details):
 | 
	
		
			
				|  |  | -        if handler_details.method == _TEST_CLIENT_TO_SERVER:
 | 
	
		
			
				|  |  | -            return grpc.unary_unary_rpc_method_handler(
 | 
	
		
			
				|  |  | -                self._test_client_to_server)
 | 
	
		
			
				|  |  | -        if handler_details.method == _TEST_SERVER_TO_CLIENT:
 | 
	
		
			
				|  |  | -            return grpc.unary_unary_rpc_method_handler(
 | 
	
		
			
				|  |  | -                self._test_server_to_client)
 | 
	
		
			
				|  |  | -        if handler_details.method == _TEST_TRAILING_METADATA:
 | 
	
		
			
				|  |  | -            return grpc.unary_unary_rpc_method_handler(
 | 
	
		
			
				|  |  | -                self._test_trailing_metadata)
 | 
	
		
			
				|  |  | -        return None
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    async def _test_unary_stream(request, context):
 | 
	
		
			
				|  |  | +        assert _REQUEST == request
 | 
	
		
			
				|  |  | +        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
 | 
	
		
			
				|  |  | +                                     context.invocation_metadata())
 | 
	
		
			
				|  |  | +        await context.send_initial_metadata(
 | 
	
		
			
				|  |  | +            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
 | 
	
		
			
				|  |  | +        yield _RESPONSE
 | 
	
		
			
				|  |  | +        context.set_trailing_metadata(_TRAILING_METADATA)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    async def _test_stream_unary(request_iterator, context):
 | 
	
		
			
				|  |  | +        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
 | 
	
		
			
				|  |  | +                                     context.invocation_metadata())
 | 
	
		
			
				|  |  | +        await context.send_initial_metadata(
 | 
	
		
			
				|  |  | +            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        async for request in request_iterator:
 | 
	
		
			
				|  |  | +            assert _REQUEST == request
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        context.set_trailing_metadata(_TRAILING_METADATA)
 | 
	
		
			
				|  |  | +        return _RESPONSE
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @staticmethod
 | 
	
		
			
				|  |  | +    async def _test_stream_stream(request_iterator, context):
 | 
	
		
			
				|  |  | +        assert _common.seen_metadata(_INITIAL_METADATA_FROM_CLIENT_TO_SERVER,
 | 
	
		
			
				|  |  | +                                     context.invocation_metadata())
 | 
	
		
			
				|  |  | +        await context.send_initial_metadata(
 | 
	
		
			
				|  |  | +            _INITIAL_METADATA_FROM_SERVER_TO_CLIENT)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        async for request in request_iterator:
 | 
	
		
			
				|  |  | +            assert _REQUEST == request
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        yield _RESPONSE
 | 
	
		
			
				|  |  | +        context.set_trailing_metadata(_TRAILING_METADATA)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def service(self, handler_call_details):
 | 
	
		
			
				|  |  | +        return self._routing_table.get(handler_call_details.method)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class _TestGenericHandlerItself(grpc.GenericRpcHandler):
 | 
	
	
		
			
				|  | @@ -112,9 +161,9 @@ class _TestGenericHandlerItself(grpc.GenericRpcHandler):
 | 
	
		
			
				|  |  |          assert _REQUEST == request
 | 
	
		
			
				|  |  |          return _RESPONSE
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def service(self, handler_details):
 | 
	
		
			
				|  |  | +    def service(self, handler_call_details):
 | 
	
		
			
				|  |  |          assert _common.seen_metadata(_INITIAL_METADATA_FOR_GENERIC_HANDLER,
 | 
	
		
			
				|  |  | -                                     handler_details.invocation_metadata)
 | 
	
		
			
				|  |  | +                                     handler_call_details.invocation_metadata)
 | 
	
		
			
				|  |  |          return grpc.unary_unary_rpc_method_handler(self._method)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -164,9 +213,10 @@ class TestMetadata(AioTestBase):
 | 
	
		
			
				|  |  |      async def test_invalid_metadata(self):
 | 
	
		
			
				|  |  |          multicallable = self._client.unary_unary(_TEST_CLIENT_TO_SERVER)
 | 
	
		
			
				|  |  |          for exception_type, metadata in _INVALID_METADATA_TEST_CASES:
 | 
	
		
			
				|  |  | -            call = multicallable(_REQUEST, metadata=metadata)
 | 
	
		
			
				|  |  | -            with self.assertRaises(exception_type):
 | 
	
		
			
				|  |  | -                await call
 | 
	
		
			
				|  |  | +            with self.subTest(metadata=metadata):
 | 
	
		
			
				|  |  | +                call = multicallable(_REQUEST, metadata=metadata)
 | 
	
		
			
				|  |  | +                with self.assertRaises(exception_type):
 | 
	
		
			
				|  |  | +                    await call
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def test_generic_handler(self):
 | 
	
		
			
				|  |  |          multicallable = self._client.unary_unary(_TEST_GENERIC_HANDLER)
 | 
	
	
		
			
				|  | @@ -175,6 +225,49 @@ class TestMetadata(AioTestBase):
 | 
	
		
			
				|  |  |          self.assertEqual(_RESPONSE, await call)
 | 
	
		
			
				|  |  |          self.assertEqual(grpc.StatusCode.OK, await call.code())
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    async def test_unary_stream(self):
 | 
	
		
			
				|  |  | +        multicallable = self._client.unary_stream(_TEST_UNARY_STREAM)
 | 
	
		
			
				|  |  | +        call = multicallable(_REQUEST,
 | 
	
		
			
				|  |  | +                             metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertTrue(
 | 
	
		
			
				|  |  | +            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
 | 
	
		
			
				|  |  | +                                  call.initial_metadata()))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertSequenceEqual([_RESPONSE],
 | 
	
		
			
				|  |  | +                                 [request async for request in call])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
 | 
	
		
			
				|  |  | +        self.assertEqual(grpc.StatusCode.OK, await call.code())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def test_stream_unary(self):
 | 
	
		
			
				|  |  | +        multicallable = self._client.stream_unary(_TEST_STREAM_UNARY)
 | 
	
		
			
				|  |  | +        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
 | 
	
		
			
				|  |  | +        await call.write(_REQUEST)
 | 
	
		
			
				|  |  | +        await call.done_writing()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertTrue(
 | 
	
		
			
				|  |  | +            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
 | 
	
		
			
				|  |  | +                                  call.initial_metadata()))
 | 
	
		
			
				|  |  | +        self.assertEqual(_RESPONSE, await call)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
 | 
	
		
			
				|  |  | +        self.assertEqual(grpc.StatusCode.OK, await call.code())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def test_stream_stream(self):
 | 
	
		
			
				|  |  | +        multicallable = self._client.stream_stream(_TEST_STREAM_STREAM)
 | 
	
		
			
				|  |  | +        call = multicallable(metadata=_INITIAL_METADATA_FROM_CLIENT_TO_SERVER)
 | 
	
		
			
				|  |  | +        await call.write(_REQUEST)
 | 
	
		
			
				|  |  | +        await call.done_writing()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertTrue(
 | 
	
		
			
				|  |  | +            _common.seen_metadata(_INITIAL_METADATA_FROM_SERVER_TO_CLIENT, await
 | 
	
		
			
				|  |  | +                                  call.initial_metadata()))
 | 
	
		
			
				|  |  | +        self.assertSequenceEqual([_RESPONSE],
 | 
	
		
			
				|  |  | +                                 [request async for request in call])
 | 
	
		
			
				|  |  | +        self.assertEqual(_TRAILING_METADATA, await call.trailing_metadata())
 | 
	
		
			
				|  |  | +        self.assertEqual(grpc.StatusCode.OK, await call.code())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  if __name__ == '__main__':
 | 
	
		
			
				|  |  |      logging.basicConfig(level=logging.DEBUG)
 |