Commit 7c9c3829 authored by markus scheller's avatar markus scheller
Browse files

Merge branch 'master' into 172-dokumentation-ueberarbeiten

parents bf03dc3f 130cbe7c
...@@ -5,10 +5,11 @@ from datetime import datetime ...@@ -5,10 +5,11 @@ from datetime import datetime
class User: class User:
def __init__(self, username, groups, timestamp=None, obsolete=False, def __init__(self, username, groups, all_groups, timestamp=None,
permanent=False): obsolete=False, permanent=False):
self.username = username self.username = username
self.groups = groups self.groups = groups
self.all_groups = all_groups
if timestamp is not None: if timestamp is not None:
self.timestamp = timestamp self.timestamp = timestamp
else: else:
...@@ -18,21 +19,23 @@ class User: ...@@ -18,21 +19,23 @@ class User:
def summarize(self): def summarize(self):
return ":".join(( return ":".join((
self.username, ",".join(self.groups), self.username, ",".join(self.groups), ",".join(self.all_groups),
str(self.timestamp.timestamp()), str(self.obsolete), str(self.timestamp.timestamp()), str(self.obsolete),
str(self.permanent))) str(self.permanent)))
@staticmethod @staticmethod
def from_summary(summary): def from_summary(summary):
parts = summary.split(":", 4) parts = summary.split(":", 5)
if len(parts) != 5: if len(parts) != 6:
return None return None
name, group_str, timestamp_str, obsolete_str, permanent_str = parts (name, group_str, all_group_str, timestamp_str, obsolete_str,
permanent_str) = parts
timestamp = datetime.fromtimestamp(float(timestamp_str)) timestamp = datetime.fromtimestamp(float(timestamp_str))
obsolete = obsolete_str == "True" obsolete = obsolete_str == "True"
groups = group_str.split(",") groups = group_str.split(",")
all_groups = group_str.split(",")
permanent = permanent_str == "True" permanent = permanent_str == "True"
return User(name, groups, timestamp, obsolete, permanent) return User(name, groups, all_groups, timestamp, obsolete, permanent)
@staticmethod @staticmethod
def from_hashstring(secure_string): def from_hashstring(secure_string):
...@@ -48,15 +51,13 @@ class UserManager: ...@@ -48,15 +51,13 @@ class UserManager:
for backend in self.backends: for backend in self.backends:
if backend.authenticate(username, password): if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password)))) groups = sorted(list(set(backend.groups(username, password))))
all_groups = sorted(list(set(backend.all_groups(
username, password))))
return User( return User(
username, groups, obsolete=backend.obsolete, username, groups, all_groups, obsolete=backend.obsolete,
permanent=permanent) permanent=permanent)
return None return None
def all_groups(self):
for backend in self.backends:
yield from backend.all_groups()
class SecurityManager: class SecurityManager:
def __init__(self, key, max_duration=300): def __init__(self, key, max_duration=300):
...@@ -107,7 +108,7 @@ class StaticUserManager: ...@@ -107,7 +108,7 @@ class StaticUserManager:
if username in self.group_map: if username in self.group_map:
yield from self.group_map[username] yield from self.group_map[username]
def all_groups(self): def all_groups(self, username, password):
yield from list(set(group for group in self.group_map.values())) yield from list(set(group for group in self.group_map.values()))
...@@ -140,7 +141,7 @@ try: ...@@ -140,7 +141,7 @@ try:
if members is not None and username in members: if members is not None and username in members:
yield group.cn.value yield group.cn.value
def all_groups(self): def all_groups(self, username, password):
connection = ldap3.Connection(self.server) connection = ldap3.Connection(self.server)
obj_def = ldap3.ObjectDef("posixgroup", connection) obj_def = ldap3.ObjectDef("posixgroup", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn) group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
...@@ -175,33 +176,45 @@ try: ...@@ -175,33 +176,45 @@ try:
def groups(self, username, password): def groups(self, username, password):
connection = self.prepare_connection(username, password) connection = self.prepare_connection(username, password)
connection.bind() if not connection.bind():
return
obj_def = ldap3.ObjectDef("user", connection) obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username) name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader( user_reader = ldap3.Reader(
connection, obj_def, self.user_dn, name_filter) connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection) group_def = ldap3.ObjectDef("group", connection)
all_group_reader = ldap3.Reader(
connection, group_def, self.group_dn)
all_groups = {
group.primaryGroupToken.value: group
for group in all_group_reader.search()
}
def _yield_recursive_groups(group_dn): def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader( group_reader = ldap3.Reader(
connection, group_def, group_dn, None) connection, group_def, group_dn)
for entry in group_reader.search(): for entry in group_reader.search():
yield entry.name.value yield entry.name.value
for child in entry.memberOf: for child in entry.memberOf:
yield from _yield_recursive_groups(child) yield from _yield_recursive_groups(child)
for result in user_reader.search(): for result in user_reader.search():
yield from _yield_recursive_groups(
all_groups[result.primaryGroupID.value]
.distinguishedName.value)
for group_dn in result.memberOf: for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn) yield from _yield_recursive_groups(group_dn)
def all_groups(self): def all_groups(self, username, password):
connection = self.prepare_connection() connection = self.prepare_connection(username, password)
connection.bind() if not connection.bind():
return
obj_def = ldap3.ObjectDef("group", connection) obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn) group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in group_reader.search(): for result in group_reader.search():
yield result.name.value yield result.name.value
except ModuleNotFoundError: except ImportError:
pass pass
...@@ -219,14 +232,13 @@ try: ...@@ -219,14 +232,13 @@ try:
return self.pam.authenticate(username, password) return self.pam.authenticate(username, password)
def groups(self, username, password=None): def groups(self, username, password=None):
print(username)
yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
for group in grp.getgrall(): for group in grp.getgrall():
if username in group.gr_mem: if username in group.gr_mem:
yield group.gr_name yield group.gr_name
def all_groups(self): def all_groups(self, username, password):
for group in grp.getgrall(): for group in grp.getgrall():
yield group.gr_name yield group.gr_name
except ModuleNotFoundError: except ImportError:
pass pass
...@@ -105,6 +105,12 @@ class ProtocolType(DatabaseModel): ...@@ -105,6 +105,12 @@ class ProtocolType(DatabaseModel):
return None return None
return candidates[0] return candidates[0]
def get_protocols_on_date(self, protocol_date):
return [
protocol for protocol in self.protocols
if protocol.date == protocol_date
]
def has_public_view_right(self, user, check_networks=True): def has_public_view_right(self, user, check_networks=True):
return ( return (
self.has_public_anonymous_view_right(check_networks=check_networks) self.has_public_anonymous_view_right(check_networks=check_networks)
...@@ -423,7 +429,12 @@ class Protocol(DatabaseModel): ...@@ -423,7 +429,12 @@ class Protocol(DatabaseModel):
tzinfo=tz.tzlocal()) tzinfo=tz.tzlocal())
@staticmethod @staticmethod
def create_new_protocol(protocoltype, date, start_time=None): def create_new_protocol(
protocoltype, date, start_time=None, allow_duplicate=False):
if not allow_duplicate:
duplicate_candidates = protocoltype.get_protocols_on_date(date)
if duplicate_candidates:
return duplicate_candidates[0]
if start_time is None: if start_time is None:
start_time = protocoltype.usual_time start_time = protocoltype.usual_time
protocol = Protocol( protocol = Protocol(
......
...@@ -75,7 +75,7 @@ try: ...@@ -75,7 +75,7 @@ try:
"release": get_git_revision(), "release": get_git_revision(),
} }
sentry.get_user_info = get_user_info sentry.get_user_info = get_user_info
except ModuleNotFoundError: except ImportError:
print("Raven not installed. Not sending issues to Sentry.") print("Raven not installed. Not sending issues to Sentry.")
except AttributeError: except AttributeError:
print("DSN not configured. Not sending issues to Sentry.") print("DSN not configured. Not sending issues to Sentry.")
...@@ -91,7 +91,7 @@ def make_celery(app, config): ...@@ -91,7 +91,7 @@ def make_celery(app, config):
raven_client = RavenClient(config.SENTRY_DSN) raven_client = RavenClient(config.SENTRY_DSN)
register_logger_signal(raven_client) register_logger_signal(raven_client)
register_signal(raven_client) register_signal(raven_client)
except ModuleNotFoundError: except ImportError:
print("Raven not installed. Not sending celery issues to Sentry.") print("Raven not installed. Not sending celery issues to Sentry.")
except AttributeError: except AttributeError:
print("DSN not configured. Not sending celery issues to Sentry.") print("DSN not configured. Not sending celery issues to Sentry.")
...@@ -1029,7 +1029,7 @@ def send_protocol_reminder(protocol): ...@@ -1029,7 +1029,7 @@ def send_protocol_reminder(protocol):
if not config.MAIL_ACTIVE: if not config.MAIL_ACTIVE:
flash("Die Mailfunktion ist nicht aktiviert.", "alert-error") flash("Die Mailfunktion ist nicht aktiviert.", "alert-error")
return back.redirect("show_protocol", protocol_id=protocol.id) return back.redirect("show_protocol", protocol_id=protocol.id)
meetingreminders = protocol.reminders meetingreminders = protocol.protocoltype.reminders
if len(meetingreminders) == 0: if len(meetingreminders) == 0:
flash("Für diesen Protokolltyp sind keine Einladungsmails " flash("Für diesen Protokolltyp sind keine Einladungsmails "
"konfiguriert.", "alert-error") "konfiguriert.", "alert-error")
...@@ -1924,7 +1924,8 @@ def check_and_send_reminders(): ...@@ -1924,7 +1924,8 @@ def check_and_send_reminders():
with app.app_context(): with app.app_context():
current_time = datetime.now() current_time = datetime.now()
current_day = current_time.date() current_day = current_time.date()
for protocol in Protocol.query.filter(not Protocol.done).all(): query = Protocol.query.filter(Protocol.done == False) # noqa: E712
for protocol in query.all():
day_difference = (protocol.date - current_day).days day_difference = (protocol.date - current_day).days
usual_time = protocol.get_time() usual_time = protocol.get_time()
protocol_time = datetime( protocol_time = datetime(
......
...@@ -492,8 +492,9 @@ def parse_protocol_async_inner(protocol): ...@@ -492,8 +492,9 @@ def parse_protocol_async_inner(protocol):
if len(protocol_tag.values) > 1: if len(protocol_tag.values) > 1:
new_protocol_time = datetime.strptime( new_protocol_time = datetime.strptime(
protocol_tag.values[1], "%H:%M") protocol_tag.values[1], "%H:%M")
Protocol.create_new_protocol( if not protocol.protocoltype.get_protocols_on_date(new_protocol_date):
protocol.protocoltype, new_protocol_date, new_protocol_time) Protocol.create_new_protocol(
protocol.protocoltype, new_protocol_date, new_protocol_time)
# TOPs # TOPs
old_tops = list(protocol.tops) old_tops = list(protocol.tops)
...@@ -582,17 +583,18 @@ def push_to_wiki(protocol, content, infobox_content, summary): ...@@ -582,17 +583,18 @@ def push_to_wiki(protocol, content, infobox_content, summary):
@celery.task @celery.task
def push_to_wiki_async(protocol_id, content, infobox_content, summary): def push_to_wiki_async(protocol_id, content, infobox_content, summary):
with WikiClient() as wiki_client, app.app_context(): with app.app_context():
protocol = Protocol.query.filter_by(id=protocol_id).first() protocol = Protocol.query.filter_by(id=protocol_id).first()
try: try:
wiki_client.edit_page( with WikiClient() as wiki_client:
title=protocol.protocoltype.get_wiki_infobox_title(), wiki_client.edit_page(
content=infobox_content, title=protocol.protocoltype.get_wiki_infobox_title(),
summary=summary) content=infobox_content,
wiki_client.edit_page( summary=summary)
title=protocol.get_wiki_title(), wiki_client.edit_page(
content=content, title=protocol.get_wiki_title(),
summary=summary) content=content,
summary=summary)
except WikiException as exc: except WikiException as exc:
return _make_error( return _make_error(
protocol, "Pushing to Wiki", "Pushing to Wiki failed.", protocol, "Pushing to Wiki", "Pushing to Wiki failed.",
......
...@@ -82,7 +82,7 @@ def get_latex_template_choices(): ...@@ -82,7 +82,7 @@ def get_latex_template_choices():
def get_group_choices(): def get_group_choices():
user = current_user() user = current_user()
groups = sorted(user.groups) groups = sorted(user.all_groups)
choices = list(zip(groups, groups)) choices = list(zip(groups, groups))
choices.insert(0, ("", "Keine Gruppe")) choices.insert(0, ("", "Keine Gruppe"))
return choices return choices
......
import requests import requests
from json import JSONDecodeError
import config import config
...@@ -117,7 +118,10 @@ class WikiClient: ...@@ -117,7 +118,10 @@ class WikiClient:
"HTTP status code {} on action {}.".format( "HTTP status code {} on action {}.".format(
req.status_code, action)) req.status_code, action))
self.cookies.update(req.cookies) self.cookies.update(req.cookies)
return req.json() try:
return req.json()
except JSONDecodeError:
raise WikiException("Server did not return valid JSON.")
def main(): def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment