| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 | 
							- #!/usr/bin/env python2.7
 
- # Copyright 2015 gRPC authors.
 
- #
 
- # Licensed under the Apache License, Version 2.0 (the "License");
 
- # you may not use this file except in compliance with the License.
 
- # You may obtain a copy of the License at
 
- #
 
- #     http://www.apache.org/licenses/LICENSE-2.0
 
- #
 
- # Unless required by applicable law or agreed to in writing, software
 
- # distributed under the License is distributed on an "AS IS" BASIS,
 
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
- # See the License for the specific language governing permissions and
 
- # limitations under the License.
 
- """Starts a local DNS server for use in tests"""
 
- import argparse
 
- import sys
 
- import yaml
 
- import signal
 
- import os
 
- import threading
 
- import time
 
- import twisted
 
- import twisted.internet
 
- import twisted.internet.reactor
 
- import twisted.internet.threads
 
- import twisted.internet.defer
 
- import twisted.internet.protocol
 
- import twisted.names
 
- import twisted.names.client
 
- import twisted.names.dns
 
- import twisted.names.server
 
- from twisted.names import client, server, common, authority, dns
 
- import argparse
 
- import platform
 
- _SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp'  # missing end '.' for twisted syntax
 
- _SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
 
class NoFileAuthority(authority.FileAuthority):
    """A FileAuthority whose SOA and records are supplied in memory.

    The FileAuthority constructor is deliberately bypassed (presumably because
    it expects a zone file to load, which this authority does not have); only
    the common ResolverBase initialisation runs, and ``soa`` / ``records`` are
    assigned directly from the constructor arguments.

    Args:
      soa: ``(zone_name, dns.Record_SOA)`` tuple, as built by
        ``start_local_dns_server``.
      records: dict mapping a record name (without trailing '.') to the list
        of ``dns.Record_*`` objects served for that name.
    """

    def __init__(self, soa, records):
        # Intentionally NOT FileAuthority.__init__: there is no file to load.
        common.ResolverBase.__init__(self)
        self.soa, self.records = soa, records
 
def _split_txt_data(txt_data, max_chunk_len=255):
    """Split TXT payload into consecutive chunks of at most max_chunk_len.

    A single DNS TXT character-string holds at most 255 octets, so longer
    data is carried as several strings inside one TXT record.

    Args:
      txt_data: the full TXT payload.
      max_chunk_len: maximum length of each chunk.

    Returns:
      List of consecutive slices of ``txt_data``; empty for empty input.
    """
    return [
        txt_data[start:start + max_chunk_len]
        for start in range(0, len(txt_data), max_chunk_len)
    ]


def _make_record(r_type, r_data, r_ttl, common_zone_name):
    """Build the twisted ``dns.Record_*`` for one config entry.

    Args:
      r_type: record type string from the config ('A', 'AAAA', 'SRV', 'TXT').
      r_data: record data string from the config.
      r_ttl: TTL in seconds.
      common_zone_name: zone suffix appended to SRV targets.

    Returns:
      The record object, or None when ``r_type`` is not supported.
    """
    if r_type == 'A':
        return dns.Record_A(r_data, ttl=r_ttl)
    if r_type == 'AAAA':
        return dns.Record_AAAA(r_data, ttl=r_ttl)
    if r_type == 'SRV':
        # Format: "<priority> <weight> <port> <target>"; split() tolerates
        # repeated whitespace between fields.
        priority, weight, port, target = r_data.split()
        target_full_name = '%s.%s' % (target, common_zone_name)
        return dns.Record_SRV(int(priority),
                              int(weight),
                              int(port),
                              target_full_name,
                              ttl=r_ttl)
    if r_type == 'TXT':
        return dns.Record_TXT(*_split_txt_data(r_data), ttl=r_ttl)
    return None


def start_local_dns_server(args):
    """Load the test records and serve them over TCP and UDP; blocks forever.

    Args:
      args: parsed command-line namespace. Reads ``records_config_path``
        (YAML records file), ``port`` (TCP and UDP port) and ``add_a_record``
        (optional "<name>:<ipv4>" extra A record).

    Raises:
      ValueError: if a configured record name is not fully qualified (does
        not end with '.').
    """
    all_records = {}

    def _push_record(name, r):
        print('pushing record: |%s|' % name)
        all_records.setdefault(name, []).append(r)

    with open(args.records_config_path) as config:
        # The config is plain data, so safe_load suffices. Bare yaml.load()
        # can construct arbitrary objects and, since PyYAML 6, raises
        # TypeError when no Loader is given.
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config['resolver_tests_common_zone_name']
    for group in test_records_config['resolver_component_tests']:
        for name, records in group['records'].items():
            for record in records:
                r_type = record['type']
                r_data = record['data']
                r_ttl = int(record['TTL'])
                record_full_name = '%s.%s' % (name, common_zone_name)
                if not record_full_name.endswith('.'):
                    raise ValueError(
                        'record name %r must end with "." (is '
                        'resolver_tests_common_zone_name fully qualified?)' %
                        record_full_name)
                # Records are keyed without the trailing '.'.
                record_full_name = record_full_name[:-1]
                r = _make_record(r_type, r_data, r_ttl, common_zone_name)
                if r is None:
                    print('skipping record |%s| with unsupported type %s' %
                          (record_full_name, r_type))
                    continue
                _push_record(record_full_name, r)
    # Add an optional IPv4 record if specified.
    if args.add_a_record:
        extra_host, extra_host_ipv4 = args.add_a_record.split(':')
        _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
    # Server health check record.
    _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
                 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
    soa_record = dns.Record_SOA(mname=common_zone_name)
    test_domain_com = NoFileAuthority(
        soa=(common_zone_name, soa_record),
        records=all_records,
    )
    # Named server_factory so it does not shadow the twisted `server` module.
    server_factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2)
    server_factory.noisy = 2
    reactor = twisted.internet.reactor
    reactor.listenTCP(args.port, server_factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(server_factory)
    dns_proto.noisy = 2
    reactor.listenUDP(args.port, dns_proto)
    print('starting local dns server on 127.0.0.1:%s' % args.port)
    print('starting twisted.internet.reactor')
    reactor.suggestThreadPoolSize(1)
    reactor.run()
 
def _quit_on_signal(signum, _frame):
    """Signal handler: stop the reactor if running, flush stdout, exit 0.

    Args:
      signum: number of the received signal.
      _frame: current stack frame (unused).
    """
    from twisted.internet.error import ReactorNotRunning
    print('Received SIGNAL %d. Quitting with exit code 0' % signum)
    try:
        twisted.internet.reactor.stop()
    except ReactorNotRunning:
        # The signal arrived before reactor.run() (e.g. while the records
        # config was still loading) or after an earlier stop(). There is
        # nothing to stop, but we must still exit cleanly with code 0.
        pass
    sys.stdout.flush()
    sys.exit(0)
 
def flush_stdout_loop(sleep_time=1, max_timeouts=60 * 10):
    """Periodically flush stdout, then send SIGTERM to this process.

    Flushes stdout once per ``sleep_time`` seconds for ``max_timeouts``
    iterations (about 10 minutes with the defaults). It then signals its own
    process with SIGTERM, so a forgotten server does not linger.

    Args:
      sleep_time: seconds to sleep between flushes.
      max_timeouts: number of sleep intervals before self-termination.
    """
    # Prevent zombies. Tests that use this server are short-lived.
    num_timeouts_so_far = 0
    while num_timeouts_so_far < max_timeouts:
        sys.stdout.flush()
        time.sleep(sleep_time)
        num_timeouts_so_far += 1
    print('Process timeout reached, or cancelled. Exitting 0.')
    # Flush the final message here rather than relying on whichever SIGTERM
    # handler is installed to do it.
    sys.stdout.flush()
    os.kill(os.getpid(), signal.SIGTERM)
 
def main():
    """Parse command-line flags and run the local DNS server until killed.

    Installs ``_quit_on_signal`` for SIGTERM/SIGINT and starts the
    ``flush_stdout_loop`` watchdog on a daemon thread before serving.
    """
    argp = argparse.ArgumentParser(
        description='Local DNS Server for resolver tests')
    # Both flags are required: with the previous default of None the server
    # always crashed later (listenTCP(None) / open(None)).
    argp.add_argument('-p',
                      '--port',
                      required=True,
                      type=int,
                      help='Port for DNS server to listen on for TCP and UDP.')
    argp.add_argument(
        '-r',
        '--records_config_path',
        required=True,
        type=str,
        help=('Path to the resolver_test_record_groups.yaml file that '
              'lists the records to serve.'))
    argp.add_argument(
        '--add_a_record',
        default=None,
        type=str,
        help=('Add an A record via the command line. Useful for when we '
              'need to serve a one-off A record that is under a '
              'different domain then the rest the records configured in '
              '--records_config_path (which all need to be under the '
              'same domain). Format: <name>:<ipv4 address>'))
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    output_flush_thread = threading.Thread(target=flush_stdout_loop)
    # Daemon thread so it never keeps the process alive on its own.
    # (The .daemon attribute replaces the deprecated setDaemon().)
    output_flush_thread.daemon = True
    output_flush_thread.start()
    start_local_dns_server(args)
 
if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    main()
 
 
  |