diff --git a/auth.py b/auth.py index 64e8fdd4917f5984d1d3147202090d8a290bcb3f..51aba9ec34fa85be60fbd3b2b1e65bb1929ad6f0 100644 --- a/auth.py +++ b/auth.py @@ -1,11 +1,12 @@ -import hmac, hashlib +import hmac +import hashlib import ssl from datetime import datetime class User: def __init__(self, username, groups, timestamp=None, obsolete=False, - permanent=False): + permanent=False): self.username = username self.groups = groups if timestamp is not None: @@ -16,8 +17,9 @@ class User: self.permanent = permanent def summarize(self): - return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), - str(self.timestamp.timestamp()), self.obsolete, self.permanent) + return ":".join(( + self.username, ",".join(self.groups), + str(self.timestamp.timestamp()), self.obsolete, self.permanent)) @staticmethod def from_summary(summary): @@ -45,7 +47,8 @@ class UserManager: for backend in self.backends: if backend.authenticate(username, password): groups = sorted(list(set(backend.groups(username, password)))) - return User(username, groups, obsolete=backend.obsolete, + return User( + username, groups, obsolete=backend.obsolete, permanent=permanent) return None @@ -77,8 +80,8 @@ class SecurityManager: if user is None: return False session_duration = datetime.now() - user.timestamp - macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), - hash) + macs_equal = hmac.compare_digest( + maccer.hexdigest().encode("utf-8"), hash) time_short = int(session_duration.total_seconds()) < self.max_duration return macs_equal and (time_short or user.permanent) @@ -97,23 +100,22 @@ class StaticUserManager: def authenticate(self, username, password): return (username in self.passwords - and self.passwords[username] == password) + and self.passwords[username] == password) def groups(self, username, password=None): if username in self.group_map: yield from self.group_map[username] def all_groups(self): - yield from list(set(group for group in groups.values())) + yield from list(set(group for group in self.group_map.values())) try: import ldap3 - from ldap3.utils.dn import parse_dn class LdapManager: def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, - obsolete=False): + obsolete=False): self.server = ldap3.Server(host, port=port, use_ssl=use_ssl) self.user_dn = user_dn self.group_dn = group_dn @@ -121,8 +123,8 @@ try: def authenticate(self, username, password): try: - connection = ldap3.Connection(self.server, - self.user_dn.format(username), password) + connection = ldap3.Connection( + self.server, self.user_dn.format(username), password) return connection.bind() except ldap3.core.exceptions.LDAPSocketOpenError: return False @@ -144,16 +146,15 @@ try: for group in group_reader.search(): yield group.cn.value - class ADManager: def __init__(self, host, domain, user_dn, group_dn, - port=636, use_ssl=True, ca_cert=None, obsolete=False): + port=636, use_ssl=True, ca_cert=None, obsolete=False): tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED) if ca_cert is not None: - tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED, - ca_certs_file=ca_cert) - self.server = ldap3.Server(host, port=port, use_ssl=use_ssl, - tls=tls_config) + tls_config = ldap3.Tls( + validate=ssl.CERT_REQUIRED, ca_certs_file=ca_cert) + self.server = ldap3.Server( + host, port=port, use_ssl=use_ssl, tls=tls_config) self.domain = domain self.user_dn = user_dn self.group_dn = group_dn @@ -176,11 +177,13 @@ try: connection.bind() obj_def = ldap3.ObjectDef("user", connection) name_filter = "cn:={}".format(username) - user_reader = ldap3.Reader(connection, obj_def, self.user_dn, - name_filter) + user_reader = ldap3.Reader( + connection, obj_def, self.user_dn, name_filter) group_def = ldap3.ObjectDef("group", connection) + def _yield_recursive_groups(group_dn): - group_reader = ldap3.Reader(connection, group_def, group_dn, None) + group_reader = ldap3.Reader( + connection, group_def, group_dn, None) for entry in group_reader.search(): yield entry.name.value for child in entry.memberOf: @@ -189,20 +192,22 @@ try: for group_dn in result.memberOf: yield from _yield_recursive_groups(group_dn) - def all_groups(self): connection = self.prepare_connection() connection.bind() obj_def = ldap3.ObjectDef("group", connection) group_reader = ldap3.Reader(connection, obj_def, self.group_dn) - for result in reader.search(): + for result in group_reader.search(): yield result.name.value + except ModuleNotFoundError: pass try: - import grp, pwd, pam + import grp + import pwd + import pam class PAMManager: def __init__(self, obsolete=False): @@ -224,4 +229,3 @@ try: yield group.gr_name except ModuleNotFoundError: pass -