Quellcode durchsuchen

WIP. Implement CB interop test

Richard Belleville vor 4 Jahren
Ursprung
Commit
c4d3fc749f
1 geänderte Dateien mit 102 neuen und 22 gelöschten Zeilen
  1. 102 22
      src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py

+ 102 - 22
src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import argparse
+import collections
+import datetime
 import logging
 import signal
 import threading
@@ -42,8 +44,16 @@ _SUPPORTED_METHODS = (
     "EmptyCall",
 )
 
+_METHOD_STR_TO_ENUM = {
+    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
+    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
+}
+
+_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}
+
 PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
 
+_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
 
 class _StatsWatcher:
     _start: int
@@ -98,6 +108,9 @@ _stop_event = threading.Event()
 _global_rpc_id: int = 0
 _watchers: Set[_StatsWatcher] = set()
 _global_server = None
+_global_rpcs_started: Mapping[str, int] = collections.defaultdict(int)
+_global_rpcs_succeeded: Mapping[str, int] = collections.defaultdict(int)
+_global_rpcs_failed: Mapping[str, int] = collections.defaultdict(int)
 
 
 def _handle_sigint(sig, frame):
@@ -129,6 +142,17 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
         logger.info("Returning stats response: {}".format(response))
         return response
 
+    def GetClientAccumulatedStats(self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest, context: grpc.ServicerContext):
+        logger.info("Received cumulative stats request.")
+        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
+        with _global_lock:
+            for method in _SUPPORTED_METHODS:
+                response.num_rpcs_started_by_method[method] = _global_rpcs_started.get[method]
+                response.num_rpcs_succeeded_by_method[method] = _global_rpcs_succeeded.get[method]
+                response.num_rpcs_failed_by_method[method] = _global_rpcs_succeeded.get[method]
+        logger.info("Returning cumulative stats request.")
+        return response
+
 
 def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
                request_id: int, stub: test_pb2_grpc.TestServiceStub,
@@ -153,11 +177,15 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
     exception = future.exception()
     hostname = ""
     if exception is not None:
+        with _global_lock:
+            _global_rpcs_failed[method] += 1
         if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
             logger.error(f"RPC {rpc_id} timed out")
         else:
             logger.error(exception)
     else:
+        with _global_lock:
+            _global_rpcs_succeeded[method] += 1
         response = future.result()
         hostname = None
         for metadatum in future.initial_metadata():
@@ -193,25 +221,52 @@ def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
     for future, _ in futures.values():
         future.cancel()
 
+class _ChannelConfiguration:
+    """Configuration for a single client channel.
+
+    Instances of this class are meant to be dealt with as PODs. That is,
+    data member should be accessed directly. This class is not thread-safe.
+    When accessing any of its members, the lock member should be held.
+    """
+    def __init__(self, method: str, metadata: Sequence[Tuple[str, str]],
+                 qps: int, server: str, rpc_timeout_sec: int,
+                 print_response: bool):
+        # condition is signalled when a change is made to the config.
+        self.condition = threading.Condition()
 
-def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]],
-                        qps: int, server: str, rpc_timeout_sec: int,
-                        print_response: bool):
+        self.method = method
+        self.metadata = metadata
+        self.qps = qps
+        self.server = server
+        self.rpc_timeout_sec = rpc_timeout_sec
+        self.print_response = print_response
+
+def _run_single_channel(config: _ChannelConfiguration):
     global _global_rpc_id  # pylint: disable=global-statement
-    duration_per_query = 1.0 / float(qps)
+    with config.condition:
+        server = config.server
     with grpc.insecure_channel(server) as channel:
         stub = test_pb2_grpc.TestServiceStub(channel)
         futures: Dict[int, Tuple[grpc.Future, str]] = {}
         while not _stop_event.is_set():
