Commit 1520fe6c authored by Robin Sonnabend's avatar Robin Sonnabend
Browse files

Move protect_csrf and db_lookup to common-web

parent 7879d04d
Subproject commit 3dfc2b71eb6a7e0746fe8794854f739c3305587b Subproject commit c7eac74cf1f7e03b06f255cfd01e54162a2b9631
from flask import request, flash, abort
from functools import wraps from functools import wraps
from hmac import compare_digest
from flask import flash
from models.database import ALL_MODELS from models.database import ALL_MODELS
from shared import current_user from shared import current_user
from utils import get_csrf_token
from common import back from common import back
ID_KEY = "id"
KEY_NOT_PRESENT_MESSAGE = "Missing {}_id."
OBJECT_DOES_NOT_EXIST_MESSAGE = "There is no {} with id {}."
MISSING_VIEW_RIGHT = "Dir fehlenden die nötigen Zugriffsrechte."
def default_redirect(): def default_redirect():
return back.redirect() return back.redirect()
...@@ -23,29 +15,7 @@ def login_redirect(): ...@@ -23,29 +15,7 @@ def login_redirect():
return back.redirect("login") return back.redirect("login")
def db_lookup(*models, check_exists=True): MISSING_VIEW_RIGHT = "Dir fehlenden die nötigen Zugriffsrechte."
def _decorator(function):
@wraps(function)
def _decorated_function(*args, **kwargs):
for model in models:
key = model.__model_name__
id_key = "{}_{}".format(key, ID_KEY)
if id_key not in kwargs:
flash(KEY_NOT_PRESENT_MESSAGE.format(key), "alert-error")
return default_redirect()
obj_id = kwargs[id_key]
obj = model.query.filter_by(id=obj_id).first()
if check_exists and obj is None:
model_name = model.__class__.__name__
flash(OBJECT_DOES_NOT_EXIST_MESSAGE.format(
model_name, obj_id),
"alert-error")
return default_redirect()
kwargs[key] = obj
kwargs.pop(id_key)
return function(*args, **kwargs)
return _decorated_function
return _decorator
def require_right(right, require_exist): def require_right(right, require_exist):
...@@ -92,14 +62,3 @@ def require_publish_right(require_exist=True): ...@@ -92,14 +62,3 @@ def require_publish_right(require_exist=True):
def require_admin_right(require_exist=True): def require_admin_right(require_exist=True):
return require_right("admin", require_exist) return require_right("admin", require_exist)
def protect_csrf(function):
@wraps(function)
def _decorated_function(*args, **kwargs):
token = request.args.get("csrf_token")
true_token = get_csrf_token()
if token is None or not compare_digest(token, true_token):
abort(400)
return function(*args, **kwargs)
return _decorated_function
...@@ -31,9 +31,8 @@ from shared import ( ...@@ -31,9 +31,8 @@ from shared import (
from utils import ( from utils import (
get_first_unused_int, get_etherpad_text, split_terms, optional_int_arg, get_first_unused_int, get_etherpad_text, split_terms, optional_int_arg,
fancy_join, footnote_hash, get_git_revision, get_max_page_length_exp, fancy_join, footnote_hash, get_git_revision, get_max_page_length_exp,
get_internal_filename, get_csrf_token, get_current_ip) get_internal_filename, get_current_ip)
from decorators import ( from decorators import (
db_lookup, protect_csrf,
require_private_view_right, require_modify_right, require_publish_right, require_private_view_right, require_modify_right, require_publish_right,
require_admin_right) require_admin_right)
from models.database import ( from models.database import (
...@@ -56,6 +55,9 @@ from views.tables import ( ...@@ -56,6 +55,9 @@ from views.tables import (
TodoMailsTable, DefaultMetasTable, DecisionCategoriesTable) TodoMailsTable, DefaultMetasTable, DecisionCategoriesTable)
from legacy import import_old_todos, import_old_protocols, import_old_todomails from legacy import import_old_todos, import_old_protocols, import_old_todomails
from common import back from common import back
from common.csrf import protect_csrf, get_csrf_token
from common.database import db_lookup
app = Flask(__name__) app = Flask(__name__)
app.config.from_object(config) app.config.from_object(config)
......
from flask import request, session from flask import request
import random import random
import string import string
...@@ -14,8 +14,6 @@ import ipaddress ...@@ -14,8 +14,6 @@ import ipaddress
from socket import getfqdn from socket import getfqdn
from uuid import uuid4 from uuid import uuid4
import subprocess import subprocess
import os
import hashlib
import config import config
...@@ -265,9 +263,3 @@ def get_max_page_length_exp(objects): ...@@ -265,9 +263,3 @@ def get_max_page_length_exp(objects):
def get_internal_filename(protocol, document, filename): def get_internal_filename(protocol, document, filename):
return "{}-{}-{}".format(protocol.id, document.id, filename) return "{}-{}-{}".format(protocol.id, document.id, filename)
def get_csrf_token():
if "_csrf" not in session:
session["_csrf"] = hashlib.sha1(os.urandom(64)).hexdigest()
return session["_csrf"]
from flask import Markup, url_for from flask import Markup, url_for
from shared import date_filter, datetime_filter, time_filter, current_user from shared import date_filter, datetime_filter, time_filter, current_user
from utils import get_csrf_token from common.csrf import get_csrf_token
import config import config
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment