Skip to content
Snippets Groups Projects
Commit 6d9c3fdd authored by Robin Sonnabend's avatar Robin Sonnabend
Browse files

Improve *.py code quality

Also fixes some latent, rather unimportant bugs
parent 663a9d49
Branches
No related tags found
No related merge requests found
import hmac, hashlib import hmac
import hashlib
import ssl import ssl
from datetime import datetime from datetime import datetime
...@@ -16,8 +17,9 @@ class User: ...@@ -16,8 +17,9 @@ class User:
self.permanent = permanent self.permanent = permanent
def summarize(self): def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), return ":".join((
str(self.timestamp.timestamp()), self.obsolete, self.permanent) self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent))
@staticmethod @staticmethod
def from_summary(summary): def from_summary(summary):
...@@ -45,7 +47,8 @@ class UserManager: ...@@ -45,7 +47,8 @@ class UserManager:
for backend in self.backends: for backend in self.backends:
if backend.authenticate(username, password): if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password)))) groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete, return User(
username, groups, obsolete=backend.obsolete,
permanent=permanent) permanent=permanent)
return None return None
...@@ -77,8 +80,8 @@ class SecurityManager: ...@@ -77,8 +80,8 @@ class SecurityManager:
if user is None: if user is None:
return False return False
session_duration = datetime.now() - user.timestamp session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), macs_equal = hmac.compare_digest(
hash) maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent) return macs_equal and (time_short or user.permanent)
...@@ -104,12 +107,11 @@ class StaticUserManager: ...@@ -104,12 +107,11 @@ class StaticUserManager:
yield from self.group_map[username] yield from self.group_map[username]
def all_groups(self): def all_groups(self):
yield from list(set(group for group in groups.values())) yield from list(set(group for group in self.group_map.values()))
try: try:
import ldap3 import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager: class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
...@@ -121,8 +123,8 @@ try: ...@@ -121,8 +123,8 @@ try:
def authenticate(self, username, password): def authenticate(self, username, password):
try: try:
connection = ldap3.Connection(self.server, connection = ldap3.Connection(
self.user_dn.format(username), password) self.server, self.user_dn.format(username), password)
return connection.bind() return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError: except ldap3.core.exceptions.LDAPSocketOpenError:
return False return False
...@@ -144,16 +146,15 @@ try: ...@@ -144,16 +146,15 @@ try:
for group in group_reader.search(): for group in group_reader.search():
yield group.cn.value yield group.cn.value
class ADManager: class ADManager:
def __init__(self, host, domain, user_dn, group_dn, def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False): port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED) tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None: if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED, tls_config = ldap3.Tls(
ca_certs_file=ca_cert) validate=ssl.CERT_REQUIRED, ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl, self.server = ldap3.Server(
tls=tls_config) host, port=port, use_ssl=use_ssl, tls=tls_config)
self.domain = domain self.domain = domain
self.user_dn = user_dn self.user_dn = user_dn
self.group_dn = group_dn self.group_dn = group_dn
...@@ -176,11 +177,13 @@ try: ...@@ -176,11 +177,13 @@ try:
connection.bind() connection.bind()
obj_def = ldap3.ObjectDef("user", connection) obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username) name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn, user_reader = ldap3.Reader(
name_filter) connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection) group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn): def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None) group_reader = ldap3.Reader(
connection, group_def, group_dn, None)
for entry in group_reader.search(): for entry in group_reader.search():
yield entry.name.value yield entry.name.value
for child in entry.memberOf: for child in entry.memberOf:
...@@ -189,20 +192,22 @@ try: ...@@ -189,20 +192,22 @@ try:
for group_dn in result.memberOf: for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn) yield from _yield_recursive_groups(group_dn)
def all_groups(self): def all_groups(self):
connection = self.prepare_connection() connection = self.prepare_connection()
connection.bind() connection.bind()
obj_def = ldap3.ObjectDef("group", connection) obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn) group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search(): for result in group_reader.search():
yield result.name.value yield result.name.value
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
try: try:
import grp, pwd, pam import grp
import pwd
import pam
class PAMManager: class PAMManager:
def __init__(self, obsolete=False): def __init__(self, obsolete=False):
...@@ -224,4 +229,3 @@ try: ...@@ -224,4 +229,3 @@ try:
yield group.gr_name yield group.gr_name
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
...@@ -10,6 +10,7 @@ import config ...@@ -10,6 +10,7 @@ import config
cookie = getattr(config, "REDIRECT_BACK_COOKIE", "back") cookie = getattr(config, "REDIRECT_BACK_COOKIE", "back")
default_view = getattr(config, "REDIRECT_BACK_DEFAULT", "index") default_view = getattr(config, "REDIRECT_BACK_DEFAULT", "index")
def anchor(func, cookie=cookie): def anchor(func, cookie=cookie):
@functools.wraps(func) @functools.wraps(func)
def result(*args, **kwargs): def result(*args, **kwargs):
...@@ -17,8 +18,10 @@ def anchor(func, cookie=cookie): ...@@ -17,8 +18,10 @@ def anchor(func, cookie=cookie):
return func(*args, **kwargs) return func(*args, **kwargs)
return result return result
def url(default=default_view, cookie=cookie, **url_args): def url(default=default_view, cookie=cookie, **url_args):
return session.get(cookie, url_for(default, **url_args)) return session.get(cookie, url_for(default, **url_args))
def redirect(default=default_view, cookie=cookie, **url_args): def redirect(default=default_view, cookie=cookie, **url_args):
return flask_redirect(url(default, cookie, **url_args)) return flask_redirect(url(default, cookie, **url_args))
...@@ -2,15 +2,16 @@ from datetime import datetime, timedelta ...@@ -2,15 +2,16 @@ from datetime import datetime, timedelta
import random import random
import quopri import quopri
from caldav import DAVClient, Principal, Calendar, Event from caldav import DAVClient
from caldav.lib.error import PropfindError
from vobject.base import ContentLine from vobject.base import ContentLine
import config import config
class CalendarException(Exception): class CalendarException(Exception):
pass pass
class Client: class Client:
def __init__(self, calendar=None, url=None): def __init__(self, calendar=None, url=None):
if not config.CALENDAR_ACTIVE: if not config.CALENDAR_ACTIVE:
...@@ -23,9 +24,12 @@ class Client: ...@@ -23,9 +24,12 @@ class Client:
self.principal = self.client.principal() self.principal = self.client.principal()
break break
except Exception as exc: except Exception as exc:
print("Got exception {} from caldav, retrying".format(str(exc))) print("Got exception {} from caldav, retrying".format(
str(exc)))
if self.principal is None: if self.principal is None:
raise CalendarException("Got {} CalDAV-error from the CalDAV server.".format(config.CALENDAR_MAX_REQUESTS)) raise CalendarException(
"Got {} CalDAV-error from the CalDAV server.".format(
config.CALENDAR_MAX_REQUESTS))
if calendar is not None: if calendar is not None:
self.calendar = self.get_calendar(calendar) self.calendar = self.get_calendar(calendar)
else: else:
...@@ -41,9 +45,11 @@ class Client: ...@@ -41,9 +45,11 @@ class Client:
for calendar in self.principal.calendars() for calendar in self.principal.calendars()
] ]
except Exception as exc: except Exception as exc:
print("Got exception {} from caldav, retrying".format(str(exc))) print("Got exception {} from caldav, retrying".format(
raise CalendarException("Got {} CalDAV Errors from the CalDAV server.".format(config.CALENDAR_MAX_REQUESTS)) str(exc)))
raise CalendarException(
"Got {} CalDAV Errors from the CalDAV server.".format(
config.CALENDAR_MAX_REQUESTS))
def get_calendar(self, calendar_name): def get_calendar(self, calendar_name):
candidates = self.principal.calendars() candidates = self.principal.calendars()
...@@ -57,12 +63,14 @@ class Client: ...@@ -57,12 +63,14 @@ class Client:
return return
candidates = [ candidates = [
Event.from_raw_event(raw_event) Event.from_raw_event(raw_event)
for raw_event in self.calendar.date_search(begin, begin + timedelta(hours=1)) for raw_event in self.calendar.date_search(
begin, begin + timedelta(hours=1))
] ]
candidates = [event for event in candidates if event.name == name] candidates = [event for event in candidates if event.name == name]
event = None event = None
if len(candidates) == 0: if len(candidates) == 0:
event = Event(None, name, description, begin, event = Event(
None, name, description, begin,
begin + timedelta(hours=config.CALENDAR_DEFAULT_DURATION)) begin + timedelta(hours=config.CALENDAR_DEFAULT_DURATION))
vevent = self.calendar.add_event(event.to_vcal()) vevent = self.calendar.add_event(event.to_vcal())
event.vevent = vevent event.vevent = vevent
...@@ -76,11 +84,14 @@ NAME_KEY = "summary" ...@@ -76,11 +84,14 @@ NAME_KEY = "summary"
DESCRIPTION_KEY = "description" DESCRIPTION_KEY = "description"
BEGIN_KEY = "dtstart" BEGIN_KEY = "dtstart"
END_KEY = "dtend" END_KEY = "dtend"
def _get_item(content, key): def _get_item(content, key):
if key in content: if key in content:
return content[key][0].value return content[key][0].value
return None return None
class Event: class Event:
def __init__(self, vevent, name, description, begin, end): def __init__(self, vevent, name, description, begin, end):
self.vevent = vevent self.vevent = vevent
...@@ -97,7 +108,8 @@ class Event: ...@@ -97,7 +108,8 @@ class Event:
description = _get_item(content, DESCRIPTION_KEY) description = _get_item(content, DESCRIPTION_KEY)
begin = _get_item(content, BEGIN_KEY) begin = _get_item(content, BEGIN_KEY)
end = _get_item(content, END_KEY) end = _get_item(content, END_KEY)
return Event(vevent=vevent, name=name, description=description, return Event(
vevent=vevent, name=name, description=description,
begin=begin, end=end) begin=begin, end=end)
def set_description(self, description): def set_description(self, description):
...@@ -105,7 +117,8 @@ class Event: ...@@ -105,7 +117,8 @@ class Event:
self.description = description self.description = description
encoded = encode_quopri(description) encoded = encode_quopri(description)
if DESCRIPTION_KEY not in raw_event.contents: if DESCRIPTION_KEY not in raw_event.contents:
raw_event.contents[DESCRIPTION_KEY] = [ContentLine(DESCRIPTION_KEY, {"ENCODING": ["QUOTED-PRINTABLE"]}, encoded)] raw_event.contents[DESCRIPTION_KEY] = [ContentLine(
DESCRIPTION_KEY, {"ENCODING": ["QUOTED-PRINTABLE"]}, encoded)]
else: else:
content_line = raw_event.contents[DESCRIPTION_KEY][0] content_line = raw_event.contents[DESCRIPTION_KEY][0]
content_line.value = encoded content_line.value = encoded
...@@ -129,21 +142,28 @@ SUMMARY:{summary} ...@@ -129,21 +142,28 @@ SUMMARY:{summary}
DESCRIPTION;ENCODING=QUOTED-PRINTABLE:{description} DESCRIPTION;ENCODING=QUOTED-PRINTABLE:{description}
END:VEVENT END:VEVENT
END:VCALENDAR""".format( END:VCALENDAR""".format(
uid=create_uid(), now=date_format(datetime.now()-offset), uid=create_uid(),
begin=date_format(self.begin-offset), end=date_format(self.end-offset), now=date_format(datetime.now() - offset),
begin=date_format(self.begin - offset),
end=date_format(self.end - offset),
summary=self.name, summary=self.name,
description=encode_quopri(self.description)) description=encode_quopri(self.description))
def create_uid(): def create_uid():
return str(random.randint(0, 1e10)).rjust(10, "0") return str(random.randint(0, 1e10)).rjust(10, "0")
def date_format(dt): def date_format(dt):
return dt.strftime("%Y%m%dT%H%M%SZ") return dt.strftime("%Y%m%dT%H%M%SZ")
def get_timezone_offset(): def get_timezone_offset():
difference = datetime.now() - datetime.utcnow() difference = datetime.now() - datetime.utcnow()
return timedelta(hours=round(difference.seconds / 3600 + difference.days * 24)) return timedelta(
hours=round(difference.seconds / 3600 + difference.days * 24))
def encode_quopri(text):
return quopri.encodestring(text.encode("utf-8")).replace(b"\n", b"=0A").decode("utf-8")
def encode_quopri(text):
return quopri.encodestring(text.encode("utf-8")).replace(
b"\n", b"=0A").decode("utf-8")
...@@ -3,14 +3,20 @@ import regex as re ...@@ -3,14 +3,20 @@ import regex as re
import os import os
import sys import sys
ROUTE_PATTERN = r'@(?:[[:alpha:]])+\.route\(\"(?<url>[^"]+)"[^)]*\)\s*(?:@[[:alpha:]_()., ]+\s*)*def\s+(?<name>[[:alpha:]][[:alnum:]_]*)\((?<params>[[:alnum:], ]*)\):' ROUTE_PATTERN = (
r'@(?:[[:alpha:]])+\.route\(\"(?<url>[^"]+)"[^)]*\)\s*'
r'(?:@[[:alpha:]_()., ]+\s*)*def\s+(?<name>[[:alpha:]][[:alnum:]_]*)'
r'\((?<params>[[:alnum:], ]*)\):')
quote_group = "[\"']" quote_group = "[\"']"
URL_FOR_PATTERN = r'url_for\({quotes}(?<name>[[:alpha:]][[:alnum:]_]*){quotes}'.format(quotes=quote_group) URL_FOR_PATTERN = (
r'url_for\({quotes}(?<name>[[:alpha:]][[:alnum:]_]*)'
'{quotes}'.format(quotes=quote_group))
ROOT_DIR = "." ROOT_DIR = "."
ENDINGS = [".py", ".html", ".txt"] ENDINGS = [".py", ".html", ".txt"]
MAX_DEPTH = 2 MAX_DEPTH = 2
def list_dir(dir, level=0): def list_dir(dir, level=0):
if level >= MAX_DEPTH: if level >= MAX_DEPTH:
return return
...@@ -25,6 +31,7 @@ def list_dir(dir, level=0): ...@@ -25,6 +31,7 @@ def list_dir(dir, level=0):
elif os.path.isdir(path): elif os.path.isdir(path):
yield from list_dir(path, level + 1) yield from list_dir(path, level + 1)
class Route: class Route:
def __init__(self, file, name, parameters): def __init__(self, file, name, parameters):
self.file = file self.file = file
...@@ -38,13 +45,15 @@ class Route: ...@@ -38,13 +45,15 @@ class Route:
def get_parameter_set(self): def get_parameter_set(self):
return {parameter.name for parameter in self.parameters} return {parameter.name for parameter in self.parameters}
class Parameter: class Parameter:
def __init__(self, name, type=None): def __init__(self, name, type=None):
self.name = name self.name = name
self.type = type self.type = type
def __repr__(self): def __repr__(self):
return "Parameter({name}, {type})".format(name=self.name, type=self.type) return "Parameter({name}, {type})".format(
name=self.name, type=self.type)
@staticmethod @staticmethod
def from_string(text): def from_string(text):
...@@ -53,6 +62,7 @@ class Parameter: ...@@ -53,6 +62,7 @@ class Parameter:
return Parameter(name, type) return Parameter(name, type)
return Parameter(text) return Parameter(text)
def split_url_parameters(url): def split_url_parameters(url):
params = [] params = []
current_param = None current_param = None
...@@ -68,9 +78,11 @@ def split_url_parameters(url): ...@@ -68,9 +78,11 @@ def split_url_parameters(url):
current_param += char current_param += char
return params return params
def split_function_parameters(parameters): def split_function_parameters(parameters):
return list(map(str.strip, parameters.split(","))) return list(map(str.strip, parameters.split(",")))
def read_url_for_parameters(content): def read_url_for_parameters(content):
params = [] params = []
bracket_level = 1 bracket_level = 1
...@@ -92,6 +104,7 @@ def read_url_for_parameters(content): ...@@ -92,6 +104,7 @@ def read_url_for_parameters(content):
elif char == ")": elif char == ")":
bracket_level -= 1 bracket_level -= 1
class UrlFor: class UrlFor:
def __init__(self, file, name, parameters): def __init__(self, file, name, parameters):
self.file = file self.file = file
...@@ -99,8 +112,10 @@ class UrlFor: ...@@ -99,8 +112,10 @@ class UrlFor:
self.parameters = parameters self.parameters = parameters
def __repr__(self): def __repr__(self):
return "UrlFor(file={file}, name={name}, parameters={parameters})".format( return (
file=self.file, name=self.name, parameters=self.parameters) "UrlFor(file={file}, name={name}, parameters={parameters})".format(
file=self.file, name=self.name, parameters=self.parameters))
routes = {} routes = {}
url_fors = [] url_fors = []
...@@ -109,24 +124,29 @@ for file in list_dir(ROOT_DIR): ...@@ -109,24 +124,29 @@ for file in list_dir(ROOT_DIR):
content = infile.read() content = infile.read()
for match in re.finditer(ROUTE_PATTERN, content): for match in re.finditer(ROUTE_PATTERN, content):
name = match.group("name") name = match.group("name")
function_parameters = split_function_parameters(match.group("params")) function_parameters = split_function_parameters(
match.group("params"))
url_parameters = split_url_parameters(match.group("url")) url_parameters = split_url_parameters(match.group("url"))
routes[name] = Route(file, name, url_parameters) routes[name] = Route(file, name, url_parameters)
for match in re.finditer(URL_FOR_PATTERN, content): for match in re.finditer(URL_FOR_PATTERN, content):
name = match.group("name") name = match.group("name")
begin, end = match.span() begin, end = match.span()
parameters = read_url_for_parameters(content[end:]) parameters = read_url_for_parameters(content[end:])
url_fors.append(UrlFor(file=file, name=name, parameters=parameters)) url_fors.append(UrlFor(
file=file, name=name, parameters=parameters))
for url_for in url_fors: for url_for in url_fors:
if url_for.name not in routes: if url_for.name not in routes:
print("Missing route '{}' (for url_for in '{}')".format(url_for.name, url_for.file)) print("Missing route '{}' (for url_for in '{}')".format(
url_for.name, url_for.file))
continue continue
route = routes[url_for.name] route = routes[url_for.name]
route_parameters = route.get_parameter_set() route_parameters = route.get_parameter_set()
url_parameters = set(url_for.parameters) url_parameters = set(url_for.parameters)
if len(route_parameters ^ url_parameters) > 0: if len(route_parameters ^ url_parameters) > 0:
print("Parameters not matching for '{}' in '{}:'".format(url_for.name, url_for.file)) print("Parameters not matching for '{}' in '{}:'".format(
url_for.name, url_for.file))
only_route = route_parameters - url_parameters only_route = route_parameters - url_parameters
only_url = url_parameters - route_parameters only_url = url_parameters - route_parameters
if len(only_route) > 0: if len(only_route) > 0:
......
from flask import redirect, flash, request, url_for from flask import flash
from functools import wraps from functools import wraps
from models.database import ALL_MODELS from models.database import ALL_MODELS
from shared import db, current_user from shared import current_user
import back import back
ID_KEY = "id" ID_KEY = "id"
...@@ -12,12 +12,15 @@ OBJECT_DOES_NOT_EXIST_MESSAGE = "There is no {} with id {}." ...@@ -12,12 +12,15 @@ OBJECT_DOES_NOT_EXIST_MESSAGE = "There is no {} with id {}."
MISSING_VIEW_RIGHT = "Dir fehlenden die nötigen Zugriffsrechte." MISSING_VIEW_RIGHT = "Dir fehlenden die nötigen Zugriffsrechte."
def default_redirect(): def default_redirect():
return back.redirect() return back.redirect()
def login_redirect(): def login_redirect():
return back.redirect("login") return back.redirect("login")
def db_lookup(*models, check_exists=True): def db_lookup(*models, check_exists=True):
def _decorator(function): def _decorator(function):
@wraps(function) @wraps(function)
...@@ -32,7 +35,8 @@ def db_lookup(*models, check_exists=True): ...@@ -32,7 +35,8 @@ def db_lookup(*models, check_exists=True):
obj = model.query.filter_by(id=obj_id).first() obj = model.query.filter_by(id=obj_id).first()
if check_exists and obj is None: if check_exists and obj is None:
model_name = model.__class__.__name__ model_name = model.__class__.__name__
flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(model_name, obj_id), flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(
model_name, obj_id),
"alert-error") "alert-error")
return default_redirect() return default_redirect()
kwargs[key] = obj kwargs[key] = obj
...@@ -41,8 +45,10 @@ def db_lookup(*models, check_exists=True): ...@@ -41,8 +45,10 @@ def db_lookup(*models, check_exists=True):
return _decorated_function return _decorated_function
return _decorator return _decorator
def require_right(right, require_exist): def require_right(right, require_exist):
necessary_right_name = "has_{}_right".format(right) necessary_right_name = "has_{}_right".format(right)
def _decorator(function): def _decorator(function):
@wraps(function) @wraps(function)
def _decorated_function(*args, **kwargs): def _decorated_function(*args, **kwargs):
...@@ -65,17 +71,22 @@ def require_right(right, require_exist): ...@@ -65,17 +71,22 @@ def require_right(right, require_exist):
return _decorated_function return _decorated_function
return _decorator return _decorator
def require_public_view_right(require_exist=True): def require_public_view_right(require_exist=True):
return require_right("public_view", require_exist) return require_right("public_view", require_exist)
def require_private_view_right(require_exist=True): def require_private_view_right(require_exist=True):
return require_right("private_view", require_exist) return require_right("private_view", require_exist)
def require_modify_right(require_exist=True): def require_modify_right(require_exist=True):
return require_right("modify", require_exist) return require_right("modify", require_exist)
def require_publish_right(require_exist=True): def require_publish_right(require_exist=True):
return require_right("publish", require_exist) return require_right("publish", require_exist)
def require_admin_right(require_exist=True): def require_admin_right(require_exist=True):
return require_right("admin", require_exist) return require_right("admin", require_exist)
from datetime import datetime from datetime import datetime
from fuzzywuzzy import fuzz, process from fuzzywuzzy import process
import tempfile
from models.database import Todo, OldTodo, Protocol, ProtocolType, TodoMail from models.database import OldTodo, Protocol, ProtocolType, TodoMail
from shared import db from shared import db
import config import config
def lookup_todo_id(old_candidates, new_who, new_description): def lookup_todo_id(old_candidates, new_who, new_description):
# Check for perfect matches # Check for perfect matches
for candidate in old_candidates: for candidate in old_candidates:
if candidate.who == new_who and candidate.description == new_description: if (candidate.who == new_who
and candidate.description == new_description):
return candidate.old_id return candidate.old_id
# Accept if who has been changed # Accept if who has been changed
for candidate in old_candidates: for candidate in old_candidates:
...@@ -32,11 +33,13 @@ def lookup_todo_id(old_candidates, new_who, new_description): ...@@ -32,11 +33,13 @@ def lookup_todo_id(old_candidates, new_who, new_description):
new_description, best_match, best_match_score)) new_description, best_match, best_match_score))
return None return None
INSERT_PROTOCOLTYPE = "INSERT INTO `protocolManager_protocoltype`" INSERT_PROTOCOLTYPE = "INSERT INTO `protocolManager_protocoltype`"
INSERT_PROTOCOL = "INSERT INTO `protocolManager_protocol`" INSERT_PROTOCOL = "INSERT INTO `protocolManager_protocol`"
INSERT_TODO = "INSERT INTO `protocolManager_todo`" INSERT_TODO = "INSERT INTO `protocolManager_todo`"
INSERT_TODOMAIL = "INSERT INTO `protocolManager_todonamemailassignment`" INSERT_TODOMAIL = "INSERT INTO `protocolManager_todonamemailassignment`"
def import_old_protocols(sql_text): def import_old_protocols(sql_text):
protocoltype_lines = [] protocoltype_lines = []
protocol_lines = [] protocol_lines = []
...@@ -50,18 +53,23 @@ def import_old_protocols(sql_text): ...@@ -50,18 +53,23 @@ def import_old_protocols(sql_text):
raise ValueError("Necessary lines not found.") raise ValueError("Necessary lines not found.")
type_id_to_handle = {} type_id_to_handle = {}
for type_line in protocoltype_lines: for type_line in protocoltype_lines:
for id, handle, name, mail, protocol_id in _split_insert_line(type_line): for id, handle, name, mail, protocol_id in _split_insert_line(
type_line):
type_id_to_handle[int(id)] = handle.lower() type_id_to_handle[int(id)] = handle.lower()
protocols = [] protocols = []
for protocol_line in protocol_lines: for protocol_line in protocol_lines:
for (protocol_id, old_type_id, date, source, textsummary, htmlsummary, for (protocol_id, old_type_id, date, source, textsummary, htmlsummary,
deleted, sent, document_id) in _split_insert_line(protocol_line): deleted, sent, document_id) in _split_insert_line(
protocol_line):
date = datetime.strptime(date, "%Y-%m-%d") date = datetime.strptime(date, "%Y-%m-%d")
handle = type_id_to_handle[int(old_type_id)] handle = type_id_to_handle[int(old_type_id)]
protocoltype = ProtocolType.query.filter(ProtocolType.short_name.ilike(handle)).first() protocoltype = ProtocolType.query.filter(
ProtocolType.short_name.ilike(handle)).first()
if protocoltype is None: if protocoltype is None:
raise KeyError("No protocoltype for handle '{}'.".format(handle)) raise KeyError(
protocol = Protocol(protocoltype_id=protocoltype.id, date=date, source=source) "No protocoltype for handle '{}'.".format(handle))
protocol = Protocol(
protocoltype_id=protocoltype.id, date=date, source=source)
db.session.add(protocol) db.session.add(protocol)
db.session.commit() db.session.commit()
import tasks import tasks
...@@ -70,6 +78,7 @@ def import_old_protocols(sql_text): ...@@ -70,6 +78,7 @@ def import_old_protocols(sql_text):
print(protocol.date) print(protocol.date)
tasks.parse_protocol(protocol) tasks.parse_protocol(protocol)
def import_old_todomails(sql_text): def import_old_todomails(sql_text):
todomail_lines = [] todomail_lines = []
for line in sql_text.splitlines(): for line in sql_text.splitlines():
...@@ -103,23 +112,29 @@ def import_old_todos(sql_text): ...@@ -103,23 +112,29 @@ def import_old_todos(sql_text):
raise ValueError("Necessary lines not found.") raise ValueError("Necessary lines not found.")
type_id_to_handle = {} type_id_to_handle = {}
for type_line in protocoltype_lines: for type_line in protocoltype_lines:
for id, handle, name, mail, protocol_id in _split_insert_line(type_line): for id, handle, name, mail, protocol_id in _split_insert_line(
type_line):
type_id_to_handle[int(id)] = handle.lower() type_id_to_handle[int(id)] = handle.lower()
protocol_id_to_key = {} protocol_id_to_key = {}
for protocol_line in protocol_lines: for protocol_line in protocol_lines:
for (protocol_id, type_id, date, source, textsummary, htmlsummary, for (protocol_id, type_id, date, source, textsummary, htmlsummary,
deleted, sent, document_id) in _split_insert_line(protocol_line): deleted, sent, document_id) in _split_insert_line(
protocol_line):
handle = type_id_to_handle[int(type_id)] handle = type_id_to_handle[int(type_id)]
date_string = date[2:] date_string = date[2:]
protocol_id_to_key[int(protocol_id)] = "{}-{}".format(handle, date_string) protocol_id_to_key[int(protocol_id)] = "{}-{}".format(
handle, date_string)
todos = [] todos = []
for todo_line in todo_lines: for todo_line in todo_lines:
for old_id, protocol_id, who, what, start_time, end_time, done in _split_insert_line(todo_line): for (old_id, protocol_id, who, what, start_time, end_time,
done) in _split_insert_line(todo_line):
protocol_id = int(protocol_id) protocol_id = int(protocol_id)
if protocol_id not in protocol_id_to_key: if protocol_id not in protocol_id_to_key:
print("Missing protocol with ID {} for Todo {}".format(protocol_id, what)) print("Missing protocol with ID {} for Todo {}".format(
protocol_id, what))
continue continue
todo = OldTodo(old_id=old_id, who=who, description=what, todo = OldTodo(
old_id=old_id, who=who, description=what,
protocol_key=protocol_id_to_key[protocol_id]) protocol_key=protocol_id_to_key[protocol_id])
todos.append(todo) todos.append(todo)
OldTodo.query.delete() OldTodo.query.delete()
...@@ -128,11 +143,15 @@ def import_old_todos(sql_text): ...@@ -128,11 +143,15 @@ def import_old_todos(sql_text):
db.session.add(todo) db.session.add(todo)
db.session.commit() db.session.commit()
def _split_insert_line(line): def _split_insert_line(line):
insert_part, values_part = line.split("VALUES", 1) insert_part, values_part = line.split("VALUES", 1)
return _split_base_level(values_part) return _split_base_level(values_part)
def _split_base_level(text, begin="(", end=")", separator=",", string_terminator="'", line_end=";", ignore=" ", escape="\\"):
def _split_base_level(
text, begin="(", end=")", separator=",", string_terminator="'",
line_end=";", ignore=" ", escape="\\"):
raw_parts = [] raw_parts = []
current_part = None current_part = None
index = 0 index = 0
...@@ -210,5 +229,3 @@ def _split_base_level(text, begin="(", end=")", separator=",", string_terminator ...@@ -210,5 +229,3 @@ def _split_base_level(text, begin="(", end=")", separator=",", string_terminator
fields.append(current_field) fields.append(current_field)
parts.append(fields) parts.append(fields)
return parts return parts
This diff is collapsed.
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from flask import session, redirect, url_for, request from flask import session, redirect, url_for, flash
import re import re
from functools import wraps from functools import wraps
...@@ -11,7 +11,8 @@ import config ...@@ -11,7 +11,8 @@ import config
db = SQLAlchemy() db = SQLAlchemy()
# the following code is written by Lars Beckers and not to be published without permission # the following code escape_tex is written by Lars Beckers
# and not to be published without permission
latex_chars = [ latex_chars = [
("\\", "\\backslash"), # this needs to be first ("\\", "\\backslash"), # this needs to be first
("$", "\$"), ("$", "\$"),
...@@ -23,7 +24,6 @@ latex_chars = [ ...@@ -23,7 +24,6 @@ latex_chars = [
('}', '\\}'), ('}', '\\}'),
('[', '\\['), ('[', '\\['),
(']', '\\]'), (']', '\\]'),
#('"', '"\''),
('~', r'$\sim{}$'), ('~', r'$\sim{}$'),
('^', r'\textasciicircum{}'), ('^', r'\textasciicircum{}'),
('Ë„', r'\textasciicircum{}'), ('Ë„', r'\textasciicircum{}'),
...@@ -40,72 +40,96 @@ latex_chars = [ ...@@ -40,72 +40,96 @@ latex_chars = [
('\\backslash', '$\\backslash$') # this needs to be last ('\\backslash', '$\\backslash$') # this needs to be last
] ]
def escape_tex(text): def escape_tex(text):
out = text out = text
for old, new in latex_chars: for old, new in latex_chars:
out = out.replace(old, new) out = out.replace(old, new)
# beware, the following is carefully crafted code # beware, the following is carefully crafted code
res = '' res = ''
k, l = (0, -1) start, end = (0, -1)
while k >= 0: while start >= 0:
k = out.find('"', l+1) start = out.find('"', end + 1)
if k >= 0: if start >= 0:
res += out[l+1:k] res += out[end + 1:start]
l = out.find('"', k+1) end = out.find('"', start + 1)
if l >= 0: if end >= 0:
res += '\\enquote{' + out[k+1:l] + '}' res += '\\enquote{' + out[start + 1:end] + '}'
else: else:
res += '"\'' + out[k+1:] res += '"\'' + out[start + 1:]
k = l start = end
else: else:
res += out[l+1:] res += out[end + 1:]
# yes, this is not quite escaping latex chars, but anyway... # yes, this is not quite escaping latex chars, but anyway...
res = re.sub('([a-z])\(', '\\1 (', res) res = re.sub('([a-z])\(', '\\1 (', res)
res = re.sub('\)([a-z])', ') \\1', res) res = re.sub('\)([a-z])', ') \\1', res)
#logging.debug('escape latex ({0}/{1}): {2} --> {3}'.format(len(text), len(res), text.split('\n')[0], res.split('\n')[0]))
return res return res
def unhyphen(text): def unhyphen(text):
return " ".join([r"\mbox{" + word + "}" for word in text.split(" ")]) return " ".join([r"\mbox{" + word + "}" for word in text.split(" ")])
def date_filter(date): def date_filter(date):
return date.strftime("%d. %B %Y") return date.strftime("%d. %B %Y")
def datetime_filter(date): def datetime_filter(date):
return date.strftime("%d. %B %Y, %H:%M") return date.strftime("%d. %B %Y, %H:%M")
def date_filter_long(date): def date_filter_long(date):
return date.strftime("%A, %d.%m.%Y, Kalenderwoche %W") return date.strftime("%A, %d.%m.%Y, Kalenderwoche %W")
def date_filter_short(date): def date_filter_short(date):
return date.strftime("%d.%m.%Y") return date.strftime("%d.%m.%Y")
def time_filter(time): def time_filter(time):
return time.strftime("%H:%M Uhr") return time.strftime("%H:%M Uhr")
def time_filter_short(time): def time_filter_short(time):
return time.strftime("%H:%M") return time.strftime("%H:%M")
def needs_date_test(todostate): def needs_date_test(todostate):
return todostate.needs_date() return todostate.needs_date()
def todostate_name_filter(todostate): def todostate_name_filter(todostate):
return todostate.get_name() return todostate.get_name()
def indent_tab_filter(text): def indent_tab_filter(text):
return "\n".join(map(lambda l: "\t{}".format(l), text.splitlines())) return "\n".join(map(lambda l: "\t{}".format(l), text.splitlines()))
def class_filter(obj): def class_filter(obj):
return obj.__class__.__name__ return obj.__class__.__name__
def code_filter(text): def code_filter(text):
return "<code>{}</code>".format(text) return "<code>{}</code>".format(text)
from auth import UserManager, SecurityManager, User from auth import UserManager, SecurityManager, User
max_duration = getattr(config, "AUTH_MAX_DURATION") max_duration = getattr(config, "AUTH_MAX_DURATION")
user_manager = UserManager(backends=config.AUTH_BACKENDS) user_manager = UserManager(backends=config.AUTH_BACKENDS)
security_manager = SecurityManager(config.SECURITY_KEY, max_duration) security_manager = SecurityManager(config.SECURITY_KEY, max_duration)
def check_login(): def check_login():
return "auth" in session and security_manager.check_user(session["auth"]) return "auth" in session and security_manager.check_user(session["auth"])
def current_user(): def current_user():
if not check_login(): if not check_login():
return None return None
return User.from_hashstring(session["auth"]) return User.from_hashstring(session["auth"])
def login_required(function): def login_required(function):
@wraps(function) @wraps(function)
def decorated_function(*args, **kwargs): def decorated_function(*args, **kwargs):
...@@ -115,6 +139,7 @@ def login_required(function): ...@@ -115,6 +139,7 @@ def login_required(function):
return redirect(url_for("login")) return redirect(url_for("login"))
return decorated_function return decorated_function
def group_required(group): def group_required(group):
def decorator(function): def decorator(function):
@wraps(function) @wraps(function)
...@@ -122,16 +147,19 @@ def group_required(group): ...@@ -122,16 +147,19 @@ def group_required(group):
if group in current_user().groups: if group in current_user().groups:
return function(*args, **kwargs) return function(*args, **kwargs)
else: else:
flash("You do not have the necessary permissions to view this page.") flash("You do not have the necessary permissions to "
"view this page.")
return back.redirect() return back.redirect()
return decorated_function return decorated_function
return decorator return decorator
DATE_KEY = "Datum" DATE_KEY = "Datum"
START_TIME_KEY = "Beginn" START_TIME_KEY = "Beginn"
END_TIME_KEY = "Ende" END_TIME_KEY = "Ende"
KNOWN_KEYS = [DATE_KEY, START_TIME_KEY, END_TIME_KEY] KNOWN_KEYS = [DATE_KEY, START_TIME_KEY, END_TIME_KEY]
class WikiType(Enum): class WikiType(Enum):
MEDIAWIKI = 0 MEDIAWIKI = 0
DOKUWIKI = 1 DOKUWIKI = 1
from flask import render_template, request from flask import request
import random import random
import string import string
import regex
import math import math
import smtplib import smtplib
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.mime.application import MIMEApplication from email.mime.application import MIMEApplication
from datetime import datetime, date, timedelta from datetime import datetime
import requests import requests
from io import BytesIO from io import BytesIO
import ipaddress import ipaddress
...@@ -18,12 +17,16 @@ import subprocess ...@@ -18,12 +17,16 @@ import subprocess
import config import config
def random_string(length): def random_string(length):
return "".join((random.choice(string.ascii_letters) for i in range(length))) return "".join((random.choice(string.ascii_letters)
for i in range(length)))
def is_past(some_date): def is_past(some_date):
return (datetime.now() - some_date).total_seconds() > 0 return (datetime.now() - some_date).total_seconds() > 0
def encode_kwargs(kwargs): def encode_kwargs(kwargs):
encoded_kwargs = {} encoded_kwargs = {}
for key in kwargs: for key in kwargs:
...@@ -34,6 +37,7 @@ def encode_kwargs(kwargs): ...@@ -34,6 +37,7 @@ def encode_kwargs(kwargs):
encoded_kwargs[key] = (type(value), value, False) encoded_kwargs[key] = (type(value), value, False)
return encoded_kwargs return encoded_kwargs
def decode_kwargs(encoded_kwargs): def decode_kwargs(encoded_kwargs):
kwargs = {} kwargs = {}
for name in encoded_kwargs: for name in encoded_kwargs:
...@@ -45,27 +49,6 @@ def decode_kwargs(encoded_kwargs): ...@@ -45,27 +49,6 @@ def decode_kwargs(encoded_kwargs):
return kwargs return kwargs
class UrlManager:
def __init__(self, config):
self.pattern = regex.compile(r"(?:(?<proto>https?):\/\/)?(?<hostname>[[:alnum:]_.]+(?:\:[[:digit:]]+)?)?(?<path>(?:\/[[:alnum:]_#]*)+)?(?:\?(?<params>.*))?")
self.base = "{}://{}{}{}"
self.proto = getattr(config, "URL_PROTO", "https")
self.root = getattr(config, "URL_ROOT", "example.com")
self.path = getattr(config, "URL_PATH", "/")
self.params = getattr(config, "URL_PARAMS", "")
def complete(self, url):
match = self.pattern.match(url)
if match is None:
return None
proto = match.group("proto") or self.proto
root = match.group("hostname") or self.root
path = match.group("path") or self.path
params = match.group("params") or self.params
return self.base.format(proto, root, path, "?" + params if len(params) > 0 else "")
#url_manager = UrlManager(config)
class MailManager: class MailManager:
def __init__(self, config): def __init__(self, config):
self.active = getattr(config, "MAIL_ACTIVE", False) self.active = getattr(config, "MAIL_ACTIVE", False)
...@@ -76,12 +59,17 @@ class MailManager: ...@@ -76,12 +59,17 @@ class MailManager:
self.use_tls = getattr(config, "MAIL_USE_TLS", True) self.use_tls = getattr(config, "MAIL_USE_TLS", True)
self.use_starttls = getattr(config, "MAIL_USE_STARTTLS", False) self.use_starttls = getattr(config, "MAIL_USE_STARTTLS", False)
def _get_smtp(self):
if self.use_tls:
return smtplib.SMTP_SSL
return smtplib.SMTP
def send(self, to_addr, subject, content, appendix=None, reply_to=None): def send(self, to_addr, subject, content, appendix=None, reply_to=None):
if (not self.active if (not self.active
or not self.hostname or not self.hostname
or not self.from_addr): or not self.from_addr):
return return
msg = MIMEMultipart("mixed") # todo: test if clients accept attachment-free mails set to multipart/mixed msg = MIMEMultipart("mixed")
msg["From"] = self.from_addr msg["From"] = self.from_addr
msg["To"] = to_addr msg["To"] = to_addr
msg["Subject"] = subject msg["Subject"] = subject
...@@ -92,9 +80,10 @@ class MailManager: ...@@ -92,9 +80,10 @@ class MailManager:
if appendix is not None: if appendix is not None:
for name, file_like in appendix: for name, file_like in appendix:
part = MIMEApplication(file_like.read(), "octet-stream") part = MIMEApplication(file_like.read(), "octet-stream")
part["Content-Disposition"] = 'attachment; filename="{}"'.format(name) part["Content-Disposition"] = (
'attachment; filename="{}"'.format(name))
msg.attach(part) msg.attach(part)
server = (smtplib.SMTP_SSL if self.use_tls else smtplib.SMTP)(self.hostname) server = self._get_smtp()(self.hostname)
if self.use_starttls: if self.use_starttls:
server.starttls() server.starttls()
if self.username not in [None, ""] and self.password not in [None, ""]: if self.username not in [None, ""] and self.password not in [None, ""]:
...@@ -102,8 +91,10 @@ class MailManager: ...@@ -102,8 +91,10 @@ class MailManager:
server.sendmail(self.from_addr, to_addr.split(","), msg.as_string()) server.sendmail(self.from_addr, to_addr.split(","), msg.as_string())
server.quit() server.quit()
mail_manager = MailManager(config) mail_manager = MailManager(config)
def get_first_unused_int(numbers): def get_first_unused_int(numbers):
positive_numbers = [number for number in numbers if number >= 0] positive_numbers = [number for number in numbers if number >= 0]
if len(positive_numbers) == 0: if len(positive_numbers) == 0:
...@@ -114,24 +105,34 @@ def get_first_unused_int(numbers): ...@@ -114,24 +105,34 @@ def get_first_unused_int(numbers):
return linear return linear
return highest + 1 return highest + 1
def normalize_pad(pad): def normalize_pad(pad):
return pad.replace(" ", "_") return pad.replace(" ", "_")
def get_etherpad_url(pad): def get_etherpad_url(pad):
return "{}/p/{}".format(config.ETHERPAD_URL, normalize_pad(pad)) return "{}/p/{}".format(config.ETHERPAD_URL, normalize_pad(pad))
def get_etherpad_export_url(pad): def get_etherpad_export_url(pad):
return "{}/p/{}/export/txt".format(config.ETHERPAD_URL, normalize_pad(pad)) return "{}/p/{}/export/txt".format(config.ETHERPAD_URL, normalize_pad(pad))
def get_etherpad_import_url(pad): def get_etherpad_import_url(pad):
return "{}/p/{}/import".format(config.ETHERPAD_URL, normalize_pad(pad)) return "{}/p/{}/import".format(config.ETHERPAD_URL, normalize_pad(pad))
def get_etherpad_text(pad): def get_etherpad_text(pad):
req = requests.get(get_etherpad_export_url(pad)) req = requests.get(get_etherpad_export_url(pad))
return req.text return req.text
def set_etherpad_text(pad, text, only_if_default=True): def set_etherpad_text(pad, text, only_if_default=True):
print(pad) print(pad)
if only_if_default: if only_if_default:
current_text = get_etherpad_text(pad) current_text = get_etherpad_text(pad)
if current_text != config.EMPTY_ETHERPAD and len(current_text.strip()) > 0: if (current_text != config.EMPTY_ETHERPAD
and len(current_text.strip()) > 0):
return False return False
file_like = BytesIO(text.encode("utf-8")) file_like = BytesIO(text.encode("utf-8"))
files = {"file": file_like} files = {"file": file_like}
...@@ -140,6 +141,7 @@ def set_etherpad_text(pad, text, only_if_default=True): ...@@ -140,6 +141,7 @@ def set_etherpad_text(pad, text, only_if_default=True):
req = requests.post(url, files=files) req = requests.post(url, files=files)
return req.status_code == 200 return req.status_code == 200
def split_terms(text, quote_chars="\"'", separators=" \t\n"): def split_terms(text, quote_chars="\"'", separators=" \t\n"):
terms = [] terms = []
in_quote = False in_quote = False
...@@ -169,12 +171,14 @@ def split_terms(text, quote_chars="\"'", separators=" \t\n"): ...@@ -169,12 +171,14 @@ def split_terms(text, quote_chars="\"'", separators=" \t\n"):
terms.append(current_term) terms.append(current_term)
return terms return terms
def optional_int_arg(name): def optional_int_arg(name):
try: try:
return int(request.args.get(name)) return int(request.args.get(name))
except (ValueError, TypeError): except (ValueError, TypeError):
return None return None
def add_line_numbers(text): def add_line_numbers(text):
raw_lines = text.splitlines() raw_lines = text.splitlines()
linenumber_length = math.ceil(math.log10(len(raw_lines)) + 1) linenumber_length = math.ceil(math.log10(len(raw_lines)) + 1)
...@@ -186,9 +190,11 @@ def add_line_numbers(text): ...@@ -186,9 +190,11 @@ def add_line_numbers(text):
)) ))
return "\n".join(lines) return "\n".join(lines)
def check_ip_in_networks(networks_string): def check_ip_in_networks(networks_string):
address = ipaddress.ip_address(request.remote_addr) address = ipaddress.ip_address(request.remote_addr)
if address == ipaddress.ip_address("127.0.0.1") and "X-Real-Ip" in request.headers: if (address == ipaddress.ip_address("127.0.0.1")
and "X-Real-Ip" in request.headers):
address = ipaddress.ip_address(request.headers["X-Real-Ip"]) address = ipaddress.ip_address(request.headers["X-Real-Ip"])
try: try:
for network_string in networks_string.split(","): for network_string in networks_string.split(","):
...@@ -199,6 +205,7 @@ def check_ip_in_networks(networks_string): ...@@ -199,6 +205,7 @@ def check_ip_in_networks(networks_string):
except ValueError: except ValueError:
return False return False
def fancy_join(values, sep1=" und ", sep2=", "): def fancy_join(values, sep1=" und ", sep2=", "):
values = list(values) values = list(values)
if len(values) <= 1: if len(values) <= 1:
...@@ -207,9 +214,11 @@ def fancy_join(values, sep1=" und ", sep2=", "): ...@@ -207,9 +214,11 @@ def fancy_join(values, sep1=" und ", sep2=", "):
start = values[:-1] start = values[:-1]
return "{}{}{}".format(sep2.join(start), sep1, last) return "{}{}{}".format(sep2.join(start), sep1, last)
def footnote_hash(text, length=5): def footnote_hash(text, length=5):
return str(sum(ord(c) * i for i, c in enumerate(text)) % 10**length) return str(sum(ord(c) * i for i, c in enumerate(text)) % 10**length)
def parse_datetime_from_string(text): def parse_datetime_from_string(text):
text = text.strip() text = text.strip()
for format in ("%d.%m.%Y", "%d.%m.%y", "%Y-%m-%d", for format in ("%d.%m.%Y", "%d.%m.%y", "%Y-%m-%d",
...@@ -220,9 +229,10 @@ def parse_datetime_from_string(text): ...@@ -220,9 +229,10 @@ def parse_datetime_from_string(text):
pass pass
for format in ("%d.%m.", "%d. %m.", "%d.%m", "%d.%m"): for format in ("%d.%m.", "%d. %m.", "%d.%m", "%d.%m"):
try: try:
return datetime.strptime(text, format).replace(year=datetime.now().year) return datetime.strptime(text, format).replace(
except ValueError as exc: year=datetime.now().year)
print(exc) except ValueError:
pass
raise ValueError("Date '{}' does not match any known format!".format(text)) raise ValueError("Date '{}' does not match any known format!".format(text))
...@@ -248,4 +258,3 @@ def get_max_page_length_exp(objects): ...@@ -248,4 +258,3 @@ def get_max_page_length_exp(objects):
def get_internal_filename(protocol, document, filename): def get_internal_filename(protocol, document, filename):
return "{}-{}-{}".format(protocol.id, document.id, filename) return "{}-{}-{}".format(protocol.id, document.id, filename)
from models.database import TodoState from models.database import TodoState
from wtforms import ValidationError from wtforms import ValidationError
from wtforms.validators import InputRequired from wtforms.validators import InputRequired
from shared import db
class CheckTodoDateByState: class CheckTodoDateByState:
def __init__(self): def __init__(self):
...@@ -16,4 +16,3 @@ class CheckTodoDateByState: ...@@ -16,4 +16,3 @@ class CheckTodoDateByState:
date_check(form, form.date) date_check(form, form.date)
except ValueError: except ValueError:
raise ValidationError("Invalid state.") raise ValidationError("Invalid state.")
import requests import requests
import json
import config import config
HTTP_STATUS_OK = 200 HTTP_STATUS_OK = 200
HTTP_STATUS_AUTHENTICATE = 401 HTTP_STATUS_AUTHENTICATE = 401
class WikiException(Exception): class WikiException(Exception):
pass pass
def _filter_params(params): def _filter_params(params):
result = {} result = {}
for key, value in sorted(params.items(), key=lambda t: t[0] == "token"): for key, value in sorted(params.items(), key=lambda t: t[0] == "token"):
...@@ -19,14 +20,20 @@ def _filter_params(params): ...@@ -19,14 +20,20 @@ def _filter_params(params):
result[key] = value result[key] = value
return result return result
class WikiClient: class WikiClient:
def __init__(self, active=None, endpoint=None, anonymous=None, user=None, password=None, domain=None): def __init__(self, active=None, endpoint=None, anonymous=None, user=None,
self.active = active if active is not None else config.WIKI_ACTIVE password=None, domain=None):
self.endpoint = endpoint if endpoint is not None else config.WIKI_API_URL def _or_default(value, default):
self.anonymous = anonymous if anonymous is not None else config.WIKI_ANONYMOUS if value is None:
self.user = user if user is not None else config.WIKI_USER return default
self.password = password if password is not None else config.WIKI_PASSWORD return value
self.domain = domain if domain is not None else config.WIKI_DOMAIN self.active = _or_default(active, config.WIKI_ACTIVE)
self.endpoint = _or_default(endpoint, config.WIKI_API_URL)
self.anonymous = _or_default(anonymous, config.WIKI_ANONYMOUS)
self.user = _or_default(user, config.WIKI_USER)
self.password = _or_default(password, config.WIKI_PASSWORD)
self.domain = _or_default(domain, config.WIKI_DOMAIN)
self.token = None self.token = None
self.cookies = requests.cookies.RequestsCookieJar() self.cookies = requests.cookies.RequestsCookieJar()
...@@ -45,28 +52,31 @@ class WikiClient: ...@@ -45,28 +52,31 @@ class WikiClient:
def login(self): def login(self):
if not self.active: if not self.active:
return return
# todo: Change this to the new MediaWiki tokens api once the wiki is updated # todo: Change this to the new MediaWiki tokens api
token_answer = self.do_action("login", method="post", lgname=self.user) token_answer = self.do_action("login", method="post", lgname=self.user)
if "login" not in token_answer or "token" not in token_answer["login"]: if "login" not in token_answer or "token" not in token_answer["login"]:
raise WikiException("No token in login answer.") raise WikiException("No token in login answer.")
lgtoken = token_answer["login"]["token"] lgtoken = token_answer["login"]["token"]
login_answer = self.do_action("login", method="post", lgname=self.user, lgpassword=self.password, lgdomain=self.domain, lgtoken=lgtoken) login_answer = self.do_action(
"login", method="post", lgname=self.user, lgpassword=self.password,
lgdomain=self.domain, lgtoken=lgtoken)
if ("login" not in login_answer if ("login" not in login_answer
or "result" not in login_answer["login"] or "result" not in login_answer["login"]
or login_answer["login"]["result"] != "Success"): or login_answer["login"]["result"] != "Success"):
raise WikiException("Login not successful.") raise WikiException("Login not successful.")
def logout(self): def logout(self):
if not self.active: if not self.active:
return return
self.do_action("logout") self.do_action("logout")
def edit_page(self, title, content, summary, recreate=True, createonly=False): def edit_page(self, title, content, summary, recreate=True,
createonly=False):
if not self.active: if not self.active:
return return
# todo: port to new api once the wiki is updated # todo: port to new api once the wiki is updated
prop_answer = self.do_action("query", method="get", prop="info", intoken="edit", titles=title) prop_answer = self.do_action(
"query", method="get", prop="info", intoken="edit", titles=title)
if ("query" not in prop_answer if ("query" not in prop_answer
or "pages" not in prop_answer["query"]): or "pages" not in prop_answer["query"]):
raise WikiException("Can't get token for page {}".format(title)) raise WikiException("Can't get token for page {}".format(title))
...@@ -78,7 +88,8 @@ class WikiClient: ...@@ -78,7 +88,8 @@ class WikiClient:
break break
if edit_token is None: if edit_token is None:
raise WikiException("Can't get token for page {}".format(title)) raise WikiException("Can't get token for page {}".format(title))
edit_answer = self.do_action(action="edit", method="post", data={"text": content}, self.do_action(
action="edit", method="post", data={"text": content},
token=edit_token, title=title, token=edit_token, title=title,
summary=summary, recreate=recreate, summary=summary, recreate=recreate,
createonly=createonly, bot=True) createonly=createonly, bot=True)
...@@ -89,18 +100,28 @@ class WikiClient: ...@@ -89,18 +100,28 @@ class WikiClient:
kwargs["action"] = action kwargs["action"] = action
kwargs["format"] = "json" kwargs["format"] = "json"
params = _filter_params(kwargs) params = _filter_params(kwargs)
def _do_request(): def _do_request():
if method == "get": if method == "get":
return requests.get(self.endpoint, cookies=self.cookies, params=params, auth=requests.auth.HTTPBasicAuth(self.user, self.password)) return requests.get(
self.endpoint, cookies=self.cookies, params=params,
auth=requests.auth.HTTPBasicAuth(self.user, self.password))
elif method == "post": elif method == "post":
return requests.post(self.endpoint, cookies=self.cookies, data=data, params=params, auth=requests.auth.HTTPBasicAuth(self.user, self.password)) return requests.post(
self.endpoint, cookies=self.cookies, data=data,
params=params, auth=requests.auth.HTTPBasicAuth(
self.user, self.password))
req = _do_request() req = _do_request()
if req.status_code != HTTP_STATUS_OK: if req.status_code != HTTP_STATUS_OK:
raise WikiException("HTTP status code {} on action {}.".format(req.status_code, action)) raise WikiException(
"HTTP status code {} on action {}.".format(
req.status_code, action))
self.cookies.update(req.cookies) self.cookies.update(req.cookies)
return req.json() return req.json()
def main(): def main():
with WikiClient() as client: with WikiClient() as client:
client.edit_page(title="Test", content="This is a very long text.", summary="API client test") client.edit_page(
title="Test", content="This is a very long text.",
summary="API client test")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment