Commit d2c68041 authored by Robin Sonnabend's avatar Robin Sonnabend

Improve *.py code quality

Also fixes some latent, rather unimportant bugs
parent 47fd859c
import hmac, hashlib
import hmac
import hashlib
import ssl
from datetime import datetime
class User:
def __init__(self, username, groups, timestamp=None, obsolete=False,
permanent=False):
permanent=False):
self.username = username
self.groups = groups
if timestamp is not None:
......@@ -16,8 +17,9 @@ class User:
self.permanent = permanent
def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent)
return ":".join((
self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent))
@staticmethod
def from_summary(summary):
......@@ -45,7 +47,8 @@ class UserManager:
for backend in self.backends:
if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete,
return User(
username, groups, obsolete=backend.obsolete,
permanent=permanent)
return None
......@@ -77,8 +80,8 @@ class SecurityManager:
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
hash)
macs_equal = hmac.compare_digest(
maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
......@@ -97,23 +100,22 @@ class StaticUserManager:
def authenticate(self, username, password):
return (username in self.passwords
and self.passwords[username] == password)
and self.passwords[username] == password)
def groups(self, username, password=None):
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
yield from list(set(group for group in self.group_map.values()))
try:
import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
obsolete=False):
obsolete=False):
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
self.user_dn = user_dn
self.group_dn = group_dn
......@@ -121,8 +123,8 @@ try:
def authenticate(self, username, password):
try:
connection = ldap3.Connection(self.server,
self.user_dn.format(username), password)
connection = ldap3.Connection(
self.server, self.user_dn.format(username), password)
return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
......@@ -144,16 +146,15 @@ try:
for group in group_reader.search():
yield group.cn.value
class ADManager:
def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False):
port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
tls=tls_config)
tls_config = ldap3.Tls(
validate=ssl.CERT_REQUIRED, ca_certs_file=ca_cert)
self.server = ldap3.Server(
host, port=port, use_ssl=use_ssl, tls=tls_config)
self.domain = domain
self.user_dn = user_dn
self.group_dn = group_dn
......@@ -176,11 +177,13 @@ try:
connection.bind()
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
name_filter)
user_reader = ldap3.Reader(
connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None)
group_reader = ldap3.Reader(
connection, group_def, group_dn, None)
for entry in group_reader.search():
yield entry.name.value
for child in entry.memberOf:
......@@ -189,20 +192,22 @@ try:
for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn)
def all_groups(self):
connection = self.prepare_connection()
connection.bind()
obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search():
for result in group_reader.search():
yield result.name.value
except ModuleNotFoundError:
pass
try:
import grp, pwd, pam
import grp
import pwd
import pam
class PAMManager:
def __init__(self, obsolete=False):
......@@ -224,4 +229,3 @@ try:
yield group.gr_name
except ModuleNotFoundError:
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment