浏览代码

Add simple eviction mechanism

Richard Belleville 5 年之前
父节点
当前提交
741b407ac4
共有 2 个文件被更改,包括 130 次插入27 次删除
  1. 105 25
      src/python/grpcio/grpc/_simple_stubs.py
  2. 25 2
      src/python/grpcio_tests/tests/unit/_simple_stubs_test.py

+ 105 - 25
src/python/grpcio/grpc/_simple_stubs.py

@@ -1,50 +1,126 @@
 # TODO: Flowerbox.
 # TODO: Flowerbox.
 
 
+import collections
+import datetime
+import os
+import logging
 import threading
 import threading
 
 
 import grpc
 import grpc
 from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union
 from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union
 
 
-_CHANNEL_CACHE = None
-_CHANNEL_CACHE_LOCK = threading.RLock()
 
 
-# TODO: Evict channels.
+_LOGGER = logging.getLogger(__name__)
+
+_EVICTION_PERIOD_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"
+if _EVICTION_PERIOD_KEY in os.environ:
+    _EVICTION_PERIOD = datetime.timedelta(seconds=float(os.environ[_EVICTION_PERIOD_KEY]))
+else:
+    _EVICTION_PERIOD = datetime.timedelta(minutes=10)
 
 
-# Eviction policy based on both channel count and time since use. Perhaps
-# OrderedDict instead?
 
 
 def _create_channel(target: Text,
 def _create_channel(target: Text,
                     options: Sequence[Tuple[Text, Text]],
                     options: Sequence[Tuple[Text, Text]],
                     channel_credentials: Optional[grpc.ChannelCredentials],
                     channel_credentials: Optional[grpc.ChannelCredentials],
                     compression: Optional[grpc.Compression]) -> grpc.Channel:
                     compression: Optional[grpc.Compression]) -> grpc.Channel:
     if channel_credentials is None:
     if channel_credentials is None:
+        _LOGGER.info(f"Creating insecure channel with options '{options}' " +
+                       f"and compression '{compression}'")
         return grpc.insecure_channel(target,
         return grpc.insecure_channel(target,
                                      options=options,
                                      options=options,
                                      compression=compression)
                                      compression=compression)
     else:
     else:
+        _LOGGER.info(f"Creating secure channel with credentials '{channel_credentials}', " +
+                       f"options '{options}' and compression '{compression}'")
         return grpc.secure_channel(target,
         return grpc.secure_channel(target,
                                    credentials=channel_credentials,
                                    credentials=channel_credentials,
                                    options=options,
                                    options=options,
                                    compression=compression)
                                    compression=compression)
 
 
+class ChannelCache:
+    _singleton = None
+    _lock = threading.RLock()
+    _condition = threading.Condition(lock=_lock)
+    _eviction_ready = threading.Event()
+
+
+    def __init__(self):
+        self._mapping = collections.OrderedDict()
+        self._eviction_thread = threading.Thread(target=ChannelCache._perform_evictions, daemon=True)
+        self._eviction_thread.start()
+
+
+    @staticmethod
+    def get():
+        with ChannelCache._lock:
+            if ChannelCache._singleton is None:
+                ChannelCache._singleton = ChannelCache()
+        ChannelCache._eviction_ready.wait()
+        return ChannelCache._singleton
+
+    # TODO: Type annotate key.
+    def _evict_locked(self, key):
+        channel, _ = self._mapping.pop(key)
+        _LOGGER.info(f"Evicting channel {channel} with configuration {key}.")
+        channel.close()
+        del channel
+
+
+    # TODO: Refactor. Way too deeply nested.
+    @staticmethod
+    def _perform_evictions():
+        while True:
+            with ChannelCache._lock:
+                ChannelCache._eviction_ready.set()
+                if not ChannelCache._singleton._mapping:
+                    ChannelCache._condition.wait()
+                else:
+                    key, (channel, eviction_time) = next(iter(ChannelCache._singleton._mapping.items()))
+                    now = datetime.datetime.now()
+                    if eviction_time <= now:
+                        ChannelCache._singleton._evict_locked(key)
+                        continue
+                    else:
+                        time_to_eviction = (eviction_time - now).total_seconds()
+                        ChannelCache._condition.wait(timeout=time_to_eviction)
 
 
-def _get_cached_channel(target: Text,
-                        options: Sequence[Tuple[Text, Text]],
-                        channel_credentials: Optional[grpc.ChannelCredentials],
-                        compression: Optional[grpc.Compression]) -> grpc.Channel:
-    global _CHANNEL_CACHE
-    global _CHANNEL_CACHE_LOCK
-    key = (target, options, channel_credentials, compression)
-    with _CHANNEL_CACHE_LOCK:
-        if _CHANNEL_CACHE is None:
-            _CHANNEL_CACHE = {}
-        channel = _CHANNEL_CACHE.get(key, None)
-        if channel is not None:
-            return channel
-        else:
-            channel = _create_channel(target, options, channel_credentials, compression)
-            _CHANNEL_CACHE[key] = channel
-            return channel
+
+
+    def get_channel(self,
+                    target: Text,
+                    options: Sequence[Tuple[Text, Text]],
+                    channel_credentials: Optional[grpc.ChannelCredentials],
+                    compression: Optional[grpc.Compression]) -> grpc.Channel:
+        key = (target, options, channel_credentials, compression)
+        with self._lock:
+            # TODO: Can the get and the pop be turned into a single operation?
+            channel_data = self._mapping.get(key, None)
+            if channel_data is not None:
+                channel = channel_data[0]
+                # # NOTE: This isn't actually necessary. The eviction thread will
+                # # always wake up because the new head of the list will, by
+                # # definition, have a later eviction time than the old head of
+                # # the list. If however, we allow for channels with heterogeneous
+                # # eviction periods, this *will* become necessary. We can imagine
+                # # this would be the case for timeouts. That is, if a timeout
+                # # longer than the eviction period is specified, we do not want
+                # # to cancel the RPC prematurely.
+                # if channel is next(iter(self._mapping.values()))[0]:
+                #     self._condition.notify()
+                # Move to the end of the map.
+                self._mapping.pop(key)
+                self._mapping[key] = (channel, datetime.datetime.now() + _EVICTION_PERIOD)
+                return channel
+            else:
+                channel = _create_channel(target, options, channel_credentials, compression)
+                self._mapping[key] = (channel, datetime.datetime.now() + _EVICTION_PERIOD)
+                if len(self._mapping) == 1:
+                    self._condition.notify()
+                return channel
+
+    def _test_only_channel_count(self) -> int:
+        with self._lock:
+            return len(self._mapping)
 
 
 
 
 def unary_unary(request: Any,
 def unary_unary(request: Any,
@@ -58,17 +134,21 @@ def unary_unary(request: Any,
                 call_credentials: Optional[grpc.CallCredentials] = None,
                 call_credentials: Optional[grpc.CallCredentials] = None,
                 compression: Optional[grpc.Compression] = None,
                 compression: Optional[grpc.Compression] = None,
                 wait_for_ready: Optional[bool] = None,
                 wait_for_ready: Optional[bool] = None,
+                timeout: Optional[float] = None,
                 metadata: Optional[Sequence[Tuple[Text, Union[Text, bytes]]]] = None) -> Any:
                 metadata: Optional[Sequence[Tuple[Text, Union[Text, bytes]]]] = None) -> Any:
     """Invokes a unary RPC without an explicitly specified channel.
     """Invokes a unary RPC without an explicitly specified channel.
 
 
-    This is backed by an LRU cache of channels evicted by a background thread
+    This is backed by a cache of channels evicted by a background thread
     on a periodic basis.
     on a periodic basis.
 
 
     TODO: Document the parameters and return value.
     TODO: Document the parameters and return value.
     """
     """
-    channel = _get_cached_channel(target, options, channel_credentials, compression)
+
+    # TODO: Warn if the timeout is greater than the channel eviction time.
+    channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
     multicallable = channel.unary_unary(method, request_serializer, request_deserializer)
     multicallable = channel.unary_unary(method, request_serializer, request_deserializer)
     return multicallable(request,
     return multicallable(request,
                          metadata=metadata,
                          metadata=metadata,
                          wait_for_ready=wait_for_ready,
                          wait_for_ready=wait_for_ready,
-                         credentials=call_credentials)
+                         credentials=call_credentials,
+                         timeout=timeout)

+ 25 - 2
src/python/grpcio_tests/tests/unit/_simple_stubs_test.py

@@ -16,12 +16,15 @@
 import contextlib
 import contextlib
 import datetime
 import datetime
 import inspect
 import inspect
+import os
 import unittest
 import unittest
 import sys
 import sys
 import time
 import time
 
 
 import logging
 import logging
 
 
+os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "1"
+
 import grpc
 import grpc
 import test_common
 import test_common
 
 
@@ -82,7 +85,7 @@ class SimpleStubsTest(unittest.TestCase):
 
 
         Args:
         Args:
           to_check: A function returning nothing, that caches values based on
           to_check: A function returning nothing, that caches values based on
-            an arbitrary supplied Text object.
+            an arbitrary supplied string.
         """
         """
         initial_runs = []
         initial_runs = []
         cached_runs = []
         cached_runs = []
@@ -121,12 +124,32 @@ class SimpleStubsTest(unittest.TestCase):
             test_name = inspect.stack()[0][3]
             test_name = inspect.stack()[0][3]
             args = (request, target, _UNARY_UNARY)
             args = (request, target, _UNARY_UNARY)
             kwargs = {"channel_credentials": grpc.local_channel_credentials()}
             kwargs = {"channel_credentials": grpc.local_channel_credentials()}
-            def _invoke(seed: Text):
+            def _invoke(seed: str):
                 run_kwargs = dict(kwargs)
                 run_kwargs = dict(kwargs)
                 run_kwargs["options"] = ((test_name + seed, ""),)
                 run_kwargs["options"] = ((test_name + seed, ""),)
                 grpc.unary_unary(*args, **run_kwargs)
                 grpc.unary_unary(*args, **run_kwargs)
             self.assert_cached(_invoke)
             self.assert_cached(_invoke)
 
 
+    # TODO: Can this somehow be made more blackbox?
+    def test_channels_evicted(self):
+        with _server(grpc.local_server_credentials()) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            response = grpc.unary_unary(request,
+                                        target,
+                                        _UNARY_UNARY,
+                                        channel_credentials=grpc.local_channel_credentials())
+            channel_count = None
+            deadline = datetime.datetime.now() + datetime.timedelta(seconds=10)
+            while datetime.datetime.now() < deadline:
+                channel_count = grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
+                if channel_count == 0:
+                    break
+                time.sleep(1)
+            else:
+                self.assertFalse("Not all channels were evicted. {channel_count} remain.")
+
+
     # TODO: Test request_serializer
     # TODO: Test request_serializer
     # TODO: Test request_deserializer
     # TODO: Test request_deserializer
     # TODO: Test channel_credentials
     # TODO: Test channel_credentials