Commit 6d9c3fdd authored by Robin Sonnabend's avatar Robin Sonnabend

Improve *.py code quality

Also fixes some latent, rather unimportant bugs
parent 663a9d49
import hmac, hashlib
import hmac
import hashlib
import ssl
from datetime import datetime
class User:
def __init__(self, username, groups, timestamp=None, obsolete=False,
permanent=False):
permanent=False):
self.username = username
self.groups = groups
if timestamp is not None:
......@@ -16,8 +17,9 @@ class User:
self.permanent = permanent
def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent)
return ":".join((
self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent))
@staticmethod
def from_summary(summary):
......@@ -45,7 +47,8 @@ class UserManager:
for backend in self.backends:
if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete,
return User(
username, groups, obsolete=backend.obsolete,
permanent=permanent)
return None
......@@ -77,8 +80,8 @@ class SecurityManager:
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
hash)
macs_equal = hmac.compare_digest(
maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
......@@ -97,23 +100,22 @@ class StaticUserManager:
def authenticate(self, username, password):
return (username in self.passwords
and self.passwords[username] == password)
and self.passwords[username] == password)
def groups(self, username, password=None):
if username in self.group_map:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
yield from list(set(group for group in self.group_map.values()))
try:
import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
obsolete=False):
obsolete=False):
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
self.user_dn = user_dn
self.group_dn = group_dn
......@@ -121,8 +123,8 @@ try:
def authenticate(self, username, password):
try:
connection = ldap3.Connection(self.server,
self.user_dn.format(username), password)
connection = ldap3.Connection(
self.server, self.user_dn.format(username), password)
return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
......@@ -144,16 +146,15 @@ try:
for group in group_reader.search():
yield group.cn.value
class ADManager:
def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False):
port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
tls=tls_config)
tls_config = ldap3.Tls(
validate=ssl.CERT_REQUIRED, ca_certs_file=ca_cert)
self.server = ldap3.Server(
host, port=port, use_ssl=use_ssl, tls=tls_config)
self.domain = domain
self.user_dn = user_dn
self.group_dn = group_dn
......@@ -176,11 +177,13 @@ try:
connection.bind()
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
name_filter)
user_reader = ldap3.Reader(
connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None)
group_reader = ldap3.Reader(
connection, group_def, group_dn, None)
for entry in group_reader.search():
yield entry.name.value
for child in entry.memberOf:
......@@ -189,20 +192,22 @@ try:
for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn)
def all_groups(self):
connection = self.prepare_connection()
connection.bind()
obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search():
for result in group_reader.search():
yield result.name.value
except ModuleNotFoundError:
pass
try:
import grp, pwd, pam
import grp
import pwd
import pam
class PAMManager:
def __init__(self, obsolete=False):
......@@ -224,4 +229,3 @@ try:
yield group.gr_name
except ModuleNotFoundError:
pass
......@@ -10,6 +10,7 @@ import config
cookie = getattr(config, "REDIRECT_BACK_COOKIE", "back")
default_view = getattr(config, "REDIRECT_BACK_DEFAULT", "index")
def anchor(func, cookie=cookie):
@functools.wraps(func)
def result(*args, **kwargs):
......@@ -17,8 +18,10 @@ def anchor(func, cookie=cookie):
return func(*args, **kwargs)
return result
def url(default=default_view, cookie=cookie, **url_args):
return session.get(cookie, url_for(default, **url_args))
def redirect(default=default_view, cookie=cookie, **url_args):
return flask_redirect(url(default, cookie, **url_args))
......@@ -2,15 +2,16 @@ from datetime import datetime, timedelta
import random
import quopri
from caldav import DAVClient, Principal, Calendar, Event
from caldav.lib.error import PropfindError
from caldav import DAVClient
from vobject.base import ContentLine
import config
class CalendarException(Exception):
pass
class Client:
def __init__(self, calendar=None, url=None):
if not config.CALENDAR_ACTIVE:
......@@ -23,9 +24,12 @@ class Client:
self.principal = self.client.principal()
break
except Exception as exc:
print("Got exception {} from caldav, retrying".format(str(exc)))
print("Got exception {} from caldav, retrying".format(
str(exc)))
if self.principal is None:
raise CalendarException("Got {} CalDAV-error from the CalDAV server.".format(config.CALENDAR_MAX_REQUESTS))
raise CalendarException(
"Got {} CalDAV-error from the CalDAV server.".format(
config.CALENDAR_MAX_REQUESTS))
if calendar is not None:
self.calendar = self.get_calendar(calendar)
else:
......@@ -41,9 +45,11 @@ class Client:
for calendar in self.principal.calendars()
]
except Exception as exc:
print("Got exception {} from caldav, retrying".format(str(exc)))
raise CalendarException("Got {} CalDAV Errors from the CalDAV server.".format(config.CALENDAR_MAX_REQUESTS))
print("Got exception {} from caldav, retrying".format(
str(exc)))
raise CalendarException(
"Got {} CalDAV Errors from the CalDAV server.".format(
config.CALENDAR_MAX_REQUESTS))
def get_calendar(self, calendar_name):
candidates = self.principal.calendars()
......@@ -57,12 +63,14 @@ class Client:
return
candidates = [
Event.from_raw_event(raw_event)
for raw_event in self.calendar.date_search(begin, begin + timedelta(hours=1))
for raw_event in self.calendar.date_search(
begin, begin + timedelta(hours=1))
]
candidates = [event for event in candidates if event.name == name]
event = None
if len(candidates) == 0:
event = Event(None, name, description, begin,
event = Event(
None, name, description, begin,
begin + timedelta(hours=config.CALENDAR_DEFAULT_DURATION))
vevent = self.calendar.add_event(event.to_vcal())
event.vevent = vevent
......@@ -76,11 +84,14 @@ NAME_KEY = "summary"
DESCRIPTION_KEY = "description"
BEGIN_KEY = "dtstart"
END_KEY = "dtend"
def _get_item(content, key):
if key in content:
return content[key][0].value
return None
class Event:
def __init__(self, vevent, name, description, begin, end):
self.vevent = vevent
......@@ -97,7 +108,8 @@ class Event:
description = _get_item(content, DESCRIPTION_KEY)
begin = _get_item(content, BEGIN_KEY)
end = _get_item(content, END_KEY)
return Event(vevent=vevent, name=name, description=description,
return Event(
vevent=vevent, name=name, description=description,
begin=begin, end=end)
def set_description(self, description):
......@@ -105,7 +117,8 @@ class Event:
self.description = description
encoded = encode_quopri(description)
if DESCRIPTION_KEY not in raw_event.contents:
raw_event.contents[DESCRIPTION_KEY] = [ContentLine(DESCRIPTION_KEY, {"ENCODING": ["QUOTED-PRINTABLE"]}, encoded)]
raw_event.contents[DESCRIPTION_KEY] = [ContentLine(
DESCRIPTION_KEY, {"ENCODING": ["QUOTED-PRINTABLE"]}, encoded)]
else:
content_line = raw_event.contents[DESCRIPTION_KEY][0]
content_line.value = encoded
......@@ -129,21 +142,28 @@ SUMMARY:{summary}
DESCRIPTION;ENCODING=QUOTED-PRINTABLE:{description}
END:VEVENT
END:VCALENDAR""".format(
uid=create_uid(), now=date_format(datetime.now()-offset),
begin=date_format(self.begin-offset), end=date_format(self.end-offset),
uid=create_uid(),
now=date_format(datetime.now() - offset),
begin=date_format(self.begin - offset),
end=date_format(self.end - offset),
summary=self.name,
description=encode_quopri(self.description))
def create_uid():
return str(random.randint(0, 1e10)).rjust(10, "0")
def date_format(dt):
return dt.strftime("%Y%m%dT%H%M%SZ")
def get_timezone_offset():
difference = datetime.now() - datetime.utcnow()
return timedelta(hours=round(difference.seconds / 3600 + difference.days * 24))
return timedelta(
hours=round(difference.seconds / 3600 + difference.days * 24))
def encode_quopri(text):
return quopri.encodestring(text.encode("utf-8")).replace(b"\n", b"=0A").decode("utf-8")
def encode_quopri(text):
return quopri.encodestring(text.encode("utf-8")).replace(
b"\n", b"=0A").decode("utf-8")
......@@ -3,14 +3,20 @@ import regex as re
import os
import sys
ROUTE_PATTERN = r'@(?:[[:alpha:]])+\.route\(\"(?<url>[^"]+)"[^)]*\)\s*(?:@[[:alpha:]_()., ]+\s*)*def\s+(?<name>[[:alpha:]][[:alnum:]_]*)\((?<params>[[:alnum:], ]*)\):'
ROUTE_PATTERN = (
r'@(?:[[:alpha:]])+\.route\(\"(?<url>[^"]+)"[^)]*\)\s*'
r'(?:@[[:alpha:]_()., ]+\s*)*def\s+(?<name>[[:alpha:]][[:alnum:]_]*)'
r'\((?<params>[[:alnum:], ]*)\):')
quote_group = "[\"']"
URL_FOR_PATTERN = r'url_for\({quotes}(?<name>[[:alpha:]][[:alnum:]_]*){quotes}'.format(quotes=quote_group)
URL_FOR_PATTERN = (
r'url_for\({quotes}(?<name>[[:alpha:]][[:alnum:]_]*)'
'{quotes}'.format(quotes=quote_group))
ROOT_DIR = "."
ENDINGS = [".py", ".html", ".txt"]
MAX_DEPTH = 2
def list_dir(dir, level=0):
if level >= MAX_DEPTH:
return
......@@ -23,7 +29,8 @@ def list_dir(dir, level=0):
if file.endswith(ending):
yield path
elif os.path.isdir(path):
yield from list_dir(path, level+1)
yield from list_dir(path, level + 1)
class Route:
def __init__(self, file, name, parameters):
......@@ -38,13 +45,15 @@ class Route:
def get_parameter_set(self):
return {parameter.name for parameter in self.parameters}
class Parameter:
def __init__(self, name, type=None):
self.name = name
self.type = type
def __repr__(self):
return "Parameter({name}, {type})".format(name=self.name, type=self.type)
return "Parameter({name}, {type})".format(
name=self.name, type=self.type)
@staticmethod
def from_string(text):
......@@ -53,6 +62,7 @@ class Parameter:
return Parameter(name, type)
return Parameter(text)
def split_url_parameters(url):
params = []
current_param = None
......@@ -68,9 +78,11 @@ def split_url_parameters(url):
current_param += char
return params
def split_function_parameters(parameters):
return list(map(str.strip, parameters.split(",")))
def read_url_for_parameters(content):
params = []
bracket_level = 1
......@@ -92,6 +104,7 @@ def read_url_for_parameters(content):
elif char == ")":
bracket_level -= 1
class UrlFor:
def __init__(self, file, name, parameters):
self.file = file
......@@ -99,8 +112,10 @@ class UrlFor:
self.parameters = parameters
def __repr__(self):
return "UrlFor(file={file}, name={name}, parameters={parameters})".format(
file=self.file, name=self.name, parameters=self.parameters)
return (
"UrlFor(file={file}, name={name}, parameters={parameters})".format(
file=self.file, name=self.name, parameters=self.parameters))
routes = {}
url_fors = []
......@@ -109,24 +124,29 @@ for file in list_dir(ROOT_DIR):
content = infile.read()
for match in re.finditer(ROUTE_PATTERN, content):
name = match.group("name")
function_parameters = split_function_parameters(match.group("params"))
function_parameters = split_function_parameters(
match.group("params"))
url_parameters = split_url_parameters(match.group("url"))
routes[name] = Route(file, name, url_parameters)
for match in re.finditer(URL_FOR_PATTERN, content):
name = match.group("name")
begin, end = match.span()
parameters = read_url_for_parameters(content[end:])
url_fors.append(UrlFor(file=file, name=name, parameters=parameters))
url_fors.append(UrlFor(
file=file, name=name, parameters=parameters))
for url_for in url_fors:
if url_for.name not in routes:
print("Missing route '{}' (for url_for in '{}')".format(url_for.name, url_for.file))
print("Missing route '{}' (for url_for in '{}')".format(
url_for.name, url_for.file))
continue
route = routes[url_for.name]
route_parameters = route.get_parameter_set()
url_parameters = set(url_for.parameters)
if len(route_parameters ^ url_parameters) > 0:
print("Parameters not matching for '{}' in '{}:'".format(url_for.name, url_for.file))
print("Parameters not matching for '{}' in '{}:'".format(
url_for.name, url_for.file))
only_route = route_parameters - url_parameters
only_url = url_parameters - route_parameters
if len(only_route) > 0:
......
from flask import redirect, flash, request, url_for
from flask import flash
from functools import wraps
from models.database import ALL_MODELS
from shared import db, current_user
from shared import current_user
import back
ID_KEY = "id"
......@@ -12,12 +12,15 @@ OBJECT_DOES_NOT_EXIST_MESSAGE = "There is no {} with id {}."
MISSING_VIEW_RIGHT = "Dir fehlenden die nötigen Zugriffsrechte."
def default_redirect():
return back.redirect()
def login_redirect():
return back.redirect("login")
def db_lookup(*models, check_exists=True):
def _decorator(function):
@wraps(function)
......@@ -32,7 +35,8 @@ def db_lookup(*models, check_exists=True):
obj = model.query.filter_by(id=obj_id).first()
if check_exists and obj is None:
model_name = model.__class__.__name__
flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(model_name, obj_id),
flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(
model_name, obj_id),
"alert-error")
return default_redirect()
kwargs[key] = obj
......@@ -41,8 +45,10 @@ def db_lookup(*models, check_exists=True):
return _decorated_function
return _decorator
def require_right(right, require_exist):
necessary_right_name = "has_{}_right".format(right)
def _decorator(function):
@wraps(function)
def _decorated_function(*args, **kwargs):
......@@ -65,17 +71,22 @@ def require_right(right, require_exist):
return _decorated_function
return _decorator
def require_public_view_right(require_exist=True):
return require_right("public_view", require_exist)
def require_private_view_right(require_exist=True):
return require_right("private_view", require_exist)
def require_modify_right(require_exist=True):
return require_right("modify", require_exist)
def require_publish_right(require_exist=True):
return require_right("publish", require_exist)
def require_admin_right(require_exist=True):
return require_right("admin", require_exist)
from datetime import datetime
from fuzzywuzzy import fuzz, process
import tempfile
from fuzzywuzzy import process
from models.database import Todo, OldTodo, Protocol, ProtocolType, TodoMail
from models.database import OldTodo, Protocol, ProtocolType, TodoMail
from shared import db
import config
def lookup_todo_id(old_candidates, new_who, new_description):
# Check for perfect matches
for candidate in old_candidates:
if candidate.who == new_who and candidate.description == new_description:
if (candidate.who == new_who
and candidate.description == new_description):
return candidate.old_id
# Accept if who has been changed
for candidate in old_candidates:
......@@ -32,11 +33,13 @@ def lookup_todo_id(old_candidates, new_who, new_description):
new_description, best_match, best_match_score))
return None
INSERT_PROTOCOLTYPE = "INSERT INTO `protocolManager_protocoltype`"
INSERT_PROTOCOL = "INSERT INTO `protocolManager_protocol`"
INSERT_TODO = "INSERT INTO `protocolManager_todo`"
INSERT_TODOMAIL = "INSERT INTO `protocolManager_todonamemailassignment`"
def import_old_protocols(sql_text):
protocoltype_lines = []
protocol_lines = []
......@@ -46,22 +49,27 @@ def import_old_protocols(sql_text):
elif line.startswith(INSERT_PROTOCOL):
protocol_lines.append(line)
if (len(protocoltype_lines) == 0
or len(protocol_lines) == 0):
or len(protocol_lines) == 0):
raise ValueError("Necessary lines not found.")
type_id_to_handle = {}
for type_line in protocoltype_lines:
for id, handle, name, mail, protocol_id in _split_insert_line(type_line):
for id, handle, name, mail, protocol_id in _split_insert_line(
type_line):
type_id_to_handle[int(id)] = handle.lower()
protocols = []
for protocol_line in protocol_lines:
for (protocol_id, old_type_id, date, source, textsummary, htmlsummary,
deleted, sent, document_id) in _split_insert_line(protocol_line):
deleted, sent, document_id) in _split_insert_line(
protocol_line):
date = datetime.strptime(date, "%Y-%m-%d")
handle = type_id_to_handle[int(old_type_id)]
protocoltype = ProtocolType.query.filter(ProtocolType.short_name.ilike(handle)).first()
protocoltype = ProtocolType.query.filter(
ProtocolType.short_name.ilike(handle)).first()
if protocoltype is None:
raise KeyError("No protocoltype for handle '{}'.".format(handle))
protocol = Protocol(protocoltype_id=protocoltype.id, date=date, source=source)
raise KeyError(
"No protocoltype for handle '{}'.".format(handle))
protocol = Protocol(
protocoltype_id=protocoltype.id, date=date, source=source)
db.session.add(protocol)
db.session.commit()
import tasks
......@@ -70,6 +78,7 @@ def import_old_protocols(sql_text):
print(protocol.date)
tasks.parse_protocol(protocol)
def import_old_todomails(sql_text):
todomail_lines = []
for line in sql_text.splitlines():
......@@ -98,28 +107,34 @@ def import_old_todos(sql_text):
elif line.startswith(INSERT_TODO):
todo_lines.append(line)
if (len(protocoltype_lines) == 0
or len(protocol_lines) == 0
or len(todo_lines) == 0):
or len(protocol_lines) == 0
or len(todo_lines) == 0):
raise ValueError("Necessary lines not found.")
type_id_to_handle = {}
for type_line in protocoltype_lines:
for id, handle, name, mail, protocol_id in _split_insert_line(type_line):
for id, handle, name, mail, protocol_id in _split_insert_line(
type_line):
type_id_to_handle[int(id)] = handle.lower()
protocol_id_to_key = {}
for protocol_line in protocol_lines:
for (protocol_id, type_id, date, source, textsummary, htmlsummary,
deleted, sent, document_id) in _split_insert_line(protocol_line):
deleted, sent, document_id) in _split_insert_line(
protocol_line):
handle = type_id_to_handle[int(type_id)]
date_string = date [2:]
protocol_id_to_key[int(protocol_id)] = "{}-{}".format(handle, date_string)
date_string = date[2:]
protocol_id_to_key[int(protocol_id)] = "{}-{}".format(
handle, date_string)
todos = []
for todo_line in todo_lines:
for old_id, protocol_id, who, what, start_time, end_time, done in _split_insert_line(todo_line):
for (old_id, protocol_id, who, what, start_time, end_time,
done) in _split_insert_line(todo_line):
protocol_id = int(protocol_id)
if protocol_id not in protocol_id_to_key:
print("Missing protocol with ID {} for Todo {}".format(protocol_id, what))
print("Missing protocol with ID {} for Todo {}".format(
protocol_id, what))
continue
todo = OldTodo(old_id=old_id, who=who, description=what,
todo = OldTodo(
old_id=old_id, who=who, description=what,
protocol_key=protocol_id_to_key[protocol_id])
todos.append(todo)
OldTodo.query.delete()
......@@ -127,12 +142,16 @@ def import_old_todos(sql_text):
for todo in todos:
db.session.add(todo)
db.session.commit()
def _split_insert_line(line):
insert_part, values_part = line.split("VALUES", 1)
return _split_base_level(values_part)
def _split_base_level(text, begin="(", end=")", separator=",", string_terminator="'", line_end=";", ignore=" ", escape="\\"):
def _split_base_level(
text, begin="(", end=")", separator=",", string_terminator="'",
line_end=";", ignore=" ", escape="\\"):
raw_parts = []
current_part = None
index = 0
......@@ -210,5 +229,3 @@ def _split_base_level(text, begin="(", end=")", separator=",", string_terminator
fields.append(current_field)
parts.append(fields)
return parts
This diff is collapsed.
from flask_sqlalchemy import SQLAlchemy
from flask import session, redirect, url_for, request
from flask import session, redirect, url_for, flash
import re
from functools import wraps
......@@ -11,9 +11,10 @@ import config
db = SQLAlchemy()
# the following code is written by Lars Beckers and not to be published without permission
# the following code escape_tex is written by Lars Beckers
# and not to be published without permission
latex_chars = [
("\\", "\\backslash"), # this needs to be first
("\\", "\\backslash"), # this needs to be first
("$", "\$"),
('%', '\\%'),
('&', '\\&'),
......@@ -23,7 +24,6 @@ latex_chars = [
('}', '\\}'),
('[', '\\['),
(']', '\\]'),
#('"', '"\''),
('~', r'$\sim{}$'),
('^', r'\textasciicircum{}'),
('Ë„', r'\textasciicircum{}'),
......@@ -37,75 +37,99 @@ latex_chars = [
('<', '$<$'),
('>', '$>$'),
('\\backslashin', '$\\in$'),
('\\backslash', '$\\backslash$') # this needs to be last
('\\backslash', '$\\backslash$') # this needs to be last
]
def escape_tex(text):
out = text
for old, new in latex_chars:
out = out.replace(old, new)
# beware, the following is carefully crafted code
res = ''
k, l = (0, -1)
while k >= 0:
k = out.find('"', l+1)
if k >= 0:
res += out[l+1:k]
l = out.find('"', k+1)
if l >= 0:
res += '\\enquote{' + out[k+1:l] + '}'
start, end = (0, -1)
while start >= 0:
start = out.find('"', end + 1)
if start >= 0:
res += out[end + 1:start]
end = out.find('"', start + 1)
if end >= 0:
res += '\\enquote{' + out[start + 1:end] + '}'
else:
res += '"\'' + out[k+1:]
k = l
res += '"\'' + out[start + 1:]
start = end
else:
res += out[l+1:]
res += out[end + 1:]
# yes, this is not quite escaping latex chars, but anyway...
res = re.sub('([a-z])\(', '\\1 (', res)
res = re.sub('\)([a-z])', ') \\1', res)
#logging.debug('escape latex ({0}/{1}): {2} --> {3}'.format(len(text), len(res), text.split('\n')[0], res.split('\n')[0]))
return res
def unhyphen(text):
return " ".join([r"\mbox{" + word + "}" for word in text.split(" ")])
def date_filter(date):
return date.strftime("%d. %B %Y")
def datetime_filter(date):
return date.strftime("%d. %B %Y, %H:%M")
def date_filter_long(date):
return date.strftime("%A, %d.%m.%Y, Kalenderwoche %W")
def date_filter_short(date):
return date.strftime("%d.%m.%Y")
def time_filter(time):
return time.strftime("%H:%M Uhr")
def time_filter_short(time):
return time.strftime("%H:%M")