|  | @@ -13,6 +13,7 @@
 | 
	
		
			
				|  |  |  # limitations under the License.
 | 
	
		
			
				|  |  |  """Test the functionality of server interceptors."""
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import asyncio
 | 
	
		
			
				|  |  |  import functools
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  |  import unittest
 | 
	
	
		
			
				|  | @@ -79,6 +80,43 @@ def _filter_server_interceptor(condition: Callable,
 | 
	
		
			
				|  |  |      return _GenericInterceptor(intercept_service)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class _CacheInterceptor(aio.ServerInterceptor):
 | 
	
		
			
				|  |  | +    """An interceptor that caches response based on request message."""
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, cache_store=None):
 | 
	
		
			
				|  |  | +        self.cache_store = cache_store or {}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def intercept_service(
 | 
	
		
			
				|  |  | +            self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
 | 
	
		
			
				|  |  | +                grpc.RpcMethodHandler]],
 | 
	
		
			
				|  |  | +            handler_call_details: grpc.HandlerCallDetails
 | 
	
		
			
				|  |  | +    ) -> grpc.RpcMethodHandler:
 | 
	
		
			
				|  |  | +        # Get the actual handler
 | 
	
		
			
				|  |  | +        handler = await continuation(handler_call_details)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Only intercept unary call RPCs
 | 
	
		
			
				|  |  | +        if handler and (handler.request_streaming or
 | 
	
		
			
				|  |  | +                        handler.response_streaming):
 | 
	
		
			
				|  |  | +            return handler
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        def wrapper(behavior: Callable[
 | 
	
		
			
				|  |  | +            [messages_pb2.SimpleRequest, aio.
 | 
	
		
			
				|  |  | +             ServicerContext], messages_pb2.SimpleResponse]):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            @functools.wraps(behavior)
 | 
	
		
			
				|  |  | +            async def wrapper(request: messages_pb2.SimpleRequest,
 | 
	
		
			
				|  |  | +                              context: aio.ServicerContext
 | 
	
		
			
				|  |  | +                             ) -> messages_pb2.SimpleResponse:
 | 
	
		
			
				|  |  | +                if request.response_size not in self.cache_store:
 | 
	
		
			
				|  |  | +                    self.cache_store[request.response_size] = await behavior(
 | 
	
		
			
				|  |  | +                        request, context)
 | 
	
		
			
				|  |  | +                return self.cache_store[request.response_size]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            return wrapper
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return wrap_server_method_handler(wrapper, handler)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  async def _create_server_stub_pair(
 | 
	
		
			
				|  |  |          *interceptors: aio.ServerInterceptor
 | 
	
		
			
				|  |  |  ) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]:
 | 
	
	
		
			
				|  | @@ -182,55 +220,29 @@ class TestServerInterceptor(AioTestBase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def test_response_caching(self):
 | 
	
		
			
				|  |  |          # Prepares a preset value to help testing
 | 
	
		
			
				|  |  | -        cache_store = {
 | 
	
		
			
				|  |  | +        interceptor = _CacheInterceptor({
 | 
	
		
			
				|  |  |              42:
 | 
	
		
			
				|  |  |                  messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
 | 
	
		
			
				|  |  |                      body=b'\x42'))
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        async def intercept_and_cache(
 | 
	
		
			
				|  |  | -                continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
 | 
	
		
			
				|  |  | -                    grpc.RpcMethodHandler]],
 | 
	
		
			
				|  |  | -                handler_call_details: grpc.HandlerCallDetails
 | 
	
		
			
				|  |  | -        ) -> grpc.RpcMethodHandler:
 | 
	
		
			
				|  |  | -            # Get the actual handler
 | 
	
		
			
				|  |  | -            handler = await continuation(handler_call_details)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            def wrapper(behavior: Callable[
 | 
	
		
			
				|  |  | -                [messages_pb2.SimpleRequest, aio.
 | 
	
		
			
				|  |  | -                 ServerInterceptor], messages_pb2.SimpleResponse]):
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -                @functools.wraps(behavior)
 | 
	
		
			
				|  |  | -                async def wrapper(request: messages_pb2.SimpleRequest,
 | 
	
		
			
				|  |  | -                                  context: aio.ServicerContext
 | 
	
		
			
				|  |  | -                                 ) -> messages_pb2.SimpleResponse:
 | 
	
		
			
				|  |  | -                    if request.response_size not in cache_store:
 | 
	
		
			
				|  |  | -                        cache_store[request.response_size] = await behavior(
 | 
	
		
			
				|  |  | -                            request, context)
 | 
	
		
			
				|  |  | -                    return cache_store[request.response_size]
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -                return wrapper
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            return wrap_server_method_handler(wrapper, handler)
 | 
	
		
			
				|  |  | +        })
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Constructs a server with the cache interceptor
 | 
	
		
			
				|  |  | -        server, stub = await _create_server_stub_pair(
 | 
	
		
			
				|  |  | -            _GenericInterceptor(intercept_and_cache))
 | 
	
		
			
				|  |  | +        server, stub = await _create_server_stub_pair(interceptor)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Tests if the cache store is used
 | 
	
		
			
				|  |  |          response = await stub.UnaryCall(
 | 
	
		
			
				|  |  |              messages_pb2.SimpleRequest(response_size=42))
 | 
	
		
			
				|  |  | -        self.assertEqual(1, len(cache_store[42].payload.body))
 | 
	
		
			
				|  |  | -        self.assertEqual(cache_store[42], response)
 | 
	
		
			
				|  |  | +        self.assertEqual(1, len(interceptor.cache_store[42].payload.body))
 | 
	
		
			
				|  |  | +        self.assertEqual(interceptor.cache_store[42], response)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Tests response can be cached
 | 
	
		
			
				|  |  |          response = await stub.UnaryCall(
 | 
	
		
			
				|  |  |              messages_pb2.SimpleRequest(response_size=1337))
 | 
	
		
			
				|  |  | -        self.assertEqual(1337, len(cache_store[1337].payload.body))
 | 
	
		
			
				|  |  | -        self.assertEqual(cache_store[1337], response)
 | 
	
		
			
				|  |  | +        self.assertEqual(1337, len(interceptor.cache_store[1337].payload.body))
 | 
	
		
			
				|  |  | +        self.assertEqual(interceptor.cache_store[1337], response)
 | 
	
		
			
				|  |  |          response = await stub.UnaryCall(
 | 
	
		
			
				|  |  |              messages_pb2.SimpleRequest(response_size=1337))
 | 
	
		
			
				|  |  | -        self.assertEqual(cache_store[1337], response)
 | 
	
		
			
				|  |  | +        self.assertEqual(interceptor.cache_store[1337], response)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def test_interceptor_unary_stream(self):
 | 
	
		
			
				|  |  |          record = []
 |