| 
					
				 | 
			
			
				@@ -0,0 +1,168 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Copyright 2020 The gRPC Authors 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Licensed under the Apache License, Version 2.0 (the "License"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# you may not use this file except in compliance with the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# You may obtain a copy of the License at 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     http://www.apache.org/licenses/LICENSE-2.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Unless required by applicable law or agreed to in writing, software 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# distributed under the License is distributed on an "AS IS" BASIS, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# See the License for the specific language governing permissions and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# limitations under the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import unittest 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from typing import Callable, Awaitable, Any 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import grpc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from grpc.experimental import aio 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from tests_aio.unit._test_server import start_test_server 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from tests_aio.unit._test_base import AioTestBase 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from src.proto.grpc.testing import messages_pb2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class _LoggingInterceptor(aio.ServerInterceptor): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, tag: str, record: list) -> None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.tag = tag 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.record = record 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def intercept_service( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                grpc.RpcMethodHandler]], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            handler_call_details: grpc.HandlerCallDetails 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) -> grpc.RpcMethodHandler: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.record.append(self.tag + ':intercept_service') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return await continuation(handler_call_details) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class _GenericInterceptor(aio.ServerInterceptor): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, fn: Callable[[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Callable[[grpc.HandlerCallDetails], Awaitable[grpc. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                          RpcMethodHandler]], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            grpc.HandlerCallDetails 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ], Any]) -> None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self._fn = fn 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def intercept_service( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self, continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                grpc.RpcMethodHandler]], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            handler_call_details: grpc.HandlerCallDetails 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) -> grpc.RpcMethodHandler: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return await self._fn(continuation, handler_call_details) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def _filter_server_interceptor(condition: Callable, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                               interceptor: aio.ServerInterceptor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                              ) -> aio.ServerInterceptor: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def intercept_service( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            continuation: Callable[[grpc.HandlerCallDetails], Awaitable[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                grpc.RpcMethodHandler]], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            handler_call_details: grpc.HandlerCallDetails 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) -> grpc.RpcMethodHandler: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if condition(handler_call_details): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return await interceptor.intercept_service(continuation, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                                       handler_call_details) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return await continuation(handler_call_details) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return _GenericInterceptor(intercept_service) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class TestServerInterceptor(AioTestBase): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def test_invalid_interceptor(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        class InvalidInterceptor: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            """Just an invalid Interceptor""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with self.assertRaises(ValueError): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            server_target, _ = await start_test_server( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                interceptors=(InvalidInterceptor(),)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def test_executed_right_order(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        record = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server_target, _ = await start_test_server(interceptors=( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            _LoggingInterceptor('log1', record), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            _LoggingInterceptor('log2', record), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        async with aio.insecure_channel(server_target) as channel: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            multicallable = channel.unary_unary( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                '/grpc.testing.TestService/UnaryCall', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                request_serializer=messages_pb2.SimpleRequest.SerializeToString, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                response_deserializer=messages_pb2.SimpleResponse.FromString) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            call = multicallable(messages_pb2.SimpleRequest()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            response = await call 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Check that all interceptors were executed, and were executed 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # in the right order. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertSequenceEqual([ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log1:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log2:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], record) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertIsInstance(response, messages_pb2.SimpleResponse) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def test_response_ok(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        record = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server_target, _ = await start_test_server( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            interceptors=(_LoggingInterceptor('log1', record),)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        async with aio.insecure_channel(server_target) as channel: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            multicallable = channel.unary_unary( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                '/grpc.testing.TestService/UnaryCall', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                request_serializer=messages_pb2.SimpleRequest.SerializeToString, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                response_deserializer=messages_pb2.SimpleResponse.FromString) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            call = multicallable(messages_pb2.SimpleRequest()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            response = await call 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            code = await call.code() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertSequenceEqual(['log1:intercept_service'], record) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertIsInstance(response, messages_pb2.SimpleResponse) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertEqual(code, grpc.StatusCode.OK) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    async def test_apply_different_interceptors_by_metadata(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        record = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        conditional_interceptor = _filter_server_interceptor( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            lambda x: ('secret', '42') in x.invocation_metadata, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            _LoggingInterceptor('log3', record)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server_target, _ = await start_test_server(interceptors=( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            _LoggingInterceptor('log1', record), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            conditional_interceptor, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            _LoggingInterceptor('log2', record), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        )) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        async with aio.insecure_channel(server_target) as channel: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            multicallable = channel.unary_unary( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                '/grpc.testing.TestService/UnaryCall', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                request_serializer=messages_pb2.SimpleRequest.SerializeToString, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                response_deserializer=messages_pb2.SimpleResponse.FromString) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            metadata = (('key', 'value'),) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            call = multicallable(messages_pb2.SimpleRequest(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                 metadata=metadata) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            await call 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertSequenceEqual([ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log1:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log2:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], record) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            record.clear() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            metadata = (('key', 'value'), ('secret', '42')) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            call = multicallable(messages_pb2.SimpleRequest(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                 metadata=metadata) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            await call 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertSequenceEqual([ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log1:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log3:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'log2:intercept_service', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], record) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    logging.basicConfig() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    unittest.main(verbosity=2) 
			 |