Skip to content
Snippets Groups Projects
Select Git revision
  • develop
  • debian-develop
2 results

Changelog

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    generate.py 7.35 KiB
    #!/usr/bin/env python3
    
    import argparse
    import configparser
    import dns.zone
    import re
    import sys
    from pathlib import Path
    
    config = configparser.ConfigParser(
            interpolation=configparser.ExtendedInterpolation())

    # ./sshgen.cfg is optional. A missing file is silently ignored; any other
    # problem reading it is reported on stderr but does not abort the run.
    try:
        with open('./sshgen.cfg') as fp:
            config.read_file(fp)
    except FileNotFoundError:
        pass
    except Exception as err:
        print('ignoring ./sshgen.cfg: {}'.format(err), file=sys.stderr)

    # Look at --cfg before building the real parser: the extra config file has
    # to be loaded first so that its presets are accepted as --preset choices
    # and can provide the default preset.
    cfg_parser = argparse.ArgumentParser(add_help=False)
    cfg_parser.add_argument('--cfg', action='store', default='./sshgen.cfg')
    cfg_path = cfg_parser.parse_known_args()[0].cfg
    if cfg_path != './sshgen.cfg':
        try:
            with open(cfg_path) as fp:
                config.read_file(fp)
        except (OSError, configparser.Error) as err:
            sys.exit('cannot read config file {}: {}'.format(cfg_path, err))

    if 'presets' not in config:
        config['presets'] = {}
        choices = []
    else:
        choices = list(config['presets'])
    # The first configured preset is the default; without presets (including
    # an empty [presets] section) there is no default.
    choice_default = choices[0] if choices else None

    parser = argparse.ArgumentParser(
            description='Generates a SSH config file from some DNS zone(s).')
    # Restrict --preset to the known presets when there are any.
    parser.add_argument('--preset', choices=choices or None,
                        default=choice_default,
                        help='select a configuration preset')
    parser.add_argument('--cfg', action='store', default='./sshgen.cfg',
                        help='config file')
    args = parser.parse_args()
    preset = args.preset

    if preset not in list(config['presets']):
        sys.exit('preset not in presets configuration')
    
    
    def get_zones(zones=None):
        """Return the zone files configured in the ``[zones]`` section.

        Each configured value is a path. A directory contributes every regular
        file directly inside it, sorted so that the generated output is stable
        between runs. A file contributes itself. Anything else is reported on
        stderr and skipped.

        Args:
            zones: optional mapping of zone label -> path to use instead of
                ``config['zones']``.

        Returns:
            list of pathlib.Path objects, in configuration order.
        """
        if zones is None:
            zones = config['zones']
        all_zones = []
        for label, location in zones.items():
            p = Path(location)
            if p.is_dir():
                # Sub-directories cannot be read as zone files; skip them.
                all_zones.extend(sorted(z for z in p.iterdir() if z.is_file()))
            elif p.is_file():
                all_zones.append(p)
            else:
                print('incorrectly configured zone {}, skipping'.format(label),
                      file=sys.stderr)
        return all_zones
    
    
    def get_zone_file(zone):
        """Return the complete text of the file *zone* (a path or str).

        The content is returned exactly as stored. readlines() keeps each
        line's terminator, so joining its result with newlines would double
        every line break and shift the line numbers reported in parse errors.
        """
        with open(str(zone), 'r') as fp:
            return fp.read()
    
    
    def retrieve_hosts():
        """Collect host names and their aliases from all configured zones.

        Every file returned by get_zones() is first parsed as a DNS zone.
        A records register a host and its address, and CNAME records add the
        alias to the target host. Files that dnspython rejects with
        UnknownOrigin are read as hosts(5)-style files instead:
        ``address name [alias ...]``, where ``#`` starts a comment.

        The collected hosts are then:

        1. filtered to those where the host or one of its aliases, or any
           dot-separated suffix of those names, is listed in the
           comma-separated ``[excludes] required`` option. Trailing dots are
           ignored, and an empty or missing option keeps every host;
        2. consolidated: a host listed as an alias of another host is folded
           into that host;
        3. consolidated by address: names sharing an address are folded into
           the one whose alias list has the longest printed form.

        Returns:
            dict mapping host name -> list of alias names. Names read from
            zone files are dns.name.Name objects; names read from hosts-style
            files are str.
        """
        h = {}  # host name -> alias names
        i = {}  # address -> host names carrying that address

        for k in get_zones():
            text = get_zone_file(k)
            try:
                z = dns.zone.from_text(text, relativize=False)

                # TODO AAAA records (and others)
                for (name, ttl, rdata) in z.iterate_rdatas('A'):
                    h.setdefault(name, [])
                    i.setdefault(rdata.address, []).append(name)
                for (name, ttl, rdata) in z.iterate_rdatas('CNAME'):
                    h.setdefault(rdata.target, []).append(name)
            except dns.zone.UnknownOrigin:
                for line in text.splitlines():
                    # Drop comments, including indented and trailing ones.
                    parts = line.split('#', 1)[0].split()
                    if len(parts) < 2:
                        continue
                    address, name, aliases = parts[0], parts[1], parts[2:]
                    i.setdefault(address, []).append(name)
                    h.setdefault(name, []).extend(aliases)

        required = {r.strip().rstrip('.') for r in
                    config['excludes'].get('required', '').split(',')}
        required.discard('')

        def matches_required(names):
            # str() is needed: dns.name.Name.split() takes an int depth.
            # Trailing dots are dropped so "example.com" also matches the
            # absolute zone name "host.example.com.".
            for n in names:
                labels = str(n).rstrip('.').split('.')
                if any('.'.join(labels[start:]) in required
                       for start in range(len(labels))):
                    return True
            return False

        if required:
            h = {k: v for k, v in h.items() if matches_required([k] + v)}

        # Fold hosts that are listed as an alias of another host into that
        # host, restarting the scan after every merge because h changed.
        # A host is never merged into itself (e.g. "addr foo foo"), which
        # would otherwise delete it.
        merged = True
        while merged:
            merged = False
            for key in h:
                owner = next((k for k in h if k != key and key in h[k]), None)
                if owner is not None:
                    h[owner].extend(h.pop(key))
                    merged = True
                    break

        # Names sharing an address are folded into the name whose alias list
        # has the longest printed form (ties go to the first such name).
        # TODO names that are no longer keys of h are ignored here.
        for names in i.values():
            present = [n for n in names if n in h]
            if len(names) < 2 or not present:
                continue
            keep = max(present, key=lambda n: len(str(h[n])))
            for n in present:
                # n may already be gone if it is listed twice for this address.
                if n != keep and n in h:
                    h[keep].append(n)
                    h[keep].extend(h.pop(n))

        return h
    
    
    def _split_list(value):
        """Return the stripped, non-empty items of a comma-separated value.

        Empty items are dropped because an empty string compiles to a regex
        that matches every name. Without this, ``hosts =`` (or a trailing
        comma) would exclude every host.
        """
        return [item.strip() for item in value.split(',') if item.strip()]


    # Settings selected by the active preset: ProxyJump rules from
    # ``proxies_*`` sections and domain suffixes to strip from ``[strips]``.
    proxies = {}
    strip_domains = []
    preset_config = _split_list(config['presets'][preset])
    for c in preset_config:
        if c.startswith('proxies_'):
            proxies.update({re.compile(k): v
                            for v in config[c]
                            for k in _split_list(config[c][v])})
        elif c.startswith('strip_'):
            strip_options = config['strips'][c[len('strip_'):]]
            strip_domains.extend(re.compile(r'\.{}\.?'.format(k))
                                 for k in _split_list(strip_options))
        # Any other preset entry is ignored.

    exclude_hosts = [re.compile(x)
                     for x in _split_list(config['excludes']['hosts'])]
    exclude_aliases = [re.compile(x)
                       for x in _split_list(config['excludes']['aliases'])]
    usernames = {re.compile(k): v
                 for v in config['usernames']
                 for k in _split_list(config['usernames'][v])}
    agents = {re.compile(k): True
              for k in _split_list(config['agents']['enabled'])}
    agents.update({re.compile(k): False
                   for k in _split_list(config['agents']['disabled'])})

    h = retrieve_hosts()
    
    
    def modify_list(h):
        """Apply the configured exclusions and extra aliases to *h*.

        Args:
            h: dict mapping host name -> list of alias names, as returned by
                retrieve_hosts().

        Returns:
            A dict of the same shape in which:

            - hosts whose name matches an ``exclude_hosts`` pattern are
              dropped, and matching names are removed from every alias list;
            - names matching an ``exclude_aliases`` pattern are removed from
              every alias list. A host whose own name matches is renamed to
              its first remaining alias, merging into an existing host of
              that name, or dropped when no alias is left;
            - aliases listed in the ``[aliases]`` section under the host's
              name (with or without its trailing dot) are appended.

        The alias lists of *h* may be modified in place.
        """
        for e in exclude_hosts:
            h = {l: m for l, m in h.items() if not e.match(str(l))}
            for k in h:
                h[k] = [l for l in h[k] if not e.match(str(l))]
        for e in exclude_aliases:
            renamed = {}  # promoted alias -> remaining aliases
            for k in h:
                h[k] = [l for l in h[k] if not e.match(str(l))]
                if e.match(str(k)) and h[k]:
                    renamed.setdefault(h[k][0], []).extend(h[k][1:])
            h = {l: m for l, m in h.items() if not e.match(str(l))}
            # Merge rather than overwrite in case the promoted alias is
            # already a host of its own.
            for k, aliases in renamed.items():
                h.setdefault(k, []).extend(aliases)
        extra_aliases = list(config['aliases'].items())
        for k in h:
            name = str(k)
            # Only a trailing root dot may be ignored; slicing off any last
            # character would let "web1" pick up the aliases of "web".
            bare = name[:-1] if name.endswith('.') else name
            for ak, av in extra_aliases:
                if ak in (name, bare):
                    h[k].extend(x.strip() for x in av.split(',') if x.strip())
        return h
    
    
    # Apply the configured host/alias exclusions and extra [aliases] entries
    # to the collected hosts before printing them.
    h = modify_list(h)
    
    
    def re_suffix(pattern, text):
        """Return the longest proper suffix of *text* that *pattern* matches.

        "Proper" means the suffix starts after the first character, so the
        pattern can never consume the whole of *text*. Every start position
        is tried instead of only the leftmost search() hit. A domain that
        also occurs earlier in the name (as in "a.lan.b.lan") is therefore
        still found at the end.

        Args:
            pattern: compiled regular expression.
            text: the name to inspect.

        Returns:
            The matching (non-empty) suffix, or None when no suffix matches.
        """
        for start in range(1, len(text)):
            if pattern.fullmatch(text, start):
                return text[start:]
        return None
    
    
    def _stripped_names(name):
        """Return *name* shortened by each strip domain it ends with."""
        shortened = []
        for domain in strip_domains:
            suffix = re_suffix(domain, name)
            if suffix:
                shortened.append(name[:-len(suffix)])
        return shortened


    def _without_root_dot(name):
        """Drop a single trailing '.' from *name*, if present."""
        return name[:-1] if name.endswith('.') else name


    # ForwardAgent rules, ordered so the "disabled" (False) ones are tried first.
    agent_rules = sorted(agents.items(), key=lambda rule: rule[1])

    # Print one ssh_config "Host" stanza per collected host.
    for host in h:
        host_name = str(host)
        alias_names = [str(alias) for alias in h[host]]

        patterns = [host_name] + _stripped_names(host_name) + alias_names
        for alias_name in alias_names:
            patterns.extend(_stripped_names(alias_name))

        print('Host ' + ' '.join(_without_root_dot(pat) for pat in patterns))
        print('\tHostName ' + _without_root_dot(host_name))

        user = next((usernames[rule] for rule in usernames
                     if rule.match(host_name)), None)
        if user is not None:
            print('\tUser ' + user)

        forward = next((enabled for rule, enabled in agent_rules
                        if rule.match(host_name)), None)
        if forward is not None:
            print('\tForwardAgent ' + ('yes' if forward else 'no'))

        jump = next((proxies[rule] for rule in proxies
                     if rule.match(host_name)), None)
        if jump is not None:
            print('\tProxyJump ' + jump)

        print()