#!/usr/bin/env python3

import argparse
import configparser
import dns.zone
import re
import sys
from pathlib import Path

# Default config location: read if present, and the default for --cfg.
DEFAULT_CFG = './sshgen.cfg'

config = configparser.ConfigParser(
        interpolation=configparser.ExtendedInterpolation())


def _read_config(path, required):
    """Merge the config file at *path* into the module-level ``config``.

    :param path: file to read.
    :param required: when true, any error (missing file, I/O error, parse
        error) propagates.  When false, a missing file is skipped silently
        and any other read/parse error is reported on stderr and skipped.
    """
    try:
        with open(path) as fp:
            config.read_file(fp)
    except FileNotFoundError:
        if required:
            raise
    except (OSError, configparser.Error) as err:
        if required:
            raise
        print('ignoring config file {}: {}'.format(path, err),
              file=sys.stderr)


# Look at --cfg before building the real parser, so that the preset choices
# and the default preset come from every config file that will be read, not
# only from the default one.
_cfg_parser = argparse.ArgumentParser(add_help=False)
_cfg_parser.add_argument('--cfg', action='store', default=DEFAULT_CFG)
_cfg_path = _cfg_parser.parse_known_args()[0].cfg

_read_config(DEFAULT_CFG, required=False)
if _cfg_path != DEFAULT_CFG:
    _read_config(_cfg_path, required=True)

if 'presets' not in config:
    config['presets'] = {}
choices = list(config['presets'])
# An existing but empty [presets] section leaves no default preset.
choice_default = choices[0] if choices else None

parser = argparse.ArgumentParser(
        description='Generates a SSH config file from some DNS zone(s).')
# choices=None lets argparse accept any value when no presets are known; the
# membership check below then rejects it with a readable message.
parser.add_argument('--preset', choices=choices or None,
                    default=choice_default,
                    help='select a configuration preset')
parser.add_argument('--cfg', action='store', default=DEFAULT_CFG,
                    help='config file')
args = parser.parse_args()
preset = args.preset

if preset not in list(config['presets']):
    sys.exit('preset not in presets configuration')


def get_zones(zones=None):
    """Collect the zone files named by the ``[zones]`` config section.

    Each option value is a path.  A directory contributes every regular file
    directly inside it, sorted so the generated output is stable between
    runs.  A file contributes itself.  Anything else is reported on stderr
    and skipped.

    :param zones: mapping of zone label -> path; defaults to
        ``config['zones']``.
    :returns: list of :class:`pathlib.Path` objects.
    """
    if zones is None:
        zones = config['zones']
    all_zones = []
    for label, location in zones.items():
        p = Path(location)
        if p.is_dir():
            # Only regular files: opening a sub-directory as a zone would fail.
            all_zones.extend(sorted(z for z in p.iterdir() if z.is_file()))
        elif p.is_file():
            all_zones.append(p)
        else:
            print('incorrectly configured zone {}, skipping'.format(label),
                  file=sys.stderr)
    return all_zones


def get_zone_file(zone):
    """Return the full text of the zone file *zone* (a path or str).

    The text is returned exactly as stored.  Joining ``readlines()`` with
    '\\n' would double every line break, because each line already ends in
    one.
    """
    with open(str(zone), 'r') as fp:
        return fp.read()


def _add_zone_records(zone, hosts, addresses):
    """Record the A and CNAME data of a parsed dnspython zone.

    Every A owner becomes a key of *hosts* and is appended to the name list
    of its address in *addresses*.  Every CNAME owner is appended to the
    alias list of its target.
    """
    # TODO AAAA records (and others)
    for name, _ttl, rdata in zone.iterate_rdatas('A'):
        hosts.setdefault(name, [])
        addresses.setdefault(rdata.address, []).append(name)
    for name, _ttl, rdata in zone.iterate_rdatas('CNAME'):
        hosts.setdefault(rdata.target, []).append(name)


def _add_hosts_file_records(text, hosts, addresses):
    """Record entries of hosts-file style text: ``addr name [alias ...]``.

    Everything after a '#' is a comment.  Lines with fewer than two fields
    are skipped.
    """
    for line in text.splitlines():
        fields = line.split('#', 1)[0].split()
        if len(fields) < 2:
            continue
        addr, name, aliases = fields[0], fields[1], fields[2:]
        addresses.setdefault(addr, []).append(name)
        hosts.setdefault(name, []).extend(aliases)


