k8s.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import functools
  15. import json
  16. import logging
  17. import subprocess
  18. import time
  19. from typing import Optional, List, Tuple
  20. # TODO(sergiitk): replace with tenacity
  21. import retrying
  22. import kubernetes.config
  23. from kubernetes import client
  24. from kubernetes import utils
  25. logger = logging.getLogger(__name__)
  26. # Type aliases
  27. V1Deployment = client.V1Deployment
  28. V1ServiceAccount = client.V1ServiceAccount
  29. V1Pod = client.V1Pod
  30. V1PodList = client.V1PodList
  31. V1Service = client.V1Service
  32. V1Namespace = client.V1Namespace
  33. ApiException = client.ApiException
  34. def simple_resource_get(func):
  35. def wrap_not_found_return_none(*args, **kwargs):
  36. try:
  37. return func(*args, **kwargs)
  38. except client.ApiException as e:
  39. if e.status == 404:
  40. # Ignore 404
  41. return None
  42. raise
  43. return wrap_not_found_return_none
  44. def label_dict_to_selector(labels: dict) -> str:
  45. return ','.join(f'{k}=={v}' for k, v in labels.items())
  46. class KubernetesApiManager:
  47. def __init__(self, context):
  48. self.context = context
  49. self.client = self._cached_api_client_for_context(context)
  50. self.apps = client.AppsV1Api(self.client)
  51. self.core = client.CoreV1Api(self.client)
  52. def close(self):
  53. self.client.close()
  54. @classmethod
  55. @functools.lru_cache(None)
  56. def _cached_api_client_for_context(cls, context: str) -> client.ApiClient:
  57. return kubernetes.config.new_client_from_config(context=context)
  58. class PortForwardingError(Exception):
  59. """Error forwarding port"""
  60. class KubernetesNamespace:
  61. NEG_STATUS_META = 'cloud.google.com/neg-status'
  62. PORT_FORWARD_LOCAL_ADDRESS: str = '127.0.0.1'
  63. DELETE_GRACE_PERIOD_SEC: int = 5
  64. def __init__(self, api: KubernetesApiManager, name: str):
  65. self.name = name
  66. self.api = api
  67. def apply_manifest(self, manifest):
  68. return utils.create_from_dict(self.api.client,
  69. manifest,
  70. namespace=self.name)
  71. @simple_resource_get
  72. def get_service(self, name) -> V1Service:
  73. return self.api.core.read_namespaced_service(name, self.name)
  74. @simple_resource_get
  75. def get_service_account(self, name) -> V1Service:
  76. return self.api.core.read_namespaced_service_account(name, self.name)
  77. def delete_service(self, name,
  78. grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
  79. self.api.core.delete_namespaced_service(
  80. name=name,
  81. namespace=self.name,
  82. body=client.V1DeleteOptions(
  83. propagation_policy='Foreground',
  84. grace_period_seconds=grace_period_seconds))
  85. def delete_service_account(self,
  86. name,
  87. grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
  88. self.api.core.delete_namespaced_service_account(
  89. name=name,
  90. namespace=self.name,
  91. body=client.V1DeleteOptions(
  92. propagation_policy='Foreground',
  93. grace_period_seconds=grace_period_seconds))
  94. @simple_resource_get
  95. def get(self) -> V1Namespace:
  96. return self.api.core.read_namespace(self.name)
  97. def delete(self, grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
  98. self.api.core.delete_namespace(
  99. name=self.name,
  100. body=client.V1DeleteOptions(
  101. propagation_policy='Foreground',
  102. grace_period_seconds=grace_period_seconds))
  103. def wait_for_service_deleted(self, name: str, timeout_sec=60, wait_sec=1):
  104. @retrying.retry(retry_on_result=lambda r: r is not None,
  105. stop_max_delay=timeout_sec * 1000,
  106. wait_fixed=wait_sec * 1000)
  107. def _wait_for_deleted_service_with_retry():
  108. service = self.get_service(name)
  109. if service is not None:
  110. logger.info('Waiting for service %s to be deleted',
  111. service.metadata.name)
  112. return service
  113. _wait_for_deleted_service_with_retry()
  114. def wait_for_service_account_deleted(self,
  115. name: str,
  116. timeout_sec=60,
  117. wait_sec=1):
  118. @retrying.retry(retry_on_result=lambda r: r is not None,
  119. stop_max_delay=timeout_sec * 1000,
  120. wait_fixed=wait_sec * 1000)
  121. def _wait_for_deleted_service_account_with_retry():
  122. service_account = self.get_service_account(name)
  123. if service_account is not None:
  124. logger.info('Waiting for service account %s to be deleted',
  125. service_account.metadata.name)
  126. return service_account
  127. _wait_for_deleted_service_account_with_retry()
  128. def wait_for_namespace_deleted(self, timeout_sec=240, wait_sec=2):
  129. @retrying.retry(retry_on_result=lambda r: r is not None,
  130. stop_max_delay=timeout_sec * 1000,
  131. wait_fixed=wait_sec * 1000)
  132. def _wait_for_deleted_namespace_with_retry():
  133. namespace = self.get()
  134. if namespace is not None:
  135. logger.info('Waiting for namespace %s to be deleted',
  136. namespace.metadata.name)
  137. return namespace
  138. _wait_for_deleted_namespace_with_retry()
  139. def wait_for_service_neg(self, name: str, timeout_sec=60, wait_sec=1):
  140. @retrying.retry(retry_on_result=lambda r: not r,
  141. stop_max_delay=timeout_sec * 1000,
  142. wait_fixed=wait_sec * 1000)
  143. def _wait_for_service_neg():
  144. service = self.get_service(name)
  145. if self.NEG_STATUS_META not in service.metadata.annotations:
  146. logger.info('Waiting for service %s NEG', service.metadata.name)
  147. return False
  148. return True
  149. _wait_for_service_neg()
  150. def get_service_neg(self, service_name: str,
  151. service_port: int) -> Tuple[str, List[str]]:
  152. service = self.get_service(service_name)
  153. neg_info: dict = json.loads(
  154. service.metadata.annotations[self.NEG_STATUS_META])
  155. neg_name: str = neg_info['network_endpoint_groups'][str(service_port)]
  156. neg_zones: List[str] = neg_info['zones']
  157. return neg_name, neg_zones
  158. @simple_resource_get
  159. def get_deployment(self, name) -> V1Deployment:
  160. return self.api.apps.read_namespaced_deployment(name, self.name)
  161. def delete_deployment(self,
  162. name,
  163. grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
  164. self.api.apps.delete_namespaced_deployment(
  165. name=name,
  166. namespace=self.name,
  167. body=client.V1DeleteOptions(
  168. propagation_policy='Foreground',
  169. grace_period_seconds=grace_period_seconds))
  170. def list_deployment_pods(self, deployment: V1Deployment) -> List[V1Pod]:
  171. # V1LabelSelector.match_expressions not supported at the moment
  172. return self.list_pods_with_labels(deployment.spec.selector.match_labels)
  173. def wait_for_deployment_available_replicas(self,
  174. name,
  175. count=1,
  176. timeout_sec=60,
  177. wait_sec=1):
  178. @retrying.retry(
  179. retry_on_result=lambda r: not self._replicas_available(r, count),
  180. stop_max_delay=timeout_sec * 1000,
  181. wait_fixed=wait_sec * 1000)
  182. def _wait_for_deployment_available_replicas():
  183. deployment = self.get_deployment(name)
  184. logger.info(
  185. 'Waiting for deployment %s to have %s available '
  186. 'replicas, current count %s', deployment.metadata.name, count,
  187. deployment.status.available_replicas)
  188. return deployment
  189. _wait_for_deployment_available_replicas()
  190. def wait_for_deployment_deleted(self,
  191. deployment_name: str,
  192. timeout_sec=60,
  193. wait_sec=1):
  194. @retrying.retry(retry_on_result=lambda r: r is not None,
  195. stop_max_delay=timeout_sec * 1000,
  196. wait_fixed=wait_sec * 1000)
  197. def _wait_for_deleted_deployment_with_retry():
  198. deployment = self.get_deployment(deployment_name)
  199. if deployment is not None:
  200. logger.info(
  201. 'Waiting for deployment %s to be deleted. '
  202. 'Non-terminated replicas: %s', deployment.metadata.name,
  203. deployment.status.replicas)
  204. return deployment
  205. _wait_for_deleted_deployment_with_retry()
  206. def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
  207. pod_list: V1PodList = self.api.core.list_namespaced_pod(
  208. self.name, label_selector=label_dict_to_selector(labels))
  209. return pod_list.items
  210. def get_pod(self, name) -> client.V1Pod:
  211. return self.api.core.read_namespaced_pod(name, self.name)
  212. def wait_for_pod_started(self, pod_name, timeout_sec=60, wait_sec=1):
  213. @retrying.retry(retry_on_result=lambda r: not self._pod_started(r),
  214. stop_max_delay=timeout_sec * 1000,
  215. wait_fixed=wait_sec * 1000)
  216. def _wait_for_pod_started():
  217. pod = self.get_pod(pod_name)
  218. logger.info('Waiting for pod %s to start, current phase: %s',
  219. pod.metadata.name, pod.status.phase)
  220. return pod
  221. _wait_for_pod_started()
  222. def port_forward_pod(
  223. self,
  224. pod: V1Pod,
  225. remote_port: int,
  226. local_port: Optional[int] = None,
  227. local_address: Optional[str] = None,
  228. ) -> subprocess.Popen:
  229. """Experimental"""
  230. local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
  231. local_port = local_port or remote_port
  232. cmd = [
  233. "kubectl", "--context", self.api.context, "--namespace", self.name,
  234. "port-forward", "--address", local_address,
  235. f"pod/{pod.metadata.name}", f"{local_port}:{remote_port}"
  236. ]
  237. pf = subprocess.Popen(cmd,
  238. stdout=subprocess.PIPE,
  239. stderr=subprocess.STDOUT,
  240. universal_newlines=True)
  241. # Wait for stdout line indicating successful start.
  242. expected = (f"Forwarding from {local_address}:{local_port}"
  243. f" -> {remote_port}")
  244. try:
  245. while True:
  246. time.sleep(0.05)
  247. output = pf.stdout.readline().strip()
  248. if not output:
  249. return_code = pf.poll()
  250. if return_code is not None:
  251. errors = [error for error in pf.stdout.readlines()]
  252. raise PortForwardingError(
  253. 'Error forwarding port, kubectl return '
  254. f'code {return_code}, output {errors}')
  255. elif output != expected:
  256. raise PortForwardingError(
  257. f'Error forwarding port, unexpected output {output}')
  258. else:
  259. logger.info(output)
  260. break
  261. except Exception:
  262. self.port_forward_stop(pf)
  263. raise
  264. # TODO(sergiitk): return new PortForwarder object
  265. return pf
  266. @staticmethod
  267. def port_forward_stop(pf):
  268. logger.info('Shutting down port forwarding, pid %s', pf.pid)
  269. pf.kill()
  270. stdout, _stderr = pf.communicate(timeout=5)
  271. logger.info('Port forwarding stopped')
  272. # TODO(sergiitk): make debug
  273. logger.info('Port forwarding remaining stdout: %s', stdout)
  274. @staticmethod
  275. def _pod_started(pod: V1Pod):
  276. return pod.status.phase not in ('Pending', 'Unknown')
  277. @staticmethod
  278. def _replicas_available(deployment, count):
  279. return (deployment is not None and
  280. deployment.status.available_replicas is not None and
  281. deployment.status.available_replicas >= count)