Skip to content
Snippets Groups Projects
Select Git revision
  • 83e94e102feaa8d265d3c91a8c68e968cbd9f6c6
  • master default protected
  • intros
  • live_sources
  • bootstrap4
  • modules
6 results

sorter.py

Blame
  • Forked from Video AG Infrastruktur / website
    Source project has a limited visibility.
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    drift_detector.py 21.89 KiB
    import re
    import sys
    import traceback
    from typing import Callable
    
    from sqlalchemy import Engine, Table, MetaData, Column, types, DefaultClause, TextClause, Constraint, CheckConstraint, \
        ForeignKeyConstraint, UniqueConstraint, PrimaryKeyConstraint
    from sqlalchemy.exc import *
    from sqlalchemy.sql.base import _NoneName, ReadOnlyColumnCollection
    from sqlalchemy.sql.schema import ColumnCollectionConstraint, ForeignKey
    from sqlalchemy.dialects import postgresql as postgresql
    
    from .utc_timestamp import UTCTimestamp
    
    
    #
    # This file does its best to detect any difference between the schema in the python files and the actual database.
    # 'Simple' differences can also be fixed. This includes:
    #  - Adding a missing table # TODO
    #  - Adding a missing type
    #  - Adding a missing column
    #
    # The following aspects are currently being checked with varying accuracy
    # - Table names
    # - Columns
    #   - Name
    #   - Type
    #   - Nullable
    #   - Primary Key
    #   - Auto Increment
    #   - Server default value (only if set in python schema, for autoincrement)
    #   - Unique
    #
    # What is NOT checked
    # - Collations
    # - Indices   TODO
    #
    # A lot of the code here is 'a bit messy' due to all the different dialects (This only really needs to work for postgres
    # but SQLAlchemy is designed to work with all of them). Especially the type comparison. If you use new types, you might
    # need to tweak some stuff here
    #
    
    
    # Improved string output
    def _to_string(value) -> str:
        match value:
            case DefaultClause() as clause:
                return f"DefaultClause({_to_string(clause.arg)}, for_update={clause.for_update})"
            
            case TextClause() as clause:
                return f"TextClause(text='{clause.text}')"  # Because TextClause doesn't print its text by default
            
            case PrimaryKeyConstraint() as constraint:
                assert isinstance(constraint, PrimaryKeyConstraint)
                return f"PrimaryKeyConstraint({_to_string(constraint.columns)}, table.name='{constraint.table.name}')"
            
            case CheckConstraint() as constraint:
                assert isinstance(constraint, CheckConstraint)
                return f"CheckConstraint({_to_string(constraint.columns)}, sqltext='{constraint.sqltext}', table.name='{constraint.table.name}')"
            
            case UniqueConstraint() as constraint:
                assert isinstance(constraint, UniqueConstraint)
                return f"UniqueConstraint({_to_string(constraint.columns)}, table.name='{constraint.table.name}')"
            
            case ForeignKeyConstraint() as constraint:
                assert isinstance(constraint, ForeignKeyConstraint)
                return (f"ForeignKeyConstraint("
                        f"columns={_to_string(constraint.columns)}, "
                        f"elements={_to_string(constraint.elements)}, "
                        f"onupdate={constraint.onupdate}, "
                        f"ondelete={constraint.ondelete}, "
                        f"ondelete={constraint.match}, "
                        f"table.name={constraint.table.name})")
            
            case ReadOnlyColumnCollection() as cols:
                return "[" + ",".join(map(lambda c: repr(c), cols)) + "]"
                
            case _:
                return repr(value)
    
    
    def _check_columns_equal(first: ReadOnlyColumnCollection[str, Column],
                             second: ReadOnlyColumnCollection[str, Column]) -> bool:
        """
        Tell whether two column collections reference the same columns.
        
        Columns are identified by their name plus the name of the table they belong to. Both collections must have the
        same size, and every column of `first` must have a same-named counterpart in `second` on an equally named table.
        """
        if len(first) != len(second):
            return False
        
        # TODO how does this work with columns of multiple tables? Can the collection map one key to multiple columns?
        return all(
            col.name in second and second[col.name].table.name == col.table.name
            for col in first
        )
    
    
    def _check_check_constraint_clause_equal(actual_constraint: CheckConstraint, schema_constraint: CheckConstraint) -> bool:
        """
        This method tries its best to determine if two check constraints have equal clauses. However, in some cases the
        database transforms the constraints. We do some regex magic to reverse these transforms to compare them with the
        schema constraint.
        
        However, these can never catch all cases. When you encounter an error, you might need to add some additional
        transform here. These transforms are also not 100% safe (e.g. you can construct cases where this returns True,
        even if the clauses are not actually equal), however that is very unlikely and probably involves some weird
        string constants in the constraint.
        
        The transforms are applied cumulatively to the database text, in order, and the texts are compared after each
        one. As a last resort, leading/trailing brackets are stripped from both texts. If nothing matches, an error
        showing both clauses is printed and False is returned.
        """
        schema_text = str(schema_constraint.sqltext)
        actual_text = str(actual_constraint.sqltext)
        
        if schema_text == actual_text:
            return True
        
        # (pattern, replacement) pairs that undo rewrites done by the database. Order matters: each one works on the
        # output of the previous ones.
        transforms = (
            # Multi-element IN checks
            #   (type <> ALL (ARRAY['plain_video'::medium_metadata_type, 'plain_audio'::medium_metadata_type]))
            # to
            #   type NOT IN ('plain_video', 'plain_audio')
            (
                r"\(([a-zA-Z0-9_]+) <> ALL \(ARRAY\[('[a-zA-Z0-9_.-]+'::[a-zA-Z0-9_]+(?:, '[a-zA-Z0-9_.-]+'::[a-zA-Z0-9_]+)+)]\)\)",
                lambda match: f"{match.group(1)} NOT IN (" + re.sub(r"::[a-zA-Z0-9_]+", "", match.group(2)) + ")",
            ),
            # One-element IN checks
            #   type <> 'course'::featured_type OR course_id IS NOT NULL
            # to
            #   type NOT IN ('course') OR course_id IS NOT NULL
            (
                r"([a-zA-Z0-9_]+) <> ('[a-zA-Z0-9_.-]+')::[a-zA-Z0-9_]+",
                lambda match: f"{match.group(1)} NOT IN ({match.group(2)})",
            ),
            # NOT ILIKE checks
            #   file_path::text !~~* '/%'::text
            # to
            #   file_path NOT ILIKE '/%'
            (
                r"([a-zA-Z0-9_]+)(?:::[a-zA-Z0-9_]+)? !~~\* ('[^']*')::[a-zA-Z0-9_]+",
                lambda match: f"{match.group(1)} NOT ILIKE {match.group(2)}",
            ),
        )
        for pattern, replacement in transforms:
            actual_text = re.sub(pattern, replacement, actual_text)
            if schema_text == actual_text:
                return True
        
        # Sometimes the db returns it without correct brackets at the start/end. Just strip them
        actual_text = actual_text.strip("()")
        schema_text = schema_text.strip("()")
        if schema_text == actual_text:
            return True
        
        print(f"Error: Check constraint '{schema_constraint.name}' has different clause in schema\n"
              f"    {schema_constraint.sqltext}\n"
              f"  than in database\n"
              f"    {actual_constraint.sqltext}\n"
              f"  Either, the database constraint is outdated, or we are unable to compare them because the database"
              f" transformed the constraint. In the latter case, you should add some more checks to handle these"
              f" transforms in _check_check_constraint_clause_equal() of drift_detector.py")
        return False
    
    
    def _check_constraint_equal(actual_constraint: Constraint, schema_constraint: Constraint) -> bool:
        """
        Tell whether a reflected database constraint corresponds to a constraint from the schema.
        
        Names only have to agree when both constraints are named. Deferrability and the constrained columns must always
        agree. After that the comparison depends on the schema constraint's kind:
          - check constraints compare their clauses (see _check_check_constraint_clause_equal),
          - foreign keys compare onupdate, ondelete and match,
          - primary keys and unique constraints only need the same kind.
        Constraints that are not column-collection constraints, or of any other kind, cannot be compared: a message
        is printed and False is returned.
        """
        both_named = actual_constraint.name is not None and schema_constraint.name is not None
        if both_named and actual_constraint.name != schema_constraint.name:
            return False
        
        if actual_constraint.deferrable != schema_constraint.deferrable:
            # TODO add to to_string
            return False
        
        if not isinstance(schema_constraint, ColumnCollectionConstraint):
            print(f"Unknown type of constraint {schema_constraint}. Can't compare")
            return False
        if not isinstance(actual_constraint, ColumnCollectionConstraint):
            return False
        if not _check_columns_equal(actual_constraint.columns, schema_constraint.columns):
            return False
        
        if isinstance(schema_constraint, CheckConstraint):
            return (isinstance(actual_constraint, CheckConstraint)
                    and _check_check_constraint_clause_equal(actual_constraint, schema_constraint))
        
        if isinstance(schema_constraint, ForeignKeyConstraint):
            # TODO elements?
            return (isinstance(actual_constraint, ForeignKeyConstraint)
                    and (schema_constraint.onupdate == actual_constraint.onupdate
                         and schema_constraint.ondelete == actual_constraint.ondelete
                         and schema_constraint.match == actual_constraint.match))
        
        if isinstance(schema_constraint, PrimaryKeyConstraint):
            return isinstance(actual_constraint, PrimaryKeyConstraint)
        
        if isinstance(schema_constraint, UniqueConstraint):
            return isinstance(actual_constraint, UniqueConstraint)
        
        print(f"Unknown type of constraint {schema_constraint}. Can't compare")
        return False
    
    
    def _check_types_equal(actual: types.TypeEngine, schema: types.TypeEngine) -> bool:
        """
        Decide whether the reflected database type `actual` matches the schema type `schema`.
        
        Most schema types are recognised by their exact class (subclasses do NOT count). Some additionally compare
        length, none_as_null or timezone. Enum and ARRAY (including subclasses) are compared structurally, ARRAY
        recursively on the item type.
        
        :raises RuntimeError: if the schema type is not one of the supported ones
        """
        schema_cls = type(schema)
        
        # Exact schema class -> class the reflected type has to be an instance of. Nothing else is compared for these.
        plain_matches = {
            types.Integer: types.Integer,
            types.SmallInteger: types.SmallInteger,
            types.BigInteger: types.BigInteger,
            types.Float: types.Float,
            types.Double: types.Double,
            types.Boolean: types.Boolean,
            types.Date: types.Date,
            types.TIMESTAMP: types.TIMESTAMP,
            UTCTimestamp: types.TIMESTAMP,
        }
        if schema_cls in plain_matches:
            return isinstance(actual, plain_matches[schema_cls])
        
        if schema_cls is types.String:
            # TODO Check collation. Right now sqlalchemy does return the collation via reflection (https://github.com/sqlalchemy/sqlalchemy/issues/6511)
            return isinstance(actual, types.VARCHAR) and actual.length == schema.length
        
        if schema_cls is types.Text:
            # Check collation. See above
            return isinstance(actual, types.Text) and actual.length == schema.length
        
        if schema_cls in (types.JSON, postgresql.JSON, postgresql.JSONB):
            # The reflected type must be an instance of the very same JSON flavour
            return isinstance(actual, schema_cls) and schema.none_as_null == actual.none_as_null
        
        if schema_cls is types.DateTime:
            return isinstance(actual, types.DateTime) and actual.timezone == schema.timezone
        
        if isinstance(schema, types.Enum):
            if not isinstance(actual, types.Enum):
                return False
            if not schema.name:
                print(f"Enum '{schema}' in schema has no name. Can't compare")
                return False
            if not schema.values_callable:
                print(f"Enum '{schema}' in schema has no values_callable. Can't compare")
                return False
            # Name in actual not required because of sqlite
            return (not actual.name or schema.name == actual.name) and schema.enums == actual.enums
        
        if isinstance(schema, types.ARRAY):
            return isinstance(actual, types.ARRAY) and _check_types_equal(actual.item_type, schema.item_type)
        
        raise RuntimeError(f"Comparison of types '{_to_string(actual)}' and '{_to_string(schema)}' is not supported")
    
    
    def _check_default_clause_equal(actual: DefaultClause | None, schema: DefaultClause | None) -> bool:
        """
        Compare a reflected server default (`actual`) with the one declared in the schema (`schema`).
        
        Either argument may be None. Without a schema default, the database must have no default either, or a default
        whose text starts with 'NULL::' (a typed NULL). Otherwise, `for_update` must agree and the default texts are
        compared.
        
        A plain string default from the schema matches these database texts:
          - the bare value,
          - the value in single quotes,
          - the value as an SQL string literal (inner quotes doubled), optionally followed by a type cast.
            PostgreSQL reports string defaults this way, e.g. 'it''s'::character varying.
        
        :raises ValueError: if the schema default is neither a string nor a TextClause
        """
        if schema is None:
            return (actual is None or
                    (isinstance(actual.arg, TextClause) and actual.arg.text.startswith("NULL::")))
        
        if actual is None:
            return False
        
        if actual.for_update != schema.for_update:
            return False
        
        # TODO check other attributes?
        if isinstance(schema.arg, str):
            actual_text = actual.arg.text
            if actual_text == schema.arg or actual_text == f"'{schema.arg}'":
                return True
            # SQL string literal with doubled inner quotes, optionally followed by a cast such as
            # ::text, ::character varying or ::"MyEnum" (the cast itself must not contain quotes)
            literal = "'" + schema.arg.replace("'", "''") + "'"
            return re.fullmatch(re.escape(literal) + r"(?:::[^']+)?", actual_text) is not None
        elif isinstance(schema.arg, TextClause):
            return actual.arg.text == schema.arg.text
        else:
            raise ValueError(f"Unable to compare default clause '{_to_string(actual)}' with '{_to_string(schema)}'")
    
    
    def _get_actual_autoincrement(column: Column) -> bool:
        """
        Resolve the effective autoincrement flag of `column`.
        
        An explicit True/False is returned unchanged. The value "auto" counts as autoincrementing only if all of these
        hold:
          - the column has no client or server default,
          - it is the only primary-key column of its table,
          - it has an Integer type,
          - it has no foreign keys.
        
        :raises ValueError: for any other autoincrement value
        """
        setting = column.autoincrement
        # Identity checks on purpose: mirrors the literal True/False patterns of a match statement
        if setting is False or setting is True:
            return setting
        
        if setting == "auto":
            return (column.default is None
                    and column.server_default is None
                    and column.primary_key
                    and not any(other is not column and other.primary_key for other in column.table.columns)
                    and _check_types_equal(column.type, types.Integer())
                    and len(column.foreign_keys) == 0)
        
        raise ValueError(f"Unknown value for autoincrement: {column.autoincrement}")
    
    
    def _check_column_autoincrement(
            actual_column: Column,
            schema_column: Column) -> bool:
        """
        Check that the database column and the schema column resolve to the same effective autoincrement flag.
        
        Prints the raw autoincrement settings of both columns and returns False if they differ.
        """
        if _get_actual_autoincrement(actual_column) == _get_actual_autoincrement(schema_column):
            return True
        
        print(f"Table '{actual_column.table.name}': Column '{actual_column.name}' has 'autoincrement' value "
              f"'{_to_string(actual_column.autoincrement)}' in database but should have '{_to_string(schema_column.autoincrement)}'")
        return False
    
    
    def _check_column_server_default_with_autoincrement(actual_column: Column, schema_column: Column) -> bool:
        """
        Validate the server default of a database column against the schema column.
        
        - Server defaults that are not DefaultClause instances (on either side) cannot be compared: report and fail.
        - If the schema declares a server default, the database default must equal it
          (see _check_default_clause_equal).
        - Otherwise, a schema column that is not autoincrementing accepts any database default.
        - An autoincrementing schema column requires the database default to be a nextval(...) call.
        
        Every failure is printed.
        """
        actual_default = actual_column.server_default
        schema_default = schema_column.server_default
        
        if actual_default is not None and not isinstance(actual_default, DefaultClause):
            print(f"Table '{actual_column.table}': Column '{actual_column.name}' has server_default value in database which "
                  f"is not of type DefaultClause. Can't compare")
            return False
        
        if schema_default is not None and not isinstance(schema_default, DefaultClause):
            print(f"Table '{schema_column.table}': Column '{schema_column.name}' has server_default value in schema which "
                  f"is not of type DefaultClause. Can't compare")
            return False
        
        if schema_default is not None:
            if not _check_default_clause_equal(actual_default, schema_default):
                print(
                    f"Table '{actual_column.table}': Column '{actual_column.name}' has server_default "
                    f"'{_to_string(actual_default)}' but should have '{_to_string(schema_default)}'")
                return False
            return True
        
        if not _get_actual_autoincrement(schema_column):
            return True
        
        draws_from_sequence = (actual_default
                               and isinstance(actual_default.arg, TextClause)
                               and actual_default.arg.text.startswith("nextval("))
        if draws_from_sequence:
            return True
        
        print(f"Table '{actual_column.table}': Column '{actual_column.name}' has autoincrement enabled but server default "
              f"value is not nextval() function in database")
        return False
    
    
    def _check_column_attribute(
            actual_column: Column,
            schema_column: Column,
            attr_name: str,
            comparator: Callable[[object, object], bool] = lambda a, s: a == s) -> bool:
        """
        Compare one attribute (looked up by name) of a database column and a schema column.
        
        `comparator` receives (database value, schema value) and defaults to plain equality. On mismatch, both values
        are printed and False is returned.
        """
        actual_value = getattr(actual_column, attr_name)
        schema_value = getattr(schema_column, attr_name)
        if comparator(actual_value, schema_value):
            return True
        
        print(f"Table '{actual_column.table.name}': Column '{actual_column.name}' has '{attr_name}' value "
              f"'{_to_string(actual_value)}' in database but should have '{_to_string(schema_value)}'")
        return False
    
    
    def _check_table_columns(actual_table: Table, schema_table: Table) -> bool:
        """
        Compare the columns of a reflected table with those of the schema table and print every difference found.
        
        Database columns unknown to the schema are tolerated if they have a server default (presumably so inserts by
        the application, which does not know them, still work). Unique and foreign key properties are not checked here;
        they are compared as constraints.
        """
        correct = True
        for column_name, actual_column in actual_table.columns.items():
            if column_name not in schema_table.columns:
                if actual_column.server_default is None:
                    correct = False
                    print(
                        f"Table '{actual_table.name}': Database contains unknown column '{column_name}' which does not have a default value")
                continue
            schema_column = schema_table.columns[column_name]
            
            # Note that 'default' is only a client-side attribute!
            assert isinstance(actual_column, Column)
            correct &= _check_column_autoincrement(actual_column, schema_column)
            correct &= _check_column_server_default_with_autoincrement(actual_column, schema_column)
            correct &= _check_column_attribute(actual_column, schema_column, "computed")
            correct &= _check_column_attribute(actual_column, schema_column, "identity")
            # TODO key
            correct &= _check_column_attribute(actual_column, schema_column, "nullable")
            correct &= _check_column_attribute(actual_column, schema_column, "onupdate")
            correct &= _check_column_attribute(actual_column, schema_column, "primary_key")
            correct &= _check_column_attribute(actual_column, schema_column, "server_onupdate")
            correct &= _check_column_attribute(actual_column, schema_column, "system")
            # noinspection PyTypeChecker
            correct &= _check_column_attribute(actual_column, schema_column, "type", _check_types_equal)
            
            # Unique and Foreign Key is checked with constraints
        
        for column_name in schema_table.columns.keys():
            if column_name not in actual_table.columns:
                correct = False
                print(f"Table '{schema_table.name}': Missing column '{column_name}' in database")
        
        return correct
    
    
    def _check_table_constraints(actual_table: Table, schema_table: Table) -> bool:
        """
        Pair the constraints of the reflected table with those of the schema table and print every mismatch.
        
        Each database constraint can satisfy at most one schema constraint. Every schema constraint needs a matching
        database constraint, and every database constraint must be claimed by some schema constraint.
        """
        correct = True
        unmatched_actual_constraints = set(actual_table.constraints)
        for schema_constraint in schema_table.constraints:
            for actual_constraint in actual_table.constraints:
                # Skip constraints already paired with an earlier schema constraint. Otherwise two equivalent schema
                # constraints could both claim the same database constraint, and remove() would raise KeyError.
                if actual_constraint not in unmatched_actual_constraints:
                    continue
                if _check_constraint_equal(actual_constraint, schema_constraint):
                    unmatched_actual_constraints.remove(actual_constraint)
                    break
            else:
                correct = False
                print(f"Missing constraint\n  {_to_string(schema_constraint)}\nin database for table {schema_table.name}. "
                      f"The following constraints do not match:")
                print("\n".join(f"  {_to_string(c)}" for c in unmatched_actual_constraints))
        
        if unmatched_actual_constraints:
            correct = False
            unexpected = "\n".join(f"  {_to_string(c)}" for c in unmatched_actual_constraints)
            print(f"Got unexpected constraint(s)\n{unexpected}\nin database for "
                  f"table {schema_table.name}. The following schema constraints do not match:")
            print("\n".join(f"  {_to_string(c)}" for c in schema_table.constraints))
        
        return correct
    
    
    def _check_table_equal(actual_table: Table, schema_table: Table) -> bool:
        """
        Compare a reflected database table with its schema definition and print every difference found.
        
        Checks the table name, then the columns (see _check_table_columns), then the constraints
        (see _check_table_constraints).
        
        :return: True if no drift was detected for this table
        """
        correct = True
        if actual_table.name != schema_table.name:
            correct = False
            print(f"Actual table has name '{actual_table.name}' but should have '{schema_table.name}'")
        
        correct &= _check_table_columns(actual_table, schema_table)
        correct &= _check_table_constraints(actual_table, schema_table)
        return correct
    
    
    def _add_column(engine: Engine, column: Column):
        """
        Add `column` to its already existing table in the database with ALTER TABLE ... ADD.
        
        The column definition is rendered by SQLAlchemy's CreateColumn DDL construct for the engine's dialect. Not
        really nice, but works for the simple cases.
        """
        import sqlalchemy as sql
        from sqlalchemy.sql import ddl
        
        col_spec = str(ddl.CreateColumn(column).compile(bind=engine))
        # Let the dialect quote (and schema-qualify) the table name instead of hard-coding double quotes
        table_spec = engine.dialect.identifier_preparer.format_table(column.table)
        # engine.begin() commits when the block exits normally (and rolls back on error), so no explicit commit
        with engine.begin() as conn:
            conn.execute(sql.text(f"ALTER TABLE {table_spec} ADD {col_spec}"))
    
    
    def _run_migration(metadata: MetaData, engine: Engine):
        """
        Create missing entities in the database.
        
        Runs metadata.create_all(), then adds columns the schema has but existing tables lack (via _add_column()).
        While this runs, SQLAlchemy's engine log is echoed to stdout. SELECT statements and the '[cached since …]' /
        '[generated in …]' parameter lines are filtered out. The logger's handler and level are restored afterwards,
        even if a statement fails.
        """
        import logging
        sqlalchemy_logger = logging.getLogger("sqlalchemy.engine")
        
        own_handler = logging.StreamHandler(sys.stdout)
        own_handler.setLevel(logging.INFO)
        
        def _filter(record: logging.LogRecord) -> bool:
            msg = record.getMessage().upper()
            return not msg.startswith("SELECT") and not (
                msg.startswith("[CACHED SINCE") and msg.endswith("}")
            ) and not (
                msg.startswith("[GENERATED IN") and msg.endswith("}")
            )
        
        own_handler.addFilter(_filter)
        
        old_level = sqlalchemy_logger.level
        sqlalchemy_logger.setLevel(logging.INFO)
        sqlalchemy_logger.addHandler(own_handler)
        try:
            metadata.create_all(engine)
            
            # Try our own luck with adding missing columns
            migration_actual_metadata = MetaData()
            for schema_table in metadata.tables.values():
                try:
                    # Reflect by name + schema: the metadata key is 'schema.name' for schema-qualified tables
                    actual_table = Table(schema_table.name, migration_actual_metadata,
                                         schema=schema_table.schema, autoload_with=engine)
                except NoSuchTableError:
                    continue
                for column_name, schema_column in schema_table.columns.items():
                    if column_name not in actual_table.columns:
                        _add_column(engine, schema_column)
        finally:
            # Restore the logger even if create_all() or an ALTER TABLE failed
            sqlalchemy_logger.removeHandler(own_handler)
            sqlalchemy_logger.setLevel(old_level)
    
    
    def check_for_drift_and_migrate(metadata: MetaData, engine: Engine, migrate: bool = False) -> bool:
        """
        Compare every table of `metadata` with the database behind `engine` and print all differences found.
        
        :param metadata: the schema as declared in python
        :param engine: engine connected to the database to check
        :param migrate: if True, first create missing tables and add missing columns (see _run_migration)
        :return: True if no drift was detected
        """
        if migrate:
            print("Letting SQLAlchemy create missing entities...")
            _run_migration(metadata, engine)
            print("SQLAlchemy is done. See log for executed statements (selects excluded)")
        
        correct = True
        actual_metadata = MetaData()
        print("Checking schema against actual entities")
        for table_name, schema_table in metadata.tables.items():
            assert isinstance(schema_table, Table)
            try:
                # Reflect by name + schema: the metadata key is 'schema.name' for schema-qualified tables
                actual_table = Table(schema_table.name, actual_metadata, schema=schema_table.schema,
                                     autoload_with=engine)
            except NoSuchTableError:
                correct = False
                if migrate:
                    print(f"Missing table '{table_name}' in database (Migration is enabled but SQLAlchemy did not create this table for some reason)")
                else:
                    print(f"Missing table '{table_name}' in database (Migration is disabled)")
                continue
            correct &= _check_table_equal(actual_table, schema_table)
        
        print(f"Drift detection is done. Schema {'is ok' if correct else 'has drifted'}")
        return correct