+            with config.condition:
+                if config.qps == 0:
+                    config.condition.wait(timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
+                    continue
+                else:
+                    duration_per_query = 1.0 / float(config.qps)
             request_id = None
             with _global_lock:
                 request_id = _global_rpc_id
                 _global_rpc_id += 1
+                _global_rpcs_started[config.method] += 1
             start = time.time()
             end = start + duration_per_query
-            _start_rpc(method, metadata, request_id, stub,
-                       float(rpc_timeout_sec), futures)
-            _remove_completed_rpcs(futures, print_response)
+            with config.condition:
+                _start_rpc(config.method, config.metadata, request_id, stub,
+                           float(config.rpc_timeout_sec), futures)
+            with config.condition:
+                _remove_completed_rpcs(futures, config.print_response)
             logger.debug(f"Currently {len(futures)} in-flight RPCs")
             now = time.time()
             while now < end:
@@ -220,26 +275,44 @@ def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]],
         _cancel_all_rpcs(futures)
 
 
+class _XdsUpdateClientConfigureServicer(test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):
+
+    def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration], qps: int):
+        super(_XdsUpdateClientConfigureServicer).__init__()
+        self._per_method_configs = per_method_configs
+        self._qps = qps
+
+    def Configure(self, request: messages_pb2.ClientConfigureRequest,
+            context: grpc.ServicerContext) -> messages_pb2.ClientConfigureResponse:
+        method_strs = (_METHOD_ENUM_TO_STR[t] for t in request.types)
+        for method in _SUPPORTED_METHODS:
+            method_enum = _METHOD_STR_TO_ENUM[method]
+            if method in method_strs:
+                qps = self._qps
+                metadata = ((md.key, md.value) for md in request.metadata if md.type == method_enum)
+            else:
+                qps = 0
+                metadata = ()
+            channel_config = self._per_method_config[method]
+            with channel_config.condition:
+                channel_config.qps = qps
+                channel_config.metadata = metadata
+                channel_config.condition.notify_all()
+        # TODO: Wait for all channels to respond until responding to RPC?
+        return messages_pb2.ClientConfigureResponse()
+
+
 class _MethodHandle:
     """An object grouping together threads driving RPCs for a method."""
 
     _channel_threads: List[threading.Thread]
 
-    def __init__(self, method: str, metadata: Sequence[Tuple[str, str]],
-                 num_channels: int, qps: int, server: str, rpc_timeout_sec: int,
-                 print_response: bool):
+    def __init__(self, num_channels: int, channel_config: _ChannelConfiguration):
         """Creates and starts a group of threads running the indicated method."""
         self._channel_threads = []
         for i in range(num_channels):
             thread = threading.Thread(target=_run_single_channel,
-                                      args=(
-                                          method,
-                                          metadata,
-                                          qps,
-                                          server,
-                                          rpc_timeout_sec,
-                                          print_response,
-                                      ))
+                                      args=(channel_config,))
             thread.start()
             self._channel_threads.append(thread)
 
@@ -254,15 +327,22 @@ def _run(args: argparse.Namespace, methods: Sequence[str],
     logger.info("Starting python xDS Interop Client.")
     global _global_server  # pylint: disable=global-statement
     method_handles = []
-    for method in methods:
+    channel_configs = {} 
+    for method in _SUPPORTED_METHODS:
+        if method in methods:
+            qps = args.qps
+        else:
+            qps = 0
+        channel_config = _ChannelConfiguration(method, per_method_metadata.get(method, []),
+                            qps, args.server, args.rpc_timeout_sec, args.print_response)
+        channel_configs[method] = channel_config
         method_handles.append(
-            _MethodHandle(method, per_method_metadata.get(method, []),
-                          args.num_channels, args.qps, args.server,
-                          args.rpc_timeout_sec, args.print_response))
+            _MethodHandle(args.num_channels, channel_config))
     _global_server = grpc.server(futures.ThreadPoolExecutor())
     _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
     test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
         _LoadBalancerStatsServicer(), _global_server)
+    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(_XdsUpdateClientConfigureServicer(channel_configs, args.qps), _global_server)
     _global_server.start()
     _global_server.wait_for_termination()
     for method_handle in method_handles: