|  | @@ -13,6 +13,8 @@
 | 
	
		
			
				|  |  |  # limitations under the License.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import argparse
 | 
	
		
			
				|  |  | +import collections
 | 
	
		
			
				|  |  | +import datetime
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  |  import signal
 | 
	
		
			
				|  |  |  import threading
 | 
	
	
		
			
				|  | @@ -42,8 +44,22 @@ _SUPPORTED_METHODS = (
 | 
	
		
			
				|  |  |      "EmptyCall",
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +_METHOD_CAMEL_TO_CAPS_SNAKE = {
 | 
	
		
			
				|  |  | +    "UnaryCall": "UNARY_CALL",
 | 
	
		
			
				|  |  | +    "EmptyCall": "EMPTY_CALL",
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +_METHOD_STR_TO_ENUM = {
 | 
	
		
			
				|  |  | +    "UnaryCall": messages_pb2.ClientConfigureRequest.UNARY_CALL,
 | 
	
		
			
				|  |  | +    "EmptyCall": messages_pb2.ClientConfigureRequest.EMPTY_CALL,
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +_METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class _StatsWatcher:
 | 
	
		
			
				|  |  |      _start: int
 | 
	
	
		
			
				|  | @@ -98,9 +114,12 @@ _stop_event = threading.Event()
 | 
	
		
			
				|  |  |  _global_rpc_id: int = 0
 | 
	
		
			
				|  |  |  _watchers: Set[_StatsWatcher] = set()
 | 
	
		
			
				|  |  |  _global_server = None
 | 
	
		
			
				|  |  | +_global_rpcs_started: Mapping[str, int] = collections.defaultdict(int)
 | 
	
		
			
				|  |  | +_global_rpcs_succeeded: Mapping[str, int] = collections.defaultdict(int)
 | 
	
		
			
				|  |  | +_global_rpcs_failed: Mapping[str, int] = collections.defaultdict(int)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def _handle_sigint(sig, frame):
 | 
	
		
			
				|  |  | +def _handle_sigint(sig, frame) -> None:
 | 
	
		
			
				|  |  |      _stop_event.set()
 | 
	
		
			
				|  |  |      _global_server.stop(None)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -127,7 +146,25 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
 | 
	
		
			
				|  |  |          response = watcher.await_rpc_stats_response(request.timeout_sec)
 | 
	
		
			
				|  |  |          with _global_lock:
 | 
	
		
			
				|  |  |              _watchers.remove(watcher)
 | 
	
		
			
				|  |  | -        logger.info("Returning stats response: {}".format(response))
 | 
	
		
			
				|  |  | +        logger.info("Returning stats response: %s", response)
 | 
	
		
			
				|  |  | +        return response
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def GetClientAccumulatedStats(
 | 
	
		
			
				|  |  | +            self, request: messages_pb2.LoadBalancerAccumulatedStatsRequest,
 | 
	
		
			
				|  |  | +            context: grpc.ServicerContext
 | 
	
		
			
				|  |  | +    ) -> messages_pb2.LoadBalancerAccumulatedStatsResponse:
 | 
	
		
			
				|  |  | +        logger.info("Received cumulative stats request.")
 | 
	
		
			
				|  |  | +        response = messages_pb2.LoadBalancerAccumulatedStatsResponse()
 | 
	
		
			
				|  |  | +        with _global_lock:
 | 
	
		
			
				|  |  | +            for method in _SUPPORTED_METHODS:
 | 
	
		
			
				|  |  | +                caps_method = _METHOD_CAMEL_TO_CAPS_SNAKE[method]
 | 
	
		
			
				|  |  | +                response.num_rpcs_started_by_method[
 | 
	
		
			
				|  |  | +                    caps_method] = _global_rpcs_started[method]
 | 
	
		
			
				|  |  | +                response.num_rpcs_succeeded_by_method[
 | 
	
		
			
				|  |  | +                    caps_method] = _global_rpcs_succeeded[method]
 | 
	
		
			
				|  |  | +                response.num_rpcs_failed_by_method[
 | 
	
		
			
				|  |  | +                    caps_method] = _global_rpcs_failed[method]
 | 
	
		
			
				|  |  | +        logger.info("Returning cumulative stats response.")
 | 
	
		
			
				|  |  |          return response
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -154,6 +191,8 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
 | 
	
		
			
				|  |  |      exception = future.exception()
 | 
	
		
			
				|  |  |      hostname = ""
 | 
	
		
			
				|  |  |      if exception is not None:
 | 
	
		
			
				|  |  | +        with _global_lock:
 | 
	
		
			
				|  |  | +            _global_rpcs_failed[method] += 1
 | 
	
		
			
				|  |  |          if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
 | 
	
		
			
				|  |  |              logger.error(f"RPC {rpc_id} timed out")
 | 
	
		
			
				|  |  |          else:
 | 
	
	
		
			
				|  | @@ -167,6 +206,12 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str,
 | 
	
		
			
				|  |  |                  break
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  |              hostname = response.hostname
 | 
	
		
			
				|  |  | +        if future.code() == grpc.StatusCode.OK:
 | 
	
		
			
				|  |  | +            with _global_lock:
 | 
	
		
			
				|  |  | +                _global_rpcs_succeeded[method] += 1
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            with _global_lock:
 | 
	
		
			
				|  |  | +                _global_rpcs_failed[method] += 1
 | 
	
		
			
				|  |  |          if print_response:
 | 
	
		
			
				|  |  |              if future.code() == grpc.StatusCode.OK:
 | 
	
		
			
				|  |  |                  logger.info("Successful response.")
 | 
	
	
		
			
				|  | @@ -195,24 +240,55 @@ def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
 | 
	
		
			
				|  |  |          future.cancel()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]],
 | 
	
		
			
				|  |  | -                        qps: int, server: str, rpc_timeout_sec: int,
 | 
	
		
			
				|  |  | -                        print_response: bool):
 | 
	
		
			
				|  |  | +class _ChannelConfiguration:
 | 
	
		
			
				|  |  | +    """Configuration for a single client channel.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    Instances of this class are meant to be dealt with as PODs. That is,
 | 
	
		
			
				|  |  | +    data member should be accessed directly. This class is not thread-safe.
 | 
	
		
			
				|  |  | +    When accessing any of its members, the lock member should be held.
 | 
	
		
			
				|  |  | +    """
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, method: str, metadata: Sequence[Tuple[str, str]],
 | 
	
		
			
				|  |  | +                 qps: int, server: str, rpc_timeout_sec: int,
 | 
	
		
			
				|  |  | +                 print_response: bool):
 | 
	
		
			
				|  |  | +        # condition is signalled when a change is made to the config.
 | 
	
		
			
				|  |  | +        self.condition = threading.Condition()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        self.method = method
 | 
	
		
			
				|  |  | +        self.metadata = metadata
 | 
	
		
			
				|  |  | +        self.qps = qps
 | 
	
		
			
				|  |  | +        self.server = server
 | 
	
		
			
				|  |  | +        self.rpc_timeout_sec = rpc_timeout_sec
 | 
	
		
			
				|  |  | +        self.print_response = print_response
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +def _run_single_channel(config: _ChannelConfiguration) -> None:
 | 
	
		
			
				|  |  |      global _global_rpc_id  # pylint: disable=global-statement
 | 
	
		
			
				|  |  | -    duration_per_query = 1.0 / float(qps)
 | 
	
		
			
				|  |  | +    with config.condition:
 | 
	
		
			
				|  |  | +        server = config.server
 | 
	
		
			
				|  |  |      with grpc.insecure_channel(server) as channel:
 | 
	
		
			
				|  |  |          stub = test_pb2_grpc.TestServiceStub(channel)
 | 
	
		
			
				|  |  |          futures: Dict[int, Tuple[grpc.Future, str]] = {}
 | 
	
		
			
				|  |  |          while not _stop_event.is_set():
 | 
	
		
			
				|  |  | +            with config.condition:
 | 
	
		
			
				|  |  | +                if config.qps == 0:
 | 
	
		
			
				|  |  | +                    config.condition.wait(
 | 
	
		
			
				|  |  | +                        timeout=_CONFIG_CHANGE_TIMEOUT.total_seconds())
 | 
	
		
			
				|  |  | +                    continue
 | 
	
		
			
				|  |  | +                else:
 | 
	
		
			
				|  |  | +                    duration_per_query = 1.0 / float(config.qps)
 | 
	
		
			
				|  |  |              request_id = None
 | 
	
		
			
				|  |  |              with _global_lock:
 | 
	
		
			
				|  |  |                  request_id = _global_rpc_id
 | 
	
		
			
				|  |  |                  _global_rpc_id += 1
 | 
	
		
			
				|  |  | +                _global_rpcs_started[config.method] += 1
 | 
	
		
			
				|  |  |              start = time.time()
 | 
	
		
			
				|  |  |              end = start + duration_per_query
 | 
	
		
			
				|  |  | -            _start_rpc(method, metadata, request_id, stub,
 | 
	
		
			
				|  |  | -                       float(rpc_timeout_sec), futures)
 | 
	
		
			
				|  |  | -            _remove_completed_rpcs(futures, print_response)
 | 
	
		
			
				|  |  | +            with config.condition:
 | 
	
		
			
				|  |  | +                _start_rpc(config.method, config.metadata, request_id, stub,
 | 
	
		
			
				|  |  | +                           float(config.rpc_timeout_sec), futures)
 | 
	
		
			
				|  |  | +            with config.condition:
 | 
	
		
			
				|  |  | +                _remove_completed_rpcs(futures, config.print_response)
 | 
	
		
			
				|  |  |              logger.debug(f"Currently {len(futures)} in-flight RPCs")
 | 
	
		
			
				|  |  |              now = time.time()
 | 
	
		
			
				|  |  |              while now < end:
 | 
	
	
		
			
				|  | @@ -221,30 +297,54 @@ def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]],
 | 
	
		
			
				|  |  |          _cancel_all_rpcs(futures)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +class _XdsUpdateClientConfigureServicer(
 | 
	
		
			
				|  |  | +        test_pb2_grpc.XdsUpdateClientConfigureServiceServicer):
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __init__(self, per_method_configs: Mapping[str, _ChannelConfiguration],
 | 
	
		
			
				|  |  | +                 qps: int):
 | 
	
		
			
				|  |  | +        super(_XdsUpdateClientConfigureServicer).__init__()
 | 
	
		
			
				|  |  | +        self._per_method_configs = per_method_configs
 | 
	
		
			
				|  |  | +        self._qps = qps
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def Configure(self, request: messages_pb2.ClientConfigureRequest,
 | 
	
		
			
				|  |  | +                  context: grpc.ServicerContext
 | 
	
		
			
				|  |  | +                 ) -> messages_pb2.ClientConfigureResponse:
 | 
	
		
			
				|  |  | +        logger.info("Received Configure RPC: %s", request)
 | 
	
		
			
				|  |  | +        method_strs = (_METHOD_ENUM_TO_STR[t] for t in request.types)
 | 
	
		
			
				|  |  | +        for method in _SUPPORTED_METHODS:
 | 
	
		
			
				|  |  | +            method_enum = _METHOD_STR_TO_ENUM[method]
 | 
	
		
			
				|  |  | +            if method in method_strs:
 | 
	
		
			
				|  |  | +                qps = self._qps
 | 
	
		
			
				|  |  | +                metadata = ((md.key, md.value)
 | 
	
		
			
				|  |  | +                            for md in request.metadata
 | 
	
		
			
				|  |  | +                            if md.type == method_enum)
 | 
	
		
			
				|  |  | +            else:
 | 
	
		
			
				|  |  | +                qps = 0
 | 
	
		
			
				|  |  | +                metadata = ()
 | 
	
		
			
				|  |  | +            channel_config = self._per_method_configs[method]
 | 
	
		
			
				|  |  | +            with channel_config.condition:
 | 
	
		
			
				|  |  | +                channel_config.qps = qps
 | 
	
		
			
				|  |  | +                channel_config.metadata = list(metadata)
 | 
	
		
			
				|  |  | +                channel_config.condition.notify_all()
 | 
	
		
			
				|  |  | +        return messages_pb2.ClientConfigureResponse()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class _MethodHandle:
 | 
	
		
			
				|  |  |      """An object grouping together threads driving RPCs for a method."""
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      _channel_threads: List[threading.Thread]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __init__(self, method: str, metadata: Sequence[Tuple[str, str]],
 | 
	
		
			
				|  |  | -                 num_channels: int, qps: int, server: str, rpc_timeout_sec: int,
 | 
	
		
			
				|  |  | -                 print_response: bool):
 | 
	
		
			
				|  |  | +    def __init__(self, num_channels: int,
 | 
	
		
			
				|  |  | +                 channel_config: _ChannelConfiguration):
 | 
	
		
			
				|  |  |          """Creates and starts a group of threads running the indicated method."""
 | 
	
		
			
				|  |  |          self._channel_threads = []
 | 
	
		
			
				|  |  |          for i in range(num_channels):
 | 
	
		
			
				|  |  |              thread = threading.Thread(target=_run_single_channel,
 | 
	
		
			
				|  |  | -                                      args=(
 | 
	
		
			
				|  |  | -                                          method,
 | 
	
		
			
				|  |  | -                                          metadata,
 | 
	
		
			
				|  |  | -                                          qps,
 | 
	
		
			
				|  |  | -                                          server,
 | 
	
		
			
				|  |  | -                                          rpc_timeout_sec,
 | 
	
		
			
				|  |  | -                                          print_response,
 | 
	
		
			
				|  |  | -                                      ))
 | 
	
		
			
				|  |  | +                                      args=(channel_config,))
 | 
	
		
			
				|  |  |              thread.start()
 | 
	
		
			
				|  |  |              self._channel_threads.append(thread)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def stop(self):
 | 
	
		
			
				|  |  | +    def stop(self) -> None:
 | 
	
		
			
				|  |  |          """Joins all threads referenced by the handle."""
 | 
	
		
			
				|  |  |          for channel_thread in self._channel_threads:
 | 
	
		
			
				|  |  |              channel_thread.join()
 | 
	
	
		
			
				|  | @@ -255,15 +355,24 @@ def _run(args: argparse.Namespace, methods: Sequence[str],
 | 
	
		
			
				|  |  |      logger.info("Starting python xDS Interop Client.")
 | 
	
		
			
				|  |  |      global _global_server  # pylint: disable=global-statement
 | 
	
		
			
				|  |  |      method_handles = []
 | 
	
		
			
				|  |  | -    for method in methods:
 | 
	
		
			
				|  |  | -        method_handles.append(
 | 
	
		
			
				|  |  | -            _MethodHandle(method, per_method_metadata.get(method, []),
 | 
	
		
			
				|  |  | -                          args.num_channels, args.qps, args.server,
 | 
	
		
			
				|  |  | -                          args.rpc_timeout_sec, args.print_response))
 | 
	
		
			
				|  |  | +    channel_configs = {}
 | 
	
		
			
				|  |  | +    for method in _SUPPORTED_METHODS:
 | 
	
		
			
				|  |  | +        if method in methods:
 | 
	
		
			
				|  |  | +            qps = args.qps
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            qps = 0
 | 
	
		
			
				|  |  | +        channel_config = _ChannelConfiguration(
 | 
	
		
			
				|  |  | +            method, per_method_metadata.get(method, []), qps, args.server,
 | 
	
		
			
				|  |  | +            args.rpc_timeout_sec, args.print_response)
 | 
	
		
			
				|  |  | +        channel_configs[method] = channel_config
 | 
	
		
			
				|  |  | +        method_handles.append(_MethodHandle(args.num_channels, channel_config))
 | 
	
		
			
				|  |  |      _global_server = grpc.server(futures.ThreadPoolExecutor())
 | 
	
		
			
				|  |  |      _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
 | 
	
		
			
				|  |  |      test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
 | 
	
		
			
				|  |  |          _LoadBalancerStatsServicer(), _global_server)
 | 
	
		
			
				|  |  | +    test_pb2_grpc.add_XdsUpdateClientConfigureServiceServicer_to_server(
 | 
	
		
			
				|  |  | +        _XdsUpdateClientConfigureServicer(channel_configs, args.qps),
 | 
	
		
			
				|  |  | +        _global_server)
 | 
	
		
			
				|  |  |      _global_server.start()
 | 
	
		
			
				|  |  |      _global_server.wait_for_termination()
 | 
	
		
			
				|  |  |      for method_handle in method_handles:
 |