Skip to content
Snippets Groups Projects
Commit 47fd859c authored by Robin Sonnabend's avatar Robin Sonnabend
Browse files

make auth backends optional

Only define authentication backends that require external packages if
their required modules have been found.
There is no need to install packages for unused backends.
parent 5acbe6f4
No related branches found
No related tags found
No related merge requests found
import hmac, hashlib
import ssl
import ldap3
from ldap3.utils.dn import parse_dn
from datetime import datetime
import grp, pwd, pam
class User:
def __init__(self, username, groups, timestamp=None, obsolete=False, permanent=False):
def __init__(self, username, groups, timestamp=None, obsolete=False,
permanent=False):
self.username = username
self.groups = groups
if timestamp is not None:
......@@ -17,7 +16,8 @@ class User:
self.permanent = permanent
def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), str(self.timestamp.timestamp()), self.obsolete, self.permanent)
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent)
@staticmethod
def from_summary(summary):
......@@ -36,6 +36,7 @@ class User:
summary, hash = secure_string.split("=", 1)
return User.from_summary(summary)
class UserManager:
def __init__(self, backends):
self.backends = backends
......@@ -44,7 +45,8 @@ class UserManager:
for backend in self.backends:
if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete, permanent=permanent)
return User(username, groups, obsolete=backend.obsolete,
permanent=permanent)
return None
def all_groups(self):
......@@ -52,8 +54,66 @@ class UserManager:
yield from backend.all_groups()
class SecurityManager:
def __init__(self, key, max_duration=300):
self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
self.max_duration = max_duration
def hash_user(self, user):
maccer = self.maccer.copy()
summary = user.summarize()
maccer.update(summary.encode("utf-8"))
return "{}={}".format(summary, maccer.hexdigest())
def check_user(self, string):
parts = string.split("=", 1)
if len(parts) != 2:
# wrong format, expecting summary:hash
return False
summary, hash = map(lambda s: s.encode("utf-8"), parts)
maccer = self.maccer.copy()
maccer.update(summary)
user = User.from_hashstring(string)
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
class StaticUserManager:
def __init__(self, users, obsolete=False):
self.passwords = {
username: password
for (username, password, groups) in users
}
self.group_map = {
username: groups
for (username, password, groups) in users
}
self.obsolete = obsolete
def authenticate(self, username, password):
return (username in self.passwords
and self.passwords[username] == password)
def groups(self, username, password=None):
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
try:
import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, obsolete=False):
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
obsolete=False):
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
self.user_dn = user_dn
self.group_dn = group_dn
......@@ -61,7 +121,8 @@ class LdapManager:
def authenticate(self, username, password):
try:
connection = ldap3.Connection(self.server, self.user_dn.format(username), password)
connection = ldap3.Connection(self.server,
self.user_dn.format(username), password)
return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
......@@ -115,7 +176,8 @@ class ADManager:
connection.bind()
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn, name_filter)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
name_filter)
group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None)
......@@ -135,31 +197,12 @@ class ADManager:
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search():
yield result.name.value
except ModuleNotFoundError:
pass
class StaticUserManager:
def __init__(self, users, obsolete=False):
self.passwords = {
username: password
for (username, password, groups) in users
}
self.group_map = {
username: groups
for (username, password, groups) in users
}
self.obsolete = obsolete
def authenticate(self, username, password):
return (username in self.passwords
and self.passwords[username] == password)
def groups(self, username, password=None):
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
try:
import grp, pwd, pam
class PAMManager:
def __init__(self, obsolete=False):
......@@ -179,31 +222,6 @@ class PAMManager:
def all_groups(self):
for group in grp.getgrall():
yield group.gr_name
class SecurityManager:
def __init__(self, key, max_duration=300):
self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
self.max_duration = max_duration
def hash_user(self, user):
maccer = self.maccer.copy()
summary = user.summarize()
maccer.update(summary.encode("utf-8"))
return "{}={}".format(summary, maccer.hexdigest())
def check_user(self, string):
parts = string.split("=", 1)
if len(parts) != 2:
# wrong format, expecting summary:hash
return False
summary, hash = map(lambda s: s.encode("utf-8"), parts)
maccer = self.maccer.copy()
maccer.update(summary)
user = User.from_hashstring(string)
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
except ModuleNotFoundError:
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment