Эх сурвалжийг харах

Add maximum-based channel eviction

Richard Belleville 5 жил өмнө
parent
commit
4ac50ceed6

+ 13 - 12
src/python/grpcio/grpc/_simple_stubs.py

@@ -15,9 +15,16 @@ _LOGGER = logging.getLogger(__name__)
 _EVICTION_PERIOD_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"
 if _EVICTION_PERIOD_KEY in os.environ:
     _EVICTION_PERIOD = datetime.timedelta(seconds=float(os.environ[_EVICTION_PERIOD_KEY]))
+    _LOGGER.info(f"Setting managed channel eviction period to {_EVICTION_PERIOD}")
 else:
     _EVICTION_PERIOD = datetime.timedelta(minutes=10)
 
+_MAXIMUM_CHANNELS_KEY = "GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM"
+if _MAXIMUM_CHANNELS_KEY in os.environ:
+    _MAXIMUM_CHANNELS = int(os.environ[_MAXIMUM_CHANNELS_KEY])
+    _LOGGER.info(f"Setting maximum managed channels to {_MAXIMUM_CHANNELS}")
+else:
+    _MAXIMUM_CHANNELS = 2 ** 8
 
 def _create_channel(target: Text,
                     options: Sequence[Tuple[Text, Text]],
@@ -74,6 +81,10 @@ class ChannelCache:
                 ChannelCache._eviction_ready.set()
                 if not ChannelCache._singleton._mapping:
                     ChannelCache._condition.wait()
+                elif len(ChannelCache._singleton._mapping) > _MAXIMUM_CHANNELS:
+                    key = next(iter(ChannelCache._singleton._mapping.keys()))
+                    ChannelCache._singleton._evict_locked(key)
+                    # And immediately reevaluate.
                 else:
                     key, (channel, eviction_time) = next(iter(ChannelCache._singleton._mapping.items()))
                     now = datetime.datetime.now()
@@ -97,24 +108,13 @@ class ChannelCache:
             channel_data = self._mapping.get(key, None)
             if channel_data is not None:
                 channel = channel_data[0]
-                # # NOTE: This isn't actually necessary. The eviction thread will
-                # # always wake up because the new head of the list will, by
-                # # definition, have a later eviction time than the old head of
-                # # the list. If however, we allow for channels with heterogeneous
-                # # eviction periods, this *will* become necessary. We can imagine
-                # # this would be the case for timeouts. That is, if a timeout
-                # # longer than the eviction period is specified, we do not want
-                # # to cancel the RPC prematurely.
-                # if channel is next(iter(self._mapping.values()))[0]:
-                #     self._condition.notify()
-                # Move to the end of the map.
                 self._mapping.pop(key)
                 self._mapping[key] = (channel, datetime.datetime.now() + _EVICTION_PERIOD)
                 return channel
             else:
                 channel = _create_channel(target, options, channel_credentials, compression)
                 self._mapping[key] = (channel, datetime.datetime.now() + _EVICTION_PERIOD)
-                if len(self._mapping) == 1:
+                if len(self._mapping) == 1 or len(self._mapping) >= _MAXIMUM_CHANNELS:
                     self._condition.notify()
                 return channel
 
@@ -123,6 +123,7 @@ class ChannelCache:
             return len(self._mapping)
 
 
+# TODO: s/Text/str/g
 def unary_unary(request: Any,
                 target: Text,
                 method: Text,

+ 47 - 21
src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py

@@ -13,22 +13,25 @@
 # limitations under the License.
 """Tests for Simple Stubs."""
 
+import os
+
+_MAXIMUM_CHANNELS = 10
+
+os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "1"
+os.environ["GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM"] = str(_MAXIMUM_CHANNELS)
+
 import contextlib
 import datetime
 import inspect
-import os
+import logging
 import unittest
 import sys
 import time
+from typing import Callable, Optional
 
-import logging
-
-os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "1"
-
-import grpc
 import test_common
+import grpc
 
-from typing import Callable, Optional
 
 _CACHE_EPOCHS = 8
 _CACHE_TRIALS = 6
@@ -71,7 +74,6 @@ def _server(credentials: Optional[grpc.ServerCredentials]):
         server.stop(None)
 
 
-@unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.")
 class SimpleStubsTest(unittest.TestCase):
 
     def assert_cached(self, to_check: Callable[[str], None]) -> None:
@@ -99,6 +101,22 @@ class SimpleStubsTest(unittest.TestCase):
         average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs)
         self.assertLess(average_warm, average_cold)
 
+
+    def assert_eventually(self,
+                          predicate: Callable[[], bool],
+                          *,
+                          timeout: Optional[datetime.timedelta] = None,
+                          message: Optional[Callable[[], str]] = None) -> None:
+        message = message or (lambda: "Proposition did not evaluate to true")
+        timeout = timeout or datetime.timedelta(seconds=10)
+        end = datetime.datetime.now() + timeout
+        while datetime.datetime.now() < end:
+            if predicate():
+                break
+            time.sleep(0.5)
+        else:
+            self.fail(message() + " after " + str(timeout))
+
     def test_unary_unary_insecure(self):
         with _server(None) as (_, port):
             target = f'localhost:{port}'
@@ -129,7 +147,6 @@ class SimpleStubsTest(unittest.TestCase):
                 grpc.unary_unary(*args, **run_kwargs)
             self.assert_cached(_invoke)
 
-    # TODO: Can this somehow be made more blackbox?
     def test_channels_evicted(self):
         with _server(grpc.local_server_credentials()) as (_, port):
             target = f'localhost:{port}'
@@ -138,15 +155,26 @@ class SimpleStubsTest(unittest.TestCase):
                                         target,
                                         _UNARY_UNARY,
                                         channel_credentials=grpc.local_channel_credentials())
-            channel_count = None
-            deadline = datetime.datetime.now() + datetime.timedelta(seconds=10)
-            while datetime.datetime.now() < deadline:
-                channel_count = grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
-                if channel_count == 0:
-                    break
-                time.sleep(1)
-            else:
-                self.assertFalse("Not all channels were evicted. {channel_count} remain.")
+            self.assert_eventually(
+                lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() == 0,
+                message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain")
+
+    def test_total_channels_enforced(self):
+        with _server(grpc.local_server_credentials()) as (_, port):
+            target = f'localhost:{port}'
+            request = b'0000'
+            for i in range(99):
+                # Ensure we get a new channel each time.
+                options = (("foo", str(i)),)
+                # Send messages at full blast.
+                grpc.unary_unary(request,
+                                 target,
+                                 _UNARY_UNARY,
+                                 options=options,
+                                 channel_credentials=grpc.local_channel_credentials())
+                self.assert_eventually(
+                    lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1,
+                    message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain")
 
 
     # TODO: Test request_serializer
@@ -158,7 +186,5 @@ class SimpleStubsTest(unittest.TestCase):
     # TODO: Test metadata
 
 if __name__ == "__main__":
-    logging.basicConfig()
+    logging.basicConfig(level=logging.INFO)
     unittest.main(verbosity=2)
-
-