def _matches_required(names, required):
    """Return True if any of *names* equals, or lies under, a *required* domain.

    Names may be dnspython ``Name`` objects or plain strings, so each is
    converted with str() before splitting (``Name.split`` takes an integer
    depth, not a separator).  Every label-aligned suffix is compared both
    with and without a trailing '.', so absolute zone names and hosts-file
    names match the same config entries.
    """
    candidates = set()
    for name in names:
        labels = str(name).rstrip('.').split('.')
        for start in range(len(labels)):
            suffix = '.'.join(labels[start:])
            candidates.add(suffix)
            candidates.add(suffix + '.')
    return not required.isdisjoint(candidates)


def _merge_alias_chains(hosts):
    """Fold every host that is listed as an alias of another host into it.

    Repeats until no key appears in a *different* key's alias list.  A name
    listed only as its own alias (e.g. a CNAME cycle) is kept instead of
    being deleted.
    """
    merged = True
    while merged:
        merged = False
        for key in hosts:
            owner = next((k for k in hosts if k != key and key in hosts[k]),
                         None)
            if owner is not None:
                hosts[owner].extend(a for a in hosts[key] if a != owner)
                del hosts[key]
                merged = True
                break  # dict changed size; restart the scan


def _merge_shared_addresses(hosts, addresses):
    """Collapse hosts that share an address into a single entry.

    For each address with several names, the surviving host is the one
    whose alias list has the longest str() rendering (a rough "most
    aliases" heuristic; the first one wins ties).  The other names, plus
    their aliases, become aliases of it.  Names no longer in *hosts* are
    ignored.
    """
    for names in addresses.values():
        if len(names) < 2:
            continue
        present = [n for n in names if n in hosts]
        if not present:
            continue
        primary = max(present, key=lambda n: len(str(hosts[n])))
        for name in names:
            if name == primary or name not in hosts:
                continue
            hosts[primary].append(name)
            hosts[primary].extend(hosts[name])
            del hosts[name]


def retrieve_hosts():
    """Build the host -> aliases map from every configured zone file.

    Processing steps:

    1. Parse each file as a DNS zone.  If dnspython raises
       ``UnknownOrigin``, parse it as a hosts-style file instead.
    2. Keep only hosts where the host or one of its aliases lies under a
       domain listed in ``[excludes] required``.  An empty list keeps
       nothing.
    3. Fold alias chains into one entry.
    4. Merge names that share an address into one entry.

    :returns: dict mapping host name (dnspython Name or str) to a list of
        aliases.
    """
    hosts = {}
    addresses = {}
    for path in get_zones():
        text = get_zone_file(path)  # read once; reused by the fallback
        try:
            zone = dns.zone.from_text(text, relativize=False)
        except dns.zone.UnknownOrigin:
            _add_hosts_file_records(text, hosts, addresses)
        else:
            _add_zone_records(zone, hosts, addresses)

    required = set(s.strip()
                   for s in config['excludes']['required'].split(',')
                   if s.strip())
    hosts = {k: v for k, v in hosts.items()
             if _matches_required([k] + v, required)}

    _merge_alias_chains(hosts)
    _merge_shared_addresses(hosts, addresses)
    return hosts


def _split_list(value):
    """Split a comma-separated config value into stripped, non-empty items.

    Empty items must never reach re.compile(): the empty pattern matches
    every name.  For example, an empty ``hosts =`` would otherwise exclude
    every host.
    """
    return [item.strip() for item in value.split(',') if item.strip()]


# Build the proxy and strip-domain rules named by the selected preset.
proxies = {}
strip_domains = []
preset_config = _split_list(config['presets'][preset])
for c in preset_config:
    if c.startswith('proxies_'):
        proxies.update({re.compile(k): v
                        for v in config[c] for k in _split_list(config[c][v])})
    elif c.startswith('strip_'):
        strip_options = config['strips'][c[len('strip_'):]]
        strip_domains.extend([re.compile(r'\.{}\.?'.format(k))
                              for k in _split_list(strip_options)])
    # any other preset entry is ignored

exclude_hosts = [re.compile(x)
                 for x in _split_list(config['excludes']['hosts'])]
exclude_aliases = [re.compile(x)
                   for x in _split_list(config['excludes']['aliases'])]
usernames = {re.compile(k): v
             for v in config['usernames']
             for k in _split_list(config['usernames'][v])}
agents = {re.compile(k): True
          for k in _split_list(config['agents']['enabled'])}
agents.update({re.compile(k): False
               for k in _split_list(config['agents']['disabled'])})

h = retrieve_hosts()


def modify_list(h, host_patterns=None, alias_patterns=None, aliases=None):
    """Apply the exclusion patterns and extra aliases to the host map *h*.

    :param h: dict mapping host name to a list of aliases; left unmodified.
    :param host_patterns: compiled regexes.  A host whose name matches any of
        them is dropped, and matching aliases are removed from every host.
        Defaults to the module-level ``exclude_hosts``.
    :param alias_patterns: compiled regexes, applied one after another.
        Matching aliases are removed.  A host whose own name matches is
        renamed to its first remaining alias, which is merged into an
        existing host of that name.  If no aliases remain, the host is
        dropped.  Defaults to ``exclude_aliases``.
    :param aliases: mapping of host name -> comma-separated extra aliases.
        The key is compared case-insensitively with the host name, with or
        without its trailing '.'.  Defaults to ``config['aliases']``.
    :returns: a new dict with the same shape as *h*.
    """
    if host_patterns is None:
        host_patterns = exclude_hosts
    if alias_patterns is None:
        alias_patterns = exclude_aliases
    if aliases is None:
        aliases = config['aliases']

    def excluded(name):
        return any(p.match(str(name)) for p in host_patterns)

    # Fresh dict and lists, so the caller's data is never mutated.
    hosts = {name: [a for a in alts if not excluded(a)]
             for name, alts in h.items() if not excluded(name)}

    for pattern in alias_patterns:
        promoted = {}
        for name in hosts:
            kept = [a for a in hosts[name] if not pattern.match(str(a))]
            hosts[name] = kept
            if kept and pattern.match(str(name)):
                promoted.setdefault(kept[0], []).extend(kept[1:])
        for new_name, rest in promoted.items():
            # Merge instead of overwriting when the alias is already a host.
            hosts.setdefault(new_name, []).extend(rest)
        hosts = {name: alts for name, alts in hosts.items()
                 if not pattern.match(str(name))}

    # configparser lowercases option names, so compare case-insensitively.
    extra = [(key.lower(), value) for key, value in aliases.items()]
    for name, alts in hosts.items():
        full = str(name).lower()
        bare = full[:-1] if full.endswith('.') else full
        for key, value in extra:
            if key in (full, bare):
                alts.extend(x.strip() for x in value.split(',') if x.strip())
    return hosts


# Apply the host/alias exclusion patterns and the [aliases] config section
# to the collected hosts.
h = modify_list(h)


def re_suffix(pattern, text):
    """Return the suffix of *text* matched by *pattern*, or None.

    The match must run to the very end of *text* and must not start at
    index 0, because stripping it would leave nothing.  Start positions are
    tried leftmost first.  This finds the suffix even when the pattern also
    matches earlier in the string, which ``pattern.search()`` alone would
    report instead.
    """
    for start in range(1, len(text)):
        if pattern.fullmatch(text, start):
            return text[start:]
    return None


# Disabled rules (False) sort before enabled ones (True), so a host matching
# both kinds gets "ForwardAgent no".  Sorted once rather than per host.
agent_rules = sorted(agents.items(), key=lambda rule: rule[1])


def _short_names(name):
    """Return *name* with each matching strip-domain suffix removed."""
    shorts = []
    for domain in strip_domains:
        suffix = re_suffix(domain, name)  # evaluated once per domain
        if suffix:
            shorts.append(name[:-len(suffix)])
    return shorts


# Emit one ssh_config "Host" stanza per collected host.
for k in h:
    hostname = str(k)
    aliases = [str(a) for a in h[k]]
    patterns = [hostname] + _short_names(hostname) + aliases
    for alias in aliases:
        patterns.extend(_short_names(alias))
    patterns = [x[:-1] if x.endswith('.') else x for x in patterns]
    # Drop repeated patterns while keeping first-seen order.
    patterns = list(dict.fromkeys(patterns))

    print('Host ' + ' '.join(patterns))
    print('\tHostName '
          + (hostname[:-1] if hostname.endswith('.') else hostname))
    for user_re in usernames:
        if user_re.match(hostname):
            print('\tUser ' + usernames[user_re])
            break
    for agent_re, enabled in agent_rules:
        if agent_re.match(hostname):
            print('\tForwardAgent ' + ('yes' if enabled else 'no'))
            break
    for proxy_re in proxies:
        if proxy_re.match(hostname):
            print('\tProxyJump ' + proxies[proxy_re])
            break
    print('')