From 4bd3d253f3f71f8809df0940a7c7e5ba7e625883 Mon Sep 17 00:00:00 2001
From: Robin Sonnabend <robin@fsmpi.rwth-aachen.de>
Date: Fri, 2 Mar 2018 16:39:18 +0100
Subject: [PATCH] make auth backends optional

Only define authentication backends that require external packages if
their required modules have been found.
There is no need to install packages for unused backends.
---
 auth.py | 274 ++++++++++++++++++++++++++++++--------------------------
 1 file changed, 146 insertions(+), 128 deletions(-)

diff --git a/auth.py b/auth.py
index 0323a2f..64e8fdd 100644
--- a/auth.py
+++ b/auth.py
@@ -1,12 +1,11 @@
 import hmac, hashlib
 import ssl
-import ldap3
-from ldap3.utils.dn import parse_dn
 from datetime import datetime
-import grp, pwd, pam
+
 
 class User:
-    def __init__(self, username, groups, timestamp=None, obsolete=False, permanent=False):
+    def __init__(self, username, groups, timestamp=None, obsolete=False,
+            permanent=False):
         self.username = username
         self.groups = groups
         if timestamp is not None:
@@ -17,7 +16,8 @@ class User:
         self.permanent = permanent
 
     def summarize(self):
-        return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups), str(self.timestamp.timestamp()), self.obsolete, self.permanent)
+        return "{}:{}:{}:{}:{}".format(self.username, ",".join(self.groups),
+            str(self.timestamp.timestamp()), self.obsolete, self.permanent)
 
     @staticmethod
     def from_summary(summary):
@@ -36,6 +36,7 @@ class User:
         summary, hash = secure_string.split("=", 1)
         return User.from_summary(summary)
 
+
 class UserManager:
     def __init__(self, backends):
         self.backends = backends
@@ -44,7 +45,8 @@ class UserManager:
         for backend in self.backends:
             if backend.authenticate(username, password):
                 groups = sorted(list(set(backend.groups(username, password))))
-                return User(username, groups, obsolete=backend.obsolete, permanent=permanent)
+                return User(username, groups, obsolete=backend.obsolete,
+                    permanent=permanent)
         return None
 
     def all_groups(self):
@@ -52,89 +54,33 @@ class UserManager:
             yield from backend.all_groups()
 
 
-class LdapManager:
-    def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True, obsolete=False):
-        self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
-        self.user_dn = user_dn
-        self.group_dn = group_dn
-        self.obsolete = obsolete
-
-    def authenticate(self, username, password):
-        try:
-            connection = ldap3.Connection(self.server, self.user_dn.format(username), password)
-            return connection.bind()
-        except ldap3.core.exceptions.LDAPSocketOpenError:
-            return False
-
-    def groups(self, username, password=None):
-        connection = ldap3.Connection(self.server)
-        obj_def = ldap3.ObjectDef("posixgroup", connection)
-        group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
-        username = username.lower()
-        for group in group_reader.search():
-            members = group.memberUid.value
-            if members is not None and username in members:
-                yield group.cn.value
+class SecurityManager:
+    def __init__(self, key, max_duration=300):
+        self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
+        self.max_duration = max_duration
 
-    def all_groups(self):
-        connection = ldap3.Connection(self.server)
-        obj_def = ldap3.ObjectDef("posixgroup", connection)
-        group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
-        for group in group_reader.search():
-            yield group.cn.value
-
-
-class ADManager:
-    def __init__(self, host, domain, user_dn, group_dn,
-        port=636, use_ssl=True, ca_cert=None, obsolete=False):
-        tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
-        if ca_cert is not None:
-            tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
-                ca_certs_file=ca_cert)
-        self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
-            tls=tls_config)
-        self.domain = domain
-        self.user_dn = user_dn
-        self.group_dn = group_dn
-        self.obsolete = obsolete
+    def hash_user(self, user):
+        maccer = self.maccer.copy()
+        summary = user.summarize()
+        maccer.update(summary.encode("utf-8"))
+        return "{}={}".format(summary, maccer.hexdigest())
 
-    def prepare_connection(self, username=None, password=None):
-        if username is not None and password is not None:
-            ad_user = "{}\\{}".format(self.domain, username)
-            return ldap3.Connection(self.server, ad_user, password)
-        return ldap3.Connection(self.server)
-        
-    def authenticate(self, username, password):
-        try:
-            return self.prepare_connection(username, password).bind()
-        except ldap3.core.exceptions.LDAPSocketOpenError:
+    def check_user(self, string):
+        parts = string.split("=", 1)
+        if len(parts) != 2:
+            # wrong format, expecting summary:hash
             return False
-
-    def groups(self, username, password):
-        connection = self.prepare_connection(username, password)
-        connection.bind()
-        obj_def = ldap3.ObjectDef("user", connection)
-        name_filter = "cn:={}".format(username)
-        user_reader = ldap3.Reader(connection, obj_def, self.user_dn, name_filter)
-        group_def = ldap3.ObjectDef("group", connection)
-        def _yield_recursive_groups(group_dn):
-            group_reader = ldap3.Reader(connection, group_def, group_dn, None)
-            for entry in group_reader.search():
-                yield entry.name.value
-                for child in entry.memberOf:
-                    yield from _yield_recursive_groups(child)
-        for result in user_reader.search():
-            for group_dn in result.memberOf:
-                yield from _yield_recursive_groups(group_dn)
-
-
-    def all_groups(self):
-        connection = self.prepare_connection()
-        connection.bind()
-        obj_def = ldap3.ObjectDef("group", connection)
-        group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
-        for result in reader.search():
-            yield result.name.value
+        summary, hash = map(lambda s: s.encode("utf-8"), parts)
+        maccer = self.maccer.copy()
+        maccer.update(summary)
+        user = User.from_hashstring(string)
+        if user is None:
+            return False
+        session_duration = datetime.now() - user.timestamp
+        macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"),
+            hash)
+        time_short = int(session_duration.total_seconds()) < self.max_duration
+        return macs_equal and (time_short or user.permanent)
 
 
 class StaticUserManager:
@@ -161,49 +107,121 @@ class StaticUserManager:
         yield from list(set(group for group in groups.values()))
 
 
-class PAMManager:
-    def __init__(self, obsolete=False):
-        self.pam = pam.pam()
-        self.obsolete = obsolete
+try:
+    import ldap3
+    from ldap3.utils.dn import parse_dn
+
+    class LdapManager:
+        def __init__(self, host, user_dn, group_dn, port=636, use_ssl=True,
+                obsolete=False):
+            self.server = ldap3.Server(host, port=port, use_ssl=use_ssl)
+            self.user_dn = user_dn
+            self.group_dn = group_dn
+            self.obsolete = obsolete
+
+        def authenticate(self, username, password):
+            try:
+                connection = ldap3.Connection(self.server,
+                    self.user_dn.format(username), password)
+                return connection.bind()
+            except ldap3.core.exceptions.LDAPSocketOpenError:
+                return False
+
+        def groups(self, username, password=None):
+            connection = ldap3.Connection(self.server)
+            obj_def = ldap3.ObjectDef("posixgroup", connection)
+            group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
+            username = username.lower()
+            for group in group_reader.search():
+                members = group.memberUid.value
+                if members is not None and username in members:
+                    yield group.cn.value
+
+        def all_groups(self):
+            connection = ldap3.Connection(self.server)
+            obj_def = ldap3.ObjectDef("posixgroup", connection)
+            group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
+            for group in group_reader.search():
+                yield group.cn.value
 
-    def authenticate(self, username, password):
-        return self.pam.authenticate(username, password)
 
-    def groups(self, username, password=None):
-        print(username)
-        yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
-        for group in grp.getgrall():
-            if username in group.gr_mem:
+    class ADManager:
+        def __init__(self, host, domain, user_dn, group_dn,
+            port=636, use_ssl=True, ca_cert=None, obsolete=False):
+            tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED)
+            if ca_cert is not None:
+                tls_config = ldap3.Tls(validate=ssl.CERT_REQUIRED,
+                    ca_certs_file=ca_cert)
+            self.server = ldap3.Server(host, port=port, use_ssl=use_ssl,
+                tls=tls_config)
+            self.domain = domain
+            self.user_dn = user_dn
+            self.group_dn = group_dn
+            self.obsolete = obsolete
+
+        def prepare_connection(self, username=None, password=None):
+            if username is not None and password is not None:
+                ad_user = "{}\\{}".format(self.domain, username)
+                return ldap3.Connection(self.server, ad_user, password)
+            return ldap3.Connection(self.server)
+
+        def authenticate(self, username, password):
+            try:
+                return self.prepare_connection(username, password).bind()
+            except ldap3.core.exceptions.LDAPSocketOpenError:
+                return False
+
+        def groups(self, username, password):
+            connection = self.prepare_connection(username, password)
+            connection.bind()
+            obj_def = ldap3.ObjectDef("user", connection)
+            name_filter = "cn:={}".format(username)
+            user_reader = ldap3.Reader(connection, obj_def, self.user_dn,
+                name_filter)
+            group_def = ldap3.ObjectDef("group", connection)
+            def _yield_recursive_groups(group_dn):
+                group_reader = ldap3.Reader(connection, group_def, group_dn, None)
+                for entry in group_reader.search():
+                    yield entry.name.value
+                    for child in entry.memberOf:
+                        yield from _yield_recursive_groups(child)
+            for result in user_reader.search():
+                for group_dn in result.memberOf:
+                    yield from _yield_recursive_groups(group_dn)
+
+
+        def all_groups(self):
+            connection = self.prepare_connection()
+            connection.bind()
+            obj_def = ldap3.ObjectDef("group", connection)
+            group_reader = ldap3.Reader(connection, obj_def, self.group_dn)
+            for result in reader.search():
+                yield result.name.value
+except ModuleNotFoundError:
+    pass
+
+
+try:
+    import grp, pwd, pam
+
+    class PAMManager:
+        def __init__(self, obsolete=False):
+            self.pam = pam.pam()
+            self.obsolete = obsolete
+
+        def authenticate(self, username, password):
+            return self.pam.authenticate(username, password)
+
+        def groups(self, username, password=None):
+            print(username)
+            yield grp.getgrgid(pwd.getpwnam(username).pw_gid).gr_name
+            for group in grp.getgrall():
+                if username in group.gr_mem:
+                    yield group.gr_name
+
+        def all_groups(self):
+            for group in grp.getgrall():
                 yield group.gr_name
-
-    def all_groups(self):
-        for group in grp.getgrall():
-            yield group.gr_name
-
-class SecurityManager:
-    def __init__(self, key, max_duration=300):
-        self.maccer = hmac.new(key.encode("utf-8"), digestmod=hashlib.sha512)
-        self.max_duration = max_duration
-
-    def hash_user(self, user):
-        maccer = self.maccer.copy()
-        summary = user.summarize()
-        maccer.update(summary.encode("utf-8"))
-        return "{}={}".format(summary, maccer.hexdigest())
-
-    def check_user(self, string):
-        parts = string.split("=", 1)
-        if len(parts) != 2:
-            # wrong format, expecting summary:hash
-            return False
-        summary, hash = map(lambda s: s.encode("utf-8"), parts)
-        maccer = self.maccer.copy()
-        maccer.update(summary)
-        user = User.from_hashstring(string)
-        if user is None:
-            return False
-        session_duration = datetime.now() - user.timestamp
-        macs_equal = hmac.compare_digest(maccer.hexdigest().encode("utf-8"), hash)
-        time_short = int(session_duration.total_seconds()) < self.max_duration 
-        return macs_equal and (time_short or user.permanent)
+except ModuleNotFoundError:
+    pass
 
-- 
GitLab