Skip to content
Snippets Groups Projects
Commit 6d9c3fdd authored by Robin Sonnabend's avatar Robin Sonnabend
Browse files

Improve *.py code quality

Also fixes some latent, rather unimportant bugs
parent 663a9d49
No related branches found
No related tags found
No related merge requests found
import hmac, hashlib
import hmac
import hashlib
import ssl
from datetime import datetime
......@@ -16,8 +17,9 @@ class User:
self.permanent = permanent
def summarize(self):
return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent)
return ":".join((
self.username, ",".join(self.groups),
str(self.timestamp.timestamp()), self.obsolete, self.permanent))
@staticmethod
def from_summary(summary):
......@@ -45,7 +47,8 @@ class UserManager:
for backend in self.backends:
if backend.authenticate(username, password):
groups = sorted(list(set(backend.groups(username, password))))
return User(username, groups, obsolete=backend.obsolete,
return User(
username, groups, obsolete=backend.obsolete,
permanent=permanent)
return None
......@@ -77,8 +80,8 @@ class SecurityManager:
if user is None:
return False
session_duration = datetime.now() - user.timestamp
macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
hash)
macs_equal = hmac.compare_digest(
maccer.hexdigest().encode("utf-8"), hash)
time_short = int(session_duration.total_seconds()) < self.max_duration
return macs_equal and (time_short or user.permanent)
......@@ -104,12 +107,11 @@ class StaticUserManager:
yield from self.group_map[username]
def all_groups(self):
yield from list(set(group for group in groups.values()))
yield from list(set(group for group in self.group_map.values()))
try:
import ldap3
from ldap3.utils.dn import parse_dn
class LdapManager:
def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
......@@ -121,8 +123,8 @@ try:
def authenticate(self, username, password):
try:
connection = ldap3.Connection(self.server,
self.user_dn.format(username), password)
connection = ldap3.Connection(
self.server, self.user_dn.format(username), password)
return connection.bind()
except ldap3.core.exceptions.LDAPSocketOpenError:
return False
......@@ -144,16 +146,15 @@ try:
for group in group_reader.search():
yield group.cn.value
class ADManager:
def __init__(self, host, domain, user_dn, group_dn,
port=636, use_ssl=True, ca_cert=None, obsolete=False):
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
if ca_cert is not None:
tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
ca_certs_file=ca_cert)
self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
tls=tls_config)
tls_config = ldap3.Tls(
validate=ssl.CERT_REQUIRED, ca_certs_file=ca_cert)
self.server = ldap3.Server(
host, port=port, use_ssl=use_ssl, tls=tls_config)
self.domain = domain
self.user_dn = user_dn
self.group_dn = group_dn
......@@ -176,11 +177,13 @@ try:
connection.bind()
obj_def = ldap3.ObjectDef("user", connection)
name_filter = "cn:={}".format(username)
user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
name_filter)
user_reader = ldap3.Reader(
connection, obj_def, self.user_dn, name_filter)
group_def = ldap3.ObjectDef("group", connection)
def _yield_recursive_groups(group_dn):
group_reader = ldap3.Reader(connection, group_def, group_dn, None)
group_reader = ldap3.Reader(
connection, group_def, group_dn, None)
for entry in group_reader.search():
yield entry.name.value
for child in entry.memberOf:
......@@ -189,20 +192,22 @@ try:
for group_dn in result.memberOf:
yield from _yield_recursive_groups(group_dn)
def all_groups(self):
connection = self.prepare_connection()
connection.bind()
obj_def = ldap3.ObjectDef("group", connection)
group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
for result in reader.search():
for result in group_reader.search():
yield result.name.value
except ModuleNotFoundError:
pass
try:
import grp, pwd, pam
import grp
import pwd
import pam
class PAMManager:
def __init__(self, obsolete=False):
......@@ -224,4 +229,3 @@ try:
yield group.gr_name
except ModuleNotFoundError:
pass
......@@ -10,6 +10,7 @@ import config
cookie = getattr(config, "REDIRECT_BACK_COOKIE", "back")
default_view = getattr(config, "REDIRECT_BACK_DEFAULT", "index")
def anchor(func, cookie=cookie):
@functools.wraps(func)
def result(*args, **kwargs):
......@@ -17,8 +18,10 @@ def anchor(func, cookie=cookie):
return func(*args, **kwargs)
return result
def url(default=default_view, cookie=cookie, **url_args):
return session.get(cookie, url_for(default, **url_args))
def redirect(default=default_view, cookie=cookie, **url_args):
return flask_redirect(url(default, cookie, **url_args))
......@@ -2,15 +2,16 @@ from datetime import datetime, timedelta
import random
import quopri
from caldav import DAVClient, Principal, Calendar, Event
from caldav.lib.error import PropfindError
from caldav import DAVClient
from vobject.base import ContentLine
import config
class CalendarException(Exception):
pass
class Client:
def __init__(self, calendar=None, url=None):
if not config.CALENDAR_ACTIVE:
......@@ -23,9 +24,12 @@ class Client:
self.principal = self.client.principal()
break
except Exception as exc:
print("Got exception {} from caldav, retrying".format(str(exc)))
print("Got exception {} from caldav, retrying".format(
str(exc)))
if self.principal is None:
raise CalendarException("Got {} CalDAV-error from the CalDAV server.".format(config.CALENDAR_MAX_REQUESTS))
raise CalendarException(
"Got {} CalDAV-error from the CalDAV server.".format(
config.CALENDAR_MAX_REQUESTS))
if calendar is not None:
self.calendar = self.get_calendar(calendar)
else:
......@@ -41,9 +45,11 @@ class Client:
for calendar in self.principal.calendars()
]
except Exception as exc:
print("Got exception {} from caldav, retrying".format(str(exc)))
raise CalendarException("Got {} CalDAV Errors from the CalDAV server.".format(config.CALENDAR_MAX_REQUESTS))
print("Got exception {} from caldav, retrying".format(
str(exc)))
raise CalendarException(
"Got {} CalDAV Errors from the CalDAV server.".format(
config.CALENDAR_MAX_REQUESTS))
def get_calendar(self, calendar_name):
candidates = self.principal.calendars()
......@@ -57,12 +63,14 @@ class Client:
return
candidates = [
Event.from_raw_event(raw_event)
for raw_event in self.calendar.date_search(begin, begin + timedelta(hours=1))
for raw_event in self.calendar.date_search(
begin, begin + timedelta(hours=1))
]
candidates = [event for event in candidates if event.name == name]
event = None
if len(candidates) == 0:
event = Event(None, name, description, begin,
event = Event(
None, name, description, begin,
begin + timedelta(hours=config.CALENDAR_DEFAULT_DURATION))
vevent = self.calendar.add_event(event.to_vcal())
event.vevent = vevent
......@@ -76,11 +84,14 @@ NAME_KEY = "summary"
DESCRIPTION_KEY = "description"
BEGIN_KEY = "dtstart"
END_KEY = "dtend"
def _get_item(content, key):
if key in content:
return content[key][0].value
return None
class Event:
def __init__(self, vevent, name, description, begin, end):
self.vevent = vevent
......@@ -97,7 +108,8 @@ class Event:
description = _get_item(content, DESCRIPTION_KEY)
begin = _get_item(content, BEGIN_KEY)
end = _get_item(content, END_KEY)
return Event(vevent=vevent, name=name, description=description,
return Event(
vevent=vevent, name=name, description=description,
begin=begin, end=end)
def set_description(self, description):
......@@ -105,7 +117,8 @@ class Event:
self.description = description
encoded = encode_quopri(description)
if DESCRIPTION_KEY not in raw_event.contents:
raw_event.contents[DESCRIPTION_KEY] = [ContentLine(DESCRIPTION_KEY, {"ENCODING": ["QUOTED-PRINTABLE"]}, encoded)]
raw_event.contents[DESCRIPTION_KEY] = [ContentLine(
DESCRIPTION_KEY, {"ENCODING": ["QUOTED-PRINTABLE"]}, encoded)]
else:
content_line = raw_event.contents[DESCRIPTION_KEY][0]
content_line.value = encoded
......@@ -129,21 +142,28 @@ SUMMARY:{summary}
DESCRIPTION;ENCODING=QUOTED-PRINTABLE:{description}
END:VEVENT
END:VCALENDAR""".format(
uid=create_uid(), now=date_format(datetime.now()-offset),
begin=date_format(self.begin-offset), end=date_format(self.end-offset),
uid=create_uid(),
now=date_format(datetime.now() - offset),
begin=date_format(self.begin - offset),
end=date_format(self.end - offset),
summary=self.name,
description=encode_quopri(self.description))
def create_uid():
return str(random.randint(0, 1e10)).rjust(10, "0")
def date_format(dt):
return dt.strftime("%Y%m%dT%H%M%SZ")
def get_timezone_offset():
difference = datetime.now() - datetime.utcnow()
return timedelta(hours=round(difference.seconds / 3600 + difference.days * 24))
return timedelta(
hours=round(difference.seconds / 3600 + difference.days * 24))
def encode_quopri(text):
return quopri.encodestring(text.encode("utf-8")).replace(b"\n", b"=0A").decode("utf-8")
def encode_quopri(text):
return quopri.encodestring(text.encode("utf-8")).replace(
b"\n", b"=0A").decode("utf-8")
......@@ -3,14 +3,20 @@ import regex as re
import os
import sys
ROUTE_PATTERN = r'@(?:[[:alpha:]])+\.route\(\"(?<url>[^"]+)"[^)]*\)\s*(?:@[[:alpha:]_()., ]+\s*)*def\s+(?<name>[[:alpha:]][[:alnum:]_]*)\((?<params>[[:alnum:], ]*)\):'
ROUTE_PATTERN = (
r'@(?:[[:alpha:]])+\.route\(\"(?<url>[^"]+)"[^)]*\)\s*'
r'(?:@[[:alpha:]_()., ]+\s*)*def\s+(?<name>[[:alpha:]][[:alnum:]_]*)'
r'\((?<params>[[:alnum:], ]*)\):')
quote_group = "[\"']"
URL_FOR_PATTERN = r'url_for\({quotes}(?<name>[[:alpha:]][[:alnum:]_]*){quotes}'.format(quotes=quote_group)
URL_FOR_PATTERN = (
r'url_for\({quotes}(?<name>[[:alpha:]][[:alnum:]_]*)'
'{quotes}'.format(quotes=quote_group))
ROOT_DIR = "."
ENDINGS = [".py", ".html", ".txt"]
MAX_DEPTH = 2
def list_dir(dir, level=0):
if level >= MAX_DEPTH:
return
......@@ -25,6 +31,7 @@ def list_dir(dir, level=0):
elif os.path.isdir(path):
yield from list_dir(path, level + 1)
class Route:
def __init__(self, file, name, parameters):
self.file = file
......@@ -38,13 +45,15 @@ class Route:
def get_parameter_set(self):
return {parameter.name for parameter in self.parameters}
class Parameter:
def __init__(self, name, type=None):
self.name = name
self.type = type
def __repr__(self):
return "Parameter({name}, {type})".format(name=self.name, type=self.type)
return "Parameter({name}, {type})".format(
name=self.name, type=self.type)
@staticmethod
def from_string(text):
......@@ -53,6 +62,7 @@ class Parameter:
return Parameter(name, type)
return Parameter(text)
def split_url_parameters(url):
params = []
current_param = None
......@@ -68,9 +78,11 @@ def split_url_parameters(url):
current_param += char
return params
def split_function_parameters(parameters):
return list(map(str.strip, parameters.split(",")))
def read_url_for_parameters(content):
params = []
bracket_level = 1
......@@ -92,6 +104,7 @@ def read_url_for_parameters(content):
elif char == ")":
bracket_level -= 1
class UrlFor:
def __init__(self, file, name, parameters):
self.file = file
......@@ -99,8 +112,10 @@ class UrlFor:
self.parameters = parameters
def __repr__(self):
return "UrlFor(file={file}, name={name}, parameters={parameters})".format(
file=self.file, name=self.name, parameters=self.parameters)
return (
"UrlFor(file={file}, name={name}, parameters={parameters})".format(
file=self.file, name=self.name, parameters=self.parameters))
routes = {}
url_fors = []
......@@ -109,24 +124,29 @@ for file in list_dir(ROOT_DIR):
content = infile.read()
for match in re.finditer(ROUTE_PATTERN, content):
name = match.group("name")
function_parameters = split_function_parameters(match.group("params"))
function_parameters = split_function_parameters(
match.group("params"))
url_parameters = split_url_parameters(match.group("url"))
routes[name] = Route(file, name, url_parameters)
for match in re.finditer(URL_FOR_PATTERN, content):
name = match.group("name")
begin, end = match.span()
parameters = read_url_for_parameters(content[end:])
url_fors.append(UrlFor(file=file, name=name, parameters=parameters))
url_fors.append(UrlFor(
file=file, name=name, parameters=parameters))
for url_for in url_fors:
if url_for.name not in routes:
print("Missing route '{}' (for url_for in '{}')".format(url_for.name, url_for.file))
print("Missing route '{}' (for url_for in '{}')".format(
url_for.name, url_for.file))
continue
route = routes[url_for.name]
route_parameters = route.get_parameter_set()
url_parameters = set(url_for.parameters)
if len(route_parameters ^ url_parameters) > 0:
print("Parameters not matching for '{}' in '{}:'".format(url_for.name, url_for.file))
print("Parameters not matching for '{}' in '{}:'".format(
url_for.name, url_for.file))
only_route = route_parameters - url_parameters
only_url = url_parameters - route_parameters
if len(only_route) > 0:
......
from flask import redirect, flash, request, url_for
from flask import flash
from functools import wraps
from models.database import ALL_MODELS
from shared import db, current_user
from shared import current_user
import back
ID_KEY = "id"
......@@ -12,12 +12,15 @@ OBJECT_DOES_NOT_EXIST_MESSAGE = "There is no {} with id {}."
MISSING_VIEW_RIGHT = "Dir fehlenden die nötigen Zugriffsrechte."
def default_redirect():
return back.redirect()
def login_redirect():
return back.redirect("login")
def db_lookup(*models, check_exists=True):
def _decorator(function):
@wraps(function)
......@@ -32,7 +35,8 @@ def db_lookup(*models, check_exists=True):
obj = model.query.filter_by(id=obj_id).first()
if check_exists and obj is None:
model_name = model.__class__.__name__
flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(model_name, obj_id),
flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(
model_name, obj_id),
"alert-error")
return default_redirect()
kwargs[key] = obj
......@@ -41,8 +45,10 @@ def db_lookup(*models, check_exists=True):
return _decorated_function
return _decorator
def require_right(right, require_exist):
necessary_right_name = "has_{}_right".format(right)
def _decorator(function):
@wraps(function)
def _decorated_function(*args, **kwargs):
......@@ -65,17 +71,22 @@ def require_right(right, require_exist):
return _decorated_function
return _decorator
def require_public_view_right(require_exist=True):
return require_right("public_view", require_exist)
def require_private_view_right(require_exist=True):
return require_right("private_view", require_exist)
def require_modify_right(require_exist=True):
return require_right("modify", require_exist)
def require_publish_right(require_exist=True):
return require_right("publish", require_exist)
def require_admin_right(require_exist=True):
return require_right("admin", require_exist)
from datetime import datetime
from fuzzywuzzy import fuzz, process
import tempfile
from fuzzywuzzy import process
from models.database import Todo, OldTodo, Protocol, ProtocolType, TodoMail
from models.database import OldTodo, Protocol, ProtocolType, TodoMail
from shared import db
import config
def lookup_todo_id(old_candidates, new_who, new_description):
# Check for perfect matches
for candidate in old_candidates:
if candidate.who == new_who and candidate.description == new_description:
if (candidate.who == new_who
and candidate.description == new_description):
return candidate.old_id
# Accept if who has been changed
for candidate in old_candidates:
......@@ -32,11 +33,13 @@ def lookup_todo_id(old_candidates, new_who, new_description):
new_description, best_match, best_match_score))
return None
INSERT_PROTOCOLTYPE = "INSERT INTO `protocolManager_protocoltype`"
INSERT_PROTOCOL = "INSERT INTO `protocolManager_protocol`"
INSERT_TODO = "INSERT INTO `protocolManager_todo`"
INSERT_TODOMAIL = "INSERT INTO `protocolManager_todonamemailassignment`"
def import_old_protocols(sql_text):
protocoltype_lines = []
protocol_lines = []
......@@ -50,18 +53,23 @@ def import_old_protocols(sql_text):
raise ValueError("Necessary lines not found.")
type_id_to_handle = {}
for type_line in protocoltype_lines:
for id, handle, name, mail, protocol_id in _split_insert_line(type_line):
for id, handle, name, mail, protocol_id in _split_insert_line(
type_line):
type_id_to_handle[int(id)] = handle.lower()
protocols = []
for protocol_line in protocol_lines:
for (protocol_id, old_type_id, date, source, textsummary, htmlsummary,
deleted, sent, document_id) in _split_insert_line(protocol_line):
deleted, sent, document_id) in _split_insert_line(
protocol_line):
date = datetime.strptime(date, "%Y-%m-%d")
handle = type_id_to_handle[int(old_type_id)]
protocoltype = ProtocolType.query.filter(ProtocolType.short_name.ilike(handle)).first()
protocoltype = ProtocolType.query.filter(
ProtocolType.short_name.ilike(handle)).first()
if protocoltype is None:
raise KeyError("No protocoltype for handle '{}'.".format(handle))
protocol = Protocol(protocoltype_id=protocoltype.id, date=date, source=source)
raise KeyError(
"No protocoltype for handle '{}'.".format(handle))
protocol = Protocol(
protocoltype_id=protocoltype.id, date=date, source=source)
db.session.add(protocol)
db.session.commit()
import tasks
......@@ -70,6 +78,7 @@ def import_old_protocols(sql_text):
print(protocol.date)
tasks.parse_protocol(protocol)
def import_old_todomails(sql_text):
todomail_lines = []
for line in sql_text.splitlines():
......@@ -103,23 +112,29 @@ def import_old_todos(sql_text):
raise ValueError("Necessary lines not found.")
type_id_to_handle = {}
for type_line in protocoltype_lines:
for id, handle, name, mail, protocol_id in _split_insert_line(type_line):
for id, handle, name, mail, protocol_id in _split_insert_line(
type_line):
type_id_to_handle[int(id)] = handle.lower()
protocol_id_to_key = {}
for protocol_line in protocol_lines:
for (protocol_id, type_id, date, source, textsummary, htmlsummary,
deleted, sent, document_id) in _split_insert_line(protocol_line):
deleted, sent, document_id) in _split_insert_line(
protocol_line):
handle = type_id_to_handle[int(type_id)]
date_string = date[2:]
protocol_id_to_key[int(protocol_id)] = "{}-{}".format(handle, date_string)
protocol_id_to_key[int(protocol_id)] = "{}-{}".format(
handle, date_string)
todos = []
for todo_line in todo_lines:
for old_id, protocol_id, who, what, start_time, end_time, done in _split_insert_line(todo_line):
for (old_id, protocol_id, who, what, start_time, end_time,
done) in _split_insert_line(todo_line):
protocol_id = int(protocol_id)
if protocol_id not in protocol_id_to_key:
print("Missing protocol with ID {} for Todo {}".format(protocol_id, what))
print("Missing protocol with ID {} for Todo {}".format(
protocol_id, what))
continue
todo = OldTodo(old_id=old_id, who=who, description=what,
todo = OldTodo(
old_id=old_id, who=who, description=what,
protocol_key=protocol_id_to_key[protocol_id])
todos.append(todo)
OldTodo.query.delete()
......@@ -128,11 +143,15 @@ def import_old_todos(sql_text):
db.session.add(todo)
db.session.commit()
def _split_insert_line(line):
insert_part, values_part = line.split("VALUES", 1)
return _split_base_level(values_part)
def _split_base_level(text, begin="(", end=")", separator=",", string_terminator="'", line_end=";", ignore=" ", escape="\\"):
def _split_base_level(
text, begin="(", end=")", separator=",", string_terminator="'",
line_end=";", ignore=" ", escape="\\"):
raw_parts = []
current_part = None
index = 0
......@@ -210,5 +229,3 @@ def _split_base_level(text, begin="(", end=")", separator=",", string_terminator
fields.append(current_field)
parts.append(fields)
return parts
This diff is collapsed.
from flask_sqlalchemy import SQLAlchemy
from flask import session, redirect, url_for, request
from flask import session, redirect, url_for, flash
import re
from functools import wraps
......@@ -11,7 +11,8 @@ import config
db = SQLAlchemy()
# the following code is written by Lars Beckers and not to be published without permission
# the following code escape_tex is written by Lars Beckers
# and not to be published without permission
latex_chars = [
("\\", "\\backslash"), # this needs to be first
("$", "\$"),
......@@ -23,7 +24,6 @@ latex_chars = [
('}', '\\}'),
('[', '\\['),
(']', '\\]'),
#('"', '"\''),
('~', r'$\sim{}$'),
('^', r'\textasciicircum{}'),
('Ë„', r'\textasciicircum{}'),
......@@ -40,72 +40,96 @@ latex_chars = [
('\\backslash', '$\\backslash$') # this needs to be last
]
def escape_tex(text):
out = text
for old, new in latex_chars:
out = out.replace(old, new)
# beware, the following is carefully crafted code
res = ''
k, l = (0, -1)
while k >= 0:
k = out.find('"', l+1)
if k >= 0:
res += out[l+1:k]
l = out.find('"', k+1)
if l >= 0:
res += '\\enquote{' + out[k+1:l] + '}'
start, end = (0, -1)
while start >= 0:
start = out.find('"', end + 1)
if start >= 0:
res += out[end + 1:start]
end = out.find('"', start + 1)
if end >= 0:
res += '\\enquote{' + out[start + 1:end] + '}'
else:
res += '"\'' + out[k+1:]
k = l
res += '"\'' + out[start + 1:]
start = end
else:
res += out[l+1:]
res += out[end + 1:]
# yes, this is not quite escaping latex chars, but anyway...
res = re.sub('([a-z])\(', '\\1 (', res)
res = re.sub('\)([a-z])', ') \\1', res)
#logging.debug('escape latex ({0}/{1}): {2} --> {3}'.format(len(text), len(res), text.split('\n')[0], res.split('\n')[0]))
return res
def unhyphen(text):
return " ".join([r"\mbox{" + word + "}" for word in text.split(" ")])
def date_filter(date):
return date.strftime("%d. %B %Y")
def datetime_filter(date):
return date.strftime("%d. %B %Y, %H:%M")
def date_filter_long(date):
return date.strftime("%A, %d.%m.%Y, Kalenderwoche %W")
def date_filter_short(date):
return date.strftime("%d.%m.%Y")
def time_filter(time):
return time.strftime("%H:%M Uhr")
def time_filter_short(time):
return time.strftime("%H:%M")
def needs_date_test(todostate):
return todostate.needs_date()
def todostate_name_filter(todostate):
return todostate.get_name()
def indent_tab_filter(text):
return "\n".join(map(lambda l: "\t{}".format(l), text.splitlines()))
def class_filter(obj):
return obj.__class__.__name__
def code_filter(text):
return "<code>{}</code>".format(text)
from auth import UserManager, SecurityManager, User
max_duration = getattr(config, "AUTH_MAX_DURATION")
user_manager = UserManager(backends=config.AUTH_BACKENDS)
security_manager = SecurityManager(config.SECURITY_KEY, max_duration)
def check_login():
return "auth" in session and security_manager.check_user(session["auth"])
def current_user():
if not check_login():
return None
return User.from_hashstring(session["auth"])
def login_required(function):
@wraps(function)
def decorated_function(*args, **kwargs):
......@@ -115,6 +139,7 @@ def login_required(function):
return redirect(url_for("login"))
return decorated_function
def group_required(group):
def decorator(function):
@wraps(function)
......@@ -122,16 +147,19 @@ def group_required(group):
if group in current_user().groups:
return function(*args, **kwargs)
else:
flash("You do not have the necessary permissions to view this page.")
flash("You do not have the necessary permissions to "
"view this page.")
return back.redirect()
return decorated_function
return decorator
DATE_KEY = "Datum"
START_TIME_KEY = "Beginn"
END_TIME_KEY = "Ende"
KNOWN_KEYS = [DATE_KEY, START_TIME_KEY, END_TIME_KEY]
class WikiType(Enum):
MEDIAWIKI = 0
DOKUWIKI = 1
from flask import render_template, request
from flask import request
import random
import string
import regex
import math
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.application import MIMEApplication
from datetime import datetime, date, timedelta
from datetime import datetime
import requests
from io import BytesIO
import ipaddress
......@@ -18,12 +17,16 @@ import subprocess
import config
def random_string(length):
return "".join((random.choice(string.ascii_letters) for i in range(length)))
return "".join((random.choice(string.ascii_letters)
for i in range(length)))
def is_past(some_date):
return (datetime.now() - some_date).total_seconds() > 0
def encode_kwargs(kwargs):
encoded_kwargs = {}
for key in kwargs:
......@@ -34,6 +37,7 @@ def encode_kwargs(kwargs):
encoded_kwargs[key] = (type(value), value, False)
return encoded_kwargs
def decode_kwargs(encoded_kwargs):
kwargs = {}
for name in encoded_kwargs:
......@@ -45,27 +49,6 @@ def decode_kwargs(encoded_kwargs):
return kwargs
class UrlManager:
def __init__(self, config):
self.pattern = regex.compile(r"(?:(?<proto>https?):\/\/)?(?<hostname>[[:alnum:]_.]+(?:\:[[:digit:]]+)?)?(?<path>(?:\/[[:alnum:]_#]*)+)?(?:\?(?<params>.*))?")
self.base = "{}://{}{}{}"
self.proto = getattr(config, "URL_PROTO", "https")
self.root = getattr(config, "URL_ROOT", "example.com")
self.path = getattr(config, "URL_PATH", "/")
self.params = getattr(config, "URL_PARAMS", "")
def complete(self, url):
match = self.pattern.match(url)
if match is None:
return None
proto = match.group("proto") or self.proto
root = match.group("hostname") or self.root
path = match.group("path") or self.path
params = match.group("params") or self.params
return self.base.format(proto, root, path, "?" + params if len(params) > 0 else "")
#url_manager = UrlManager(config)
class MailManager:
def __init__(self, config):
self.active = getattr(config, "MAIL_ACTIVE", False)
......@@ -76,12 +59,17 @@ class MailManager:
self.use_tls = getattr(config, "MAIL_USE_TLS", True)
self.use_starttls = getattr(config, "MAIL_USE_STARTTLS", False)
def _get_smtp(self):
if self.use_tls:
return smtplib.SMTP_SSL
return smtplib.SMTP
def send(self, to_addr, subject, content, appendix=None, reply_to=None):
if (not self.active
or not self.hostname
or not self.from_addr):
return
msg = MIMEMultipart("mixed") # todo: test if clients accept attachment-free mails set to multipart/mixed
msg = MIMEMultipart("mixed")
msg["From"] = self.from_addr
msg["To"] = to_addr
msg["Subject"] = subject
......@@ -92,9 +80,10 @@ class MailManager:
if appendix is not None:
for name, file_like in appendix:
part = MIMEApplication(file_like.read(), "octet-stream")
part["Content-Disposition"] = 'attachment; filename="{}"'.format(name)
part["Content-Disposition"] = (
'attachment; filename="{}"'.format(name))
msg.attach(part)
server = (smtplib.SMTP_SSL if self.use_tls else smtplib.SMTP)(self.hostname)
server = self._get_smtp()(self.hostname)
if self.use_starttls:
server.starttls()
if self.username not in [None, ""] and self.password not in [None, ""]:
......@@ -102,8 +91,10 @@ class MailManager:
server.sendmail(self.from_addr, to_addr.split(","), msg.as_string())
server.quit()
mail_manager = MailManager(config)
def get_first_unused_int(numbers):
positive_numbers = [number for number in numbers if number >= 0]
if len(positive_numbers) == 0:
......@@ -114,24 +105,34 @@ def get_first_unused_int(numbers):
return linear
return highest + 1
def normalize_pad(pad):
return pad.replace(" ", "_")
def get_etherpad_url(pad):
return "{}/p/{}".format(config.ETHERPAD_URL, normalize_pad(pad))
def get_etherpad_export_url(pad):
return "{}/p/{}/export/txt".format(config.ETHERPAD_URL, normalize_pad(pad))
def get_etherpad_import_url(pad):
return "{}/p/{}/import".format(config.ETHERPAD_URL, normalize_pad(pad))
def get_etherpad_text(pad):
req = requests.get(get_etherpad_export_url(pad))
return req.text
def set_etherpad_text(pad, text, only_if_default=True):
print(pad)
if only_if_default:
current_text = get_etherpad_text(pad)
if current_text != config.EMPTY_ETHERPAD and len(current_text.strip()) > 0:
if (current_text != config.EMPTY_ETHERPAD
and len(current_text.strip()) > 0):
return False
file_like = BytesIO(text.encode("utf-8"))
files = {"file": file_like}
......@@ -140,6 +141,7 @@ def set_etherpad_text(pad, text, only_if_default=True):
req = requests.post(url, files=files)
return req.status_code == 200
def split_terms(text, quote_chars="\"'", separators=" \t\n"):
terms = []
in_quote = False
......@@ -169,12 +171,14 @@ def split_terms(text, quote_chars="\"'", separators=" \t\n"):
terms.append(current_term)
return terms
def optional_int_arg(name):
try:
return int(request.args.get(name))
except (ValueError, TypeError):
return None
def add_line_numbers(text):
raw_lines = text.splitlines()
linenumber_length = math.ceil(math.log10(len(raw_lines)) + 1)
......@@ -186,9 +190,11 @@ def add_line_numbers(text):
))
return "\n".join(lines)
def check_ip_in_networks(networks_string):
address = ipaddress.ip_address(request.remote_addr)
if address == ipaddress.ip_address("127.0.0.1") and "X-Real-Ip" in request.headers:
if (address == ipaddress.ip_address("127.0.0.1")
and "X-Real-Ip" in request.headers):
address = ipaddress.ip_address(request.headers["X-Real-Ip"])
try:
for network_string in networks_string.split(","):
......@@ -199,6 +205,7 @@ def check_ip_in_networks(networks_string):
except ValueError:
return False
def fancy_join(values, sep1=" und ", sep2=", "):
values = list(values)
if len(values) <= 1:
......@@ -207,9 +214,11 @@ def fancy_join(values, sep1=" und ", sep2=", "):
start = values[:-1]
return "{}{}{}".format(sep2.join(start), sep1, last)
def footnote_hash(text, length=5):
return str(sum(ord(c) * i for i, c in enumerate(text)) % 10**length)
def parse_datetime_from_string(text):
text = text.strip()
for format in ("%d.%m.%Y", "%d.%m.%y", "%Y-%m-%d",
......@@ -220,9 +229,10 @@ def parse_datetime_from_string(text):
pass
for format in ("%d.%m.", "%d. %m.", "%d.%m", "%d.%m"):
try:
return datetime.strptime(text, format).replace(year=datetime.now().year)
except ValueError as exc:
print(exc)
return datetime.strptime(text, format).replace(
year=datetime.now().year)
except ValueError:
pass
raise ValueError("Date '{}' does not match any known format!".format(text))
......@@ -248,4 +258,3 @@ def get_max_page_length_exp(objects):
def get_internal_filename(protocol, document, filename):
return "{}-{}-{}".format(protocol.id, document.id, filename)
from models.database import TodoState
from wtforms import ValidationError
from wtforms.validators import InputRequired
from shared import db
class CheckTodoDateByState:
def __init__(self):
......@@ -16,4 +16,3 @@ class CheckTodoDateByState:
date_check(form, form.date)
except ValueError:
raise ValidationError("Invalid state.")
import requests
import json
import config
HTTP_STATUS_OK = 200
HTTP_STATUS_AUTHENTICATE = 401
class WikiException(Exception):
pass
def _filter_params(params):
result = {}
for key, value in sorted(params.items(), key=lambda t: t[0] == "token"):
......@@ -19,14 +20,20 @@ def _filter_params(params):
result[key] = value
return result
class WikiClient:
def __init__(self, active=None, endpoint=None, anonymous=None, user=None, password=None, domain=None):
self.active = active if active is not None else config.WIKI_ACTIVE
self.endpoint = endpoint if endpoint is not None else config.WIKI_API_URL
self.anonymous = anonymous if anonymous is not None else config.WIKI_ANONYMOUS
self.user = user if user is not None else config.WIKI_USER
self.password = password if password is not None else config.WIKI_PASSWORD
self.domain = domain if domain is not None else config.WIKI_DOMAIN
def __init__(self, active=None, endpoint=None, anonymous=None, user=None,
password=None, domain=None):
def _or_default(value, default):
if value is None:
return default
return value
self.active = _or_default(active, config.WIKI_ACTIVE)
self.endpoint = _or_default(endpoint, config.WIKI_API_URL)
self.anonymous = _or_default(anonymous, config.WIKI_ANONYMOUS)
self.user = _or_default(user, config.WIKI_USER)
self.password = _or_default(password, config.WIKI_PASSWORD)
self.domain = _or_default(domain, config.WIKI_DOMAIN)
self.token = None
self.cookies = requests.cookies.RequestsCookieJar()
......@@ -45,28 +52,31 @@ class WikiClient:
def login(self):
if not self.active:
return
# todo: Change this to the new MediaWiki tokens api once the wiki is updated
# todo: Change this to the new MediaWiki tokens api
token_answer = self.do_action("login", method="post", lgname=self.user)
if "login" not in token_answer or "token" not in token_answer["login"]:
raise WikiException("No token in login answer.")
lgtoken = token_answer["login"]["token"]
login_answer = self.do_action("login", method="post", lgname=self.user, lgpassword=self.password, lgdomain=self.domain, lgtoken=lgtoken)
login_answer = self.do_action(
"login", method="post", lgname=self.user, lgpassword=self.password,
lgdomain=self.domain, lgtoken=lgtoken)
if ("login" not in login_answer
or "result" not in login_answer["login"]
or login_answer["login"]["result"] != "Success"):
raise WikiException("Login not successful.")
def logout(self):
if not self.active:
return
self.do_action("logout")
def edit_page(self, title, content, summary, recreate=True, createonly=False):
def edit_page(self, title, content, summary, recreate=True,
createonly=False):
if not self.active:
return
# todo: port to new api once the wiki is updated
prop_answer = self.do_action("query", method="get", prop="info", intoken="edit", titles=title)
prop_answer = self.do_action(
"query", method="get", prop="info", intoken="edit", titles=title)
if ("query" not in prop_answer
or "pages" not in prop_answer["query"]):
raise WikiException("Can't get token for page {}".format(title))
......@@ -78,7 +88,8 @@ class WikiClient:
break
if edit_token is None:
raise WikiException("Can't get token for page {}".format(title))
edit_answer = self.do_action(action="edit", method="post", data={"text": content},
self.do_action(
action="edit", method="post", data={"text": content},
token=edit_token, title=title,
summary=summary, recreate=recreate,
createonly=createonly, bot=True)
......@@ -89,18 +100,28 @@ class WikiClient:
kwargs["action"] = action
kwargs["format"] = "json"
params = _filter_params(kwargs)
def _do_request():
if method == "get":
return requests.get(self.endpoint, cookies=self.cookies, params=params, auth=requests.auth.HTTPBasicAuth(self.user, self.password))
return requests.get(
self.endpoint, cookies=self.cookies, params=params,
auth=requests.auth.HTTPBasicAuth(self.user, self.password))
elif method == "post":
return requests.post(self.endpoint, cookies=self.cookies, data=data, params=params, auth=requests.auth.HTTPBasicAuth(self.user, self.password))
return requests.post(
self.endpoint, cookies=self.cookies, data=data,
params=params, auth=requests.auth.HTTPBasicAuth(
self.user, self.password))
req = _do_request()
if req.status_code != HTTP_STATUS_OK:
raise WikiException("HTTP status code {} on action {}.".format(req.status_code, action))
raise WikiException(
"HTTP status code {} on action {}.".format(
req.status_code, action))
self.cookies.update(req.cookies)
return req.json()
def main():
with WikiClient() as client:
client.edit_page(title="Test", content="This is a very long text.", summary="API client test")
client.edit_page(
title="Test", content="This is a very long text.",
summary="API client test")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment