Sergii Tkachenko před 4 roky
rodič
revize
28a6f740f5
22 změnil soubory, kde provedl 538 přidání a 476 odebrání
  1. 6 4
      tools/run_tests/xds_test_driver/bin/run_channelz.py
  2. 30 26
      tools/run_tests/xds_test_driver/bin/run_td_setup.py
  3. 15 13
      tools/run_tests/xds_test_driver/bin/run_test_client.py
  4. 12 10
      tools/run_tests/xds_test_driver/bin/run_test_server.py
  5. 0 1
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/__init__.py
  6. 33 21
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/api.py
  7. 77 73
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/compute.py
  8. 13 13
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_security.py
  9. 2 1
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_services.py
  10. 65 51
      tools/run_tests/xds_test_driver/framework/infrastructure/k8s.py
  11. 82 67
      tools/run_tests/xds_test_driver/framework/infrastructure/traffic_director.py
  12. 3 3
      tools/run_tests/xds_test_driver/framework/rpc/__init__.py
  13. 5 9
      tools/run_tests/xds_test_driver/framework/rpc/grpc_channelz.py
  14. 2 2
      tools/run_tests/xds_test_driver/framework/rpc/grpc_testing.py
  15. 22 31
      tools/run_tests/xds_test_driver/framework/test_app/base_runner.py
  16. 15 11
      tools/run_tests/xds_test_driver/framework/test_app/client_app.py
  17. 29 22
      tools/run_tests/xds_test_driver/framework/test_app/server_app.py
  18. 26 23
      tools/run_tests/xds_test_driver/framework/xds_flags.py
  19. 15 11
      tools/run_tests/xds_test_driver/framework/xds_k8s_flags.py
  20. 70 76
      tools/run_tests/xds_test_driver/framework/xds_k8s_testcase.py
  21. 1 0
      tools/run_tests/xds_test_driver/tests/baseline_test.py
  22. 15 8
      tools/run_tests/xds_test_driver/tests/security_test.py

+ 6 - 4
tools/run_tests/xds_test_driver/bin/run_channelz.py

@@ -26,10 +26,12 @@ from framework.test_app import client_app
 
 logger = logging.getLogger(__name__)
 # Flags
-_SERVER_RPC_HOST = flags.DEFINE_string(
-    'server_rpc_host', default='127.0.0.1', help='Server RPC host')
-_CLIENT_RPC_HOST = flags.DEFINE_string(
-    'client_rpc_host', default='127.0.0.1', help='Client RPC host')
+_SERVER_RPC_HOST = flags.DEFINE_string('server_rpc_host',
+                                       default='127.0.0.1',
+                                       help='Server RPC host')
+_CLIENT_RPC_HOST = flags.DEFINE_string('client_rpc_host',
+                                       default='127.0.0.1',
+                                       help='Client RPC host')
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)
 

+ 30 - 26
tools/run_tests/xds_test_driver/bin/run_td_setup.py

@@ -22,17 +22,19 @@ from framework.infrastructure import gcp
 from framework.infrastructure import k8s
 from framework.infrastructure import traffic_director
 
-
 logger = logging.getLogger(__name__)
 # Flags
-_CMD = flags.DEFINE_enum(
-    'cmd', default='create',
-    enum_values=['cycle', 'create', 'cleanup',
-                 'backends-add', 'backends-cleanup'],
-    help='Command')
-_SECURITY = flags.DEFINE_enum(
-    'security', default=None, enum_values=['mtls', 'tls', 'plaintext'],
-    help='Configure td with security')
+_CMD = flags.DEFINE_enum('cmd',
+                         default='create',
+                         enum_values=[
+                             'cycle', 'create', 'cleanup', 'backends-add',
+                             'backends-cleanup'
+                         ],
+                         help='Command')
+_SECURITY = flags.DEFINE_enum('security',
+                              default=None,
+                              enum_values=['mtls', 'tls', 'plaintext'],
+                              help='Configure td with security')
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)
 
@@ -57,11 +59,10 @@ def main(argv):
     gcp_api_manager = gcp.api.GcpApiManager()
 
     if security_mode is None:
-        td = traffic_director.TrafficDirectorManager(
-            gcp_api_manager,
-            project=project,
-            resource_prefix=namespace,
-            network=network)
+        td = traffic_director.TrafficDirectorManager(gcp_api_manager,
+                                                     project=project,
+                                                     resource_prefix=namespace,
+                                                     network=network)
     else:
         td = traffic_director.TrafficDirectorSecureManager(
             gcp_api_manager,
@@ -80,26 +81,29 @@ def main(argv):
             elif security_mode == 'mtls':
                 logger.info('Setting up mtls')
                 td.setup_for_grpc(server_xds_host, server_xds_port)
-                td.setup_server_security(server_port,
-                                         tls=True, mtls=True)
-                td.setup_client_security(namespace, server_name,
-                                         tls=True, mtls=True)
+                td.setup_server_security(server_port, tls=True, mtls=True)
+                td.setup_client_security(namespace,
+                                         server_name,
+                                         tls=True,
+                                         mtls=True)
 
             elif security_mode == 'tls':
                 logger.info('Setting up tls')
                 td.setup_for_grpc(server_xds_host, server_xds_port)
-                td.setup_server_security(server_port,
-                                         tls=True, mtls=False)
-                td.setup_client_security(namespace, server_name,
-                                         tls=True, mtls=False)
+                td.setup_server_security(server_port, tls=True, mtls=False)
+                td.setup_client_security(namespace,
+                                         server_name,
+                                         tls=True,
+                                         mtls=False)
 
             elif security_mode == 'plaintext':
                 logger.info('Setting up plaintext')
                 td.setup_for_grpc(server_xds_host, server_xds_port)
-                td.setup_server_security(server_port,
-                                         tls=False, mtls=False)
-                td.setup_client_security(namespace, server_name,
-                                         tls=False, mtls=False)
+                td.setup_server_security(server_port, tls=False, mtls=False)
+                td.setup_client_security(namespace,
+                                         server_name,
+                                         tls=False,
+                                         mtls=False)
 
             logger.info('Works!')
     except Exception:

+ 15 - 13
tools/run_tests/xds_test_driver/bin/run_test_client.py

@@ -23,21 +23,23 @@ from framework.test_app import client_app
 
 logger = logging.getLogger(__name__)
 # Flags
-_CMD = flags.DEFINE_enum(
-    'cmd', default='run', enum_values=['run', 'cleanup'],
-    help='Command')
-_SECURE = flags.DEFINE_bool(
-    "secure", default=False,
-    help="Run client in the secure mode")
+_CMD = flags.DEFINE_enum('cmd',
+                         default='run',
+                         enum_values=['run', 'cleanup'],
+                         help='Command')
+_SECURE = flags.DEFINE_bool("secure",
+                            default=False,
+                            help="Run client in the secure mode")
 _QPS = flags.DEFINE_integer('qps', default=25, help='Queries per second')
-_PRINT_RESPONSE = flags.DEFINE_bool(
-    "print_response", default=False,
-    help="Client prints responses")
-_REUSE_NAMESPACE = flags.DEFINE_bool(
-    "reuse_namespace", default=True,
-    help="Use existing namespace if exists")
+_PRINT_RESPONSE = flags.DEFINE_bool("print_response",
+                                    default=False,
+                                    help="Client prints responses")
+_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace",
+                                     default=True,
+                                     help="Use existing namespace if exists")
 _CLEANUP_NAMESPACE = flags.DEFINE_bool(
-    "cleanup_namespace", default=False,
+    "cleanup_namespace",
+    default=False,
     help="Delete namespace during resource cleanup")
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)

+ 12 - 10
tools/run_tests/xds_test_driver/bin/run_test_server.py

@@ -35,17 +35,19 @@ from framework.test_app import server_app
 
 logger = logging.getLogger(__name__)
 # Flags
-_CMD = flags.DEFINE_enum(
-    'cmd', default='run', enum_values=['run', 'cleanup'],
-    help='Command')
-_SECURE = flags.DEFINE_bool(
-    "secure", default=False,
-    help="Run server in the secure mode")
-_REUSE_NAMESPACE = flags.DEFINE_bool(
-    "reuse_namespace", default=True,
-    help="Use existing namespace if exists")
+_CMD = flags.DEFINE_enum('cmd',
+                         default='run',
+                         enum_values=['run', 'cleanup'],
+                         help='Command')
+_SECURE = flags.DEFINE_bool("secure",
+                            default=False,
+                            help="Run server in the secure mode")
+_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace",
+                                     default=True,
+                                     help="Use existing namespace if exists")
 _CLEANUP_NAMESPACE = flags.DEFINE_bool(
-    "cleanup_namespace", default=False,
+    "cleanup_namespace",
+    default=False,
     help="Delete namespace during resource cleanup")
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)

+ 0 - 1
tools/run_tests/xds_test_driver/framework/infrastructure/gcp/__init__.py

@@ -15,4 +15,3 @@ from framework.infrastructure.gcp import api
 from framework.infrastructure.gcp import compute
 from framework.infrastructure.gcp import network_security
 from framework.infrastructure.gcp import network_services
-

+ 33 - 21
tools/run_tests/xds_test_driver/framework/infrastructure/gcp/api.py

@@ -28,14 +28,15 @@ import googleapiclient.errors
 import tenacity
 
 logger = logging.getLogger(__name__)
-V1_DISCOVERY_URI = flags.DEFINE_string(
-    "v1_discovery_uri", default=discovery.V1_DISCOVERY_URI,
-    help="Override v1 Discovery URI")
-V2_DISCOVERY_URI = flags.DEFINE_string(
-    "v2_discovery_uri", default=discovery.V2_DISCOVERY_URI,
-    help="Override v2 Discovery URI")
+V1_DISCOVERY_URI = flags.DEFINE_string("v1_discovery_uri",
+                                       default=discovery.V1_DISCOVERY_URI,
+                                       help="Override v1 Discovery URI")
+V2_DISCOVERY_URI = flags.DEFINE_string("v2_discovery_uri",
+                                       default=discovery.V2_DISCOVERY_URI,
+                                       help="Override v2 Discovery URI")
 COMPUTE_V1_DISCOVERY_FILE = flags.DEFINE_string(
-    "compute_v1_discovery_file", default=None,
+    "compute_v1_discovery_file",
+    default=None,
     help="Load compute v1 from discovery file")
 
 # Type aliases
@@ -43,7 +44,9 @@ Operation = operations_pb2.Operation
 
 
 class GcpApiManager:
-    def __init__(self, *,
+
+    def __init__(self,
+                 *,
                  v1_discovery_uri=None,
                  v2_discovery_uri=None,
                  compute_v1_discovery_file=None,
@@ -73,8 +76,9 @@ class GcpApiManager:
     def networksecurity(self, version):
         api_name = 'networksecurity'
         if version == 'v1alpha1':
-            return self._build_from_discovery_v2(
-                api_name, version, api_key=self.private_api_key)
+            return self._build_from_discovery_v2(api_name,
+                                                 version,
+                                                 api_key=self.private_api_key)
 
         raise NotImplementedError(f'Network Security {version} not supported')
 
@@ -82,22 +86,26 @@ class GcpApiManager:
     def networkservices(self, version):
         api_name = 'networkservices'
         if version == 'v1alpha1':
-            return self._build_from_discovery_v2(
-                api_name, version, api_key=self.private_api_key)
+            return self._build_from_discovery_v2(api_name,
+                                                 version,
+                                                 api_key=self.private_api_key)
 
         raise NotImplementedError(f'Network Services {version} not supported')
 
     def _build_from_discovery_v1(self, api_name, version):
-        api = discovery.build(
-            api_name, version, cache_discovery=False,
-            discoveryServiceUrl=self.v1_discovery_uri)
+        api = discovery.build(api_name,
+                              version,
+                              cache_discovery=False,
+                              discoveryServiceUrl=self.v1_discovery_uri)
         self._exit_stack.enter_context(api)
         return api
 
     def _build_from_discovery_v2(self, api_name, version, *, api_key=None):
         key_arg = f'&key={api_key}' if api_key else ''
         api = discovery.build(
-            api_name, version, cache_discovery=False,
+            api_name,
+            version,
+            cache_discovery=False,
             discoveryServiceUrl=f'{self.v2_discovery_uri}{key_arg}')
         self._exit_stack.enter_context(api)
         return api
@@ -121,6 +129,7 @@ class OperationError(Error):
     https://cloud.google.com/apis/design/design_patterns#long_running_operations
     https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto
     """
+
     def __init__(self, api_name, operation_response, message=None):
         self.api_name = api_name
         operation = json_format.ParseDict(operation_response, Operation())
@@ -175,7 +184,8 @@ class GcpStandardCloudApiResource(GcpProjectApiResource):
                          **kwargs):
         logger.debug("Creating %s", body)
         create_req = collection.create(parent=self.parent(),
-                                       body=body, **kwargs)
+                                       body=body,
+                                       **kwargs)
         self._execute(create_req)
 
     @staticmethod
@@ -191,15 +201,17 @@ class GcpStandardCloudApiResource(GcpProjectApiResource):
         except googleapiclient.errors.HttpError as error:
             # noinspection PyProtectedMember
             reason = error._get_reason()
-            logger.info('Delete failed. Error: %s %s',
-                        error.resp.status, reason)
+            logger.info('Delete failed. Error: %s %s', error.resp.status,
+                        reason)
 
-    def _execute(self, request,
+    def _execute(self,
+                 request,
                  timeout_sec=GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
         operation = request.execute(num_retries=self._GCP_API_RETRIES)
         self._wait(operation, timeout_sec)
 
-    def _wait(self, operation,
+    def _wait(self,
+              operation,
               timeout_sec=GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
         op_name = operation['name']
         logger.debug('Waiting for %s operation, timeout %s sec: %s',

+ 77 - 73
tools/run_tests/xds_test_driver/framework/infrastructure/gcp/compute.py

@@ -50,7 +50,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         HTTP2 = enum.auto()
         GRPC = enum.auto()
 
-    def create_health_check_tcp(self, name,
+    def create_health_check_tcp(self,
+                                name,
                                 use_serving_port=False) -> GcpResource:
         health_check_settings = {}
         if use_serving_port:
@@ -66,30 +67,31 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         self._delete_resource(self.api.healthChecks(), healthCheck=name)
 
     def create_backend_service_traffic_director(
-        self,
-        name: str,
-        health_check: GcpResource,
-        protocol: Optional[BackendServiceProtocol] = None
-    ) -> GcpResource:
+            self,
+            name: str,
+            health_check: GcpResource,
+            protocol: Optional[BackendServiceProtocol] = None) -> GcpResource:
         if not isinstance(protocol, self.BackendServiceProtocol):
             raise TypeError(f'Unexpected Backend Service protocol: {protocol}')
-        return self._insert_resource(self.api.backendServices(), {
-            'name': name,
-            'loadBalancingScheme': 'INTERNAL_SELF_MANAGED',  # Traffic Director
-            'healthChecks': [health_check.url],
-            'protocol': protocol.name,
-        })
+        return self._insert_resource(
+            self.api.backendServices(),
+            {
+                'name': name,
+                'loadBalancingScheme':
+                    'INTERNAL_SELF_MANAGED',  # Traffic Director
+                'healthChecks': [health_check.url],
+                'protocol': protocol.name,
+            })
 
     def get_backend_service_traffic_director(self, name: str) -> GcpResource:
         return self._get_resource(self.api.backendServices(),
                                   backendService=name)
 
     def patch_backend_service(self, backend_service, body, **kwargs):
-        self._patch_resource(
-            collection=self.api.backendServices(),
-            backendService=backend_service.name,
-            body=body,
-            **kwargs)
+        self._patch_resource(collection=self.api.backendServices(),
+                             backendService=backend_service.name,
+                             body=body,
+                             **kwargs)
 
     def backend_service_add_backends(self, backend_service, backends):
         backend_list = [{
@@ -98,16 +100,14 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
             'maxRatePerEndpoint': 5
         } for backend in backends]
 
-        self._patch_resource(
-            collection=self.api.backendServices(),
-            body={'backends': backend_list},
-            backendService=backend_service.name)
+        self._patch_resource(collection=self.api.backendServices(),
+                             body={'backends': backend_list},
+                             backendService=backend_service.name)
 
     def backend_service_remove_all_backends(self, backend_service):
-        self._patch_resource(
-            collection=self.api.backendServices(),
-            body={'backends': []},
-            backendService=backend_service.name)
+        self._patch_resource(collection=self.api.backendServices(),
+                             body={'backends': []},
+                             backendService=backend_service.name)
 
     def delete_backend_service(self, name):
         self._delete_resource(self.api.backendServices(), backendService=name)
@@ -122,18 +122,21 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
     ) -> GcpResource:
         if dst_host_rule_match_backend_service is None:
             dst_host_rule_match_backend_service = dst_default_backend_service
-        return self._insert_resource(self.api.urlMaps(), {
-            'name': name,
-            'defaultService': dst_default_backend_service.url,
-            'hostRules': [{
-                'hosts': src_hosts,
-                'pathMatcher': matcher_name,
-            }],
-            'pathMatchers': [{
-                'name': matcher_name,
-                'defaultService': dst_host_rule_match_backend_service.url,
-            }],
-        })
+        return self._insert_resource(
+            self.api.urlMaps(), {
+                'name':
+                    name,
+                'defaultService':
+                    dst_default_backend_service.url,
+                'hostRules': [{
+                    'hosts': src_hosts,
+                    'pathMatcher': matcher_name,
+                }],
+                'pathMatchers': [{
+                    'name': matcher_name,
+                    'defaultService': dst_host_rule_match_backend_service.url,
+                }],
+            })
 
     def delete_url_map(self, name):
         self._delete_resource(self.api.urlMaps(), urlMap=name)
@@ -174,14 +177,17 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         target_proxy: GcpResource,
         network_url: str,
     ) -> GcpResource:
-        return self._insert_resource(self.api.globalForwardingRules(), {
-            'name': name,
-            'loadBalancingScheme': 'INTERNAL_SELF_MANAGED',  # Traffic Director
-            'portRange': src_port,
-            'IPAddress': '0.0.0.0',
-            'network': network_url,
-            'target': target_proxy.url,
-        })
+        return self._insert_resource(
+            self.api.globalForwardingRules(),
+            {
+                'name': name,
+                'loadBalancingScheme':
+                    'INTERNAL_SELF_MANAGED',  # Traffic Director
+                'portRange': src_port,
+                'IPAddress': '0.0.0.0',
+                'network': network_url,
+                'target': target_proxy.url,
+            })
 
     def delete_forwarding_rule(self, name):
         self._delete_resource(self.api.globalForwardingRules(),
@@ -192,15 +198,16 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         return not neg or neg.get('size', 0) == 0
 
     def wait_for_network_endpoint_group(self, name, zone):
+
         @retrying.retry(retry_on_result=self._network_endpoint_group_not_ready,
                         stop_max_delay=60 * 1000,
                         wait_fixed=2 * 1000)
         def _wait_for_network_endpoint_group_ready():
             try:
                 neg = self.get_network_endpoint_group(name, zone)
-                logger.debug('Waiting for endpoints: NEG %s in zone %s, '
-                             'current count %s',
-                             neg['name'], zone, neg.get('size'))
+                logger.debug(
+                    'Waiting for endpoints: NEG %s in zone %s, '
+                    'current count %s', neg['name'], zone, neg.get('size'))
             except googleapiclient.errors.HttpError as error:
                 # noinspection PyProtectedMember
                 reason = error._get_reason()
@@ -211,10 +218,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
 
         network_endpoint_group = _wait_for_network_endpoint_group_ready()
         # @todo(sergiitk): dataclass
-        return self.ZonalGcpResource(
-            network_endpoint_group['name'],
-            network_endpoint_group['selfLink'],
-            zone)
+        return self.ZonalGcpResource(network_endpoint_group['name'],
+                                     network_endpoint_group['selfLink'], zone)
 
     def get_network_endpoint_group(self, name, zone):
         neg = self.api.networkEndpointGroups().get(project=self.project,
@@ -232,10 +237,9 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
     ):
         pending = set(backends)
 
-        @retrying.retry(
-            retry_on_result=lambda result: not result,
-            stop_max_delay=timeout_sec * 1000,
-            wait_fixed=wait_sec * 1000)
+        @retrying.retry(retry_on_result=lambda result: not result,
+                        stop_max_delay=timeout_sec * 1000,
+                        wait_fixed=wait_sec * 1000)
         def _retry_backends_health():
             for backend in pending:
                 result = self.get_backend_service_backend_health(
@@ -250,9 +254,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
                 for instance in result['healthStatus']:
                     logger.debug(
                         'Backend %s in zone %s: instance %s:%s health: %s',
-                        backend.name, backend.zone,
-                        instance['ipAddress'], instance['port'],
-                        instance['healthState'])
+                        backend.name, backend.zone, instance['ipAddress'],
+                        instance['port'], instance['healthState'])
                     if instance['healthState'] != 'HEALTHY':
                         backend_healthy = False
 
@@ -267,8 +270,11 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
 
     def get_backend_service_backend_health(self, backend_service, backend):
         return self.api.backendServices().getHealth(
-            project=self.project, backendService=backend_service.name,
-            body={"group": backend.url}).execute()
+            project=self.project,
+            backendService=backend_service.name,
+            body={
+                "group": backend.url
+            }).execute()
 
     def _get_resource(self, collection: discovery.Resource,
                       **kwargs) -> GcpResource:
@@ -276,11 +282,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         logger.debug("Loaded %r", resp)
         return self.GcpResource(resp['name'], resp['selfLink'])
 
-    def _insert_resource(
-        self,
-        collection: discovery.Resource,
-        body: Dict[str, Any]
-    ) -> GcpResource:
+    def _insert_resource(self, collection: discovery.Resource,
+                         body: Dict[str, Any]) -> GcpResource:
         logger.debug("Creating %s", body)
         resp = self._execute(collection.insert(project=self.project, body=body))
         return self.GcpResource(body['name'], resp['targetLink'])
@@ -297,14 +300,16 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         except googleapiclient.errors.HttpError as error:
             # noinspection PyProtectedMember
             reason = error._get_reason()
-            logger.info('Delete failed. Error: %s %s',
-                        error.resp.status, reason)
+            logger.info('Delete failed. Error: %s %s', error.resp.status,
+                        reason)
 
     @staticmethod
     def _operation_status_done(operation):
         return 'status' in operation and operation['status'] == 'DONE'
 
-    def _execute(self, request, *,
+    def _execute(self,
+                 request,
+                 *,
                  test_success_fn=None,
                  timeout_sec=_WAIT_FOR_OPERATION_SEC):
         operation = request.execute(num_retries=self._GCP_API_RETRIES)
@@ -320,10 +325,9 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
 
         logger.debug('Waiting for global operation %s, timeout %s sec',
                      operation['name'], timeout_sec)
-        response = self.wait_for_operation(
-            operation_request=operation_request,
-            test_success_fn=test_success_fn,
-            timeout_sec=timeout_sec)
+        response = self.wait_for_operation(operation_request=operation_request,
+                                           test_success_fn=test_success_fn,
+                                           timeout_sec=timeout_sec)
 
         if 'error' in response:
             logger.debug('Waiting for global operation failed, response: %r',

+ 13 - 13
tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_security.py

@@ -52,22 +52,22 @@ class NetworkSecurityV1Alpha1(gcp.api.GcpStandardCloudApiResource):
         self._api_locations = self.api.projects().locations()
 
     def create_server_tls_policy(self, name, body: dict):
-        return self._create_resource(
-            self._api_locations.serverTlsPolicies(),
-            body, serverTlsPolicyId=name)
+        return self._create_resource(self._api_locations.serverTlsPolicies(),
+                                     body,
+                                     serverTlsPolicyId=name)
 
     def get_server_tls_policy(self, name: str) -> ServerTlsPolicy:
         result = self._get_resource(
             collection=self._api_locations.serverTlsPolicies(),
             full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
 
-        return self.ServerTlsPolicy(
-            name=name,
-            url=result['name'],
-            server_certificate=result.get('serverCertificate', {}),
-            mtls_policy=result.get('mtlsPolicy', {}),
-            create_time=result['createTime'],
-            update_time=result['updateTime'])
+        return self.ServerTlsPolicy(name=name,
+                                    url=result['name'],
+                                    server_certificate=result.get(
+                                        'serverCertificate', {}),
+                                    mtls_policy=result.get('mtlsPolicy', {}),
+                                    create_time=result['createTime'],
+                                    update_time=result['updateTime'])
 
     def delete_server_tls_policy(self, name):
         return self._delete_resource(
@@ -75,9 +75,9 @@ class NetworkSecurityV1Alpha1(gcp.api.GcpStandardCloudApiResource):
             full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
 
     def create_client_tls_policy(self, name, body: dict):
-        return self._create_resource(
-            self._api_locations.clientTlsPolicies(),
-            body, clientTlsPolicyId=name)
+        return self._create_resource(self._api_locations.clientTlsPolicies(),
+                                     body,
+                                     clientTlsPolicyId=name)
 
     def get_client_tls_policy(self, name: str) -> ClientTlsPolicy:
         result = self._get_resource(

+ 2 - 1
tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_services.py

@@ -49,7 +49,8 @@ class NetworkServicesV1Alpha1(gcp.api.GcpStandardCloudApiResource):
     def create_endpoint_config_selector(self, name, body: dict):
         return self._create_resource(
             self._api_locations.endpointConfigSelectors(),
-            body, endpointConfigSelectorId=name)
+            body,
+            endpointConfigSelectorId=name)
 
     def get_endpoint_config_selector(self, name: str) -> EndpointConfigSelector:
         result = self._get_resource(

+ 65 - 51
tools/run_tests/xds_test_driver/framework/infrastructure/k8s.py

@@ -35,6 +35,7 @@ ApiException = client.ApiException
 
 
 def simple_resource_get(func):
+
     def wrap_not_found_return_none(*args, **kwargs):
         try:
             return func(*args, **kwargs)
@@ -43,6 +44,7 @@ def simple_resource_get(func):
                 # Ignore 404
                 return None
             raise
+
     return wrap_not_found_return_none
 
 
@@ -51,6 +53,7 @@ def label_dict_to_selector(labels: dict) -> str:
 
 
 class KubernetesApiManager:
+
     def __init__(self, context):
         self.context = context
         self.client = self._cached_api_client_for_context(context)
@@ -80,7 +83,8 @@ class KubernetesNamespace:
         self.api = api
 
     def apply_manifest(self, manifest):
-        return utils.create_from_dict(self.api.client, manifest,
+        return utils.create_from_dict(self.api.client,
+                                      manifest,
                                       namespace=self.name)
 
     @simple_resource_get
@@ -91,24 +95,22 @@ class KubernetesNamespace:
     def get_service_account(self, name) -> V1Service:
         return self.api.core.read_namespaced_service_account(name, self.name)
 
-    def delete_service(
-        self,
-        name,
-        grace_period_seconds=DELETE_GRACE_PERIOD_SEC
-    ):
+    def delete_service(self,
+                       name,
+                       grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
         self.api.core.delete_namespaced_service(
-            name=name, namespace=self.name,
+            name=name,
+            namespace=self.name,
             body=client.V1DeleteOptions(
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
 
-    def delete_service_account(
-        self,
-        name,
-        grace_period_seconds=DELETE_GRACE_PERIOD_SEC
-    ):
+    def delete_service_account(self,
+                               name,
+                               grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
         self.api.core.delete_namespaced_service_account(
-            name=name, namespace=self.name,
+            name=name,
+            namespace=self.name,
             body=client.V1DeleteOptions(
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
@@ -124,8 +126,8 @@ class KubernetesNamespace:
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
 
-    def wait_for_service_deleted(self, name: str,
-                                 timeout_sec=60, wait_sec=1):
+    def wait_for_service_deleted(self, name: str, timeout_sec=60, wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
@@ -135,10 +137,14 @@ class KubernetesNamespace:
                 logger.info('Waiting for service %s to be deleted',
                             service.metadata.name)
             return service
+
         _wait_for_deleted_service_with_retry()
 
-    def wait_for_service_account_deleted(self, name: str,
-                                         timeout_sec=60, wait_sec=1):
+    def wait_for_service_account_deleted(self,
+                                         name: str,
+                                         timeout_sec=60,
+                                         wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
@@ -148,10 +154,11 @@ class KubernetesNamespace:
                 logger.info('Waiting for service account %s to be deleted',
                             service_account.metadata.name)
             return service_account
+
         _wait_for_deleted_service_account_with_retry()
 
-    def wait_for_namespace_deleted(self,
-                                   timeout_sec=240, wait_sec=2):
+    def wait_for_namespace_deleted(self, timeout_sec=240, wait_sec=2):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
@@ -161,27 +168,25 @@ class KubernetesNamespace:
                 logger.info('Waiting for namespace %s to be deleted',
                             namespace.metadata.name)
             return namespace
+
         _wait_for_deleted_namespace_with_retry()
 
-    def wait_for_service_neg(self, name: str,
-                             timeout_sec=60, wait_sec=1):
+    def wait_for_service_neg(self, name: str, timeout_sec=60, wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: not r,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
         def _wait_for_service_neg():
             service = self.get_service(name)
             if self.NEG_STATUS_META not in service.metadata.annotations:
-                logger.info('Waiting for service %s NEG',
-                            service.metadata.name)
+                logger.info('Waiting for service %s NEG', service.metadata.name)
                 return False
             return True
+
         _wait_for_service_neg()
 
-    def get_service_neg(
-        self,
-        service_name: str,
-        service_port: int
-    ) -> Tuple[str, List[str]]:
+    def get_service_neg(self, service_name: str,
+                        service_port: int) -> Tuple[str, List[str]]:
         service = self.get_service(service_name)
         neg_info: dict = json.loads(
             service.metadata.annotations[self.NEG_STATUS_META])
@@ -193,13 +198,12 @@ class KubernetesNamespace:
     def get_deployment(self, name) -> V1Deployment:
         return self.api.apps.read_namespaced_deployment(name, self.name)
 
-    def delete_deployment(
-        self,
-        name,
-        grace_period_seconds=DELETE_GRACE_PERIOD_SEC
-    ):
+    def delete_deployment(self,
+                          name,
+                          grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
         self.api.apps.delete_namespaced_deployment(
-            name=name, namespace=self.name,
+            name=name,
+            namespace=self.name,
             body=client.V1DeleteOptions(
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
@@ -208,34 +212,43 @@ class KubernetesNamespace:
         # V1LabelSelector.match_expressions not supported at the moment
         return self.list_pods_with_labels(deployment.spec.selector.match_labels)
 
-    def wait_for_deployment_available_replicas(self, name, count=1,
-                                               timeout_sec=60, wait_sec=1):
+    def wait_for_deployment_available_replicas(self,
+                                               name,
+                                               count=1,
+                                               timeout_sec=60,
+                                               wait_sec=1):
+
         @retrying.retry(
             retry_on_result=lambda r: not self._replicas_available(r, count),
             stop_max_delay=timeout_sec * 1000,
             wait_fixed=wait_sec * 1000)
         def _wait_for_deployment_available_replicas():
             deployment = self.get_deployment(name)
-            logger.info('Waiting for deployment %s to have %s available '
-                        'replicas, current count %s',
-                        deployment.metadata.name,
-                        count, deployment.status.available_replicas)
+            logger.info(
+                'Waiting for deployment %s to have %s available '
+                'replicas, current count %s', deployment.metadata.name, count,
+                deployment.status.available_replicas)
             return deployment
+
         _wait_for_deployment_available_replicas()
 
-    def wait_for_deployment_deleted(self, deployment_name: str,
-                                    timeout_sec=60, wait_sec=1):
+    def wait_for_deployment_deleted(self,
+                                    deployment_name: str,
+                                    timeout_sec=60,
+                                    wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
         def _wait_for_deleted_deployment_with_retry():
             deployment = self.get_deployment(deployment_name)
             if deployment is not None:
-                logger.info('Waiting for deployment %s to be deleted. '
-                            'Non-terminated replicas: %s',
-                            deployment.metadata.name,
-                            deployment.status.replicas)
+                logger.info(
+                    'Waiting for deployment %s to be deleted. '
+                    'Non-terminated replicas: %s', deployment.metadata.name,
+                    deployment.status.replicas)
             return deployment
+
         _wait_for_deleted_deployment_with_retry()
 
     def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
@@ -247,15 +260,16 @@ class KubernetesNamespace:
         return self.api.core.read_namespaced_pod(name, self.name)
 
     def wait_for_pod_started(self, pod_name, timeout_sec=60, wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: not self._pod_started(r),
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
         def _wait_for_pod_started():
             pod = self.get_pod(pod_name)
             logger.info('Waiting for pod %s to start, current phase: %s',
-                        pod.metadata.name,
-                        pod.status.phase)
+                        pod.metadata.name, pod.status.phase)
             return pod
+
         _wait_for_pod_started()
 
     def port_forward_pod(
@@ -269,12 +283,12 @@ class KubernetesNamespace:
         local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
         local_port = local_port or remote_port
         cmd = [
-            "kubectl", "--context", self.api.context,
-            "--namespace", self.name,
+            "kubectl", "--context", self.api.context, "--namespace", self.name,
             "port-forward", "--address", local_address,
             f"pod/{pod.metadata.name}", f"{local_port}:{remote_port}"
         ]
-        pf = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+        pf = subprocess.Popen(cmd,
+                              stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               universal_newlines=True)
         # Wait for stdout line indicating successful start.

+ 82 - 67
tools/run_tests/xds_test_driver/framework/infrastructure/traffic_director.py

@@ -75,13 +75,11 @@ class TrafficDirectorManager:
     def network_url(self):
         return f'global/networks/{self.network}'
 
-    def setup_for_grpc(
-        self,
-        service_host,
-        service_port,
-        *,
-        backend_protocol=BackendServiceProtocol.GRPC
-    ):
+    def setup_for_grpc(self,
+                       service_host,
+                       service_port,
+                       *,
+                       backend_protocol=BackendServiceProtocol.GRPC):
         self.create_health_check()
         self.create_backend_service(protocol=backend_protocol)
         self.create_url_map(service_host, service_port)
@@ -130,9 +128,8 @@ class TrafficDirectorManager:
         self.health_check = None
 
     def create_backend_service(
-        self,
-        protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC
-    ):
+            self,
+            protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC):
         name = self._ns_name(self.BACKEND_SERVICE_NAME)
         logger.info('Creating %s Backend Service %s', protocol.name, name)
         resource = self.compute.create_backend_service_traffic_director(
@@ -168,8 +165,8 @@ class TrafficDirectorManager:
     def backend_service_add_backends(self):
         logging.info('Adding backends to Backend Service %s: %r',
                      self.backend_service.name, self.backends)
-        self.compute.backend_service_add_backends(
-            self.backend_service, self.backends)
+        self.compute.backend_service_add_backends(self.backend_service,
+                                                  self.backends)
 
     def backend_service_remove_all_backends(self):
         logging.info('Removing backends from Backend Service %s',
@@ -180,8 +177,8 @@ class TrafficDirectorManager:
         logger.debug(
             "Waiting for Backend Service %s to report all backends healthy %r",
             self.backend_service, self.backends)
-        self.compute.wait_for_backends_healthy_status(
-            self.backend_service, self.backends)
+        self.compute.wait_for_backends_healthy_status(self.backend_service,
+                                                      self.backends)
 
     def create_url_map(
         self,
@@ -191,10 +188,11 @@ class TrafficDirectorManager:
         src_address = f'{src_host}:{src_port}'
         name = self._ns_name(self.URL_MAP_NAME)
         matcher_name = self._ns_name(self.URL_MAP_PATH_MATCHER_NAME)
-        logger.info('Creating URL map %s %s -> %s',
-                    name, src_address, self.backend_service.name)
-        resource = self.compute.create_url_map(
-            name, matcher_name, [src_address], self.backend_service)
+        logger.info('Creating URL map %s %s -> %s', name, src_address,
+                    self.backend_service.name)
+        resource = self.compute.create_url_map(name, matcher_name,
+                                               [src_address],
+                                               self.backend_service)
         self.url_map = resource
         return resource
 
@@ -212,10 +210,9 @@ class TrafficDirectorManager:
     def create_target_grpc_proxy(self):
         # todo: different kinds
         name = self._ns_name(self.TARGET_PROXY_NAME)
-        logger.info('Creating target GRPC proxy %s to url map %s',
-                    name, self.url_map.name)
-        resource = self.compute.create_target_grpc_proxy(
-            name, self.url_map)
+        logger.info('Creating target GRPC proxy %s to url map %s', name,
+                    self.url_map.name)
+        resource = self.compute.create_target_grpc_proxy(name, self.url_map)
         self.target_proxy = resource
 
     def delete_target_grpc_proxy(self, force=False):
@@ -233,10 +230,9 @@ class TrafficDirectorManager:
     def create_target_http_proxy(self):
         # todo: different kinds
         name = self._ns_name(self.TARGET_PROXY_NAME)
-        logger.info('Creating target HTTP proxy %s to url map %s',
-                    name, self.url_map.name)
-        resource = self.compute.create_target_http_proxy(
-            name, self.url_map)
+        logger.info('Creating target HTTP proxy %s to url map %s', name,
+                    self.url_map.name)
+        resource = self.compute.create_target_http_proxy(name, self.url_map)
         self.target_proxy = resource
         self.target_proxy_is_http = True
 
@@ -255,10 +251,11 @@ class TrafficDirectorManager:
     def create_forwarding_rule(self, src_port: int):
         name = self._ns_name(self.FORWARDING_RULE_NAME)
         src_port = int(src_port)
-        logging.info('Creating forwarding rule %s 0.0.0.0:%s -> %s in %s',
-                     name, src_port, self.target_proxy.url, self.network)
-        resource = self.compute.create_forwarding_rule(
-            name, src_port, self.target_proxy, self.network_url)
+        logging.info('Creating forwarding rule %s 0.0.0.0:%s -> %s in %s', name,
+                     src_port, self.target_proxy.url, self.network)
+        resource = self.compute.create_forwarding_rule(name, src_port,
+                                                       self.target_proxy,
+                                                       self.network_url)
         self.forwarding_rule = resource
         return resource
 
@@ -289,8 +286,10 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         resource_prefix: str,
         network: str = 'default',
     ):
-        super().__init__(gcp_api_manager, project,
-                         resource_prefix=resource_prefix, network=network)
+        super().__init__(gcp_api_manager,
+                         project,
+                         resource_prefix=resource_prefix,
+                         network=network)
 
         # API
         self.netsec = NetworkSecurityV1Alpha1(gcp_api_manager, project)
@@ -301,25 +300,28 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         self.ecs: Optional[EndpointConfigSelector] = None
         self.client_tls_policy: Optional[ClientTlsPolicy] = None
 
-    def setup_for_grpc(
-        self,
-        service_host,
-        service_port,
-        *,
-        backend_protocol=BackendServiceProtocol.HTTP2
-    ):
-        super().setup_for_grpc(service_host, service_port,
+    def setup_for_grpc(self,
+                       service_host,
+                       service_port,
+                       *,
+                       backend_protocol=BackendServiceProtocol.HTTP2):
+        super().setup_for_grpc(service_host,
+                               service_port,
                                backend_protocol=backend_protocol)
 
     def setup_server_security(self, server_port, *, tls, mtls):
         self.create_server_tls_policy(tls=tls, mtls=mtls)
         self.create_endpoint_config_selector(server_port)
 
-    def setup_client_security(self, server_namespace, server_name,
-                              *, tls=True, mtls=True):
+    def setup_client_security(self,
+                              server_namespace,
+                              server_name,
+                              *,
+                              tls=True,
+                              mtls=True):
         self.create_client_tls_policy(tls=tls, mtls=mtls)
-        self.backend_service_apply_client_mtls_policy(
-            server_namespace, server_name)
+        self.backend_service_apply_client_mtls_policy(server_namespace,
+                                                      server_name)
 
     def cleanup(self, *, force=False):
         # Cleanup in the reverse order of creation
@@ -334,12 +336,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         name = self._ns_name(self.SERVER_TLS_POLICY_NAME)
         logger.info('Creating Server TLS Policy %s', name)
         if not tls and not mtls:
-            logger.warning('Server TLS Policy %s neither TLS, nor mTLS '
-                           'policy. Skipping creation', name)
+            logger.warning(
+                'Server TLS Policy %s neither TLS, nor mTLS '
+                'policy. Skipping creation', name)
             return
 
         grpc_endpoint = {
-            "grpcEndpoint": {"targetUri": self.GRPC_ENDPOINT_TARGET_URI}}
+            "grpcEndpoint": {
+                "targetUri": self.GRPC_ENDPOINT_TARGET_URI
+            }
+        }
 
         policy = {}
         if tls:
@@ -381,13 +387,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
             "type": "SIDECAR_PROXY",
             "httpFilters": {},
             "trafficPortSelector": port_selector,
-            "endpointMatcher": {"metadataLabelMatcher": label_matcher_all},
+            "endpointMatcher": {
+                "metadataLabelMatcher": label_matcher_all
+            },
         }
         if self.server_tls_policy:
             config["serverTlsPolicy"] = self.server_tls_policy.name
         else:
-            logger.warning('Creating Endpoint Config Selector %s with '
-                           'no Server TLS policy attached', name)
+            logger.warning(
+                'Creating Endpoint Config Selector %s with '
+                'no Server TLS policy attached', name)
 
         self.netsvc.create_endpoint_config_selector(name, config)
         self.ecs = self.netsvc.get_endpoint_config_selector(name)
@@ -408,12 +417,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         name = self._ns_name(self.CLIENT_TLS_POLICY_NAME)
         logger.info('Creating Client TLS Policy %s', name)
         if not tls and not mtls:
-            logger.warning('Client TLS Policy %s neither TLS, nor mTLS '
-                           'policy. Skipping creation', name)
+            logger.warning(
+                'Client TLS Policy %s neither TLS, nor mTLS '
+                'policy. Skipping creation', name)
             return
 
         grpc_endpoint = {
-            "grpcEndpoint": {"targetUri": self.GRPC_ENDPOINT_TARGET_URI}}
+            "grpcEndpoint": {
+                "targetUri": self.GRPC_ENDPOINT_TARGET_URI
+            }
+        }
 
         policy = {}
         if tls:
@@ -442,21 +455,23 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         server_name,
     ):
         if not self.client_tls_policy:
-            logger.warning('Client TLS policy not created, '
-                           'skipping attaching to Backend Service %s',
-                           self.backend_service.name)
+            logger.warning(
+                'Client TLS policy not created, '
+                'skipping attaching to Backend Service %s',
+                self.backend_service.name)
             return
 
         server_spiffe = (f'spiffe://{self.project}.svc.id.goog/'
                          f'ns/{server_namespace}/sa/{server_name}')
-        logging.info('Adding Client TLS Policy to Backend Service %s: %s, '
-                     'server %s',
-                     self.backend_service.name,
-                     self.client_tls_policy.url,
-                     server_spiffe)
-
-        self.compute.patch_backend_service(self.backend_service, {
-            'securitySettings': {
-                'clientTlsPolicy': self.client_tls_policy.url,
-                'subjectAltNames': [server_spiffe]
-            }})
+        logging.info(
+            'Adding Client TLS Policy to Backend Service %s: %s, '
+            'server %s', self.backend_service.name, self.client_tls_policy.url,
+            server_spiffe)
+
+        self.compute.patch_backend_service(
+            self.backend_service, {
+                'securitySettings': {
+                    'clientTlsPolicy': self.client_tls_policy.url,
+                    'subjectAltNames': [server_spiffe]
+                }
+            })

+ 3 - 3
tools/run_tests/xds_test_driver/framework/rpc/__init__.py

@@ -37,7 +37,8 @@ class GrpcClientHelper:
         self.service_name = re.sub('Stub$', '', self.stub.__class__.__name__)
 
     def call_unary_when_channel_ready(
-        self, *,
+        self,
+        *,
         rpc: str,
         req: Message,
         wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
@@ -56,8 +57,7 @@ class GrpcClientHelper:
         return rpc_callable(req, **call_kwargs)
 
     def _log_debug(self, rpc, req, call_kwargs):
-        logger.debug('RPC %s.%s(request=%s(%r), %s)',
-                     self.service_name, rpc,
+        logger.debug('RPC %s.%s(request=%s(%r), %s)', self.service_name, rpc,
                      req.__class__.__name__, json_format.MessageToDict(req),
                      ', '.join({f'{k}={v}' for k, v in call_kwargs.items()}))
 

+ 5 - 9
tools/run_tests/xds_test_driver/framework/rpc/grpc_channelz.py

@@ -83,10 +83,8 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
                 f'remote={cls.sock_address_to_str(socket.remote)}')
 
     @staticmethod
-    def find_server_socket_matching_client(
-        server_sockets: Iterator[Socket],
-        client_socket: Socket
-    ) -> Socket:
+    def find_server_socket_matching_client(server_sockets: Iterator[Socket],
+                                           client_socket: Socket) -> Socket:
         for server_socket in server_sockets:
             if server_socket.remote == client_socket.local:
                 return server_socket
@@ -103,7 +101,7 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
                 listen_socket = self.get_socket(listen_socket_ref.socket_id)
                 listen_address: Address = listen_socket.local
                 if (self.is_sock_tcpip_address(listen_address) and
-                    listen_address.tcpip_address.port == port):
+                        listen_address.tcpip_address.port == port):
                     return server
         return None
 
@@ -136,8 +134,7 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_when_channel_ready(
-                rpc='GetServers',
-                req=GetServersRequest(start_server_id=start))
+                rpc='GetServers', req=GetServersRequest(start_server_id=start))
             for server in response.server:
                 start = max(start, server.ref.server_id)
                 yield server
@@ -170,6 +167,5 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
     def get_socket(self, socket_id) -> Socket:
         """Return a single Socket, otherwise raises RpcError."""
         response: GetSocketResponse = self.call_unary_when_channel_ready(
-            rpc='GetSocket',
-            req=GetSocketRequest(socket_id=socket_id))
+            rpc='GetSocket', req=GetSocketRequest(socket_id=socket_id))
         return response.socket

+ 2 - 2
tools/run_tests/xds_test_driver/framework/rpc/grpc_testing.py

@@ -19,7 +19,6 @@ import framework.rpc
 from src.proto.grpc.testing import test_pb2_grpc
 from src.proto.grpc.testing import messages_pb2
 
-
 # Type aliases
 LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest
 LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse
@@ -33,7 +32,8 @@ class LoadBalancerStatsServiceClient(framework.rpc.GrpcClientHelper):
         super().__init__(channel, test_pb2_grpc.LoadBalancerStatsServiceStub)
 
     def get_client_stats(
-        self, *,
+        self,
+        *,
         num_rpcs: int,
         timeout_sec: Optional[int] = STATS_PARTIAL_RESULTS_TIMEOUT_SEC,
     ) -> LoadBalancerStatsResponse:

+ 22 - 31
tools/run_tests/xds_test_driver/framework/test_app/base_runner.py

@@ -33,6 +33,7 @@ TEMPLATE_DIR = '../../kubernetes-manifests'
 
 
 class KubernetesBaseRunner:
+
     def __init__(self,
                  k8s_namespace,
                  namespace_template=None,
@@ -50,8 +51,7 @@ class KubernetesBaseRunner:
             self.namespace = self._reuse_namespace()
         if not self.namespace:
             self.namespace = self._create_namespace(
-                self.namespace_template,
-                namespace_name=self.k8s_namespace.name)
+                self.namespace_template, namespace_name=self.k8s_namespace.name)
 
     def cleanup(self, *, force=False):
         if (self.namespace and not self.reuse_namespace) or force:
@@ -127,27 +127,21 @@ class KubernetesBaseRunner:
             raise RunnerError('Expected V1Namespace to be created '
                               f'from manifest {template}')
         if namespace.metadata.name != kwargs['namespace_name']:
-            raise RunnerError(
-                'Namespace created with unexpected name: '
-                f'{namespace.metadata.name}')
-        logger.info('Deployment %s created at %s',
-                    namespace.metadata.self_link,
+            raise RunnerError('Namespace created with unexpected name: '
+                              f'{namespace.metadata.name}')
+        logger.info('Deployment %s created at %s', namespace.metadata.self_link,
                     namespace.metadata.creation_timestamp)
         return namespace
 
-    def _create_service_account(
-        self,
-        template,
-        **kwargs
-    ) -> k8s.V1ServiceAccount:
+    def _create_service_account(self, template,
+                                **kwargs) -> k8s.V1ServiceAccount:
         resource = self._create_from_template(template, **kwargs)
         if not isinstance(resource, k8s.V1ServiceAccount):
             raise RunnerError('Expected V1ServiceAccount to be created '
                               f'from manifest {template}')
         if resource.metadata.name != kwargs['service_account_name']:
-            raise RunnerError(
-                'V1ServiceAccount created with unexpected name: '
-                f'{resource.metadata.name}')
+            raise RunnerError('V1ServiceAccount created with unexpected name: '
+                              f'{resource.metadata.name}')
         logger.info('V1ServiceAccount %s created at %s',
                     resource.metadata.self_link,
                     resource.metadata.creation_timestamp)
@@ -159,9 +153,8 @@ class KubernetesBaseRunner:
             raise RunnerError('Expected V1Deployment to be created '
                               f'from manifest {template}')
         if deployment.metadata.name != kwargs['deployment_name']:
-            raise RunnerError(
-                'Deployment created with unexpected name: '
-                f'{deployment.metadata.name}')
+            raise RunnerError('Deployment created with unexpected name: '
+                              f'{deployment.metadata.name}')
         logger.info('Deployment %s created at %s',
                     deployment.metadata.self_link,
                     deployment.metadata.creation_timestamp)
@@ -173,11 +166,9 @@ class KubernetesBaseRunner:
             raise RunnerError('Expected V1Service to be created '
                               f'from manifest {template}')
         if service.metadata.name != kwargs['service_name']:
-            raise RunnerError(
-                'Service created with unexpected name: '
-                f'{service.metadata.name}')
-        logger.info('Service %s created at %s',
-                    service.metadata.self_link,
+            raise RunnerError('Service created with unexpected name: '
+                              f'{service.metadata.name}')
+        logger.info('Service %s created at %s', service.metadata.self_link,
                     service.metadata.creation_timestamp)
         return service
 
@@ -185,8 +176,8 @@ class KubernetesBaseRunner:
         try:
             self.k8s_namespace.delete_deployment(name)
         except k8s.ApiException as e:
-            logger.info('Deployment %s deletion failed, error: %s %s',
-                        name, e.status, e.reason)
+            logger.info('Deployment %s deletion failed, error: %s %s', name,
+                        e.status, e.reason)
             return
 
         if wait_for_deletion:
@@ -197,8 +188,8 @@ class KubernetesBaseRunner:
         try:
             self.k8s_namespace.delete_service(name)
         except k8s.ApiException as e:
-            logger.info('Service %s deletion failed, error: %s %s',
-                        name, e.status, e.reason)
+            logger.info('Service %s deletion failed, error: %s %s', name,
+                        e.status, e.reason)
             return
 
         if wait_for_deletion:
@@ -232,8 +223,8 @@ class KubernetesBaseRunner:
     def _wait_deployment_with_available_replicas(self, name, count=1, **kwargs):
         logger.info('Waiting for deployment %s to have %s available replicas',
                     name, count)
-        self.k8s_namespace.wait_for_deployment_available_replicas(name, count,
-                                                                  **kwargs)
+        self.k8s_namespace.wait_for_deployment_available_replicas(
+            name, count, **kwargs)
         deployment = self.k8s_namespace.get_deployment(name)
         logger.info('Deployment %s has %i replicas available',
                     deployment.metadata.name,
@@ -251,5 +242,5 @@ class KubernetesBaseRunner:
         self.k8s_namespace.wait_for_service_neg(name, **kwargs)
         neg_name, neg_zones = self.k8s_namespace.get_service_neg(
             name, service_port)
-        logger.info("Service %s: detected NEG=%s in zones=%s", name,
-                    neg_name, neg_zones)
+        logger.info("Service %s: detected NEG=%s in zones=%s", name, neg_name,
+                    neg_zones)

+ 15 - 11
tools/run_tests/xds_test_driver/framework/test_app/client_app.py

@@ -32,7 +32,9 @@ LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
 
 
 class XdsTestClient(framework.rpc.GrpcApp):
-    def __init__(self, *,
+
+    def __init__(self,
+                 *,
                  ip: str,
                  rpc_port: int,
                  server_target: str,
@@ -55,7 +57,8 @@ class XdsTestClient(framework.rpc.GrpcApp):
         return ChannelzServiceClient(self._make_channel(self.maintenance_port))
 
     def get_load_balancer_stats(
-        self, *,
+        self,
+        *,
         num_rpcs: int,
         timeout_sec: Optional[int] = None,
     ) -> grpc_testing.LoadBalancerStatsResponse:
@@ -76,16 +79,14 @@ class XdsTestClient(framework.rpc.GrpcApp):
             stop=tenacity.stop_after_delay(60 * 3),
             reraise=True)
         channel = retryer(self.get_active_server_channel)
-        logger.info(
-            'Active server channel found: channel_id: %s, %s',
-            channel.ref.channel_id, channel.ref.name)
+        logger.info('Active server channel found: channel_id: %s, %s',
+                    channel.ref.channel_id, channel.ref.name)
         logger.debug('Server channel:\n%r', channel)
 
     def get_active_server_channel(self) -> Optional[grpc_channelz.Channel]:
         for channel in self.get_server_channels():
             state: ChannelConnectivityState = channel.data.state
-            logger.debug('Server channel: %s, state: %s',
-                         channel.ref.name,
+            logger.debug('Server channel: %s, state: %s', channel.ref.name,
                          ChannelConnectivityState.State.Name(state.state))
             if state.state is ChannelConnectivityState.READY:
                 return channel
@@ -107,6 +108,7 @@ class XdsTestClient(framework.rpc.GrpcApp):
 
 
 class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
+
     def __init__(self,
                  k8s_namespace,
                  *,
@@ -142,9 +144,11 @@ class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
         self.service_account: Optional[k8s.V1ServiceAccount] = None
         self.port_forwarder = None
 
-    def run(self, *,
+    def run(self,
+            *,
             server_target,
-            rpc='UnaryCall', qps=25,
+            rpc='UnaryCall',
+            qps=25,
             secure_mode=False,
             print_response=False) -> XdsTestClient:
         super().run()
@@ -183,8 +187,8 @@ class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
 
         # Experimental, for local debugging.
         if self.debug_use_port_forwarding:
-            logger.info('Enabling port forwarding from %s:%s',
-                        pod_ip, self.stats_port)
+            logger.info('Enabling port forwarding from %s:%s', pod_ip,
+                        self.stats_port)
             self.port_forwarder = self.k8s_namespace.port_forward_pod(
                 pod, remote_port=self.stats_port)
             rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS

+ 29 - 22
tools/run_tests/xds_test_driver/framework/test_app/server_app.py

@@ -27,7 +27,9 @@ ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
 
 
 class XdsTestServer(framework.rpc.GrpcApp):
-    def __init__(self, *,
+
+    def __init__(self,
+                 *,
                  ip: str,
                  rpc_port: int,
                  maintenance_port: Optional[int] = None,
@@ -54,13 +56,16 @@ class XdsTestServer(framework.rpc.GrpcApp):
 
     @property
     def xds_address(self) -> str:
-        if not self.xds_host: return ''
-        if not self.xds_port: return self.xds_host
+        if not self.xds_host:
+            return ''
+        if not self.xds_port:
+            return self.xds_host
         return f'{self.xds_host}:{self.xds_port}'
 
     @property
     def xds_uri(self) -> str:
-        if not self.xds_host: return ''
+        if not self.xds_host:
+            return ''
         return f'xds:///{self.xds_address}'
 
     def get_test_server(self):
@@ -74,10 +79,8 @@ class XdsTestServer(framework.rpc.GrpcApp):
         server = self.get_test_server()
         return self.channelz.list_server_sockets(server.ref.server_id)
 
-    def get_server_socket_matching_client(
-        self,
-        client_socket: grpc_channelz.Socket
-    ):
+    def get_server_socket_matching_client(self,
+                                          client_socket: grpc_channelz.Socket):
         client_local = self.channelz.sock_address_to_str(client_socket.local)
         logger.debug('Looking for a server socket connected to the client %s',
                      client_local)
@@ -95,6 +98,7 @@ class XdsTestServer(framework.rpc.GrpcApp):
 
 
 class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
+
     def __init__(self,
                  k8s_namespace,
                  *,
@@ -140,9 +144,12 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
         self.service: Optional[k8s.V1Service] = None
         self.port_forwarder = None
 
-    def run(self, *,
-            test_port=8080, maintenance_port=None,
-            secure_mode=False, server_id=None,
+    def run(self,
+            *,
+            test_port=8080,
+            maintenance_port=None,
+            secure_mode=False,
+            server_id=None,
             replica_count=1) -> XdsTestServer:
         # todo(sergiitk): multiple replicas
         if replica_count != 1:
@@ -201,8 +208,9 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
             server_id=server_id,
             secure_mode=secure_mode)
 
-        self._wait_deployment_with_available_replicas(
-            self.deployment_name, replica_count, timeout_sec=120)
+        self._wait_deployment_with_available_replicas(self.deployment_name,
+                                                      replica_count,
+                                                      timeout_sec=120)
 
         # Wait for pods running
         pods = self.k8s_namespace.list_deployment_pods(self.deployment)
@@ -215,19 +223,18 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
         rpc_host = None
         # Experimental, for local debugging.
         if self.debug_use_port_forwarding:
-            logger.info('Enabling port forwarding from %s:%s',
-                        pod_ip, maintenance_port)
+            logger.info('Enabling port forwarding from %s:%s', pod_ip,
+                        maintenance_port)
             self.port_forwarder = self.k8s_namespace.port_forward_pod(
                 pod, remote_port=maintenance_port)
             rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS
 
-        return XdsTestServer(
-            ip=pod_ip,
-            rpc_port=test_port,
-            maintenance_port=maintenance_port,
-            secure_mode=secure_mode,
-            server_id=server_id,
-            rpc_host=rpc_host)
+        return XdsTestServer(ip=pod_ip,
+                             rpc_port=test_port,
+                             maintenance_port=maintenance_port,
+                             secure_mode=secure_mode,
+                             server_id=server_id,
+                             rpc_host=rpc_host)
 
     def cleanup(self, *, force=False, force_namespace=False):
         if self.port_forwarder:

+ 26 - 23
tools/run_tests/xds_test_driver/framework/xds_flags.py

@@ -15,35 +15,38 @@ from absl import flags
 import googleapiclient.discovery
 
 # GCP
-PROJECT = flags.DEFINE_string(
-    "project", default=None, help="GCP Project ID. Required")
+PROJECT = flags.DEFINE_string("project",
+                              default=None,
+                              help="GCP Project ID. Required")
 NAMESPACE = flags.DEFINE_string(
-    "namespace", default=None,
+    "namespace",
+    default=None,
     help="Isolate GCP resources using given namespace / name prefix. Required")
-NETWORK = flags.DEFINE_string(
-    "network", default="default", help="GCP Network ID")
+NETWORK = flags.DEFINE_string("network",
+                              default="default",
+                              help="GCP Network ID")
 
 # Test server
-SERVER_NAME = flags.DEFINE_string(
-    "server_name", default="psm-grpc-server",
-    help="Server deployment and service name")
-SERVER_PORT = flags.DEFINE_integer(
-    "server_port", default=8080,
-    help="Server test port")
-SERVER_XDS_HOST = flags.DEFINE_string(
-    "server_xds_host", default='xds-test-server',
-    help="Test server xDS hostname")
-SERVER_XDS_PORT = flags.DEFINE_integer(
-    "server_xds_port", default=8000, help="Test server xDS port")
+SERVER_NAME = flags.DEFINE_string("server_name",
+                                  default="psm-grpc-server",
+                                  help="Server deployment and service name")
+SERVER_PORT = flags.DEFINE_integer("server_port",
+                                   default=8080,
+                                   help="Server test port")
+SERVER_XDS_HOST = flags.DEFINE_string("server_xds_host",
+                                      default='xds-test-server',
+                                      help="Test server xDS hostname")
+SERVER_XDS_PORT = flags.DEFINE_integer("server_xds_port",
+                                       default=8000,
+                                       help="Test server xDS port")
 
 # Test client
-CLIENT_NAME = flags.DEFINE_string(
-    "client_name", default="psm-grpc-client",
-    help="Client deployment and service name")
-CLIENT_PORT = flags.DEFINE_integer(
-    "client_port", default=8079,
-    help="Client test port")
-
+CLIENT_NAME = flags.DEFINE_string("client_name",
+                                  default="psm-grpc-client",
+                                  help="Client deployment and service name")
+CLIENT_PORT = flags.DEFINE_integer("client_port",
+                                   default=8079,
+                                   help="Client test port")
 
 flags.mark_flags_as_required([
     "project",

+ 15 - 11
tools/run_tests/xds_test_driver/framework/xds_k8s_flags.py

@@ -14,24 +14,28 @@
 from absl import flags
 
 # GCP
-KUBE_CONTEXT = flags.DEFINE_string(
-    "kube_context", default=None, help="Kubectl context to use")
+KUBE_CONTEXT = flags.DEFINE_string("kube_context",
+                                   default=None,
+                                   help="Kubectl context to use")
 GCP_SERVICE_ACCOUNT = flags.DEFINE_string(
-    "gcp_service_account", default=None,
+    "gcp_service_account",
+    default=None,
     help="GCP Service account for GKE workloads to impersonate")
 TD_BOOTSTRAP_IMAGE = flags.DEFINE_string(
-    "td_bootstrap_image", default=None,
+    "td_bootstrap_image",
+    default=None,
     help="Traffic Director gRPC Bootstrap Docker image")
 
 # Test app
-SERVER_IMAGE = flags.DEFINE_string(
-    "server_image", default=None,
-    help="Server Docker image name")
-CLIENT_IMAGE = flags.DEFINE_string(
-    "client_image", default=None,
-    help="Client Docker image name")
+SERVER_IMAGE = flags.DEFINE_string("server_image",
+                                   default=None,
+                                   help="Server Docker image name")
+CLIENT_IMAGE = flags.DEFINE_string("client_image",
+                                   default=None,
+                                   help="Client Docker image name")
 CLIENT_PORT_FORWARDING = flags.DEFINE_bool(
-    "client_debug_use_port_forwarding", default=False,
+    "client_debug_use_port_forwarding",
+    default=False,
     help="Development only: use kubectl port-forward to connect to test client")
 
 flags.mark_flags_as_required([

+ 70 - 76
tools/run_tests/xds_test_driver/framework/xds_k8s_testcase.py

@@ -107,11 +107,9 @@ class XdsKubernetesTestCase(absltest.TestCase):
         # Add backends to the Backend Service
         self.td.backend_service_add_neg_backends(neg_name, neg_zones)
 
-    def assertSuccessfulRpcs(
-        self,
-        test_client: XdsTestClient,
-        num_rpcs: int = 100
-    ):
+    def assertSuccessfulRpcs(self,
+                             test_client: XdsTestClient,
+                             num_rpcs: int = 100):
         # Run the test
         lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
         # Check the results
@@ -123,17 +121,20 @@ class XdsKubernetesTestCase(absltest.TestCase):
         logger.info(lb_stats.rpcs_by_peer)
         for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
             self.assertGreater(
-                int(rpcs_count), 0,
+                int(rpcs_count),
+                0,
                 msg='Backend {backend} did not receive a single RPC')
 
     def assertFailedRpcsAtMost(self, lb_stats, limit):
         failed = int(lb_stats.num_failures)
         self.assertLessEqual(
-            failed, limit,
+            failed,
+            limit,
             msg=f'Unexpected number of RPC failures {failed} > {limit}')
 
 
 class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
+
     def setUp(self):
         super().setUp()
 
@@ -168,15 +169,13 @@ class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
             reuse_namespace=self.server_namespace == self.client_namespace)
 
     def startTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
-        test_server = self.server_runner.run(
-            replica_count=replica_count,
-            test_port=self.server_port,
-            **kwargs)
+        test_server = self.server_runner.run(replica_count=replica_count,
+                                             test_port=self.server_port,
+                                             **kwargs)
         test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
         return test_server
 
-    def startTestClient(self,
-                        test_server: XdsTestServer,
+    def startTestClient(self, test_server: XdsTestServer,
                         **kwargs) -> XdsTestClient:
         test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                              **kwargs)
@@ -187,6 +186,7 @@ class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
 
 
 class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
+
     class SecurityMode(enum.Enum):
         MTLS = enum.auto()
         TLS = enum.auto()
@@ -229,43 +229,39 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
             debug_use_port_forwarding=self.client_port_forwarding)
 
     def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
-        test_server = self.server_runner.run(
-            replica_count=replica_count,
-            test_port=self.server_port,
-            maintenance_port=8081,
-            secure_mode=True,
-            **kwargs)
+        test_server = self.server_runner.run(replica_count=replica_count,
+                                             test_port=self.server_port,
+                                             maintenance_port=8081,
+                                             secure_mode=True,
+                                             **kwargs)
         test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
         return test_server
 
-    def setupSecurityPolicies(self, *,
-                              server_tls, server_mtls,
-                              client_tls, client_mtls):
-        self.td.setup_client_security(self.server_namespace, self.server_name,
-                                      tls=client_tls, mtls=client_mtls)
+    def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
+                              client_mtls):
+        self.td.setup_client_security(self.server_namespace,
+                                      self.server_name,
+                                      tls=client_tls,
+                                      mtls=client_mtls)
         self.td.setup_server_security(self.server_port,
-                                      tls=server_tls, mtls=server_mtls)
-
-    def startSecureTestClient(
-        self,
-        test_server: XdsTestServer,
-        **kwargs
-    ) -> XdsTestClient:
-        test_client = self.client_runner.run(
-            server_target=test_server.xds_uri,
-            secure_mode=True,
-            **kwargs)
+                                      tls=server_tls,
+                                      mtls=server_mtls)
+
+    def startSecureTestClient(self, test_server: XdsTestServer,
+                              **kwargs) -> XdsTestClient:
+        test_client = self.client_runner.run(server_target=test_server.xds_uri,
+                                             secure_mode=True,
+                                             **kwargs)
         logger.debug('Waiting fot the client to establish healthy channel with '
                      'the server')
         test_client.wait_for_active_server_channel()
         return test_client
 
-    def assertTestAppSecurity(self,
-                              mode: SecurityMode,
+    def assertTestAppSecurity(self, mode: SecurityMode,
                               test_client: XdsTestClient,
                               test_server: XdsTestServer):
-        client_socket, server_socket = self.getConnectedSockets(test_client,
-                                                                test_server)
+        client_socket, server_socket = self.getConnectedSockets(
+            test_client, test_server)
         server_security: grpc_channelz.Security = server_socket.security
         client_security: grpc_channelz.Security = client_socket.security
         logger.info('Server certs: %s', self.debug_sock_certs(server_security))
@@ -280,72 +276,70 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
         else:
             raise TypeError(f'Incorrect security mode')
 
-    def assertSecurityMtls(self,
-                           client_security: grpc_channelz.Security,
+    def assertSecurityMtls(self, client_security: grpc_channelz.Security,
                            server_security: grpc_channelz.Security):
-        self.assertEqual(client_security.WhichOneof('model'), 'tls',
+        self.assertEqual(client_security.WhichOneof('model'),
+                         'tls',
                          msg='(mTLS) Client socket security model must be TLS')
-        self.assertEqual(server_security.WhichOneof('model'), 'tls',
+        self.assertEqual(server_security.WhichOneof('model'),
+                         'tls',
                          msg='(mTLS) Server socket security model must be TLS')
         server_tls, client_tls = server_security.tls, client_security.tls
 
         # Confirm regular TLS: server local cert == client remote cert
-        self.assertNotEmpty(
+        self.assertNotEmpty(server_tls.local_certificate,
+                            msg="(mTLS) Server local certificate is missing")
+        self.assertNotEmpty(client_tls.remote_certificate,
+                            msg="(mTLS) Client remote certificate is missing")
+        self.assertEqual(
             server_tls.local_certificate,
-            msg="(mTLS) Server local certificate is missing")
-        self.assertNotEmpty(
             client_tls.remote_certificate,
-            msg="(mTLS) Client remote certificate is missing")
-        self.assertEqual(
-            server_tls.local_certificate, client_tls.remote_certificate,
             msg="(mTLS) Server local certificate must match client's "
-                "remote certificate")
+            "remote certificate")
 
         # mTLS: server remote cert == client local cert
-        self.assertNotEmpty(
+        self.assertNotEmpty(server_tls.remote_certificate,
+                            msg="(mTLS) Server remote certificate is missing")
+        self.assertNotEmpty(client_tls.local_certificate,
+                            msg="(mTLS) Client local certificate is missing")
+        self.assertEqual(
             server_tls.remote_certificate,
-            msg="(mTLS) Server remote certificate is missing")
-        self.assertNotEmpty(
             client_tls.local_certificate,
-            msg="(mTLS) Client local certificate is missing")
-        self.assertEqual(
-            server_tls.remote_certificate, client_tls.local_certificate,
             msg="(mTLS) Server remote certificate must match client's "
-                "local certificate")
+            "local certificate")
 
         # Success
         logger.info('mTLS security mode  confirmed!')
 
-    def assertSecurityTls(self,
-                          client_security: grpc_channelz.Security,
+    def assertSecurityTls(self, client_security: grpc_channelz.Security,
                           server_security: grpc_channelz.Security):
-        self.assertEqual(client_security.WhichOneof('model'), 'tls',
+        self.assertEqual(client_security.WhichOneof('model'),
+                         'tls',
                          msg='(TLS) Client socket security model must be TLS')
-        self.assertEqual(server_security.WhichOneof('model'), 'tls',
+        self.assertEqual(server_security.WhichOneof('model'),
+                         'tls',
                          msg='(TLS) Server socket security model must be TLS')
         server_tls, client_tls = server_security.tls, client_security.tls
 
         # Regular TLS: server local cert == client remote cert
-        self.assertNotEmpty(
-            server_tls.local_certificate,
-            msg="(TLS) Server local certificate is missing")
-        self.assertNotEmpty(
-            client_tls.remote_certificate,
-            msg="(TLS) Client remote certificate is missing")
-        self.assertEqual(
-            server_tls.local_certificate, client_tls.remote_certificate,
-            msg="(TLS) Server local certificate must match client "
-                "remote certificate")
+        self.assertNotEmpty(server_tls.local_certificate,
+                            msg="(TLS) Server local certificate is missing")
+        self.assertNotEmpty(client_tls.remote_certificate,
+                            msg="(TLS) Client remote certificate is missing")
+        self.assertEqual(server_tls.local_certificate,
+                         client_tls.remote_certificate,
+                         msg="(TLS) Server local certificate must match client "
+                         "remote certificate")
 
         # mTLS must not be used
         self.assertEmpty(
             server_tls.remote_certificate,
             msg="(TLS) Server remote certificate must be empty in TLS mode. "
-                "Is server security incorrectly configured for mTLS?")
+            "Is server security incorrectly configured for mTLS?")
         self.assertEmpty(
             client_tls.local_certificate,
             msg="(TLS) Client local certificate must be empty in TLS mode. "
-                "Is client security incorrectly configured for mTLS?")
+            "Is client security incorrectly configured for mTLS?")
 
         # Success
         logger.info('TLS security mode confirmed!')
@@ -373,8 +367,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
 
     @staticmethod
     def getConnectedSockets(
-        test_client: XdsTestClient,
-        test_server: XdsTestServer
+        test_client: XdsTestClient, test_server: XdsTestServer
     ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
         client_sock = test_client.get_client_socket_with_test_server()
         server_sock = test_server.get_server_socket_matching_client(client_sock)
@@ -390,6 +383,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
 
     @staticmethod
     def debug_cert(cert):
-        if not cert: return 'missing'
+        if not cert:
+            return 'missing'
         sha1 = hashlib.sha1(cert)
         return f'sha1={sha1.hexdigest()}, len={len(cert)}'

+ 1 - 0
tools/run_tests/xds_test_driver/tests/baseline_test.py

@@ -27,6 +27,7 @@ XdsTestClient = xds_k8s_testcase.XdsTestClient
 
 
 class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
+
     def test_ping_pong(self):
         self.setupTrafficDirectorGrpc()
 

+ 15 - 8
tools/run_tests/xds_test_driver/tests/security_test.py

@@ -29,10 +29,13 @@ SecurityMode = xds_k8s_testcase.SecurityXdsKubernetesTestCase.SecurityMode
 
 
 class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
+
     def test_mtls(self):
         self.setupTrafficDirectorGrpc()
-        self.setupSecurityPolicies(server_tls=True, server_mtls=True,
-                                   client_tls=True, client_mtls=True)
+        self.setupSecurityPolicies(server_tls=True,
+                                   server_mtls=True,
+                                   client_tls=True,
+                                   client_mtls=True)
 
         test_server: XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
@@ -43,8 +46,10 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
 
     def test_tls(self):
         self.setupTrafficDirectorGrpc()
-        self.setupSecurityPolicies(server_tls=True, server_mtls=False,
-                                   client_tls=True, client_mtls=False)
+        self.setupSecurityPolicies(server_tls=True,
+                                   server_mtls=False,
+                                   client_tls=True,
+                                   client_mtls=False)
 
         test_server: XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
@@ -55,15 +60,17 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
 
     def test_plaintext_fallback(self):
         self.setupTrafficDirectorGrpc()
-        self.setupSecurityPolicies(server_tls=False, server_mtls=False,
-                                   client_tls=False, client_mtls=False)
+        self.setupSecurityPolicies(server_tls=False,
+                                   server_mtls=False,
+                                   client_tls=False,
+                                   client_mtls=False)
 
         test_server: XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
         test_client: XdsTestClient = self.startSecureTestClient(test_server)
 
-        self.assertTestAppSecurity(
-            SecurityMode.PLAINTEXT, test_client, test_server)
+        self.assertTestAppSecurity(SecurityMode.PLAINTEXT, test_client,
+                                   test_server)
         self.assertSuccessfulRpcs(test_client)
 
     @absltest.skip(SKIP_REASON)