Skip to content
Snippets Groups Projects
Commit d2c68041 authored by Robin Sonnabend's avatar Robin Sonnabend
Browse files

Improve *.py code quality

Also fixes some latent, rather unimportant bugs
parent 47fd859c
Branches
No related tags found
No related merge requests found
import hmac, hashlib import hmac
import hashlib
import ssl import ssl
from datetime import datetime from datetime import datetime
...@@ -16,8 +17,9 @@ class User: ...@@ -16,8 +17,9 @@ class User:
self.permanent = permanent self.permanent = permanent
def summarize(self): def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), return ":".join((
str(self.timestamp.timestamp()), self.obsolete, self.permanent) self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent))
@staticmethod @staticmethod
def from_summary(summary): def from_summary(summary):
...@@ -45,7 +47,8 @@ class UserManager: ...@@ -45,7 +47,8 @@ class UserManager:
for backend in self.backends: for backend in self.backends:
if backend.authenticate(username, password): if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password)))) groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete, return User(
username, groups, obsolete=backend.obsolete,
permanent=permanent) permanent=permanent)
return None return None
...@@ -77,8 +80,8 @@ class SecurityManager: ...@@ -77,8 +80,8 @@ class SecurityManager:
if user is None: if user is None:
return False return False
session_duration = datetime.now() - user.timestamp session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), macs_equal = hmac.compare_digest(
hash) maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent) return macs_equal and (time_short or user.permanent)
...@@ -104,12 +107,11 @@ class StaticUserManager: ...@@ -104,12 +107,11 @@ class StaticUserManager:
yield from self.group_map[username] yield from self.group_map[username]
def all_groups(self): def all_groups(self):
yield from list(set(group for group in groups.values())) yield from list(set(group for group in self.group_map.values()))
try: try:
import ldap3 import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager: class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
...@@ -121,8 +123,8 @@ try: ...@@ -121,8 +123,8 @@ try:
def authenticate(self, username, password): def authenticate(self, username, password):
try: try:
connection = ldap3.Connection(self.server, connection = ldap3.Connection(
self.user_dn.format(username), password) self.server, self.user_dn.format(username), password)
return connection.bind() return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError: except ldap3.core.exceptions.LDAPSocketOpenError:
return False return False
...@@ -144,16 +146,15 @@ try: ...@@ -144,16 +146,15 @@ try:
for group in group_reader.search(): for group in group_reader.search():
yield group.cn.value yield group.cn.value
class ADManager: class ADManager:
def __init__(self, host, domain, user_dn, group_dn, def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False): port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED) tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None: if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED, tls_config = ldap3.Tls(
ca_certs_file=ca_cert) validate=ssl.CERT_REQUIRED, ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl, self.server = ldap3.Server(
tls=tls_config) host, port=port, use_ssl=use_ssl, tls=tls_config)
self.domain = domain self.domain = domain
self.user_dn = user_dn self.user_dn = user_dn
self.group_dn = group_dn self.group_dn = group_dn
...@@ -176,11 +177,13 @@ try: ...@@ -176,11 +177,13 @@ try:
connection.bind() connection.bind()
obj_def = ldap3.ObjectDef("user", connection) obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username) name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn, user_reader = ldap3.Reader(
name_filter) connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection) group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn): def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None) group_reader = ldap3.Reader(
connection, group_def, group_dn, None)
for entry in group_reader.search(): for entry in group_reader.search():
yield entry.name.value yield entry.name.value
for child in entry.memberOf: for child in entry.memberOf:
...@@ -189,20 +192,22 @@ try: ...@@ -189,20 +192,22 @@ try:
for group_dn in result.memberOf: for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn) yield from _yield_recursive_groups(group_dn)
def all_groups(self): def all_groups(self):
connection = self.prepare_connection() connection = self.prepare_connection()
connection.bind() connection.bind()
obj_def = ldap3.ObjectDef("group", connection) obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn) group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search(): for result in group_reader.search():
yield result.name.value yield result.name.value
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
try: try:
import grp, pwd, pam import grp
import pwd
import pam
class PAMManager: class PAMManager:
def __init__(self, obsolete=False): def __init__(self, obsolete=False):
...@@ -224,4 +229,3 @@ try: ...@@ -224,4 +229,3 @@ try:
yield group.gr_name yield group.gr_name
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment