Skip to content
Snippets Groups Projects
Commit 47fd859c authored by Robin Sonnabend's avatar Robin Sonnabend
Browse files

make auth backends optional

Only define authentication backends that require external packages if
their required modules have been found.
There is no need to install packages for unused backends.
parent 5acbe6f4
No related branches found
No related tags found
No related merge requests found
import hmac, hashlib import hmac, hashlib
import ssl import ssl
import ldap3
from ldap3.utils.dn import parse_dn
from datetime import datetime from datetime import datetime
import grp, pwd, pam
class User: class User:
def __init__(self, username, groups, timestamp=None, obsolete=False, permanent=False): def __init__(self, username, groups, timestamp=None, obsolete=False,
permanent=False):
self.username = username self.username = username
self.groups = groups self.groups = groups
if timestamp is not None: if timestamp is not None:
...@@ -17,7 +16,8 @@ class User: ...@@ -17,7 +16,8 @@ class User:
self.permanent = permanent self.permanent = permanent
def summarize(self): def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), str(self.timestamp.timestamp()), self.obsolete, self.permanent) return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent)
@staticmethod @staticmethod
def from_summary(summary): def from_summary(summary):
...@@ -36,6 +36,7 @@ class User: ...@@ -36,6 +36,7 @@ class User:
summary, hash = secure_string.split("=", 1) summary, hash = secure_string.split("=", 1)
return User.from_summary(summary) return User.from_summary(summary)
class UserManager: class UserManager:
def __init__(self, backends): def __init__(self, backends):
self.backends = backends self.backends = backends
...@@ -44,7 +45,8 @@ class UserManager: ...@@ -44,7 +45,8 @@ class UserManager:
for backend in self.backends: for backend in self.backends:
if backend.authenticate(username, password): if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password)))) groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete, permanent=permanent) return User(username, groups, obsolete=backend.obsolete,
permanent=permanent)
return None return None
def all_groups(self): def all_groups(self):
...@@ -52,8 +54,66 @@ class UserManager: ...@@ -52,8 +54,66 @@ class UserManager:
yield from backend.all_groups() yield from backend.all_groups()
class SecurityManager:
def __init__(self, key, max_duration=300):
self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
self.max_duration = max_duration
def hash_user(self, user):
maccer = self.maccer.copy()
summary = user.summarize()
maccer.update(summary.encode("utf-8"))
return "{}={}".format(summary, maccer.hexdigest())
def check_user(self, string):
parts = string.split("=", 1)
if len(parts) != 2:
# wrong format, expecting summary:hash
return False
summary, hash = map(lambda s: s.encode("utf-8"), parts)
maccer = self.maccer.copy()
maccer.update(summary)
user = User.from_hashstring(string)
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
class StaticUserManager:
def __init__(self, users, obsolete=False):
self.passwords = {
username: password
for (username, password, groups) in users
}
self.group_map = {
username: groups
for (username, password, groups) in users
}
self.obsolete = obsolete
def authenticate(self, username, password):
return (username in self.passwords
and self.passwords[username] == password)
def groups(self, username, password=None):
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
try:
import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager: class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, obsolete=False): def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
obsolete=False):
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl) self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
self.user_dn = user_dn self.user_dn = user_dn
self.group_dn = group_dn self.group_dn = group_dn
...@@ -61,7 +121,8 @@ class LdapManager: ...@@ -61,7 +121,8 @@ class LdapManager:
def authenticate(self, username, password): def authenticate(self, username, password):
try: try:
connection = ldap3.Connection(self.server, self.user_dn.format(username), password) connection = ldap3.Connection(self.server,
self.user_dn.format(username), password)
return connection.bind() return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError: except ldap3.core.exceptions.LDAPSocketOpenError:
return False return False
...@@ -115,7 +176,8 @@ class ADManager: ...@@ -115,7 +176,8 @@ class ADManager:
connection.bind() connection.bind()
obj_def = ldap3.ObjectDef("user", connection) obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username) name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn, name_filter) user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
name_filter)
group_def = ldap3.ObjectDef("group", connection) group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn): def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None) group_reader = ldap3.Reader(connection, group_def, group_dn, None)
...@@ -135,31 +197,12 @@ class ADManager: ...@@ -135,31 +197,12 @@ class ADManager:
group_reader = ldap3.Reader(connection, obj_def, self.group_dn) group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search(): for result in reader.search():
yield result.name.value yield result.name.value
except ModuleNotFoundError:
pass
class StaticUserManager: try:
def __init__(self, users, obsolete=False): import grp, pwd, pam
self.passwords = {
username: password
for (username, password, groups) in users
}
self.group_map = {
username: groups
for (username, password, groups) in users
}
self.obsolete = obsolete
def authenticate(self, username, password):
return (username in self.passwords
and self.passwords[username] == password)
def groups(self, username, password=None):
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
class PAMManager: class PAMManager:
def __init__(self, obsolete=False): def __init__(self, obsolete=False):
...@@ -179,31 +222,6 @@ class PAMManager: ...@@ -179,31 +222,6 @@ class PAMManager:
def all_groups(self): def all_groups(self):
for group in grp.getgrall(): for group in grp.getgrall():
yield group.gr_name yield group.gr_name
except ModuleNotFoundError:
class SecurityManager: pass
def __init__(self, key, max_duration=300):
self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
self.max_duration = max_duration
def hash_user(self, user):
maccer = self.maccer.copy()
summary = user.summarize()
maccer.update(summary.encode("utf-8"))
return "{}={}".format(summary, maccer.hexdigest())
def check_user(self, string):
parts = string.split("=", 1)
if len(parts) != 2:
# wrong format, expecting summary:hash
return False
summary, hash = map(lambda s: s.encode("utf-8"), parts)
maccer = self.maccer.copy()
maccer.update(summary)
user = User.from_hashstring(string)
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment