Commit 459f9e57 authored by Robin Sonnabend's avatar Robin Sonnabend

Fix recursive group membership

Honor user.primaryGroupID in addition to memberOf.
Also, the group choice dialog now offers all groups anyone can have.

/close #160
parent f5f1f42a
......@@ -5,10 +5,11 @@ from datetime import datetime
class User:
def __init__(self, username, groups, timestamp=None, obsolete=False,
permanent=False):
def __init__(self, username, groups, all_groups, timestamp=None,
obsolete=False, permanent=False):
self.username = username
self.groups = groups
self.all_groups = all_groups
if timestamp is not None:
self.timestamp = timestamp
else:
......@@ -18,21 +19,23 @@ class User:
def summarize(self):
return ":".join((
self.username, ",".join(self.groups),
self.username, ",".join(self.groups), ",".join(self.all_groups),
str(self.timestamp.timestamp()), str(self.obsolete),
str(self.permanent)))
@staticmethod
def from_summary(summary):
parts = summary.split(":", 4)
if len(parts) != 5:
parts = summary.split(":", 5)
if len(parts) != 6:
return None
name, group_str, timestamp_str, obsolete_str, permanent_str = parts
(name, group_str, all_group_str, timestamp_str, obsolete_str,
permanent_str) = parts
timestamp = datetime.fromtimestamp(float(timestamp_str))
obsolete = obsolete_str == "True"
groups = group_str.split(",")
all_groups = group_str.split(",")
permanent = permanent_str == "True"
return User(name, groups, timestamp, obsolete, permanent)
return User(name, groups, all_groups, timestamp, obsolete, permanent)
@staticmethod
def from_hashstring(secure_string):
......@@ -48,15 +51,13 @@ class UserManager:
for backend in self.backends:
if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password))))
all_groups = sorted(list(set(backend.all_groups(
username, password))))
return User(
username, groups, obsolete=backend.obsolete,
username, groups, all_groups, obsolete=backend.obsolete,
permanent=permanent)
return None
def all_groups(self):
for backend in self.backends:
yield from backend.all_groups()
class SecurityManager:
def __init__(self, key, max_duration=300):
......@@ -107,7 +108,7 @@ class StaticUserManager:
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
def all_groups(self, username, password):
yield from list(set(group for group in self.group_map.values()))
......@@ -140,7 +141,7 @@ try:
if members is not None and username in members:
yield group.cn.value
def all_groups(self):
def all_groups(self, username, password):
connection = ldap3.Connection(self.server)
obj_def = ldap3.ObjectDef("posixgroup", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
......@@ -175,33 +176,45 @@ try:
def groups(self, username, password):
connection = self.prepare_connection(username, password)
connection.bind()
if not connection.bind():
return
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(
connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection)
all_group_reader = ldap3.Reader(
connection, group_def, self.group_dn)
all_groups = {
group.primaryGroupToken.value: group
for group in all_group_reader.search()
}
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(
connection, group_def, group_dn, None)
connection, group_def, group_dn)
for entry in group_reader.search():
yield entry.name.value
for child in entry.memberOf:
yield from _yield_recursive_groups(child)
for result in user_reader.search():
yield from _yield_recursive_groups(
all_groups[result.primaryGroupID.value]
.distinguishedName.value)
for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn)
def all_groups(self):
connection = self.prepare_connection()
connection.bind()
def all_groups(self, username, password):
connection = self.prepare_connection(username, password)
if not connection.bind():
return
obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in group_reader.search():
yield result.name.value
except ModuleNotFoundError:
except ImportError:
pass
......@@ -219,14 +232,13 @@ try:
return self.pam.authenticate(username, password)
def groups(self, username, password=None):
print(username)
yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
for group in grp.getgrall():
if username in group.gr_mem:
yield group.gr_name
def all_groups(self):
def all_groups(self, username, password):
for group in grp.getgrall():
yield group.gr_name
except ModuleNotFoundError:
except ImportError:
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment