|  | @@ -11,17 +11,23 @@
 | 
	
		
			
				|  |  |  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
	
		
			
				|  |  |  # See the License for the specific language governing permissions and
 | 
	
		
			
				|  |  |  # limitations under the License.
 | 
	
		
			
				|  |  | +"""Test the functionality of server interceptors."""
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import functools
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  |  import unittest
 | 
	
		
			
				|  |  | -from typing import Callable, Awaitable, Any
 | 
	
		
			
				|  |  | +from typing import Any, Awaitable, Callable, Tuple
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import grpc
 | 
	
		
			
				|  |  | +from grpc.experimental import aio, wrap_server_method_handler
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from grpc.experimental import aio
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -from tests_aio.unit._test_server import start_test_server
 | 
	
		
			
				|  |  | +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
 | 
	
		
			
				|  |  |  from tests_aio.unit._test_base import AioTestBase
 | 
	
		
			
				|  |  | -from src.proto.grpc.testing import messages_pb2
 | 
	
		
			
				|  |  | +from tests_aio.unit._test_server import start_test_server
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +_NUM_STREAM_RESPONSES = 5
 | 
	
		
			
				|  |  | +_REQUEST_PAYLOAD_SIZE = 7
 | 
	
		
			
				|  |  | +_RESPONSE_PAYLOAD_SIZE = 42
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class _LoggingInterceptor(aio.ServerInterceptor):
 | 
	
	
		
			
				|  | @@ -73,6 +79,18 @@ def _filter_server_interceptor(condition: Callable,
 | 
	
		
			
				|  |  |      return _GenericInterceptor(intercept_service)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +async def _create_server_stub_pair(
 | 
	
		
			
				|  |  | +        *interceptors: aio.ServerInterceptor
 | 
	
		
			
				|  |  | +) -> Tuple[aio.Server, test_pb2_grpc.TestServiceStub]:
 | 
	
		
			
				|  |  | +    """Creates a server-stub pair with given interceptors.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    Returning the server object to protect it from being garbage collected.
 | 
	
		
			
				|  |  | +    """
 | 
	
		
			
				|  |  | +    server_target, server = await start_test_server(interceptors=interceptors)
 | 
	
		
			
				|  |  | +    channel = aio.insecure_channel(server_target)
 | 
	
		
			
				|  |  | +    return server, test_pb2_grpc.TestServiceStub(channel)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class TestServerInterceptor(AioTestBase):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      async def test_invalid_interceptor(self):
 | 
	
	
		
			
				|  | @@ -162,6 +180,135 @@ class TestServerInterceptor(AioTestBase):
 | 
	
		
			
				|  |  |                  'log2:intercept_service',
 | 
	
		
			
				|  |  |              ], record)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    async def test_response_caching(self):
 | 
	
		
			
				|  |  | +        # Prepares a preset value to help testing
 | 
	
		
			
				|  |  | +        cache_store = {
 | 
	
		
			
				|  |  | +            42:
 | 
	
		
			
				|  |  | +                messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
 | 
	
		
			
				|  |  | +                    body=b'\x42'))
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        async def intercept_and_cache(
 | 
	
		
			
				|  |  | +                continuation: Callable[[grpc.HandlerCallDetails], Awaitable[
 | 
	
		
			
				|  |  | +                    grpc.RpcMethodHandler]],
 | 
	
		
			
				|  |  | +                handler_call_details: grpc.HandlerCallDetails
 | 
	
		
			
				|  |  | +        ) -> grpc.RpcMethodHandler:
 | 
	
		
			
				|  |  | +            # Get the actual handler
 | 
	
		
			
				|  |  | +            handler = await continuation(handler_call_details)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            def wrap_handler(handler: grpc.RpcMethodHandler):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                @functools.wraps(handler)
 | 
	
		
			
				|  |  | +                async def wrapper(request: messages_pb2.SimpleRequest,
 | 
	
		
			
				|  |  | +                                  context: aio.ServicerContext):
 | 
	
		
			
				|  |  | +                    if request.response_size not in cache_store:
 | 
	
		
			
				|  |  | +                        cache_store[request.response_size] = await handler(
 | 
	
		
			
				|  |  | +                            request, context)
 | 
	
		
			
				|  |  | +                    return cache_store[request.response_size]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                return wrapper
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            return wrap_server_method_handler(wrap_handler, handler)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Constructs a server with the cache interceptor
 | 
	
		
			
				|  |  | +        server, stub = await _create_server_stub_pair(
 | 
	
		
			
				|  |  | +            _GenericInterceptor(intercept_and_cache))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Tests if the cache store is used
 | 
	
		
			
				|  |  | +        response = await stub.UnaryCall(
 | 
	
		
			
				|  |  | +            messages_pb2.SimpleRequest(response_size=42))
 | 
	
		
			
				|  |  | +        self.assertEqual(1, len(cache_store[42].payload.body))
 | 
	
		
			
				|  |  | +        self.assertEqual(cache_store[42], response)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Tests response can be cached
 | 
	
		
			
				|  |  | +        response = await stub.UnaryCall(
 | 
	
		
			
				|  |  | +            messages_pb2.SimpleRequest(response_size=1337))
 | 
	
		
			
				|  |  | +        self.assertEqual(1337, len(cache_store[1337].payload.body))
 | 
	
		
			
				|  |  | +        self.assertEqual(cache_store[1337], response)
 | 
	
		
			
				|  |  | +        response = await stub.UnaryCall(
 | 
	
		
			
				|  |  | +            messages_pb2.SimpleRequest(response_size=1337))
 | 
	
		
			
				|  |  | +        self.assertEqual(cache_store[1337], response)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def test_interceptor_unary_stream(self):
 | 
	
		
			
				|  |  | +        record = []
 | 
	
		
			
				|  |  | +        server, stub = await _create_server_stub_pair(
 | 
	
		
			
				|  |  | +            _LoggingInterceptor('log_unary_stream', record))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Prepares the request
 | 
	
		
			
				|  |  | +        request = messages_pb2.StreamingOutputCallRequest()
 | 
	
		
			
				|  |  | +        for _ in range(_NUM_STREAM_RESPONSES):
 | 
	
		
			
				|  |  | +            request.response_parameters.append(
 | 
	
		
			
				|  |  | +                messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Tests if the cache store is used
 | 
	
		
			
				|  |  | +        call = stub.StreamingOutputCall(request)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Ensures the RPC goes fine
 | 
	
		
			
				|  |  | +        async for response in call:
 | 
	
		
			
				|  |  | +            self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
 | 
	
		
			
				|  |  | +        self.assertEqual(await call.code(), grpc.StatusCode.OK)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertSequenceEqual([
 | 
	
		
			
				|  |  | +            'log_unary_stream:intercept_service',
 | 
	
		
			
				|  |  | +        ], record)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def test_interceptor_stream_unary(self):
 | 
	
		
			
				|  |  | +        record = []
 | 
	
		
			
				|  |  | +        server, stub = await _create_server_stub_pair(
 | 
	
		
			
				|  |  | +            _LoggingInterceptor('log_stream_unary', record))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Invokes the actual RPC
 | 
	
		
			
				|  |  | +        call = stub.StreamingInputCall()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Prepares the request
 | 
	
		
			
				|  |  | +        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
 | 
	
		
			
				|  |  | +        request = messages_pb2.StreamingInputCallRequest(payload=payload)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Sends out requests
 | 
	
		
			
				|  |  | +        for _ in range(_NUM_STREAM_RESPONSES):
 | 
	
		
			
				|  |  | +            await call.write(request)
 | 
	
		
			
				|  |  | +        await call.done_writing()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Validates the responses
 | 
	
		
			
				|  |  | +        response = await call
 | 
	
		
			
				|  |  | +        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
 | 
	
		
			
				|  |  | +        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
 | 
	
		
			
				|  |  | +                         response.aggregated_payload_size)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertEqual(await call.code(), grpc.StatusCode.OK)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertSequenceEqual([
 | 
	
		
			
				|  |  | +            'log_stream_unary:intercept_service',
 | 
	
		
			
				|  |  | +        ], record)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def test_interceptor_stream_stream(self):
 | 
	
		
			
				|  |  | +        record = []
 | 
	
		
			
				|  |  | +        server, stub = await _create_server_stub_pair(
 | 
	
		
			
				|  |  | +            _LoggingInterceptor('log_stream_stream', record))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Prepares the request
 | 
	
		
			
				|  |  | +        payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
 | 
	
		
			
				|  |  | +        request = messages_pb2.StreamingInputCallRequest(payload=payload)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        async def gen():
 | 
	
		
			
				|  |  | +            for _ in range(_NUM_STREAM_RESPONSES):
 | 
	
		
			
				|  |  | +                yield request
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Invokes the actual RPC
 | 
	
		
			
				|  |  | +        call = stub.StreamingInputCall(gen())
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        # Validates the responses
 | 
	
		
			
				|  |  | +        response = await call
 | 
	
		
			
				|  |  | +        self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
 | 
	
		
			
				|  |  | +        self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
 | 
	
		
			
				|  |  | +                         response.aggregated_payload_size)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertEqual(await call.code(), grpc.StatusCode.OK)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.assertSequenceEqual([
 | 
	
		
			
				|  |  | +            'log_stream_stream:intercept_service',
 | 
	
		
			
				|  |  | +        ], record)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  if __name__ == '__main__':
 | 
	
		
			
				|  |  |      logging.basicConfig(level=logging.DEBUG)
 |