소스 검색

Use a weakset for storing ongoing calls

Pau Freixes 5 년 전
부모
커밋
c94364f311
2개의 변경된 파일29개의 추가작업 그리고 20개의 파일을 삭제
  1. 5 4
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 24 16
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

+ 5 - 4
src/python/grpcio/grpc/experimental/aio/_channel.py

@@ -13,7 +13,8 @@
 # limitations under the License.
 """Invocation-side implementation of gRPC Asyncio Python."""
 import asyncio
-from typing import Any, AsyncIterable, Optional, Sequence, Set, Text
+from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text
+from weakref import WeakSet
 
 import logging
 import grpc
@@ -37,10 +38,10 @@ _LOGGER = logging.getLogger(__name__)
 class _OngoingCalls:
     """Internal class used for have visibility of the ongoing calls."""
 
-    _calls: Set[_base_call.RpcContext]
+    _calls: AbstractSet[_base_call.RpcContext]
 
     def __init__(self):
-        self._calls = set()
+        self._calls = WeakSet()
 
     def _remove_call(self, call: _base_call.RpcContext):
         self._calls.remove(call)
@@ -401,7 +402,7 @@ class Channel:
         # A new set is created acting as a shallow copy because
         # when cancellation happens the calls are automatically
         # removed from the originally set.
-        calls = set(self._ongoing_calls.calls)
+        calls = WeakSet(data=self._ongoing_calls.calls)
         for call in calls:
             call.cancel()
   

+ 24 - 16
src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@@ -16,6 +16,7 @@
 import asyncio
 import logging
 import unittest
+from weakref import WeakSet
 
 import grpc
 from grpc.experimental import aio
@@ -32,36 +33,43 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
 
 class TestOngoingCalls(unittest.TestCase):
 
-    def test_trace_call(self):
-
-        class FakeCall(_base_call.RpcContext):
+    class FakeCall(_base_call.RpcContext):
 
-            def add_done_callback(self, callback):
-                self.callback = callback
+        def add_done_callback(self, callback):
+            self.callback = callback
 
-            def cancel(self):
-                raise NotImplementedError
+        def cancel(self):
+            raise NotImplementedError
 
-            def cancelled(self):
-                raise NotImplementedError
+        def cancelled(self):
+            raise NotImplementedError
 
-            def done(self):
-                raise NotImplementedError
+        def done(self):
+            raise NotImplementedError
 
-            def time_remaining(self):
-                raise NotImplementedError
+        def time_remaining(self):
+            raise NotImplementedError
 
+    def test_trace_call(self):
         ongoing_calls = _OngoingCalls()
         self.assertEqual(ongoing_calls.size(), 0)
 
-        call = FakeCall()
+        call = TestOngoingCalls.FakeCall()
         ongoing_calls.trace_call(call)
         self.assertEqual(ongoing_calls.size(), 1)
-        self.assertEqual(ongoing_calls.calls, set([call]))
+        self.assertEqual(ongoing_calls.calls, WeakSet([call]))
 
         call.callback(call)
         self.assertEqual(ongoing_calls.size(), 0)
-        self.assertEqual(ongoing_calls.calls, set())
+        self.assertEqual(ongoing_calls.calls, WeakSet())
+
+    def test_deleted_call(self):
+        ongoing_calls = _OngoingCalls()
+
+        call = TestOngoingCalls.FakeCall()
+        ongoing_calls.trace_call(call)
+        del(call)
+        self.assertEqual(ongoing_calls.size(), 0)
 
 
 class TestCloseChannel(AioTestBase):