diff --git a/auth.py b/auth.py index 3241a712952fce4fd8bdc24dfa4a5e5d49d13740..7130963886eaa8fcb3e94aac2476f921cb7b4ee4 100644 --- a/auth.py +++ b/auth.py @@ -5,10 +5,11 @@ from datetime import datetime class User: - def __init__(self, username, groups, timestamp=None, obsolete=False, - permanent=False): + def __init__(self, username, groups, all_groups, timestamp=None, + obsolete=False, permanent=False): self.username = username self.groups = groups + self.all_groups = all_groups if timestamp is not None: self.timestamp = timestamp else: @@ -18,21 +19,23 @@ class User: def summarize(self): return ":".join(( - self.username, ",".join(self.groups), + self.username, ",".join(self.groups), ",".join(self.all_groups), str(self.timestamp.timestamp()), str(self.obsolete), str(self.permanent))) @staticmethod def from_summary(summary): - parts = summary.split(":", 4) - if len(parts) != 5: + parts = summary.split(":", 5) + if len(parts) != 6: return None - name, group_str, timestamp_str, obsolete_str, permanent_str = parts + (name, group_str, all_group_str, timestamp_str, obsolete_str, + permanent_str) = parts timestamp = datetime.fromtimestamp(float(timestamp_str)) obsolete = obsolete_str == "True" groups = group_str.split(",") + all_groups = group_str.split(",") permanent = permanent_str == "True" - return User(name, groups, timestamp, obsolete, permanent) + return User(name, groups, all_groups, timestamp, obsolete, permanent) @staticmethod def from_hashstring(secure_string): @@ -48,15 +51,13 @@ class UserManager: for backend in self.backends: if backend.authenticate(username, password): groups = sorted(list(set(backend.groups(username, password)))) + all_groups = sorted(list(set(backend.all_groups( + username, password)))) return User( - username, groups, obsolete=backend.obsolete, + username, groups, all_groups, obsolete=backend.obsolete, permanent=permanent) return None - def all_groups(self): - for backend in self.backends: - yield from backend.all_groups() - class SecurityManager: def __init__(self, key, max_duration=300): @@ -107,7 +108,7 @@ class StaticUserManager: if username in self.group_map: yield from self.group_map[username] - def all_groups(self): + def all_groups(self, username, password): yield from list(set(group for group in self.group_map.values())) @@ -140,7 +141,7 @@ try: if members is not None and username in members: yield group.cn.value - def all_groups(self): + def all_groups(self, username, password): connection = ldap3.Connection(self.server) obj_def = ldap3.ObjectDef("posixgroup", connection) group_reader = ldap3.Reader(connection, obj_def, self.group_dn) @@ -175,33 +176,45 @@ try: def groups(self, username, password): connection = self.prepare_connection(username, password) - connection.bind() + if not connection.bind(): + return obj_def = ldap3.ObjectDef("user", connection) name_filter = "cn:={}".format(username) user_reader = ldap3.Reader( connection, obj_def, self.user_dn, name_filter) group_def = ldap3.ObjectDef("group", connection) + all_group_reader = ldap3.Reader( + connection, group_def, self.group_dn) + all_groups = { + group.primaryGroupToken.value: group + for group in all_group_reader.search() + } + def _yield_recursive_groups(group_dn): group_reader = ldap3.Reader( - connection, group_def, group_dn, None) + connection, group_def, group_dn) for entry in group_reader.search(): yield entry.name.value for child in entry.memberOf: yield from _yield_recursive_groups(child) for result in user_reader.search(): + yield from _yield_recursive_groups( + all_groups[result.primaryGroupID.value] + .distinguishedName.value) for group_dn in result.memberOf: yield from _yield_recursive_groups(group_dn) - def all_groups(self): - connection = self.prepare_connection() - connection.bind() + def all_groups(self, username, password): + connection = self.prepare_connection(username, password) + if not connection.bind(): + return obj_def = ldap3.ObjectDef("group", connection) group_reader = ldap3.Reader(connection, obj_def, self.group_dn) for result in group_reader.search(): yield result.name.value -except ModuleNotFoundError: +except ImportError: pass @@ -219,14 +232,13 @@ try: return self.pam.authenticate(username, password) def groups(self, username, password=None): - print(username) yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name for group in grp.getgrall(): if username in group.gr_mem: yield group.gr_name - def all_groups(self): + def all_groups(self, username, password): for group in grp.getgrall(): yield group.gr_name -except ModuleNotFoundError: +except ImportError: pass diff --git a/views/forms.py b/views/forms.py index 843cae1b4d949b60fbcab5b41d704f8fe325722a..5e94ec7ea4a60767ae934f015348f151c050d221 100644 --- a/views/forms.py +++ b/views/forms.py @@ -82,7 +82,7 @@ def get_latex_template_choices(): def get_group_choices(): user = current_user() - groups = sorted(user.groups) + groups = sorted(user.all_groups) choices = list(zip(groups, groups)) choices.insert(0, ("", "Keine Gruppe")) return choices