# worker_server.py
  1. # Copyright 2016 gRPC authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import multiprocessing
  15. import random
  16. import threading
  17. import time
  18. from concurrent import futures
  19. import grpc
  20. from src.proto.grpc.testing import control_pb2
  21. from src.proto.grpc.testing import benchmark_service_pb2_grpc
  22. from src.proto.grpc.testing import worker_service_pb2_grpc
  23. from src.proto.grpc.testing import stats_pb2
  24. from tests.qps import benchmark_client
  25. from tests.qps import benchmark_server
  26. from tests.qps import client_runner
  27. from tests.qps import histogram
  28. from tests.unit import resources
  29. from tests.unit import test_common
  30. class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
  31. """Python Worker Server implementation."""
  32. def __init__(self, server_port=None):
  33. self._quit_event = threading.Event()
  34. self._server_port = server_port
  35. def RunServer(self, request_iterator, context):
  36. config = next(request_iterator).setup #pylint: disable=stop-iteration-return
  37. server, port = self._create_server(config)
  38. cores = multiprocessing.cpu_count()
  39. server.start()
  40. start_time = time.time()
  41. yield self._get_server_status(start_time, start_time, port, cores)
  42. for request in request_iterator:
  43. end_time = time.time()
  44. status = self._get_server_status(start_time, end_time, port, cores)
  45. if request.mark.reset:
  46. start_time = end_time
  47. yield status
  48. server.stop(None)
  49. def _get_server_status(self, start_time, end_time, port, cores):
  50. end_time = time.time()
  51. elapsed_time = end_time - start_time
  52. stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
  53. time_user=elapsed_time,
  54. time_system=elapsed_time)
  55. return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)
  56. def _create_server(self, config):
  57. if config.async_server_threads == 0:
  58. # This is the default concurrent.futures thread pool size, but
  59. # None doesn't seem to work
  60. server_threads = multiprocessing.cpu_count() * 5
  61. else:
  62. server_threads = config.async_server_threads
  63. server = test_common.test_server(max_workers=server_threads)
  64. if config.server_type == control_pb2.ASYNC_SERVER:
  65. servicer = benchmark_server.BenchmarkServer()
  66. benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
  67. servicer, server)
  68. elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
  69. resp_size = config.payload_config.bytebuf_params.resp_size
  70. servicer = benchmark_server.GenericBenchmarkServer(resp_size)
  71. method_implementations = {
  72. 'StreamingCall':
  73. grpc.stream_stream_rpc_method_handler(servicer.StreamingCall
  74. ),
  75. 'UnaryCall':
  76. grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
  77. }
  78. handler = grpc.method_handlers_generic_handler(
  79. 'grpc.testing.BenchmarkService', method_implementations)
  80. server.add_generic_rpc_handlers((handler,))
  81. else:
  82. raise Exception('Unsupported server type {}'.format(
  83. config.server_type))
  84. if self._server_port is not None and config.port == 0:
  85. server_port = self._server_port
  86. else:
  87. server_port = config.port
  88. if config.HasField('security_params'): # Use SSL
  89. server_creds = grpc.ssl_server_credentials(
  90. ((resources.private_key(), resources.certificate_chain()),))
  91. port = server.add_secure_port('[::]:{}'.format(server_port),
  92. server_creds)
  93. else:
  94. port = server.add_insecure_port('[::]:{}'.format(server_port))
  95. return (server, port)
  96. def RunClient(self, request_iterator, context):
  97. config = next(request_iterator).setup #pylint: disable=stop-iteration-return
  98. client_runners = []
  99. qps_data = histogram.Histogram(config.histogram_params.resolution,
  100. config.histogram_params.max_possible)
  101. start_time = time.time()
  102. # Create a client for each channel
  103. for i in range(config.client_channels):
  104. server = config.server_targets[i % len(config.server_targets)]
  105. runner = self._create_client_runner(server, config, qps_data)
  106. client_runners.append(runner)
  107. runner.start()
  108. end_time = time.time()
  109. yield self._get_client_status(start_time, end_time, qps_data)
  110. # Respond to stat requests
  111. for request in request_iterator:
  112. end_time = time.time()
  113. status = self._get_client_status(start_time, end_time, qps_data)
  114. if request.mark.reset:
  115. qps_data.reset()
  116. start_time = time.time()
  117. yield status
  118. # Cleanup the clients
  119. for runner in client_runners:
  120. runner.stop()
  121. def _get_client_status(self, start_time, end_time, qps_data):
  122. latencies = qps_data.get_data()
  123. end_time = time.time()
  124. elapsed_time = end_time - start_time
  125. stats = stats_pb2.ClientStats(latencies=latencies,
  126. time_elapsed=elapsed_time,
  127. time_user=elapsed_time,
  128. time_system=elapsed_time)
  129. return control_pb2.ClientStatus(stats=stats)
  130. def _create_client_runner(self, server, config, qps_data):
  131. no_ping_pong = False
  132. if config.client_type == control_pb2.SYNC_CLIENT:
  133. if config.rpc_type == control_pb2.UNARY:
  134. client = benchmark_client.UnarySyncBenchmarkClient(
  135. server, config, qps_data)
  136. elif config.rpc_type == control_pb2.STREAMING:
  137. client = benchmark_client.StreamingSyncBenchmarkClient(
  138. server, config, qps_data)
  139. elif config.rpc_type == control_pb2.STREAMING_FROM_SERVER:
  140. no_ping_pong = True
  141. client = benchmark_client.ServerStreamingSyncBenchmarkClient(
  142. server, config, qps_data)
  143. elif config.client_type == control_pb2.ASYNC_CLIENT:
  144. if config.rpc_type == control_pb2.UNARY:
  145. client = benchmark_client.UnaryAsyncBenchmarkClient(
  146. server, config, qps_data)
  147. else:
  148. raise Exception('Async streaming client not supported')
  149. else:
  150. raise Exception('Unsupported client type {}'.format(
  151. config.client_type))
  152. # In multi-channel tests, we split the load across all channels
  153. load_factor = float(config.client_channels)
  154. if config.load_params.WhichOneof('load') == 'closed_loop':
  155. runner = client_runner.ClosedLoopClientRunner(
  156. client, config.outstanding_rpcs_per_channel, no_ping_pong)
  157. else: # Open loop Poisson
  158. alpha = config.load_params.poisson.offered_load / load_factor
  159. def poisson():
  160. while True:
  161. yield random.expovariate(alpha)
  162. runner = client_runner.OpenLoopClientRunner(client, poisson())
  163. return runner
  164. def CoreCount(self, request, context):
  165. return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())
  166. def QuitWorker(self, request, context):
  167. self._quit_event.set()
  168. return control_pb2.Void()
  169. def wait_for_quit(self):
  170. self._quit_event.wait()