|  | @@ -14,9 +14,12 @@
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import asyncio
 | 
	
		
			
				|  |  |  import logging
 | 
	
		
			
				|  |  | +import os
 | 
	
		
			
				|  |  |  import multiprocessing
 | 
	
		
			
				|  |  | +import sys
 | 
	
		
			
				|  |  |  import time
 | 
	
		
			
				|  |  |  from typing import Tuple
 | 
	
		
			
				|  |  | +import collections
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import grpc
 | 
	
		
			
				|  |  |  from grpc.experimental import aio
 | 
	
	
		
			
				|  | @@ -117,6 +120,26 @@ def _create_client(server: str, config: control_pb2.ClientConfig,
 | 
	
		
			
				|  |  |      return client_type(server, config, qps_data)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +WORKER_ENTRY_FILE = os.path.split(os.path.abspath(__file__))[0] + 'worker.py'
 | 
	
		
			
				|  |  | +SubWorker = collections.namedtuple('SubWorker', ['process', 'port', 'channel', 'stub'])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +async def _create_sub_worker(port: int) -> SubWorker:
 | 
	
		
			
				|  |  | +    process = asyncio.create_subprocess_exec(
 | 
	
		
			
				|  |  | +        sys.executable,
 | 
	
		
			
				|  |  | +        WORKER_ENTRY_FILE,
 | 
	
		
			
				|  |  | +        '--driver_port', port
 | 
	
		
			
				|  |  | +    )
 | 
	
		
			
				|  |  | +    channel = aio.insecure_channel(f'localhost:{port}')
 | 
	
		
			
				|  |  | +    stub = worker_service_pb2_grpc.WorkerServiceStub(channel)
 | 
	
		
			
				|  |  | +    return SubWorker(
 | 
	
		
			
				|  |  | +        process=process,
 | 
	
		
			
				|  |  | +        port=port,
 | 
	
		
			
				|  |  | +        channel=channel,
 | 
	
		
			
				|  |  | +        stub=stub,
 | 
	
		
			
				|  |  | +    )
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
 | 
	
		
			
				|  |  |      """Python Worker Server implementation."""
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -143,10 +166,7 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
 | 
	
		
			
				|  |  |              yield status
 | 
	
		
			
				|  |  |          await server.stop(None)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def RunClient(self, request_iterator, context):
 | 
	
		
			
				|  |  | -        config = (await context.read()).setup
 | 
	
		
			
				|  |  | -        _LOGGER.info('Received ClientConfig: %s', config)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +    async def _run_single_client(self, config, request_iterator, context):
 | 
	
		
			
				|  |  |          running_tasks = []
 | 
	
		
			
				|  |  |          qps_data = histogram.Histogram(config.histogram_params.resolution,
 | 
	
		
			
				|  |  |                                         config.histogram_params.max_possible)
 | 
	
	
		
			
				|  | @@ -160,7 +180,7 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
 | 
	
		
			
				|  |  |              running_tasks.append(self._loop.create_task(client.run()))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          end_time = time.time()
 | 
	
		
			
				|  |  | -        yield _get_client_status(start_time, end_time, qps_data)
 | 
	
		
			
				|  |  | +        await context.write(_get_client_status(start_time, end_time, qps_data))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Respond to stat requests
 | 
	
		
			
				|  |  |          async for request in request_iterator:
 | 
	
	
		
			
				|  | @@ -169,16 +189,66 @@ class WorkerServicer(worker_service_pb2_grpc.WorkerServiceServicer):
 | 
	
		
			
				|  |  |              if request.mark.reset:
 | 
	
		
			
				|  |  |                  qps_data.reset()
 | 
	
		
			
				|  |  |                  start_time = time.time()
 | 
	
		
			
				|  |  | -            yield status
 | 
	
		
			
				|  |  | +            await context.write(status)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # Cleanup the clients
 | 
	
		
			
				|  |  |          for task in running_tasks:
 | 
	
		
			
				|  |  |              task.cancel()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def CoreCount(self, request, context):
 | 
	
		
			
				|  |  | +    async def RunClient(self, request_iterator, context):
 | 
	
		
			
				|  |  | +        config_request = await context.read()
 | 
	
		
			
				|  |  | +        config = config_request.setup
 | 
	
		
			
				|  |  | +        _LOGGER.info('Received ClientConfig: %s', config)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if config.async_server_threads <= 0:
 | 
	
		
			
				|  |  | +            raise ValueError('async_server_threads can\'t be [%d]' % config.async_server_threads)
 | 
	
		
			
				|  |  | +        elif config.async_server_threads == 1:
 | 
	
		
			
				|  |  | +            await self._run_single_client(config, request_iterator, context)
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            sub_workers = []
 | 
	
		
			
				|  |  | +            for i in range(config.async_server_threads):
 | 
	
		
			
				|  |  | +                port = 40000+i
 | 
	
		
			
				|  |  | +                _LOGGER.info('Creating sub worker at port [%d]...', port)
 | 
	
		
			
				|  |  | +                sub_workers.append(await _create_sub_worker(port))
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            calls = [worker.stub.RunClient() for worker in sub_workers]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            for call in calls:
 | 
	
		
			
				|  |  | +                await call.write(config_request)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            start_time = time.time()
 | 
	
		
			
				|  |  | +            result = histogram.Histogram(config.histogram_params.resolution,
 | 
	
		
			
				|  |  | +                                         config.histogram_params.max_possible)
 | 
	
		
			
				|  |  | +            end_time = time.time()
 | 
	
		
			
				|  |  | +            yield _get_client_status(start_time, end_time, result)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            async for request in request_iterator:
 | 
	
		
			
				|  |  | +                end_time = time.time()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                for call in calls:
 | 
	
		
			
				|  |  | +                    await call.write(request)
 | 
	
		
			
				|  |  | +                    sub_status = await call.read()
 | 
	
		
			
				|  |  | +                    result.merge(sub_status.latencies)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                status = _get_client_status(start_time, end_time, result)
 | 
	
		
			
				|  |  | +                if request.mark.reset:
 | 
	
		
			
				|  |  | +                    result.reset()
 | 
	
		
			
				|  |  | +                    start_time = time.time()
 | 
	
		
			
				|  |  | +                yield status
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            for call in calls:
 | 
	
		
			
				|  |  | +                await call.QuitWorker()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            for worker in sub_workers:
 | 
	
		
			
				|  |  | +                await worker.channel.close()
 | 
	
		
			
				|  |  | +                _LOGGER.info('Waiting for sub worker [%s] to quit...', worker)
 | 
	
		
			
				|  |  | +                await worker.process.wait()
 | 
	
		
			
				|  |  | +                _LOGGER.info('Sub worker [%s] quit', worker)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    async def CoreCount(self, unused_request, unused_context):
 | 
	
		
			
				|  |  |          return control_pb2.CoreResponse(cores=_NUM_CORES)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    async def QuitWorker(self, request, context):
 | 
	
		
			
				|  |  | +    async def QuitWorker(self, unused_request, unused_context):
 | 
	
		
			
				|  |  |          _LOGGER.info('QuitWorker command received.')
 | 
	
		
			
				|  |  |          self._quit_event.set()
 | 
	
		
			
				|  |  |          return control_pb2.Void()
 |