|  | @@ -19,7 +19,7 @@ import threading
 | 
	
		
			
				|  |  |  import time
 | 
	
		
			
				|  |  |  import sys
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from typing import DefaultDict, Dict, List, Mapping, Set, Sequence
 | 
	
		
			
				|  |  | +from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple
 | 
	
		
			
				|  |  |  import collections
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from concurrent import futures
 | 
	
	
		
			
				|  | @@ -39,6 +39,7 @@ logger.addHandler(console_handler)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  _SUPPORTED_METHODS = ("UnaryCall", "EmptyCall",)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class _StatsWatcher:
 | 
	
		
			
				|  |  |      _start: int
 | 
	
	
		
			
				|  | @@ -118,14 +119,16 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
 | 
	
		
			
				|  |  |          return response
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def _start_rpc(method: str, request_id: int, stub: test_pb2_grpc.TestServiceStub,
 | 
	
		
			
				|  |  | +def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]], request_id: int, stub: test_pb2_grpc.TestServiceStub,
 | 
	
		
			
				|  |  |                 timeout: float, futures: Mapping[int, grpc.Future]) -> None:
 | 
	
		
			
				|  |  | -    logger.info(f"Sending request to backend: {request_id}")
 | 
	
		
			
				|  |  | +    logger.info(f"Sending {method} request to backend: {request_id}")
 | 
	
		
			
				|  |  |      if method == "UnaryCall":
 | 
	
		
			
				|  |  |          future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
 | 
	
		
			
				|  |  | +                                       metadata=metadata,
 | 
	
		
			
				|  |  |                                         timeout=timeout)
 | 
	
		
			
				|  |  |      elif method == "EmptyCall":
 | 
	
		
			
				|  |  |          future = stub.EmptyCall.future(empty_pb2.Empty(),
 | 
	
		
			
				|  |  | +                                       metadata=metadata,
 | 
	
		
			
				|  |  |                                         timeout=timeout)
 | 
	
		
			
				|  |  |      else:
 | 
	
		
			
				|  |  |          raise ValueError(f"Unrecognized method '{method}'.")
 | 
	
	
		
			
				|  | @@ -173,7 +176,7 @@ def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
 | 
	
		
			
				|  |  |          future.cancel()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def _run_single_channel(method: str, qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
 | 
	
		
			
				|  |  | +def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]], qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
 | 
	
		
			
				|  |  |      global _global_rpc_id  # pylint: disable=global-statement
 | 
	
		
			
				|  |  |      duration_per_query = 1.0 / float(qps)
 | 
	
		
			
				|  |  |      with grpc.insecure_channel(server) as channel:
 | 
	
	
		
			
				|  | @@ -186,7 +189,7 @@ def _run_single_channel(method: str, qps: int, server: str, rpc_timeout_sec: int
 | 
	
		
			
				|  |  |                  _global_rpc_id += 1
 | 
	
		
			
				|  |  |              start = time.time()
 | 
	
		
			
				|  |  |              end = start + duration_per_query
 | 
	
		
			
				|  |  | -            _start_rpc(method, request_id, stub, float(rpc_timeout_sec), futures)
 | 
	
		
			
				|  |  | +            _start_rpc(method, metadata, request_id, stub, float(rpc_timeout_sec), futures)
 | 
	
		
			
				|  |  |              _remove_completed_rpcs(futures, print_response)
 | 
	
		
			
				|  |  |              logger.debug(f"Currently {len(futures)} in-flight RPCs")
 | 
	
		
			
				|  |  |              now = time.time()
 | 
	
	
		
			
				|  | @@ -200,11 +203,11 @@ class _MethodHandle:
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      _channel_threads: List[threading.Thread]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def __init__(self, method: str, num_channels: int, qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
 | 
	
		
			
				|  |  | +    def __init__(self, method: str, metadata: Sequence[Tuple[str, str]], num_channels: int, qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
 | 
	
		
			
				|  |  |          """Creates and starts a group of threads running the indicated method."""
 | 
	
		
			
				|  |  |          self._channel_threads = []
 | 
	
		
			
				|  |  |          for i in range(num_channels):
 | 
	
		
			
				|  |  | -            thread = threading.Thread(target=_run_single_channel, args=(method, qps, server, rpc_timeout_sec, print_response,))
 | 
	
		
			
				|  |  | +            thread = threading.Thread(target=_run_single_channel, args=(method, metadata, qps, server, rpc_timeout_sec, print_response,))
 | 
	
		
			
				|  |  |              thread.start()
 | 
	
		
			
				|  |  |              self._channel_threads.append(thread)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -214,12 +217,12 @@ class _MethodHandle:
 | 
	
		
			
				|  |  |              channel_thread.join()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -def _run(args: argparse.Namespace, methods: Sequence[str]) -> None:
 | 
	
		
			
				|  |  | +def _run(args: argparse.Namespace, methods: Sequence[str], per_method_metadata: PerMethodMetadataType) -> None:
 | 
	
		
			
				|  |  |      logger.info("Starting python xDS Interop Client.")
 | 
	
		
			
				|  |  |      global _global_server  # pylint: disable=global-statement
 | 
	
		
			
				|  |  |      method_handles = []
 | 
	
		
			
				|  |  |      for method in methods:
 | 
	
		
			
				|  |  | -        method_handles.append(_MethodHandle(method, args.num_channels, args.qps, args.server, args.rpc_timeout_sec, args.print_response))
 | 
	
		
			
				|  |  | +        method_handles.append(_MethodHandle(method, per_method_metadata.get(method, []), args.num_channels, args.qps, args.server, args.rpc_timeout_sec, args.print_response))
 | 
	
		
			
				|  |  |      _global_server = grpc.server(futures.ThreadPoolExecutor())
 | 
	
		
			
				|  |  |      _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
 | 
	
		
			
				|  |  |      test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
 | 
	
	
		
			
				|  | @@ -274,6 +277,13 @@ if __name__ == "__main__":
 | 
	
		
			
				|  |  |                          default="UnaryCall",
 | 
	
		
			
				|  |  |                          type=str,
 | 
	
		
			
				|  |  |                          help=rpc_help)
 | 
	
		
			
				|  |  | +    metadata_help = ("A comma-delimited list of 3-tuples of the form " +
 | 
	
		
			
				|  |  | +                     "METHOD:KEY:VALUE, e.g. " +
 | 
	
		
			
				|  |  | +                     "EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3")
 | 
	
		
			
				|  |  | +    parser.add_argument("--metadata",
 | 
	
		
			
				|  |  | +                        default="",
 | 
	
		
			
				|  |  | +                        type=str,
 | 
	
		
			
				|  |  | +                        help=metadata_help)
 | 
	
		
			
				|  |  |      args = parser.parse_args()
 | 
	
		
			
				|  |  |      signal.signal(signal.SIGINT, _handle_sigint)
 | 
	
		
			
				|  |  |      if args.verbose:
 | 
	
	
		
			
				|  | @@ -282,7 +292,16 @@ if __name__ == "__main__":
 | 
	
		
			
				|  |  |          file_handler = logging.FileHandler(args.log_file, mode='a')
 | 
	
		
			
				|  |  |          file_handler.setFormatter(formatter)
 | 
	
		
			
				|  |  |          logger.addHandler(file_handler)
 | 
	
		
			
				|  |  | -    methods =  args.rpc.split(",")
 | 
	
		
			
				|  |  | +    methods = args.rpc.split(",")
 | 
	
		
			
				|  |  |      if set(methods) - set(_SUPPORTED_METHODS):
 | 
	
		
			
				|  |  |          raise ValueError("--rpc supported methods: {}".format(", ".join(_SUPPORTED_METHODS)))
 | 
	
		
			
				|  |  | -    _run(args, methods)
 | 
	
		
			
				|  |  | +    per_method_metadata = collections.defaultdict(list)
 | 
	
		
			
				|  |  | +    metadata = args.metadata.split(",") if args.metadata else []
 | 
	
		
			
				|  |  | +    for metadatum in metadata:
 | 
	
		
			
				|  |  | +        elems = metadatum.split(":")
 | 
	
		
			
				|  |  | +        if len(elems) != 3:
 | 
	
		
			
				|  |  | +            raise ValueError(f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
 | 
	
		
			
				|  |  | +        if elems[0] not in _SUPPORTED_METHODS:
 | 
	
		
			
				|  |  | +            raise ValueError(f"Unrecognized method '{elems[0]}'")
 | 
	
		
			
				|  |  | +        per_method_metadata[elems[0]].append((elems[1], elems[2]))
 | 
	
		
			
				|  |  | +    _run(args, methods, per_method_metadata)
 |