Эх сурвалжийг харах

Indicate private type aliases

Sergii Tkachenko 4 жил өмнө
parent
commit
9d5a7fad4d

+ 10 - 10
tools/run_tests/xds_test_driver/bin/run_channelz.py

@@ -36,9 +36,9 @@ flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)
 
 # Type aliases
-Socket = grpc_channelz.Socket
-XdsTestServer = server_app.XdsTestServer
-XdsTestClient = client_app.XdsTestClient
+_Socket = grpc_channelz.Socket
+_XdsTestServer = server_app.XdsTestServer
+_XdsTestClient = client_app.XdsTestClient
 
 
 def debug_cert(cert):
@@ -75,7 +75,7 @@ def main(argv):
     server_name = xds_flags.SERVER_NAME.value
     server_port = xds_flags.SERVER_PORT.value
     server_pod_ip = get_deployment_pod_ips(server_k8s_ns, server_name)[0]
-    test_server: XdsTestServer = XdsTestServer(
+    test_server: _XdsTestServer = _XdsTestServer(
         ip=server_pod_ip,
         rpc_port=server_port,
         xds_host=xds_flags.SERVER_XDS_HOST.value,
@@ -88,7 +88,7 @@ def main(argv):
     client_port = xds_flags.CLIENT_PORT.value
     client_pod_ip = get_deployment_pod_ips(client_k8s_ns, client_name)[0]
 
-    test_client: XdsTestClient = XdsTestClient(
+    test_client: _XdsTestClient = _XdsTestClient(
         ip=client_pod_ip,
         server_target=test_server.xds_uri,
         rpc_port=client_port,
@@ -96,12 +96,12 @@ def main(argv):
 
     with test_client, test_server:
         test_client.wait_for_active_server_channel()
-        client_socket: Socket = test_client.get_client_socket_with_test_server()
-        server_socket: Socket = test_server.get_server_socket_matching_client(
-            client_socket)
+        client_sock: _Socket = test_client.get_client_socket_with_test_server()
+        server_sock: _Socket = test_server.get_server_socket_matching_client(
+            client_sock)
 
-        server_tls = server_socket.security.tls
-        client_tls = client_socket.security.tls
+        server_tls = server_sock.security.tls
+        client_tls = client_sock.security.tls
 
         print(f'\nServer certs:\n{debug_sock_tls(server_tls)}')
         print(f'\nClient certs:\n{debug_sock_tls(client_tls)}')

+ 15 - 15
tools/run_tests/xds_test_driver/framework/infrastructure/traffic_director.py

@@ -20,24 +20,24 @@ logger = logging.getLogger(__name__)
 
 # Type aliases
 # Compute
-ComputeV1 = gcp.compute.ComputeV1
-HealthCheckProtocol = ComputeV1.HealthCheckProtocol
-BackendServiceProtocol = ComputeV1.BackendServiceProtocol
-GcpResource = ComputeV1.GcpResource
-ZonalGcpResource = ComputeV1.ZonalGcpResource
+_ComputeV1 = gcp.compute.ComputeV1
+HealthCheckProtocol = _ComputeV1.HealthCheckProtocol
+BackendServiceProtocol = _ComputeV1.BackendServiceProtocol
+GcpResource = _ComputeV1.GcpResource
+ZonalGcpResource = _ComputeV1.ZonalGcpResource
 
 # Network Security
-NetworkSecurityV1Alpha1 = gcp.network_security.NetworkSecurityV1Alpha1
-ServerTlsPolicy = NetworkSecurityV1Alpha1.ServerTlsPolicy
-ClientTlsPolicy = NetworkSecurityV1Alpha1.ClientTlsPolicy
+_NetworkSecurityV1Alpha1 = gcp.network_security.NetworkSecurityV1Alpha1
+ServerTlsPolicy = _NetworkSecurityV1Alpha1.ServerTlsPolicy
+ClientTlsPolicy = _NetworkSecurityV1Alpha1.ClientTlsPolicy
 
 # Network Services
-NetworkServicesV1Alpha1 = gcp.network_services.NetworkServicesV1Alpha1
-EndpointConfigSelector = NetworkServicesV1Alpha1.EndpointConfigSelector
+_NetworkServicesV1Alpha1 = gcp.network_services.NetworkServicesV1Alpha1
+EndpointConfigSelector = _NetworkServicesV1Alpha1.EndpointConfigSelector
 
 
 class TrafficDirectorManager:
-    compute: ComputeV1
+    compute: _ComputeV1
     BACKEND_SERVICE_NAME = "backend-service"
     HEALTH_CHECK_NAME = "health-check"
     URL_MAP_NAME = "url-map"
@@ -54,7 +54,7 @@ class TrafficDirectorManager:
             network: str = 'default',
     ):
         # API
-        self.compute = ComputeV1(gcp_api_manager, project)
+        self.compute = _ComputeV1(gcp_api_manager, project)
 
         # Settings
         self.project: str = project
@@ -272,7 +272,7 @@ class TrafficDirectorManager:
 
 
 class TrafficDirectorSecureManager(TrafficDirectorManager):
-    netsec: Optional[NetworkSecurityV1Alpha1]
+    netsec: Optional[_NetworkSecurityV1Alpha1]
     SERVER_TLS_POLICY_NAME = "server-tls-policy"
     CLIENT_TLS_POLICY_NAME = "client-tls-policy"
     ENDPOINT_CONFIG_SELECTOR_NAME = "endpoint-config-selector"
@@ -292,8 +292,8 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
                          network=network)
 
         # API
-        self.netsec = NetworkSecurityV1Alpha1(gcp_api_manager, project)
-        self.netsvc = NetworkServicesV1Alpha1(gcp_api_manager, project)
+        self.netsec = _NetworkSecurityV1Alpha1(gcp_api_manager, project)
+        self.netsvc = _NetworkServicesV1Alpha1(gcp_api_manager, project)
 
         # Managed resources
         self.server_tls_policy: Optional[ServerTlsPolicy] = None

+ 21 - 21
tools/run_tests/xds_test_driver/framework/rpc/grpc_channelz.py

@@ -27,26 +27,26 @@ logger = logging.getLogger(__name__)
 # Channel
 Channel = channelz_pb2.Channel
 ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
-GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
-GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
+_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
+_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
 # Subchannel
 Subchannel = channelz_pb2.Subchannel
-GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
-GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
+_GetSubchannelRequest = channelz_pb2.GetSubchannelRequest
+_GetSubchannelResponse = channelz_pb2.GetSubchannelResponse
 # Server
 Server = channelz_pb2.Server
-GetServersRequest = channelz_pb2.GetServersRequest
-GetServersResponse = channelz_pb2.GetServersResponse
+_GetServersRequest = channelz_pb2.GetServersRequest
+_GetServersResponse = channelz_pb2.GetServersResponse
 # Sockets
 Socket = channelz_pb2.Socket
 SocketRef = channelz_pb2.SocketRef
-GetSocketRequest = channelz_pb2.GetSocketRequest
-GetSocketResponse = channelz_pb2.GetSocketResponse
+_GetSocketRequest = channelz_pb2.GetSocketRequest
+_GetSocketResponse = channelz_pb2.GetSocketResponse
 Address = channelz_pb2.Address
 Security = channelz_pb2.Security
 # Server Sockets
-GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
-GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
+_GetServerSocketsRequest = channelz_pb2.GetServerSocketsRequest
+_GetServerSocketsResponse = channelz_pb2.GetServerSocketsResponse
 
 
 class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
@@ -113,14 +113,14 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
         This does not include subchannels nor non-top level channels.
         """
         start: int = -1
-        response: Optional[GetTopChannelsResponse] = None
+        response: Optional[_GetTopChannelsResponse] = None
         while start < 0 or not response.end:
             # From proto: To request subsequent pages, the client generates this
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_when_channel_ready(
                 rpc='GetTopChannels',
-                req=GetTopChannelsRequest(start_channel_id=start))
+                req=_GetTopChannelsRequest(start_channel_id=start))
             for channel in response.channel:
                 start = max(start, channel.ref.channel_id)
                 yield channel
@@ -128,13 +128,13 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
     def list_servers(self) -> Iterator[Server]:
         """Iterate over all pages of all servers that exist in the process."""
         start: int = -1
-        response: Optional[GetServersResponse] = None
+        response: Optional[_GetServersResponse] = None
         while start < 0 or not response.end:
             # From proto: To request subsequent pages, the client generates this
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_when_channel_ready(
-                rpc='GetServers', req=GetServersRequest(start_server_id=start))
+                rpc='GetServers', req=_GetServersRequest(start_server_id=start))
             for server in response.server:
                 start = max(start, server.ref.server_id)
                 yield server
@@ -142,15 +142,15 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
     def list_server_sockets(self, server_id) -> Iterator[Socket]:
         """Iterate over all server sockets that exist in server process."""
         start: int = -1
-        response: Optional[GetServerSocketsResponse] = None
+        response: Optional[_GetServerSocketsResponse] = None
         while start < 0 or not response.end:
             # From proto: To request subsequent pages, the client generates this
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_when_channel_ready(
                 rpc='GetServerSockets',
-                req=GetServerSocketsRequest(server_id=server_id,
-                                            start_socket_id=start))
+                req=_GetServerSocketsRequest(server_id=server_id,
+                                             start_socket_id=start))
             socket_ref: SocketRef
             for socket_ref in response.socket_ref:
                 start = max(start, socket_ref.socket_id)
@@ -159,13 +159,13 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
 
     def get_subchannel(self, subchannel_id) -> Subchannel:
         """Return a single Subchannel, otherwise raises RpcError."""
-        response: GetSubchannelResponse = self.call_unary_when_channel_ready(
+        response: _GetSubchannelResponse = self.call_unary_when_channel_ready(
             rpc='GetSubchannel',
-            req=GetSubchannelRequest(subchannel_id=subchannel_id))
+            req=_GetSubchannelRequest(subchannel_id=subchannel_id))
         return response.subchannel
 
     def get_socket(self, socket_id) -> Socket:
         """Return a single Socket, otherwise raises RpcError."""
-        response: GetSocketResponse = self.call_unary_when_channel_ready(
-            rpc='GetSocket', req=GetSocketRequest(socket_id=socket_id))
+        response: _GetSocketResponse = self.call_unary_when_channel_ready(
+            rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id))
         return response.socket

+ 5 - 5
tools/run_tests/xds_test_driver/framework/rpc/grpc_testing.py

@@ -20,8 +20,8 @@ from src.proto.grpc.testing import test_pb2_grpc
 from src.proto.grpc.testing import messages_pb2
 
 # Type aliases
-LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest
-LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse
+_LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest
+_LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse
 
 
 class LoadBalancerStatsServiceClient(framework.rpc.GrpcClientHelper):
@@ -36,12 +36,12 @@ class LoadBalancerStatsServiceClient(framework.rpc.GrpcClientHelper):
             *,
             num_rpcs: int,
             timeout_sec: Optional[int] = STATS_PARTIAL_RESULTS_TIMEOUT_SEC,
-    ) -> LoadBalancerStatsResponse:
+    ) -> _LoadBalancerStatsResponse:
         if timeout_sec is None:
             timeout_sec = self.STATS_PARTIAL_RESULTS_TIMEOUT_SEC
 
         return self.call_unary_when_channel_ready(
             rpc='GetClientStats',
             wait_for_ready_sec=timeout_sec,
-            req=LoadBalancerStatsRequest(num_rpcs=num_rpcs,
-                                         timeout_sec=timeout_sec))
+            req=_LoadBalancerStatsRequest(num_rpcs=num_rpcs,
+                                          timeout_sec=timeout_sec))

+ 12 - 11
tools/run_tests/xds_test_driver/framework/test_app/client_app.py

@@ -26,9 +26,9 @@ from framework.test_app import base_runner
 logger = logging.getLogger(__name__)
 
 # Type aliases
-ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
-ChannelConnectivityState = grpc_channelz.ChannelConnectivityState
-LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
+_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
+_ChannelConnectivityState = grpc_channelz.ChannelConnectivityState
+_LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
 
 
 class XdsTestClient(framework.rpc.GrpcApp):
@@ -48,20 +48,21 @@ class XdsTestClient(framework.rpc.GrpcApp):
 
     @property
     @functools.lru_cache(None)
-    def load_balancer_stats(self) -> LoadBalancerStatsServiceClient:
-        return LoadBalancerStatsServiceClient(self._make_channel(self.rpc_port))
+    def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
+        return _LoadBalancerStatsServiceClient(self._make_channel(
+            self.rpc_port))
 
     @property
     @functools.lru_cache(None)
-    def channelz(self) -> ChannelzServiceClient:
-        return ChannelzServiceClient(self._make_channel(self.maintenance_port))
+    def channelz(self) -> _ChannelzServiceClient:
+        return _ChannelzServiceClient(self._make_channel(self.maintenance_port))
 
     def get_load_balancer_stats(
             self,
             *,
             num_rpcs: int,
             timeout_sec: Optional[int] = None,
-    ) -> grpc_testing.LoadBalancerStatsResponse:
+    ) -> grpc_testing._LoadBalancerStatsResponse:
         """
         Shortcut to LoadBalancerStatsServiceClient.get_client_stats()
         """
@@ -85,10 +86,10 @@ class XdsTestClient(framework.rpc.GrpcApp):
 
     def get_active_server_channel(self) -> Optional[grpc_channelz.Channel]:
         for channel in self.get_server_channels():
-            state: ChannelConnectivityState = channel.data.state
+            state: _ChannelConnectivityState = channel.data.state
             logger.debug('Server channel: %s, state: %s', channel.ref.name,
-                         ChannelConnectivityState.State.Name(state.state))
-            if state.state is ChannelConnectivityState.READY:
+                         _ChannelConnectivityState.State.Name(state.state))
+            if state.state is _ChannelConnectivityState.READY:
                 return channel
         raise self.NotFound('Client has no active channel with the server')
 

+ 3 - 3
tools/run_tests/xds_test_driver/framework/test_app/server_app.py

@@ -23,7 +23,7 @@ from framework.test_app import base_runner
 logger = logging.getLogger(__name__)
 
 # Type aliases
-ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
+_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
 
 
 class XdsTestServer(framework.rpc.GrpcApp):
@@ -48,8 +48,8 @@ class XdsTestServer(framework.rpc.GrpcApp):
 
     @property
     @functools.lru_cache(None)
-    def channelz(self) -> ChannelzServiceClient:
-        return ChannelzServiceClient(self._make_channel(self.maintenance_port))
+    def channelz(self) -> _ChannelzServiceClient:
+        return _ChannelzServiceClient(self._make_channel(self.maintenance_port))
 
     def set_xds_address(self, xds_host, xds_port: Optional[int] = None):
         self.xds_host, self.xds_port = xds_host, xds_port

+ 4 - 4
tools/run_tests/xds_test_driver/tests/baseline_test.py

@@ -22,8 +22,8 @@ logger = logging.getLogger(__name__)
 flags.adopt_module_key_flags(xds_k8s_testcase)
 
 # Type aliases
-XdsTestServer = xds_k8s_testcase.XdsTestServer
-XdsTestClient = xds_k8s_testcase.XdsTestClient
+_XdsTestServer = xds_k8s_testcase.XdsTestServer
+_XdsTestClient = xds_k8s_testcase.XdsTestClient
 
 
 class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
@@ -31,10 +31,10 @@ class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
     def test_ping_pong(self):
         self.setupTrafficDirectorGrpc()
 
-        test_server: XdsTestServer = self.startTestServer()
+        test_server: _XdsTestServer = self.startTestServer()
         self.setupServerBackends()
 
-        test_client: XdsTestClient = self.startTestClient(test_server)
+        test_client: _XdsTestClient = self.startTestClient(test_server)
         self.assertSuccessfulRpcs(test_client)
 
 

+ 12 - 12
tools/run_tests/xds_test_driver/tests/security_test.py

@@ -23,9 +23,9 @@ flags.adopt_module_key_flags(xds_k8s_testcase)
 SKIP_REASON = 'Work in progress'
 
 # Type aliases
-XdsTestServer = xds_k8s_testcase.XdsTestServer
-XdsTestClient = xds_k8s_testcase.XdsTestClient
-SecurityMode = xds_k8s_testcase.SecurityXdsKubernetesTestCase.SecurityMode
+_XdsTestServer = xds_k8s_testcase.XdsTestServer
+_XdsTestClient = xds_k8s_testcase.XdsTestClient
+_SecurityMode = xds_k8s_testcase.SecurityXdsKubernetesTestCase.SecurityMode
 
 
 class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
@@ -37,11 +37,11 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
                                    client_tls=True,
                                    client_mtls=True)
 
-        test_server: XdsTestServer = self.startSecureTestServer()
+        test_server: _XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
-        test_client: XdsTestClient = self.startSecureTestClient(test_server)
+        test_client: _XdsTestClient = self.startSecureTestClient(test_server)
 
-        self.assertTestAppSecurity(SecurityMode.MTLS, test_client, test_server)
+        self.assertTestAppSecurity(_SecurityMode.MTLS, test_client, test_server)
         self.assertSuccessfulRpcs(test_client)
 
     def test_tls(self):
@@ -51,11 +51,11 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
                                    client_tls=True,
                                    client_mtls=False)
 
-        test_server: XdsTestServer = self.startSecureTestServer()
+        test_server: _XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
-        test_client: XdsTestClient = self.startSecureTestClient(test_server)
+        test_client: _XdsTestClient = self.startSecureTestClient(test_server)
 
-        self.assertTestAppSecurity(SecurityMode.TLS, test_client, test_server)
+        self.assertTestAppSecurity(_SecurityMode.TLS, test_client, test_server)
         self.assertSuccessfulRpcs(test_client)
 
     def test_plaintext_fallback(self):
@@ -65,11 +65,11 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
                                    client_tls=False,
                                    client_mtls=False)
 
-        test_server: XdsTestServer = self.startSecureTestServer()
+        test_server: _XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
-        test_client: XdsTestClient = self.startSecureTestClient(test_server)
+        test_client: _XdsTestClient = self.startSecureTestClient(test_server)
 
-        self.assertTestAppSecurity(SecurityMode.PLAINTEXT, test_client,
+        self.assertTestAppSecurity(_SecurityMode.PLAINTEXT, test_client,
                                    test_server)
         self.assertSuccessfulRpcs(test_client)