"""User sessions, MAC-signed session strings and authentication backends."""

import hmac
import hashlib
import ssl
from datetime import datetime


class User:
    """An authenticated user with group memberships and session metadata.

    Attributes are plain values: ``username`` (str), ``groups`` and
    ``all_groups`` (iterables of str), ``timestamp`` (datetime of session
    creation) and ``permanent`` (bool, exempts the session from expiry in
    ``SecurityManager.check_user``).
    """

    def __init__(self, username, groups, all_groups, timestamp=None,
                 permanent=False):
        self.username = username
        self.groups = groups
        self.all_groups = all_groups
        # A missing timestamp means "session starts now".
        self.timestamp = timestamp if timestamp is not None else datetime.now()
        self.permanent = permanent

    def __repr__(self):
        return "<User({})>".format(self.username)

    def summarize(self):
        """Serialize the user into a ``:``-separated summary string.

        The layout is ``username:groups:all_groups:timestamp:permanent`` with
        group lists joined by ``,``.

        NOTE: fields are joined without escaping, so a ``:`` in the username
        or a ``:``/``,`` inside a group name will not survive a round trip
        through ``from_summary``.
        """
        return ":".join((
            self.username,
            ",".join(self.groups),
            ",".join(self.all_groups),
            str(self.timestamp.timestamp()),
            str(self.permanent),
        ))

    @staticmethod
    def from_summary(summary):
        """Rebuild a ``User`` from the output of ``summarize``.

        Returns ``None`` when the summary does not have five fields or its
        timestamp field is not a usable number.
        """
        parts = summary.split(":", 4)
        if len(parts) != 5:
            return None
        name, group_str, all_group_str, timestamp_str, permanent_str = parts
        try:
            timestamp = datetime.fromtimestamp(float(timestamp_str))
        except (ValueError, OverflowError, OSError):
            # Malformed or out-of-range timestamp: treat like any other
            # malformed summary instead of crashing the caller.
            return None
        # "".split(",") would give [""]; an empty field means "no groups".
        groups = group_str.split(",") if group_str else []
        all_groups = all_group_str.split(",") if all_group_str else []
        permanent = permanent_str == "True"
        return User(name, groups, all_groups, timestamp, permanent)

    @staticmethod
    def from_hashstring(secure_string):
        """Parse a ``summary=digest`` string WITHOUT verifying the digest.

        Returns ``None`` if the string has no ``=`` separator or the summary
        part is malformed. Use ``SecurityManager.check_user`` to validate the
        digest before trusting the result.
        """
        # The hex digest never contains "=", the summary may (e.g. group
        # names such as "cn=admins"), so split on the LAST "=".
        summary, separator, _digest = secure_string.rpartition("=")
        if not separator:
            return None
        return User.from_summary(summary)


class UserManager:
    """Tries a list of authentication backends in order.

    Each backend must provide ``authenticate(username, password)``,
    ``groups(username, password)`` and ``all_groups(username, password)``.
    """

    def __init__(self, backends):
        self.backends = backends

    def login(self, username, password, permanent=False):
        """Return a ``User`` from the first backend that accepts the login.

        Group lists are de-duplicated and sorted. Returns ``None`` when no
        backend authenticates the credentials.
        """
        for backend in self.backends:
            if not backend.authenticate(username, password):
                continue
            groups = sorted(set(backend.groups(username, password)))
            all_groups = sorted(set(backend.all_groups(username, password)))
            return User(username, groups, all_groups, permanent=permanent)
        return None


class SecurityManager:
    """Signs and verifies user summaries with HMAC-SHA512."""

    def __init__(self, key, max_duration=300):
        """Create a manager.

        ``key`` is the MAC key (str keys are UTF-8 encoded); ``max_duration``
        is the maximum session age in seconds for non-permanent users.
        """
        if isinstance(key, str):
            key = key.encode("utf-8")
        # Template MAC object; every operation works on a copy of it.
        self.maccer = hmac.new(key, digestmod=hashlib.sha512)
        self.max_duration = max_duration

    def _digest(self, summary):
        """Return the hex MAC of ``summary`` (a str)."""
        maccer = self.maccer.copy()
        maccer.update(summary.encode("utf-8"))
        return maccer.hexdigest()

    def hash_user(self, user):
        """Return ``"<summary>=<hex digest>"`` for ``user``."""
        summary = user.summarize()
        return "{}={}".format(summary, self._digest(summary))

    def check_user(self, string):
        """Return True if ``string`` is a validly signed, unexpired session.

        The string must have the form ``summary=digest``. The digest is
        verified before the summary is parsed, so tampered input is rejected
        without being interpreted. A valid session is accepted when it is
        younger than ``max_duration`` seconds or the user is permanent.
        """
        # Split on the last "=": the hex digest cannot contain one, while the
        # summary can.
        summary, separator, digest = string.rpartition("=")
        if not separator:
            # wrong format, expecting summary=hash
            return False
        expected = self._digest(summary).encode("utf-8")
        if not hmac.compare_digest(expected, digest.encode("utf-8")):
            return False
        user = User.from_summary(summary)
        if user is None:
            return False
        session_duration = datetime.now() - user.timestamp
        time_short = int(session_duration.total_seconds()) < self.max_duration
        return bool(time_short or user.permanent)


class StaticUserManager:
    """Backend with a fixed in-memory list of users.

    ``users`` is an iterable of ``(username, password, groups)`` triples.
    """

    def __init__(self, users):
        # Materialize once so a one-shot iterable can feed both mappings.
        entries = list(users)
        self.passwords = {
            username: password for (username, password, groups) in entries
        }
        self.group_map = {
            username: tuple(groups)
            for (username, password, groups) in entries
        }

    def __repr__(self):
        # NOTE(review): this representation includes the stored passwords;
        # avoid writing it to logs.
        users = [
            (username, self.passwords[username], self.group_map[username])
            for username in self.passwords
        ]
        return "StaticUserManager({})".format(users)

    def authenticate(self, username, password):
        """Return True if ``username`` exists and ``password`` matches."""
        return (username in self.passwords
                and self.passwords[username] == password)

    def groups(self, username, password=None):
        """Yield the groups of ``username`` (nothing for unknown users)."""
        if username in self.group_map:
            yield from self.group_map[username]

    def all_groups(self, username, password):
        """Yield every distinct group known to this backend."""
        yield from {
            group
            for groups in self.group_map.values()
            for group in groups
        }


# The LDAP/AD backends are optional: they only exist when ldap3 is installed.
try:
    import ldap3
except ImportError:
    pass
else:
    class LdapManager:
        """Backend authenticating against an LDAP server via ldap3.

        ``user_dn`` is a format string that receives the username;
        ``group_dn`` is the search base for ``posixgroup`` entries.
        """

        def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True):
            self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
            self.user_dn = user_dn
            self.group_dn = group_dn

        def __repr__(self):
            return (
                "LdapManager(host='{host}', user_dn='{user_dn}', "
                "group_dn='{group_dn}', port={port}, use_ssl={use_ssl})"
                .format(
                    host=self.server.host,
                    user_dn=self.user_dn,
                    group_dn=self.group_dn,
                    port=self.server.port,
                    use_ssl=self.server.ssl))

        def authenticate(self, username, password):
            """Return True if a bind as ``username`` succeeds.

            Returns False when the server cannot be reached or the password
            is empty.
            """
            if not password:
                # An empty password can turn the bind into an anonymous /
                # unauthenticated one that succeeds for any username.
                return False
            try:
                connection = ldap3.Connection(
                    self.server, self.user_dn.format(username), password)
                return connection.bind()
            except ldap3.core.exceptions.LDAPSocketOpenError:
                return False

        def groups(self, username, password=None):
            """Yield the ``cn`` of each posixgroup listing ``username``."""
            connection = ldap3.Connection(self.server)
            obj_def = ldap3.ObjectDef("posixgroup", connection)
            group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
            username = username.lower()
            for group in group_reader.search():
                members = group.memberUid.value
                if members is not None and username in members:
                    yield group.cn.value

        def all_groups(self, username, password):
            """Yield the ``cn`` of every posixgroup below ``group_dn``."""
            connection = ldap3.Connection(self.server)
            obj_def = ldap3.ObjectDef("posixgroup", connection)
            group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
            for group in group_reader.search():
                yield group.cn.value

    class ADManager:
        """Backend authenticating against Active Directory via ldap3.

        ``host`` is either a single host name or an iterable of host names
        (used as a server pool). Users bind as ``DOMAIN\\username``.
        """

        def __init__(self, host, domain, user_dn, group_dn,
                     port=636, use_ssl=True, ca_cert=None):
            tls_kwargs = {"validate": ssl.CERT_REQUIRED}
            if ca_cert is not None:
                tls_kwargs["ca_certs_file"] = ca_cert
            tls_config = ldap3.Tls(**tls_kwargs)
            if isinstance(host, str):
                self.server = ldap3.Server(
                    host, port=port, use_ssl=use_ssl, tls=tls_config)
            else:
                self.server = ldap3.ServerPool([
                    ldap3.Server(
                        single_host, port=port, use_ssl=use_ssl,
                        tls=tls_config)
                    for single_host in host
                ], ldap3.FIRST)
            self.domain = domain
            self.user_dn = user_dn
            self.group_dn = group_dn
            self.ca_cert = ca_cert
            self.host = host
            self.port = port
            self.use_ssl = use_ssl

        def __repr__(self):
            return (
                "ADManager(host='{host}', domain='{domain}', "
                "user_dn='{user_dn}', group_dn='{group_dn}', "
                "port={port}, use_ssl={use_ssl}, ca_cert='{ca_cert}')"
                .format(
                    host=self.host,
                    domain=self.domain,
                    user_dn=self.user_dn,
                    group_dn=self.group_dn,
                    port=self.port,
                    use_ssl=self.use_ssl,
                    ca_cert=self.ca_cert))

        def prepare_connection(self, username=None, password=None):
            """Return an unbound connection, with credentials if both given."""
            if username is not None and password is not None:
                ad_user = "{}\\{}".format(self.domain, username)
                return ldap3.Connection(self.server, ad_user, password)
            return ldap3.Connection(self.server)

        def authenticate(self, username, password):
            """Return True if a bind as ``DOMAIN\\username`` succeeds.

            Returns False when the server cannot be reached or the password
            is empty.
            """
            if not password:
                # Never let an empty password degrade into an anonymous bind.
                return False
            try:
                return self.prepare_connection(username, password).bind()
            except ldap3.core.exceptions.LDAPSocketOpenError:
                return False

        def groups(self, username, password):
            """Yield names of the user's groups, following ``memberOf``.

            Starts from the user's primary group and direct memberships and
            walks up nested groups. Yields nothing if the bind fails.
            """
            connection = self.prepare_connection(username, password)
            if not connection.bind():
                return
            obj_def = ldap3.ObjectDef("user", connection)
            # NOTE(review): the username goes into the reader query
            # unescaped; confirm it cannot carry query metacharacters.
            name_filter = "cn:={}".format(username)
            user_reader = ldap3.Reader(
                connection, obj_def, self.user_dn, name_filter)
            group_def = ldap3.ObjectDef("group", connection)
            all_group_reader = ldap3.Reader(
                connection, group_def, self.group_dn)
            # Maps primaryGroupToken -> group entry so the numeric
            # primaryGroupID of a user can be resolved to a group.
            all_groups = {
                group.primaryGroupToken.value: group
                for group in all_group_reader.search()
            }
            # DNs already expanded; protects against circular group nesting
            # (which would otherwise recurse forever).
            visited = set()

            def _yield_recursive_groups(group_dn):
                if group_dn in visited:
                    return
                visited.add(group_dn)
                group_reader = ldap3.Reader(
                    connection, group_def, group_dn)
                for entry in group_reader.search():
                    yield entry.name.value
                    for parent_dn in entry.memberOf:
                        yield from _yield_recursive_groups(parent_dn)

            for result in user_reader.search():
                # NOTE(review): raises KeyError if the primary group is not
                # found below group_dn.
                yield from _yield_recursive_groups(
                    all_groups[result.primaryGroupID.value]
                    .distinguishedName.value)
                for group_dn in result.memberOf:
                    yield from _yield_recursive_groups(group_dn)

        def all_groups(self, username, password):
            """Yield the name of every group below ``group_dn``.

            Yields nothing if the bind fails.
            """
            connection = self.prepare_connection(username, password)
            if not connection.bind():
                return
            obj_def = ldap3.ObjectDef("group", connection)
            group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
            for result in group_reader.search():
                yield result.name.value


# The PAM backend is optional: it needs the grp, pwd and pam modules.
try:
    import grp
    import pwd
    import pam
except ImportError:
    pass
else:
    class PAMManager:
        """Backend authenticating local system users through PAM."""

        def __init__(self):
            self.pam = pam.pam()

        def __repr__(self):
            return "PAMManager()"

        def authenticate(self, username, password):
            """Return the result of the PAM authentication attempt."""
            return self.pam.authenticate(username, password)

        def groups(self, username, password=None):
            """Yield the user's primary group, then supplementary groups."""
            yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
            for group in grp.getgrall():
                if username in group.gr_mem:
                    yield group.gr_name

        def all_groups(self, username, password):
            """Yield the name of every group in the group database."""
            for group in grp.getgrall():
                yield group.gr_name