| 
					
				 | 
			
			
				@@ -13,14 +13,24 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # limitations under the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 """Tests for Simple Stubs.""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import contextlib 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import datetime 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import inspect 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import unittest 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import sys 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import grpc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import test_common 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# TODO: Figure out how to get this test to run only for Python 3. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from typing import Callable, Optional 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+_CACHE_EPOCHS = 8 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+_CACHE_TRIALS = 6 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 _UNARY_UNARY = "/test/UnaryUnary" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -37,26 +47,93 @@ class _GenericHandler(grpc.GenericRpcHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             raise NotImplementedError() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    start = datetime.datetime.now() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    to_time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return datetime.datetime.now() - start 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+@contextlib.contextmanager 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def _server(credentials: Optional[grpc.ServerCredentials]): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server = test_common.test_server() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        target = '[::]:0' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if credentials is None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            port = server.add_insecure_port(target) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            port = server.add_secure_port(target, credentials) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server.add_generic_rpc_handlers((_GenericHandler(),)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server.start() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        yield server, port 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    finally: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        server.stop(None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 @unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class SimpleStubsTest(unittest.TestCase): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    @classmethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def setUpClass(cls): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        super(SimpleStubsTest, cls).setUpClass() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        cls._server = test_common.test_server() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        cls._port = cls._server.add_insecure_port('[::]:0') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        cls._server.add_generic_rpc_handlers((_GenericHandler(),)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        cls._server.start() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    @classmethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def tearDownClass(cls): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        cls._server.stop(None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        super(SimpleStubsTest, cls).tearDownClass() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def test_unary_unary(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        target = f'localhost:{self._port}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        request = b'0000' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        response = grpc.unary_unary(request, target, _UNARY_UNARY) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.assertEqual(request, response) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def assert_cached(self, to_check: Callable[[str], None]) -> None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """Asserts that a function caches intermediate data/state. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        To be specific, given a function whose caching behavior is 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        deterministic in the value of a supplied string, this function asserts 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        that, on average, subsequent invocations of the function for a specific 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        string are faster than first invocations with that same string. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Args: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+          to_check: A function returning nothing, that caches values based on 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            an arbitrary supplied Text object. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        initial_runs = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        cached_runs = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for epoch in range(_CACHE_EPOCHS): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            runs = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            text = str(epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for trial in range(_CACHE_TRIALS): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                runs.append(_time_invocation(lambda: to_check(text))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            initial_runs.append(runs[0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            cached_runs.extend(runs[1:]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        average_cold = sum((run for run in initial_runs), datetime.timedelta()) / len(initial_runs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.assertLess(average_warm, average_cold) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_unary_unary_insecure(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with _server(None) as (_, port): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            target = f'localhost:{port}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            request = b'0000' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            response = grpc.unary_unary(request, target, _UNARY_UNARY) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertEqual(request, response) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_unary_unary_secure(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with _server(grpc.local_server_credentials()) as (_, port): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            target = f'localhost:{port}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            request = b'0000' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            response = grpc.unary_unary(request, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                        target, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                        _UNARY_UNARY, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                        channel_credentials=grpc.local_channel_credentials()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assertEqual(request, response) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def test_channels_cached(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with _server(grpc.local_server_credentials()) as (_, port): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            target = f'localhost:{port}' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            request = b'0000' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            test_name = inspect.stack()[0][3] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            args = (request, target, _UNARY_UNARY) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            kwargs = {"channel_credentials": grpc.local_channel_credentials()} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            def _invoke(seed: Text): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                run_kwargs = dict(kwargs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                run_kwargs["options"] = ((test_name + seed, ""),) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                grpc.unary_unary(*args, **run_kwargs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self.assert_cached(_invoke) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test request_serializer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test request_deserializer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test channel_credentials 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test call_credentials 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test compression 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test wait_for_ready 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO: Test metadata 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 if __name__ == "__main__": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     logging.basicConfig() 
			 |