|
@@ -13,22 +13,25 @@
|
|
|
# limitations under the License.
|
|
|
"""Tests for Simple Stubs."""
|
|
|
|
|
|
+import os
|
|
|
+
|
|
|
+_MAXIMUM_CHANNELS = 10
|
|
|
+
|
|
|
+os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "1"
|
|
|
+os.environ["GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM"] = str(_MAXIMUM_CHANNELS)
|
|
|
+
|
|
|
import contextlib
|
|
|
import datetime
|
|
|
import inspect
|
|
|
-import os
|
|
|
+import logging
|
|
|
import unittest
|
|
|
import sys
|
|
|
import time
|
|
|
+from typing import Callable, Optional
|
|
|
|
|
|
-import logging
|
|
|
-
|
|
|
-os.environ["GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS"] = "1"
|
|
|
-
|
|
|
-import grpc
|
|
|
import test_common
|
|
|
+import grpc
|
|
|
|
|
|
-from typing import Callable, Optional
|
|
|
|
|
|
_CACHE_EPOCHS = 8
|
|
|
_CACHE_TRIALS = 6
|
|
@@ -71,7 +74,6 @@ def _server(credentials: Optional[grpc.ServerCredentials]):
|
|
|
server.stop(None)
|
|
|
|
|
|
|
|
|
-@unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.")
|
|
|
class SimpleStubsTest(unittest.TestCase):
|
|
|
|
|
|
def assert_cached(self, to_check: Callable[[str], None]) -> None:
|
|
@@ -99,6 +101,22 @@ class SimpleStubsTest(unittest.TestCase):
|
|
|
average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs)
|
|
|
self.assertLess(average_warm, average_cold)
|
|
|
|
|
|
+
|
|
|
+ def assert_eventually(self,
|
|
|
+ predicate: Callable[[], bool],
|
|
|
+ *,
|
|
|
+ timeout: Optional[datetime.timedelta] = None,
|
|
|
+ message: Optional[Callable[[], str]] = None) -> None:
|
|
|
+ message = message or (lambda: "Proposition did not evaluate to true")
|
|
|
+ timeout = timeout or datetime.timedelta(seconds=10)
|
|
|
+ end = datetime.datetime.now() + timeout
|
|
|
+ while datetime.datetime.now() < end:
|
|
|
+ if predicate():
|
|
|
+ break
|
|
|
+ time.sleep(0.5)
|
|
|
+ else:
|
|
|
+ self.fail(message() + " after " + str(timeout))
|
|
|
+
|
|
|
def test_unary_unary_insecure(self):
|
|
|
with _server(None) as (_, port):
|
|
|
target = f'localhost:{port}'
|
|
@@ -129,7 +147,6 @@ class SimpleStubsTest(unittest.TestCase):
|
|
|
grpc.unary_unary(*args, **run_kwargs)
|
|
|
self.assert_cached(_invoke)
|
|
|
|
|
|
- # TODO: Can this somehow be made more blackbox?
|
|
|
def test_channels_evicted(self):
|
|
|
with _server(grpc.local_server_credentials()) as (_, port):
|
|
|
target = f'localhost:{port}'
|
|
@@ -138,15 +155,26 @@ class SimpleStubsTest(unittest.TestCase):
|
|
|
target,
|
|
|
_UNARY_UNARY,
|
|
|
channel_credentials=grpc.local_channel_credentials())
|
|
|
- channel_count = None
|
|
|
- deadline = datetime.datetime.now() + datetime.timedelta(seconds=10)
|
|
|
- while datetime.datetime.now() < deadline:
|
|
|
- channel_count = grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
|
|
|
- if channel_count == 0:
|
|
|
- break
|
|
|
- time.sleep(1)
|
|
|
- else:
|
|
|
- self.assertFalse("Not all channels were evicted. {channel_count} remain.")
|
|
|
+ self.assert_eventually(
|
|
|
+ lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() == 0,
|
|
|
+ message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain")
|
|
|
+
|
|
|
+ def test_total_channels_enforced(self):
|
|
|
+ with _server(grpc.local_server_credentials()) as (_, port):
|
|
|
+ target = f'localhost:{port}'
|
|
|
+ request = b'0000'
|
|
|
+ for i in range(99):
|
|
|
+ # Ensure we get a new channel each time.
|
|
|
+ options = (("foo", str(i)),)
|
|
|
+ # Send messages at full blast.
|
|
|
+ grpc.unary_unary(request,
|
|
|
+ target,
|
|
|
+ _UNARY_UNARY,
|
|
|
+ options=options,
|
|
|
+ channel_credentials=grpc.local_channel_credentials())
|
|
|
+ self.assert_eventually(
|
|
|
+ lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1,
|
|
|
+ message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain")
|
|
|
|
|
|
|
|
|
# TODO: Test request_serializer
|
|
@@ -158,7 +186,5 @@ class SimpleStubsTest(unittest.TestCase):
|
|
|
# TODO: Test metadata
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- logging.basicConfig()
|
|
|
+ logging.basicConfig(level=logging.INFO)
|
|
|
unittest.main(verbosity=2)
|
|
|
-
|
|
|
-
|