|
@@ -13,14 +13,24 @@
|
|
|
# limitations under the License.
|
|
|
"""Tests for Simple Stubs."""
|
|
|
|
|
|
+import contextlib
|
|
|
+import datetime
|
|
|
+import inspect
|
|
|
import unittest
|
|
|
import sys
|
|
|
+import time
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import grpc
|
|
|
import test_common
|
|
|
|
|
|
+# TODO: Figure out how to get this test to run only for Python 3.
|
|
|
+from typing import Callable, Optional
|
|
|
+
|
|
|
+_CACHE_EPOCHS = 8
|
|
|
+_CACHE_TRIALS = 6
|
|
|
+
|
|
|
|
|
|
_UNARY_UNARY = "/test/UnaryUnary"
|
|
|
|
|
@@ -37,26 +47,93 @@ class _GenericHandler(grpc.GenericRpcHandler):
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
+def _time_invocation(to_time: Callable[[], None]) -> datetime.timedelta:
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ to_time()
|
|
|
+ return datetime.datetime.now() - start
|
|
|
+
|
|
|
+
|
|
|
+@contextlib.contextmanager
|
|
|
+def _server(credentials: Optional[grpc.ServerCredentials]):
|
|
|
+ try:
|
|
|
+ server = test_common.test_server()
|
|
|
+ target = '[::]:0'
|
|
|
+ if credentials is None:
|
|
|
+ port = server.add_insecure_port(target)
|
|
|
+ else:
|
|
|
+ port = server.add_secure_port(target, credentials)
|
|
|
+ server.add_generic_rpc_handlers((_GenericHandler(),))
|
|
|
+ server.start()
|
|
|
+ yield server, port
|
|
|
+ finally:
|
|
|
+ server.stop(None)
|
|
|
+
|
|
|
+
|
|
|
@unittest.skipIf(sys.version_info[0] < 3, "Unsupported on Python 2.")
|
|
|
class SimpleStubsTest(unittest.TestCase):
|
|
|
- @classmethod
|
|
|
- def setUpClass(cls):
|
|
|
- super(SimpleStubsTest, cls).setUpClass()
|
|
|
- cls._server = test_common.test_server()
|
|
|
- cls._port = cls._server.add_insecure_port('[::]:0')
|
|
|
- cls._server.add_generic_rpc_handlers((_GenericHandler(),))
|
|
|
- cls._server.start()
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def tearDownClass(cls):
|
|
|
- cls._server.stop(None)
|
|
|
- super(SimpleStubsTest, cls).tearDownClass()
|
|
|
-
|
|
|
- def test_unary_unary(self):
|
|
|
- target = f'localhost:{self._port}'
|
|
|
- request = b'0000'
|
|
|
- response = grpc.unary_unary(request, target, _UNARY_UNARY)
|
|
|
- self.assertEqual(request, response)
|
|
|
+
|
|
|
+ def assert_cached(self, to_check: Callable[[str], None]) -> None:
|
|
|
+ """Asserts that a function caches intermediate data/state.
|
|
|
+
|
|
|
+ To be specific, given a function whose caching behavior is
|
|
|
+ deterministic in the value of a supplied string, this function asserts
|
|
|
+ that, on average, subsequent invocations of the function for a specific
|
|
|
+ string are faster than first invocations with that same string.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ to_check: A function returning nothing, that caches values based on
|
|
|
+ an arbitrary supplied Text object.
|
|
|
+ """
|
|
|
+ initial_runs = []
|
|
|
+ cached_runs = []
|
|
|
+ for epoch in range(_CACHE_EPOCHS):
|
|
|
+ runs = []
|
|
|
+ text = str(epoch)
|
|
|
+ for trial in range(_CACHE_TRIALS):
|
|
|
+ runs.append(_time_invocation(lambda: to_check(text)))
|
|
|
+ initial_runs.append(runs[0])
|
|
|
+ cached_runs.extend(runs[1:])
|
|
|
+ average_cold = sum((run for run in initial_runs), datetime.timedelta()) / len(initial_runs)
|
|
|
+ average_warm = sum((run for run in cached_runs), datetime.timedelta()) / len(cached_runs)
|
|
|
+ self.assertLess(average_warm, average_cold)
|
|
|
+
|
|
|
+ def test_unary_unary_insecure(self):
|
|
|
+ with _server(None) as (_, port):
|
|
|
+ target = f'localhost:{port}'
|
|
|
+ request = b'0000'
|
|
|
+ response = grpc.unary_unary(request, target, _UNARY_UNARY)
|
|
|
+ self.assertEqual(request, response)
|
|
|
+
|
|
|
+ def test_unary_unary_secure(self):
|
|
|
+ with _server(grpc.local_server_credentials()) as (_, port):
|
|
|
+ target = f'localhost:{port}'
|
|
|
+ request = b'0000'
|
|
|
+ response = grpc.unary_unary(request,
|
|
|
+ target,
|
|
|
+ _UNARY_UNARY,
|
|
|
+ channel_credentials=grpc.local_channel_credentials())
|
|
|
+ self.assertEqual(request, response)
|
|
|
+
|
|
|
+ def test_channels_cached(self):
|
|
|
+ with _server(grpc.local_server_credentials()) as (_, port):
|
|
|
+ target = f'localhost:{port}'
|
|
|
+ request = b'0000'
|
|
|
+ test_name = inspect.stack()[0][3]
|
|
|
+ args = (request, target, _UNARY_UNARY)
|
|
|
+ kwargs = {"channel_credentials": grpc.local_channel_credentials()}
|
|
|
+ def _invoke(seed: Text):
|
|
|
+ run_kwargs = dict(kwargs)
|
|
|
+ run_kwargs["options"] = ((test_name + seed, ""),)
|
|
|
+ grpc.unary_unary(*args, **run_kwargs)
|
|
|
+ self.assert_cached(_invoke)
|
|
|
+
|
|
|
+ # TODO: Test request_serializer
|
|
|
+ # TODO: Test request_deserializer
|
|
|
+ # TODO: Test channel_credentials
|
|
|
+ # TODO: Test call_credentials
|
|
|
+ # TODO: Test compression
|
|
|
+ # TODO: Test wait_for_ready
|
|
|
+ # TODO: Test metadata
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
logging.basicConfig()
|