
Add a concurrent RPC limit feature to the asyncio server
* Reduce the allocation of new functions

Lidi Zheng, 4 years ago
Commit 3da3cc2168
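
For context, a minimal sketch of how the new argument is used from the public asyncio API; the port, the limit value, and the omitted handler registration are placeholders for illustration, not part of this change:

    import asyncio
    import grpc

    async def serve():
        # At most 5 RPCs are serviced at once; further incoming calls wait
        # in the limiter until an in-flight RPC finishes.
        server = grpc.aio.server(maximum_concurrent_rpcs=5)
        server.add_insecure_port('[::]:50051')
        # Service handlers would be registered here before starting.
        await server.start()
        await server.wait_for_termination()

    if __name__ == '__main__':
        asyncio.run(serve())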

+ 8 - 0
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi

@@ -67,6 +67,13 @@ cdef enum AioServerStatus:
     AIO_SERVER_STATUS_STOPPING
 
 
+cdef class _ConcurrentRpcLimiter:
+    cdef int _maximum_concurrent_rpcs
+    cdef int _active_rpcs
+    cdef object _active_rpcs_condition  # asyncio.Condition
+    cdef object _loop  # asyncio.EventLoop
+
+
 cdef class AioServer:
     cdef Server _server
     cdef list _generic_handlers
@@ -79,5 +86,6 @@ cdef class AioServer:
     cdef object _crash_exception  # Exception
     cdef tuple _interceptors
     cdef object _thread_pool  # concurrent.futures.ThreadPoolExecutor
+    cdef _ConcurrentRpcLimiter _limiter
 
     cdef thread_pool(self)

+ 44 - 4
src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi

@@ -781,6 +781,40 @@ cdef CallbackFailureHandler SERVER_SHUTDOWN_FAILURE_HANDLER = CallbackFailureHan
     InternalError)
 
 
+cdef class _ConcurrentRpcLimiter:
+
+    def __cinit__(self, int maximum_concurrent_rpcs, object loop):
+        if maximum_concurrent_rpcs <= 0:
+            raise ValueError("maximum_concurrent_rpcs should be a positive integer")
+        self._maximum_concurrent_rpcs = maximum_concurrent_rpcs
+        self._active_rpcs = 0
+        self._active_rpcs_condition = asyncio.Condition()
+        self._loop = loop
+
+    async def check_before_request_call(self):
+        await self._active_rpcs_condition.acquire()
+        try:
+            predicate = lambda: self._active_rpcs < self._maximum_concurrent_rpcs
+            await self._active_rpcs_condition.wait_for(predicate)
+            self._active_rpcs += 1
+        finally:
+            self._active_rpcs_condition.release()
+
+    async def _decrease_active_rpcs_count_with_lock(self):
+        await self._active_rpcs_condition.acquire()
+        try:
+            self._active_rpcs -= 1
+            self._active_rpcs_condition.notify()
+        finally:
+            self._active_rpcs_condition.release()
+
+    def _decrease_active_rpcs_count(self, unused_future):
+        self._loop.create_task(self._decrease_active_rpcs_count_with_lock())
+
+    def decrease_once_finished(self, object rpc_task):
+        rpc_task.add_done_callback(self._decrease_active_rpcs_count)
+
+
 cdef class AioServer:
 
     def __init__(self, loop, thread_pool, generic_handlers, interceptors,
@@ -815,9 +849,9 @@ cdef class AioServer:
             self._interceptors = ()
 
         self._thread_pool = thread_pool
-
-        if maximum_concurrent_rpcs:
-            raise NotImplementedError()
+        if maximum_concurrent_rpcs is not None:
+            self._limiter = _ConcurrentRpcLimiter(maximum_concurrent_rpcs,
+                                                  loop)
 
     def add_generic_rpc_handlers(self, object generic_rpc_handlers):
         self._generic_handlers.extend(generic_rpc_handlers)
@@ -860,6 +894,9 @@ cdef class AioServer:
             if self._status != AIO_SERVER_STATUS_RUNNING:
                 break
 
+            if self._limiter is not None:
+                await self._limiter.check_before_request_call()
+
             # Accepts new request from Core
             rpc_state = await self._request_call()
 
@@ -874,7 +911,7 @@ cdef class AioServer:
                                    self._loop)
 
             # Fires off a task that listens on the cancellation from client.
-            self._loop.create_task(
+            rpc_task = self._loop.create_task(
                 _schedule_rpc_coro(
                     rpc_coro,
                     rpc_state,
@@ -882,6 +919,9 @@ cdef class AioServer:
                 )
             )
 
+            if self._limiter is not None:
+                self._limiter.decrease_once_finished(rpc_task)
+
     def _serving_task_crash_handler(self, object task):
         """Shutdown the server immediately if unexpectedly exited."""
         if task.cancelled():
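
For readers less familiar with Cython, the limiter above is plain asyncio underneath. Below is a rough pure-Python sketch of the same pattern, with async with blocks standing in for the explicit acquire()/release() pairs; it mirrors the names above but is an illustration, not the shipped implementation:

    import asyncio

    class ConcurrentRpcLimiter:

        def __init__(self, maximum_concurrent_rpcs, loop):
            if maximum_concurrent_rpcs <= 0:
                raise ValueError("maximum_concurrent_rpcs should be a positive integer")
            self._maximum_concurrent_rpcs = maximum_concurrent_rpcs
            self._active_rpcs = 0
            self._active_rpcs_condition = asyncio.Condition()
            self._loop = loop

        async def check_before_request_call(self):
            # Block the accept loop while the server sits at its concurrency cap.
            async with self._active_rpcs_condition:
                await self._active_rpcs_condition.wait_for(
                    lambda: self._active_rpcs < self._maximum_concurrent_rpcs)
                self._active_rpcs += 1

        async def _decrease_active_rpcs_count_with_lock(self):
            async with self._active_rpcs_condition:
                self._active_rpcs -= 1
                # Wake one coroutine waiting in check_before_request_call, if any.
                self._active_rpcs_condition.notify()

        def decrease_once_finished(self, rpc_task):
            # Done callbacks run synchronously on the event loop, so schedule a
            # task that takes the condition lock rather than awaiting here.
            rpc_task.add_done_callback(
                lambda _: self._loop.create_task(
                    self._decrease_active_rpcs_count_with_lock()))

The serving loop awaits check_before_request_call before asking Core for the next request, so at most maximum_concurrent_rpcs handler tasks run at once while additional requests simply wait.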

+ 27 - 1
src/python/grpcio_tests/tests_aio/unit/server_test.py

@@ -47,6 +47,7 @@ _REQUEST = b'\x00\x00\x00'
 _RESPONSE = b'\x01\x01\x01'
 _NUM_STREAM_REQUESTS = 3
 _NUM_STREAM_RESPONSES = 5
+_MAXIMUM_CONCURRENT_RPCS = 5
 
 
 class _GenericHandler(grpc.GenericRpcHandler):
@@ -189,7 +190,8 @@ class _GenericHandler(grpc.GenericRpcHandler):
         context.set_code(grpc.StatusCode.INTERNAL)
 
     def service(self, handler_details):
-        self._called.set_result(None)
+        if not self._called.done():
+            self._called.set_result(None)
         return self._routing_table.get(handler_details.method)
 
     async def wait_for_call(self):
@@ -480,6 +482,30 @@ class TestServer(AioTestBase):
         with self.assertRaises(RuntimeError):
             server.add_secure_port(bind_address, server_credentials)
 
+    async def test_maximum_concurrent_rpcs(self):
+        # Build the server with the maximum_concurrent_rpcs argument
+        server = aio.server(maximum_concurrent_rpcs=_MAXIMUM_CONCURRENT_RPCS)
+        port = server.add_insecure_port('localhost:0')
+        bind_address = "localhost:%d" % port
+        server.add_generic_rpc_handlers((_GenericHandler(),))
+        await server.start()
+        # Build the channel
+        channel = aio.insecure_channel(bind_address)
+        # Deplete the concurrency quota with 3 times the maximum number of RPCs
+        rpcs = []
+        for _ in range(3 * _MAXIMUM_CONCURRENT_RPCS):
+            rpcs.append(channel.unary_unary(_BLOCK_BRIEFLY)(_REQUEST))
+        task = self.loop.create_task(
+            asyncio.wait(rpcs, return_when=asyncio.FIRST_EXCEPTION))
+        # Each batch of RPCs takes test_constants.SHORT_TIMEOUT / 2
+        start_time = time.time()
+        await task
+        elapsed_time = time.time() - start_time
+        self.assertGreater(elapsed_time, test_constants.SHORT_TIMEOUT * 3 / 2)
+        # Clean-up
+        await channel.close()
+        await server.stop(0)
+
 
 if __name__ == '__main__':
     logging.basicConfig(level=logging.DEBUG)
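
A note on the timing assertion in test_maximum_concurrent_rpcs: with 3 * _MAXIMUM_CONCURRENT_RPCS calls and at most _MAXIMUM_CONCURRENT_RPCS in flight, the server has to work through at least three back-to-back batches. Assuming the _BLOCK_BRIEFLY handler holds each RPC for test_constants.SHORT_TIMEOUT / 2, as the comment in the test suggests, the full run cannot finish in less than 3 * SHORT_TIMEOUT / 2, which is the lower bound assertGreater enforces.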