Commit 47fd859c authored by Robin Sonnabend's avatar Robin Sonnabend

make auth backends optional

Only define authentication backends that require external packages if
their required modules have been found.
There is no need to install packages for unused backends.
parent 5acbe6f4
import hmac, hashlib
import ssl
import ldap3
from ldap3.utils.dn import parse_dn
from datetime import datetime
import grp, pwd, pam
class User:
def __init__(self, username, groups, timestamp=None, obsolete=False, permanent=False):
def __init__(self, username, groups, timestamp=None, obsolete=False,
permanent=False):
self.username = username
self.groups = groups
if timestamp is not None:
......@@ -17,7 +16,8 @@ class User:
self.permanent = permanent
def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), str(self.timestamp.timestamp()), self.obsolete, self.permanent)
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent)
@staticmethod
def from_summary(summary):
......@@ -36,6 +36,7 @@ class User:
summary, hash = secure_string.split("=", 1)
return User.from_summary(summary)
class UserManager:
def __init__(self, backends):
self.backends = backends
......@@ -44,7 +45,8 @@ class UserManager:
for backend in self.backends:
if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete, permanent=permanent)
return User(username, groups, obsolete=backend.obsolete,
permanent=permanent)
return None
def all_groups(self):
......@@ -52,89 +54,33 @@ class UserManager:
yield from backend.all_groups()
class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, obsolete=False):
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
self.user_dn = user_dn
self.group_dn = group_dn
self.obsolete = obsolete
def authenticate(self, username, password):
try:
connection = ldap3.Connection(self.server, self.user_dn.format(username), password)
return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
def groups(self, username, password=None):
connection = ldap3.Connection(self.server)
obj_def = ldap3.ObjectDef("posixgroup", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
username = username.lower()
for group in group_reader.search():
members = group.memberUid.value
if members is not None and username in members:
yield group.cn.value
class SecurityManager:
def __init__(self, key, max_duration=300):
self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
self.max_duration = max_duration
def all_groups(self):
connection = ldap3.Connection(self.server)
obj_def = ldap3.ObjectDef("posixgroup", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for group in group_reader.search():
yield group.cn.value
class ADManager:
def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
tls=tls_config)
self.domain = domain
self.user_dn = user_dn
self.group_dn = group_dn
self.obsolete = obsolete
def hash_user(self, user):
maccer = self.maccer.copy()
summary = user.summarize()
maccer.update(summary.encode("utf-8"))
return "{}={}".format(summary, maccer.hexdigest())
def prepare_connection(self, username=None, password=None):
if username is not None and password is not None:
ad_user = "{}\\{}".format(self.domain, username)
return ldap3.Connection(self.server, ad_user, password)
return ldap3.Connection(self.server)
def authenticate(self, username, password):
try:
return self.prepare_connection(username, password).bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
def check_user(self, string):
parts = string.split("=", 1)
if len(parts) != 2:
# wrong format, expecting summary:hash
return False
def groups(self, username, password):
connection = self.prepare_connection(username, password)
connection.bind()
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None)
for entry in group_reader.search():
yield entry.name.value
for child in entry.memberOf:
yield from _yield_recursive_groups(child)
for result in user_reader.search():
for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn)
def all_groups(self):
connection = self.prepare_connection()
connection.bind()
obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search():
yield result.name.value
summary, hash = map(lambda s: s.encode("utf-8"), parts)
maccer = self.maccer.copy()
maccer.update(summary)
user = User.from_hashstring(string)
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
class StaticUserManager:
......@@ -161,49 +107,121 @@ class StaticUserManager:
yield from list(set(group for group in groups.values()))
class PAMManager:
def __init__(self, obsolete=False):
self.pam = pam.pam()
self.obsolete = obsolete
try:
import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
obsolete=False):
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
self.user_dn = user_dn
self.group_dn = group_dn
self.obsolete = obsolete
def authenticate(self, username, password):
try:
connection = ldap3.Connection(self.server,
self.user_dn.format(username), password)
return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
def groups(self, username, password=None):
connection = ldap3.Connection(self.server)
obj_def = ldap3.ObjectDef("posixgroup", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
username = username.lower()
for group in group_reader.search():
members = group.memberUid.value
if members is not None and username in members:
yield group.cn.value
def all_groups(self):
connection = ldap3.Connection(self.server)
obj_def = ldap3.ObjectDef("posixgroup", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for group in group_reader.search():
yield group.cn.value
def authenticate(self, username, password):
return self.pam.authenticate(username, password)
def groups(self, username, password=None):
print(username)
yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
for group in grp.getgrall():
if username in group.gr_mem:
class ADManager:
def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
tls=tls_config)
self.domain = domain
self.user_dn = user_dn
self.group_dn = group_dn
self.obsolete = obsolete
def prepare_connection(self, username=None, password=None):
if username is not None and password is not None:
ad_user = "{}\\{}".format(self.domain, username)
return ldap3.Connection(self.server, ad_user, password)
return ldap3.Connection(self.server)
def authenticate(self, username, password):
try:
return self.prepare_connection(username, password).bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
def groups(self, username, password):
connection = self.prepare_connection(username, password)
connection.bind()
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
name_filter)
group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None)
for entry in group_reader.search():
yield entry.name.value
for child in entry.memberOf:
yield from _yield_recursive_groups(child)
for result in user_reader.search():
for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn)
def all_groups(self):
connection = self.prepare_connection()
connection.bind()
obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search():
yield result.name.value
except ModuleNotFoundError:
pass
try:
import grp, pwd, pam
class PAMManager:
def __init__(self, obsolete=False):
self.pam = pam.pam()
self.obsolete = obsolete
def authenticate(self, username, password):
return self.pam.authenticate(username, password)
def groups(self, username, password=None):
print(username)
yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
for group in grp.getgrall():
if username in group.gr_mem:
yield group.gr_name
def all_groups(self):
for group in grp.getgrall():
yield group.gr_name
def all_groups(self):
for group in grp.getgrall():
yield group.gr_name
class SecurityManager:
def __init__(self, key, max_duration=300):
self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
self.max_duration = max_duration
def hash_user(self, user):
maccer = self.maccer.copy()
summary = user.summarize()
maccer.update(summary.encode("utf-8"))
return "{}={}".format(summary, maccer.hexdigest())
def check_user(self, string):
parts = string.split("=", 1)
if len(parts) != 2:
# wrong format, expecting summary:hash
return False
summary, hash = map(lambda s: s.encode("utf-8"), parts)
maccer = self.maccer.copy()
maccer.update(summary)
user = User.from_hashstring(string)
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
except ModuleNotFoundError:
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment