Browse Source

Add thread-safe channel cache. Test that it actually caches

Richard Belleville 5 years ago
parent
commit
5a8a6e3ad3

+ 33 - 5
src/python/grpcio/grpc/_simple_stubs.py

@@ -1,14 +1,22 @@
 # TODO: Flowerbox.
 
+import threading
+
 import grpc
 from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union
 
+_CHANNEL_CACHE = None
+_CHANNEL_CACHE_LOCK = threading.RLock()
 
-def _get_cached_channel(target: Text,
-                        options: Sequence[Tuple[Text, Text]],
-                        channel_credentials: Optional[grpc.ChannelCredentials],
-                        compression: Optional[grpc.Compression]) -> grpc.Channel:
-    # TODO: Actually cache.
+# TODO: Evict channels.
+
+# Eviction policy based on both channel count and time since use. Perhaps
+# OrderedDict instead?
+
+def _create_channel(target: Text,
+                    options: Sequence[Tuple[Text, Text]],
+                    channel_credentials: Optional[grpc.ChannelCredentials],
+                    compression: Optional[grpc.Compression]) -> grpc.Channel:
     if channel_credentials is None:
         return grpc.insecure_channel(target,
                                      options=options,
@@ -19,6 +27,26 @@ def _get_cached_channel(target: Text,
                                    options=options,
                                    compression=compression)
 
+
+def _get_cached_channel(target: Text,
+                        options: Sequence[Tuple[Text, Text]],
+                        channel_credentials: Optional[grpc.ChannelCredentials],
+                        compression: Optional[grpc.Compression]) -> grpc.Channel:
+    global _CHANNEL_CACHE
+    global _CHANNEL_CACHE_LOCK
+    key = (target, options, channel_credentials, compression)
+    with _CHANNEL_CACHE_LOCK:
+        if _CHANNEL_CACHE is None:
+            _CHANNEL_CACHE = {}
+        channel = _CHANNEL_CACHE.get(key, None)
+        if channel is not None:
+            return channel
+        else:
+            channel = _create_channel(target, options, channel_credentials, compression)
+            _CHANNEL_CACHE[key] = channel
+            return channel
+
+
 def unary_unary(request: Any,
                 target: Text,
                 method: Text,

+ 95 - 18
src/python/grpcio_tests/tests/unit/_simple_stubs_test.py

@@ -13,14 +13,24 @@
 # limitations under the License.
 """Tests for Simple Stubs."""
 
+import contextlib
+import datetime
+import inspect
 import unittest
 import sys
+import time
 
 import logging
 
 import grpc
 import test_common
 
+# TODO: Figure out how to get this test to run only for Python 3.
+from typing import Callable, Optional
+
+_CACHE_EPOCHS = 8
+_CACHE_TRIALS = 6
+
 
 _UNARY_UNARY = "/test/UnaryUnary"
 
@@ -37,26 +47,93 @@ class _GenericHandler(grpc.GenericRpcHandler):
             raise NotImplementedError()
 
 
+def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta:
+    start = datetime.datetime.now()
+    to_time()
+    return datetime.datetime.now() - start
+
+
+@contextlib.contextmanager
+def _server(credentials: Optional[grpc.ServerCredentials]):
+    try:
+        server = test_common.test_server()
+        target = '[::]:0'
+        if credentials is None:
+            port = server.add_insecure_port(target)
+        else:
+            port = server.add_secure_port(target, credentials)
+        server.add_generic_rpc_handlers((_GenericHandler(),))
+        server.start()
+        yield server, port
+    finally:
+        server.stop(None)
+
+
 @unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.")
 class SimpleStubsTest(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        super(SimpleStubsTest, cls).setUpClass()
-        cls._server = test_common.test_server()
-        cls._port = cls._server.add_insecure_port('[::]:0')
-        cls._server.add_generic_rpc_handlers((_GenericHandler(),))
-        cls._server.start()
-
-    @classmethod
-    def tearDownClass(cls):
-        cls._server.stop(None)
-        super(SimpleStubsTest, cls).tearDownClass()
-
-    def test_unary_unary(self):
-        target = f'localhost:{self._port}'
-        request = b'0000'
-        response = grpc.unary_unary(request, target, _UNARY_UNARY)
-        self.assertEqual(request, response)
+
+    def assert_cached(self, to_check: Callable[[str], None]) -> None:
+        """Asserts that a function caches intermediate data/state.
+
+        To be specific, given a function whose caching behavior is
+        deterministic in the value of a supplied string, this function asserts
+        that, on average, subsequent invocations of the function for a specific
+        string are faster than first invocations with that same string.
+
+        Args:
+          to_check: A function returning nothing, that caches values based on
+            an arbitrary supplied Text object.
+        """
+        initial_runs = []
+        cached_runs = []
+        for epoch in range(_CACHE_EPOCHS):
+            runs = []
+            text = str(epoch)
+            for trial in range(_CACHE_TRIALS):
+                runs.append(_time_invocation(lambda: to_check(text)))
+            initial_runs.append(runs[0])
+            cached_runs.extend(runs[1:])
+        average_cold = sum((run for run in initial_runs), datetime.timedelta()) / len(initial_runs)
+        average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs)
+        self.assertLess(average_warm, average_cold)
+
+    def test_unary_unary_insecure(self):
+        with _server(None) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            response = grpc.unary_unary(request, target, _UNARY_UNARY)
+            self.assertEqual(request, response)
+
+    def test_unary_unary_secure(self):
+        with _server(grpc.local_server_credentials()) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            response = grpc.unary_unary(request,
+                                        target,
+                                        _UNARY_UNARY,
+                                        channel_credentials=grpc.local_channel_credentials())
+            self.assertEqual(request, response)
+
+    def test_channels_cached(self):
+        with _server(grpc.local_server_credentials()) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            test_name = inspect.stack()[0][3]
+            args = (request, target, _UNARY_UNARY)
+            kwargs = {"channel_credentials": grpc.local_channel_credentials()}
+            def _invoke(seed: Text):
+                run_kwargs = dict(kwargs)
+                run_kwargs["options"] = ((test_name + seed, ""),)
+                grpc.unary_unary(*args, **run_kwargs)
+            self.assert_cached(_invoke)
+
+    # TODO: Test request_serializer
+    # TODO: Test request_deserializer
+    # TODO: Test channel_credentials
+    # TODO: Test call_credentials
+    # TODO: Test compression
+    # TODO: Test wait_for_ready
+    # TODO: Test metadata
 
 if __name__ == "__main__":
     logging.basicConfig()