From 1482d1ea92e9ea07b823ed30231218b641d83662 Mon Sep 17 00:00:00 2001
From: Julian Rother <julianr@fsmpi.rwth-aachen.de>
Date: Fri, 9 Sep 2016 23:13:08 +0200
Subject: [PATCH] Refactore db and ldap wrapper

The functionality of query() is now split into query() and modify().
For all queries that modify the db (like INSERT and UPDATE) the latter
must be used. This simplifies the interface (e.g. the return value) and
later allows query caching.
---
 db.py       | 174 ++++++++++++++++++++++++++++------------------------
 importer.py |   2 +-
 server.py   |  26 ++++----
 3 files changed, 106 insertions(+), 96 deletions(-)

diff --git a/db.py b/db.py
index c0c13e0..056c41b 100644
--- a/db.py
+++ b/db.py
@@ -1,74 +1,82 @@
 from server import *
-import sqlite3
 import re
 
 if config['DB_ENGINE'] == 'sqlite':
-	created = not os.path.exists(config['SQLITE_DB'])
-	db = sqlite3.connect(config['SQLITE_DB'])
-	cur = db.cursor()
-	if config['SQLITE_INIT_SCHEMA']:
-		cur.executescript(open(config['DB_SCHEMA']).read())
-	if config['SQLITE_INIT_DATA'] and created:
-		cur.executescript(open(config['DB_DATA']).read())
-	db.commit()
-	db.close()
-
-# Row wrapper for sqlite
-def dict_factory(cursor, row):
-	d = {}
-	for idx, col in enumerate(cursor.description):
-		if type(row[idx]) == str:
-			d[col[0].split('.')[-1]] = row[idx].replace('\\n','\n').replace('\\r','\r')
-		else:
-			d[col[0].split('.')[-1]] = row[idx]
-	return d
-
-# From sqlite3 module, but with error catching
-def convert_timestamp(val):
-	try:
-		datepart, timepart = val.split(b" ")
-		year, month, day = map(int, datepart.split(b"-"))
-		timepart_full = timepart.split(b".")
-		hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
-		val = datetime(year, month, day, hours, minutes, seconds, 0)
-	except ValueError:
-		val = None
-	return val
-
-sqlite3.register_converter('datetime', convert_timestamp)
-sqlite3.register_converter('timestamp', convert_timestamp)
-
-def query(operation, *params):
-	_params = []
-	for p in params:
-		if isinstance(p, datetime):
-			p = p.replace(microsecond=0)
-		_params.append(p)
-	params = _params
-	if config['DB_ENGINE'] == 'mysql':
-		import mysql.connector
-		if 'db' not in g or not g.db.is_connected():
-			g.db = mysql.connector.connect(user=config['MYSQL_USER'], password=config['MYSQL_PASSWD'], host=config['MYSQL_HOST'], database=config['MYSQL_DB'])
-		if not hasattr(request, 'db'):
-			request.db = g.db.cursor(dictionary=True)
-		request.db.execute(operation.replace('?', '%s'), params)
-	elif config['DB_ENGINE'] == 'sqlite':
+	import sqlite3
+
+	# From sqlite3 module, but with error catching
+	def convert_timestamp(val):
+		try:
+			datepart, timepart = val.split(b" ")
+			year, month, day = map(int, datepart.split(b"-"))
+			timepart_full = timepart.split(b".")
+			hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
+			val = datetime(year, month, day, hours, minutes, seconds, 0)
+		except ValueError:
+			val = None
+		return val
+
+	sqlite3.register_converter('datetime', convert_timestamp)
+	sqlite3.register_converter('timestamp', convert_timestamp)
+
+	if config['DB_ENGINE'] == 'sqlite':
+		created = not os.path.exists(config['SQLITE_DB'])
+		db = sqlite3.connect(config['SQLITE_DB'])
+		cur = db.cursor()
+		if config['SQLITE_INIT_SCHEMA']:
+			cur.executescript(open(config['DB_SCHEMA']).read())
+		if config['SQLITE_INIT_DATA'] and created:
+			cur.executescript(open(config['db_data']).read())
+		db.commit()
+		db.close()
+
+	def get_dbcursor():
 		if 'db' not in g:
 			g.db = sqlite3.connect(config['SQLITE_DB'], detect_types=sqlite3.PARSE_DECLTYPES)
-			g.db.row_factory = dict_factory
 			g.db.isolation_level = None
 		if not hasattr(request, 'db'):
 			request.db = g.db.cursor()
-		request.db.execute(operation, params)
-	else:
-		return []
-	try:
-		rows = request.db.fetchall()
-	except:
-		rows = []
-	if not rows and request.db.lastrowid != None:
-		return request.db.lastrowid
-	return rows
+		return request.db
+
+	def fix_query(operation, params):
+		params = [(p.replace(microsecond=0) if isinstance(p, datetime) else p) for p in params]
+		return operation, params
+
+elif config['DB_ENGINE'] == 'mysql':
+	import mysql.connector
+
+	def get_dbcursor():
+		if 'db' not in g or not g.db.is_connected():
+			g.db = mysql.connector.connect(user=config['MYSQL_USER'], password=config['MYSQL_PASSWD'], host=config['MYSQL_HOST'], database=config['MYSQL_DB'])
+		if not hasattr(request, 'db'):
+			request.db = g.db.cursor()
+		return request.db
+
+	def fix_query(operation, params):
+		operation = operation.replace('?', '%s')
+		params = [(p.replace(microsecond=0) if isinstance(p, datetime) else p) for p in params]
+		return operation, params
+
+def query(operation, *params, delim="sep"):
+	operation, params = fix_query(operation, params)
+	cur = get_dbcursor()
+	cur.execute(operation, params)
+	rows = cur.fetchall()
+	res = []
+	for row in rows:
+		res.append({})
+		for col, desc in zip(row, cur.description):
+			name = desc[0].split('.')[-1]
+			if type(col) == str:
+				col = col.replace('\\n', '\n').replace('\\r', '\r')
+			res[-1][name] = col
+	return res
+
+def modify(operation, *params):
+	operation, params = fix_query(operation, params)
+	cur = get_dbcursor()
+	cur.execute(operation, params)
+	return cur.lastrowid
 
 @app.teardown_request
 def commit_db(*args):
@@ -94,15 +102,11 @@ def searchquery(text, columns, match, tables, suffix, *suffixparams):
 	return query(expr, *params, *suffixparams)
 
 LDAP_USERRE = re.compile(r'[^a-z0-9]')
-notldap = {
-	'videoag':('videoag', ['users','videoag'], {'uid': 'videoag', 'givenName': 'Video', 'sn': 'Geier'}),
-	'gustav':('passwort', ['users'], {'uid': 'gustav', 'givenName': 'Gustav', 'sn': 'Geier'})
-}
-
-def ldapauth(user, password):
-	user = LDAP_USERRE.sub(r'', user.lower())
-	if 'LDAP_HOST' in config:
-		import ldap3
+if 'LDAP_HOST' in config:
+	import ldap3
+
+	def ldapauth(user, password):
+		user = LDAP_USERRE.sub(r'', user.lower())
 		try:
 			conn = ldap3.Connection(config['LDAP_HOST'], 'uid=%s,ou=users,dc=fsmpi,dc=rwth-aachen,dc=de'%user, password, auto_bind=True)
 			if conn.search("ou=groups,dc=fsmpi,dc=rwth-aachen,dc=de", "(&(cn=*)(memberUid=%s))"%user, attributes=['cn']):
@@ -111,14 +115,9 @@ def ldapauth(user, password):
 			return user, groups
 		except ldap3.core.exceptions.LDAPBindError:
 			pass
-	elif config.get('DEBUG') and user in notldap and password == notldap[user][0]:
-		return user, notldap[user][1]
-	return None, []
-
-def ldapget(user):
-	user = LDAP_USERRE.sub(r'', user.lower())
-	if 'LDAP_HOST' in config:
-		import ldap3
+
+	def ldapget(user):
+		user = LDAP_USERRE.sub(r'', user.lower())
 		conn = ldap3.Connection('ldaps://rumo.fsmpi.rwth-aachen.de', auto_bind=True)
 		conn.search("ou=users,dc=fsmpi,dc=rwth-aachen,dc=de", "(uid=%s)"%user,
 				attributes=ldap3.ALL_ATTRIBUTES)
@@ -126,6 +125,19 @@ def ldapget(user):
 			return {}
 		e = conn.entries[0]
 		return {'uid': user, 'givenName': e.givenName.value, 'sn':e.sn.value}
-	else:
-		return notldap[user][2]
 
+else:
+	notldap = {
+		'videoag':('videoag', ['users','videoag'], {'uid': 'videoag', 'givenName': 'Video', 'sn': 'Geier'}),
+		'gustav':('passwort', ['users'], {'uid': 'gustav', 'givenName': 'Gustav', 'sn': 'Geier'})
+	}
+
+	def ldapauth(user, password):
+		user = LDAP_USERRE.sub(r'', user.lower())
+		if config.get('DEBUG') and user in notldap and password == notldap[user][0]:
+			return user, notldap[user][1]
+		return None, []
+
+	def ldapget(user):
+		user = LDAP_USERRE.sub(r'', user.lower())
+		return notldap[user][2]
diff --git a/importer.py b/importer.py
index f98381f..b47f0f9 100755
--- a/importer.py
+++ b/importer.py
@@ -17,7 +17,7 @@ def import_from(source=None, id=None):
 	for i in campus:
 		if i.startswith('new'):
 			if campus[i]['url'] != '':
-				query('INSERT INTO import_campus (url, type, course_id, last_checked, changed) VALUES (?, ?, ?, ?, 1)',campus[i]['url'],campus[i]['type'],id,datetime.now())
+				modify('INSERT INTO import_campus (url, type, course_id, last_checked, changed) VALUES (?, ?, ?, ?, 1)',campus[i]['url'],campus[i]['type'],id,datetime.now())
 		else:
 			if campus[i]['url'] != '':
 				query('UPDATE import_campus SET url = ?, `type` = ? WHERE (course_id = ?) AND (id = ?)', campus[i]['url'],campus[i]['type'],id,int(i))	
diff --git a/server.py b/server.py
index 1c2ab2d..15812db 100755
--- a/server.py
+++ b/server.py
@@ -34,7 +34,7 @@ if config['DEBUG']:
 if not config.get('SECRET_KEY', None):
 	config['SECRET_KEY'] = os.urandom(24)
 
-from db import query, searchquery, ldapauth, ldapget, convert_timestamp
+from db import query, modify, searchquery, ldapauth, ldapget
 
 mod_endpoints = []
 
@@ -235,7 +235,7 @@ def login():
 	session['user'] = ldapget(user)
 	dbuser = query('SELECT * FROM users WHERE name = ?', user)
 	if not dbuser:
-		query('INSERT INTO users (name, realname, fsacc, level, calendar_key, rfc6238) VALUES (?, ?, ?, 1, "", "")', user, session['user']['givenName'], user)
+		modify('INSERT INTO users (name, realname, fsacc, level, calendar_key, rfc6238) VALUES (?, ?, ?, 1, "", "")', user, session['user']['givenName'], user)
 		dbuser = query('SELECT * FROM users WHERE name = ?', user)
 	session['user']['dbid'] = dbuser[0]['id']
 	return redirect(request.values.get('ref', url_for('index')))
@@ -259,7 +259,7 @@ def edit(prefix="", ignore=[]):
 		'chapters': ('chapters', 'id', ['time', 'text', 'visible', 'deleted']),
 		'announcements': ('announcements', 'id', ['text', 'internal', 'level', 'visible', 'deleted'])
 	}
-	query('BEGIN')
+	modify('BEGIN')
 	if request.is_json:
 		changes = request.get_json().items()
 	else:
@@ -267,21 +267,19 @@ def edit(prefix="", ignore=[]):
 	for key, val in changes:
 		if key in ignore:
 			continue
-		print('edit:', key, val)
 		key = prefix+key
-		print (key,val)
 		table, id, column = key.split('.', 2)
 		assert table in tabs
 		assert column in tabs[table][2]
-		query('INSERT INTO changelog ("table",id_value,id_key,field,value_new,value_old,"when",who,executed) VALUES (?,?,?,?,?,(SELECT %s FROM %s WHERE %s = ?),?,?,1)'%(column,tabs[table][0],tabs[table][1]),table,id,tabs[table][1],column,val,id,datetime.now(),session['user']['givenName'])
-		query('UPDATE %s SET %s = ?, time_updated = ? WHERE %s = ?'%(tabs[table][0], column, tabs[table][1]), val, datetime.now(), id)
-	query('COMMIT')
+		modify('INSERT INTO changelog ("table",id_value,id_key,field,value_new,value_old,"when",who,executed) VALUES (?,?,?,?,?,(SELECT %s FROM %s WHERE %s = ?),?,?,1)'%(column,tabs[table][0],tabs[table][1]),table,id,tabs[table][1],column,val,id,datetime.now(),session['user']['givenName'])
+		modify('UPDATE %s SET %s = ?, time_updated = ? WHERE %s = ?'%(tabs[table][0], column, tabs[table][1]), val, datetime.now(), id)
+	modify('COMMIT')
 	return "OK", 200
 
 @app.route('/newcourse', methods=['GET', 'POST'])
 @mod_required
 def new_course():
-	id = query('''
+	id = modify('''
 		INSERT INTO courses_data
 			(visible, title, short, handle, organizer, subject, created_by, time_created,
 			 time_updated, semester, settings, description, internal, responsible, feed_url)
@@ -296,7 +294,7 @@ def new_course():
 @app.route('/newlecture/<courseid>', methods=['GET', 'POST'])
 @mod_required
 def new_lecture(courseid):
-	id = query('''
+	id = modify('''
 		INSERT INTO lectures_data
 			(course_id, visible, drehplan, title, comment, internal, speaker, place,
 				time, time_created, time_updated, jumplist, titlefile)
@@ -344,7 +342,7 @@ def auth(): # For use with nginx auth_request
 	if not types[0] or allowed or ismod() or \
 			(auth and check_mod(*ldapauth(auth.username, auth.password))):
 		return 'OK', 200
-		query('INSERT INTO log VALUES (?, "", ?, "video", ?, ?)', ip, datetime.now(), videos[0]['id'], url)
+		modify('INSERT INTO log VALUES (?, "", ?, "video", ?, ?)', ip, datetime.now(), videos[0]['id'], url)
 	elif 'password' in types:
 		return Response("Login required", 401, {'WWW-Authenticate': 'Basic realm="Login Required"'})
 	return "Not allowed", 403
@@ -375,7 +373,7 @@ def suggest_chapter(lectureid):
 	submitter = None
 	if not ismod():
 		submitter = request.environ['REMOTE_ADDR']
-	id = query('INSERT INTO chapters (lecture_id, time, text, time_created, time_updated, created_by, submitted_by) VALUES (?, ?, ?, ?, ?, ?, ?)',
+	id = modify('INSERT INTO chapters (lecture_id, time, text, time_created, time_updated, created_by, submitted_by) VALUES (?, ?, ?, ?, ?, ?, ?)',
 				lectureid, time, text, datetime.now(), datetime.now(), session.get('user', {'dbid':None})['dbid'], submitter)
 	if 'ref' in request.values:
 		return redirect(request.values['ref'])
@@ -384,7 +382,7 @@ def suggest_chapter(lectureid):
 @app.route('/newpsa', methods=['POST', 'GET'])
 @mod_required
 def new_announcement():
-	id = query('INSERT INTO announcements (text, internal, time_created, time_updated, created_by) VALUES ("Neue Ankündigung", "", ?, ?, ?)',
+	id = modify('INSERT INTO announcements (text, internal, time_created, time_updated, created_by) VALUES ("Neue Ankündigung", "", ?, ?, ?)',
 			datetime.now(), datetime.now(), session.get('user', {'dbid':None})['dbid'])
 	if 'ref' in request.values:
 		return redirect(request.values['ref'])
@@ -393,7 +391,7 @@ def new_announcement():
 @app.route('/newfeatured', methods=['POST', 'GET'])
 @mod_required
 def new_featured():
-	id = query('INSERT INTO featured (time_created, time_updated, created_by) VALUES (?, ?, ?)',
+	id = modify('INSERT INTO featured (time_created, time_updated, created_by) VALUES (?, ?, ?)',
 			datetime.now(), datetime.now(), session.get('user', {'dbid':None})['dbid'])
 	if 'ref' in request.values:
 		return redirect(request.values['ref'])
-- 
GitLab