From 47fd859c25cb1acc5ef45dedaef94aa9c974be53 Mon Sep 17 00:00:00 2001 From: Robin Sonnabend <robin@fsmpi.rwth-aachen.de> Date: Fri, 2 Mar 2018 16:39:18 +0100 Subject: [PATCH] make auth backends optional Only define authentication backends that require external packages if their required modules have been found. There is no need to install packages for unused backends. --- auth.py | 274 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 146 insertions(+), 128 deletions(-) diff --git a/auth.py b/auth.py index 0323a2f..64e8fdd 100644 --- a/auth.py +++ b/auth.py @@ -1,12 +1,11 @@ import hmac, hashlib import ssl -import ldap3 -from ldap3.utils.dn import parse_dn from datetime import datetime -import grp, pwd, pam + class User: - def __init__(self, username, groups, timestamp=None, obsolete=False, permanent=False): + def __init__(self, username, groups, timestamp=None, obsolete=False, + permanent=False): self.username = username self.groups = groups if timestamp is not None: @@ -17,7 +16,8 @@ class User: self.permanent = permanent def summarize(self): - return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), str(self.timestamp.timestamp()), self.obsolete, self.permanent) + return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), + str(self.timestamp.timestamp()), self.obsolete, self.permanent) @staticmethod def from_summary(summary): @@ -36,6 +36,7 @@ class User: summary, hash = secure_string.split("=", 1) return User.from_summary(summary) + class UserManager: def __init__(self, backends): self.backends = backends @@ -44,7 +45,8 @@ class UserManager: for backend in self.backends: if backend.authenticate(username, password): groups = sorted(list(set(backend.groups(username, password)))) - return User(username, groups, obsolete=backend.obsolete, permanent=permanent) + return User(username, groups, obsolete=backend.obsolete, + permanent=permanent) return None def all_groups(self): @@ -52,89 +54,33 @@ class UserManager: yield from backend.all_groups() -class LdapManager: - def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, obsolete=False): - self.server = ldap3.Server(host, port=port, use_ssl=use_ssl) - self.user_dn = user_dn - self.group_dn = group_dn - self.obsolete = obsolete - - def authenticate(self, username, password): - try: - connection = ldap3.Connection(self.server, self.user_dn.format(username), password) - return connection.bind() - except ldap3.core.exceptions.LDAPSocketOpenError: - return False - - def groups(self, username, password=None): - connection = ldap3.Connection(self.server) - obj_def = ldap3.ObjectDef("posixgroup", connection) - group_reader = ldap3.Reader(connection, obj_def, self.group_dn) - username = username.lower() - for group in group_reader.search(): - members = group.memberUid.value - if members is not None and username in members: - yield group.cn.value +class SecurityManager: + def __init__(self, key, max_duration=300): + self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512) + self.max_duration = max_duration - def all_groups(self): - connection = ldap3.Connection(self.server) - obj_def = ldap3.ObjectDef("posixgroup", connection) - group_reader = ldap3.Reader(connection, obj_def, self.group_dn) - for group in group_reader.search(): - yield group.cn.value - - -class ADManager: - def __init__(self, host, domain, user_dn, group_dn, - port=636, use_ssl=True, ca_cert=None, obsolete=False): - tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED) - if ca_cert is not None: - tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED, - ca_certs_file=ca_cert) - self.server = ldap3.Server(host, port=port, use_ssl=use_ssl, - tls=tls_config) - self.domain = domain - self.user_dn = user_dn - self.group_dn = group_dn - self.obsolete = obsolete + def hash_user(self, user): + maccer = self.maccer.copy() + summary = user.summarize() + maccer.update(summary.encode("utf-8")) + return "{}={}".format(summary, maccer.hexdigest()) - def prepare_connection(self, username=None, password=None): - if username is not None and password is not None: - ad_user = "{}\\{}".format(self.domain, username) - return ldap3.Connection(self.server, ad_user, password) - return ldap3.Connection(self.server) - - def authenticate(self, username, password): - try: - return self.prepare_connection(username, password).bind() - except ldap3.core.exceptions.LDAPSocketOpenError: + def check_user(self, string): + parts = string.split("=", 1) + if len(parts) != 2: + # wrong format, expecting summary:hash return False - - def groups(self, username, password): - connection = self.prepare_connection(username, password) - connection.bind() - obj_def = ldap3.ObjectDef("user", connection) - name_filter = "cn:={}".format(username) - user_reader = ldap3.Reader(connection, obj_def, self.user_dn, name_filter) - group_def = ldap3.ObjectDef("group", connection) - def _yield_recursive_groups(group_dn): - group_reader = ldap3.Reader(connection, group_def, group_dn, None) - for entry in group_reader.search(): - yield entry.name.value - for child in entry.memberOf: - yield from _yield_recursive_groups(child) - for result in user_reader.search(): - for group_dn in result.memberOf: - yield from _yield_recursive_groups(group_dn) - - - def all_groups(self): - connection = self.prepare_connection() - connection.bind() - obj_def = ldap3.ObjectDef("group", connection) - group_reader = ldap3.Reader(connection, obj_def, self.group_dn) - for result in reader.search(): - yield result.name.value + summary, hash = map(lambda s: s.encode("utf-8"), parts) + maccer = self.maccer.copy() + maccer.update(summary) + user = User.from_hashstring(string) + if user is None: + return False + session_duration = datetime.now() - user.timestamp + macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), + hash) + time_short = int(session_duration.total_seconds()) < self.max_duration + return macs_equal and (time_short or user.permanent) class StaticUserManager: @@ -161,49 +107,121 @@ class StaticUserManager: yield from list(set(group for group in groups.values())) -class PAMManager: - def __init__(self, obsolete=False): - self.pam = pam.pam() - self.obsolete = obsolete +try: + import ldap3 + from ldap3.utils.dn import parse_dn + + class LdapManager: + def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, + obsolete=False): + self.server = ldap3.Server(host, port=port, use_ssl=use_ssl) + self.user_dn = user_dn + self.group_dn = group_dn + self.obsolete = obsolete + + def authenticate(self, username, password): + try: + connection = ldap3.Connection(self.server, + self.user_dn.format(username), password) + return connection.bind() + except ldap3.core.exceptions.LDAPSocketOpenError: + return False + + def groups(self, username, password=None): + connection = ldap3.Connection(self.server) + obj_def = ldap3.ObjectDef("posixgroup", connection) + group_reader = ldap3.Reader(connection, obj_def, self.group_dn) + username = username.lower() + for group in group_reader.search(): + members = group.memberUid.value + if members is not None and username in members: + yield group.cn.value + + def all_groups(self): + connection = ldap3.Connection(self.server) + obj_def = ldap3.ObjectDef("posixgroup", connection) + group_reader = ldap3.Reader(connection, obj_def, self.group_dn) + for group in group_reader.search(): + yield group.cn.value - def authenticate(self, username, password): - return self.pam.authenticate(username, password) - def groups(self, username, password=None): - print(username) - yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name - for group in grp.getgrall(): - if username in group.gr_mem: + class ADManager: + def __init__(self, host, domain, user_dn, group_dn, + port=636, use_ssl=True, ca_cert=None, obsolete=False): + tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED) + if ca_cert is not None: + tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED, + ca_certs_file=ca_cert) + self.server = ldap3.Server(host, port=port, use_ssl=use_ssl, + tls=tls_config) + self.domain = domain + self.user_dn = user_dn + self.group_dn = group_dn + self.obsolete = obsolete + + def prepare_connection(self, username=None, password=None): + if username is not None and password is not None: + ad_user = "{}\\{}".format(self.domain, username) + return ldap3.Connection(self.server, ad_user, password) + return ldap3.Connection(self.server) + + def authenticate(self, username, password): + try: + return self.prepare_connection(username, password).bind() + except ldap3.core.exceptions.LDAPSocketOpenError: + return False + + def groups(self, username, password): + connection = self.prepare_connection(username, password) + connection.bind() + obj_def = ldap3.ObjectDef("user", connection) + name_filter = "cn:={}".format(username) + user_reader = ldap3.Reader(connection, obj_def, self.user_dn, + name_filter) + group_def = ldap3.ObjectDef("group", connection) + def _yield_recursive_groups(group_dn): + group_reader = ldap3.Reader(connection, group_def, group_dn, None) + for entry in group_reader.search(): + yield entry.name.value + for child in entry.memberOf: + yield from _yield_recursive_groups(child) + for result in user_reader.search(): + for group_dn in result.memberOf: + yield from _yield_recursive_groups(group_dn) + + + def all_groups(self): + connection = self.prepare_connection() + connection.bind() + obj_def = ldap3.ObjectDef("group", connection) + group_reader = ldap3.Reader(connection, obj_def, self.group_dn) + for result in reader.search(): + yield result.name.value +except ModuleNotFoundError: + pass + + +try: + import grp, pwd, pam + + class PAMManager: + def __init__(self, obsolete=False): + self.pam = pam.pam() + self.obsolete = obsolete + + def authenticate(self, username, password): + return self.pam.authenticate(username, password) + + def groups(self, username, password=None): + print(username) + yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name + for group in grp.getgrall(): + if username in group.gr_mem: + yield group.gr_name + + def all_groups(self): + for group in grp.getgrall(): yield group.gr_name - - def all_groups(self): - for group in grp.getgrall(): - yield group.gr_name - -class SecurityManager: - def __init__(self, key, max_duration=300): - self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512) - self.max_duration = max_duration - - def hash_user(self, user): - maccer = self.maccer.copy() - summary = user.summarize() - maccer.update(summary.encode("utf-8")) - return "{}={}".format(summary, maccer.hexdigest()) - - def check_user(self, string): - parts = string.split("=", 1) - if len(parts) != 2: - # wrong format, expecting summary:hash - return False - summary, hash = map(lambda s: s.encode("utf-8"), parts) - maccer = self.maccer.copy() - maccer.update(summary) - user = User.from_hashstring(string) - if user is None: - return False - session_duration = datetime.now() - user.timestamp - macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), hash) - time_short = int(session_duration.total_seconds()) < self.max_duration - return macs_equal and (time_short or user.permanent) +except ModuleNotFoundError: + pass -- GitLab