# drift_detector.py
# (Forked from: Video AG Infrastruktur / website)
import re
import sys
import traceback
from typing import Callable
from sqlalchemy import Engine, Table, MetaData, Column, types, DefaultClause, TextClause, Constraint, CheckConstraint, \
ForeignKeyConstraint, UniqueConstraint, PrimaryKeyConstraint
from sqlalchemy.exc import *
from sqlalchemy.sql.base import _NoneName, ReadOnlyColumnCollection
from sqlalchemy.sql.schema import ColumnCollectionConstraint, ForeignKey
from sqlalchemy.dialects import postgresql as postgresql
from .utc_timestamp import UTCTimestamp
#
# This file does its best to detect any differences between the schema defined in the python files and the actual database.
# 'Simple' differences can also be fixed. This includes:
# - Adding a missing table # TODO
# - Adding a missing type
# - Adding a missing column
#
# The following aspects are currently being checked with varying accuracy:
# - Table names
# - Columns
#   - Name
#   - Type
#   - Nullable
#   - Primary Key
#   - Auto Increment
#   - Server default value (only if set in python schema, for autoincrement)
#   - Unique
#
# What is NOT checked
# - Collations
# - Indices TODO
#
# A lot of the code here is 'a bit messy' due to all the different dialects (This only really needs to work for postgres
# but SQLAlchemy is designed to work with all of them). Especially the type comparison. If you use new types, you might
# need to tweak some stuff here
#
# Improved string output
def _to_string(value) -> str:
match value:
case DefaultClause() as clause:
return f"DefaultClause({_to_string(clause.arg)}, for_update={clause.for_update})"
case TextClause() as clause:
return f"TextClause(text='{clause.text}')" # Because TextClause doesn't print its text by default
case PrimaryKeyConstraint() as constraint:
assert isinstance(constraint, PrimaryKeyConstraint)
return f"PrimaryKeyConstraint({_to_string(constraint.columns)}, table.name='{constraint.table.name}')"
case CheckConstraint() as constraint:
assert isinstance(constraint, CheckConstraint)
return f"CheckConstraint({_to_string(constraint.columns)}, sqltext='{constraint.sqltext}', table.name='{constraint.table.name}')"
case UniqueConstraint() as constraint:
assert isinstance(constraint, UniqueConstraint)
return f"UniqueConstraint({_to_string(constraint.columns)}, table.name='{constraint.table.name}')"
case ForeignKeyConstraint() as constraint:
assert isinstance(constraint, ForeignKeyConstraint)
return (f"ForeignKeyConstraint("
f"columns={_to_string(constraint.columns)}, "
f"elements={_to_string(constraint.elements)}, "
f"onupdate={constraint.onupdate}, "
f"ondelete={constraint.ondelete}, "
f"ondelete={constraint.match}, "
f"table.name={constraint.table.name})")
case ReadOnlyColumnCollection() as cols:
return "[" + ",".join(map(lambda c: repr(c), cols)) + "]"
case _:
return repr(value)
def _check_columns_equal(first: ReadOnlyColumnCollection[str, Column],
                         second: ReadOnlyColumnCollection[str, Column]) -> bool:
    """
    Check whether two column collections refer to the same columns.

    Columns are matched by name; a matched pair must also belong to tables of the same name. Types and other
    column attributes are not looked at here.

    :param first: One column collection.
    :param second: The column collection to compare it with.
    :return: True if both collections have the same size and every column of 'first' has a counterpart in 'second'.
    """
    if len(first) != len(second):
        return False
    # TODO how does this work with columns of multiple tables? Can the collection map one key to multiple columns?
    return all(
        col.name in second and second[col.name].table.name == col.table.name
        for col in first
    )
def _check_check_constraint_clause_equal(actual_constraint: CheckConstraint, schema_constraint: CheckConstraint) -> bool:
    """
    Decide whether the clause of a reflected check constraint equals the clause declared in the schema.

    The database may hand back a rewritten form of the clause. A fixed list of regex rewrites is applied to the
    reflected text, one after another, and the result is compared with the schema text after each step. As a last
    resort, leading and trailing brackets are stripped from both texts.

    The rewrites can not cover every transformation. If you hit a false error, you probably need to add another
    rewrite here. They are also not perfectly safe: with unusual string constants inside a clause this may return
    True for clauses which are not actually equal, but that is very unlikely.

    If no form matches, an error is printed.

    :param actual_constraint: The constraint reflected from the database.
    :param schema_constraint: The constraint from the python schema.
    :return: True if the clauses are considered equal.
    """
    expected = str(schema_constraint.sqltext)
    reflected = str(actual_constraint.sqltext)
    if expected == reflected:
        return True

    def _multi_not_in(m):
        # Drop the '::type' casts from the array elements
        return f"{m.group(1)} NOT IN (" + re.sub("::[a-zA-Z0-9_]+", "", m.group(2)) + ")"

    def _single_not_in(m):
        return f"{m.group(1)} NOT IN ({m.group(2)})"

    def _not_ilike(m):
        return f"{m.group(1)} NOT ILIKE {m.group(2)}"

    rewrites = (
        # Transform IN checks (which the db transformed)
        #   type <> ALL (ARRAY['plain_video'::medium_metadata_type, 'plain_audio'::medium_metadata_type])
        # to
        #   type NOT IN ('plain_video', 'plain_audio')
        (
            "\\(([a-zA-Z0-9_]+) <> ALL \\(ARRAY\\[('[a-zA-Z0-9_.-]+'::[a-zA-Z0-9_]+(?:, '[a-zA-Z0-9_.-]+'::[a-zA-Z0-9_]+)+)]\\)\\)",
            _multi_not_in,
        ),
        # Transform one-element IN checks
        #   type <> 'course'::featured_type OR course_id IS NOT NULL
        # to
        #   type NOT IN ('course') OR course_id IS NOT NULL
        (
            "([a-zA-Z0-9_]+) <> ('[a-zA-Z0-9_.-]+')::[a-zA-Z0-9_]+",
            _single_not_in,
        ),
        # Transform NOT ILIKE checks
        #   file_path::text !~~* '/%'::text
        # to
        #   file_path NOT ILIKE '/%'
        (
            "([a-zA-Z0-9_]+)(?:::[a-zA-Z0-9_]+)? !~~\\* ('[^']*')::[a-zA-Z0-9_]+",
            _not_ilike,
        ),
    )
    # The rewrites accumulate: each one works on the output of the previous one
    for pattern, replacement in rewrites:
        reflected = re.sub(pattern, replacement, reflected)
        if expected == reflected:
            return True

    # Sometimes the db returns it without correct brackets at the start/end. Just strip them
    if expected.strip("()") == reflected.strip("()"):
        return True

    print(f"Error: Check contraint '{schema_constraint.name}' has different clause in schema\n"
          f" {schema_constraint.sqltext}\n"
          f" than in database\n"
          f" {actual_constraint.sqltext}\n"
          f" Either, the database constraint is outdated, or we are unable to compare them because the database"
          f" transformed the constraint. In the latter case, you should add some more checks to handle these"
          f" transforms in _check_check_constraint_clause_equal() of drift_detector.py")
    return False
def _check_constraint_equal(actual_constraint: Constraint, schema_constraint: Constraint) -> bool:
    """
    Check whether a reflected constraint matches a constraint of the python schema.

    Names are only compared if both constraints have one. Besides that, 'deferrable', the columns and the kind of
    constraint must match. Check constraints additionally compare their clause and foreign key constraints their
    'onupdate', 'ondelete' and 'match' values.

    :param actual_constraint: The constraint reflected from the database.
    :param schema_constraint: The constraint from the python schema.
    :return: True if the constraints are considered equal.
    """
    both_named = actual_constraint.name is not None and schema_constraint.name is not None
    if both_named and actual_constraint.name != schema_constraint.name:
        return False

    if actual_constraint.deferrable != schema_constraint.deferrable:
        # TODO add to to_string
        return False

    if not isinstance(schema_constraint, ColumnCollectionConstraint):
        print(f"Unknown type of constraint {schema_constraint}. Can't compare")
        return False

    if not isinstance(actual_constraint, ColumnCollectionConstraint):
        return False

    if not _check_columns_equal(actual_constraint.columns, schema_constraint.columns):
        return False

    if isinstance(schema_constraint, CheckConstraint):
        if not isinstance(actual_constraint, CheckConstraint):
            return False
        return _check_check_constraint_clause_equal(actual_constraint, schema_constraint)

    if isinstance(schema_constraint, ForeignKeyConstraint):
        if not isinstance(actual_constraint, ForeignKeyConstraint):
            return False
        # TODO elements?
        return (schema_constraint.onupdate == actual_constraint.onupdate
                and schema_constraint.ondelete == actual_constraint.ondelete
                and schema_constraint.match == actual_constraint.match)

    # For these kinds, being of the same kind is all that is left to check
    for kind in (PrimaryKeyConstraint, UniqueConstraint):
        if isinstance(schema_constraint, kind):
            return isinstance(actual_constraint, kind)

    print(f"Unknown type of constraint {schema_constraint}. Can't compare")
    return False
def _check_types_equal(actual: types.TypeEngine, schema: types.TypeEngine) -> bool:
    """
    Check whether a reflected column type matches the type declared in the python schema.

    The schema type selects the rule by its exact class (Enum and ARRAY: by isinstance). The reflected type is
    then tested with isinstance() against the class belonging to that rule, so subclasses of that class are
    accepted on the database side. Some rules compare additional attributes (length, timezone, ...).

    Collations are not compared (see https://github.com/sqlalchemy/sqlalchemy/issues/6511).

    :param actual: The type reflected from the database.
    :param schema: The type from the python schema.
    :return: True if the types are considered equal.
    :raises RuntimeError: If there is no rule for the schema type.
    """
    schema_cls = type(schema)

    # Schema class -> class the reflected type has to be an instance of. Nothing else is compared for these.
    class_only_rules = {
        types.Integer: types.Integer,
        types.SmallInteger: types.SmallInteger,
        types.BigInteger: types.BigInteger,
        types.Float: types.Float,
        types.Double: types.Double,
        types.Boolean: types.Boolean,
        types.Date: types.Date,
        types.TIMESTAMP: types.TIMESTAMP,
        UTCTimestamp: types.TIMESTAMP,
    }
    if schema_cls in class_only_rules:
        return isinstance(actual, class_only_rules[schema_cls])

    # TODO Check collation for the string types
    if schema_cls is types.String:
        return isinstance(actual, types.VARCHAR) and actual.length == schema.length
    if schema_cls is types.Text:
        return isinstance(actual, types.Text) and actual.length == schema.length

    for json_cls in (types.JSON, postgresql.JSON, postgresql.JSONB):
        if schema_cls is json_cls:
            return isinstance(actual, json_cls) and schema.none_as_null == actual.none_as_null

    if schema_cls is types.DateTime:
        return isinstance(actual, types.DateTime) and actual.timezone == schema.timezone

    if isinstance(schema, types.Enum):
        if not isinstance(actual, types.Enum):
            return False
        if not schema.name:
            print(f"Enum '{schema}' in schema has no name. Can't compare")
            return False
        if not schema.values_callable:
            print(f"Enum '{schema}' in schema has no values_callable. Can't compare")
            return False
        # Name in actual not required because of sqlite
        name_matches = not actual.name or schema.name == actual.name
        return name_matches and schema.enums == actual.enums

    if isinstance(schema, types.ARRAY):
        if not isinstance(actual, types.ARRAY):
            return False
        return _check_types_equal(actual.item_type, schema.item_type)

    raise RuntimeError(f"Comparison of types '{_to_string(actual)}' and '{_to_string(schema)}' is not supported")
def _check_default_clause_equal(actual: DefaultClause, schema: DefaultClause) -> bool:
    """
    Check whether a reflected server default matches the server default of the python schema.

    Without a default in the schema, the database must have no default either, or a text default starting
    with 'NULL::'. Otherwise 'for_update' and the default text must match; a plain string default in the schema
    may show up in the database with or without surrounding single quotes.

    :param actual: The default reflected from the database (may be None).
    :param schema: The default from the python schema (may be None).
    :return: True if the defaults are considered equal.
    :raises ValueError: If the schema default is neither a string nor a TextClause.
    """
    if schema is None:
        if actual is None:
            return True
        return isinstance(actual.arg, TextClause) and actual.arg.text.startswith("NULL::")

    if actual is None or actual.for_update != schema.for_update:
        return False

    # TODO check other attributes?
    expected = schema.arg
    if isinstance(expected, str):
        reflected_text = actual.arg.text
        return reflected_text == expected or reflected_text == f"'{expected}'"
    if isinstance(expected, TextClause):
        return actual.arg.text == expected.text
    raise ValueError(f"Unable to compare default clause '{_to_string(actual)}' with '{_to_string(schema)}'")
def _get_actual_autoincrement(column: Column) -> bool:
    """
    Resolve the 'autoincrement' setting of a column to a boolean.

    True and False are returned as they are. For "auto", the column counts as auto incrementing if it has no
    default and no server default, is the only primary key column of its table, passes the Integer type check
    and has no foreign keys.

    :param column: The column to inspect.
    :return: Whether the column is treated as auto incrementing.
    :raises ValueError: If 'autoincrement' has any other value.
    """
    setting = column.autoincrement
    if setting is False:
        return False
    if setting is True:
        return True
    if setting == "auto":
        return (column.default is None
                and column.server_default is None
                and column.primary_key
                and not any(other is not column and other.primary_key for other in column.table.columns)
                and _check_types_equal(column.type, types.Integer())
                and len(column.foreign_keys) == 0)
    raise ValueError(f"Unknown value for autoincrement: {column.autoincrement}")
def _check_column_autoincrement(
        actual_column: Column,
        schema_column: Column) -> bool:
    """
    Check that a reflected column and its schema column agree on whether they auto increment.

    The comparison uses the resolved booleans from _get_actual_autoincrement(), not the raw 'autoincrement'
    settings. The message therefore prints both: the raw settings alone can be identical (e.g. both 'auto')
    although the resolved values differ.

    :param actual_column: The column reflected from the database.
    :param schema_column: The column from the python schema.
    :return: True if both resolve to the same value. Otherwise an error is printed and False is returned.
    """
    actual_autoincrement = _get_actual_autoincrement(actual_column)
    schema_autoincrement = _get_actual_autoincrement(schema_column)
    if actual_autoincrement == schema_autoincrement:
        return True

    print(f"Table '{actual_column.table.name}': Column '{actual_column.name}' has 'autoincrement' value "
          f"'{_to_string(actual_column.autoincrement)}' (effectively {actual_autoincrement}) in database but should "
          f"have '{_to_string(schema_column.autoincrement)}' (effectively {schema_autoincrement})")
    return False
def _check_column_server_default_with_autoincrement(actual_column: Column, schema_column: Column) -> bool:
    """
    Check the server default of a reflected column against the python schema.

    - If the schema column has a server default, the database default must equal it.
    - If it has none and does not auto increment, whatever the database has is accepted.
    - If it has none but auto increments, the database default must be a text clause starting with 'nextval('.

    Server defaults which are not a DefaultClause can't be compared and count as a mismatch. Every mismatch is
    printed.

    :param actual_column: The column reflected from the database.
    :param schema_column: The column from the python schema.
    :return: True if the server default is as expected.
    """
    actual_default = actual_column.server_default
    expected_default = schema_column.server_default

    if actual_default is not None and not isinstance(actual_default, DefaultClause):
        print(f"Table '{actual_column.table}': Column '{actual_column.name}' has server_default value in database which "
              f"is not of type DefaultClause. Can't compare")
        return False

    if expected_default is not None and not isinstance(expected_default, DefaultClause):
        print(f"Table '{schema_column.table}': Column '{schema_column.name}' has server_default value in schema which "
              f"is not of type DefaultClause. Can't compare")
        return False

    if expected_default is not None:
        if _check_default_clause_equal(actual_default, expected_default):
            return True
        print(
            f"Table '{actual_column.table}': Column '{actual_column.name}' has server_default "
            f"'{_to_string(actual_default)}' but should have '{_to_string(expected_default)}'")
        return False

    # No default in the schema: only auto incrementing columns put a requirement on the database default
    if not _get_actual_autoincrement(schema_column):
        return True

    uses_nextval = (actual_default
                    and isinstance(actual_default.arg, TextClause)
                    and actual_default.arg.text.startswith("nextval("))
    if uses_nextval:
        return True

    print(f"Table '{actual_column.table}': Column '{actual_column.name}' has autoincrement enabled but server default "
          f"value is not nextval() function in database")
    return False
def _check_column_attribute(
        actual_column: Column,
        schema_column: Column,
        attr_name: str,
        comparator: Callable[[object, object], bool] = lambda a, s: a == s) -> bool:
    """
    Compare one attribute of a reflected column with the same attribute of its schema column.

    :param actual_column: The column reflected from the database.
    :param schema_column: The column from the python schema.
    :param attr_name: Name of the attribute to read from both columns.
    :param comparator: Called with (database value, schema value); returns whether they match. Defaults to '=='.
    :return: True if the values match. Otherwise an error is printed and False is returned.
    """
    found = getattr(actual_column, attr_name)
    wanted = getattr(schema_column, attr_name)
    if comparator(found, wanted):
        return True

    print(f"Table '{actual_column.table.name}': Column '{actual_column.name}' has '{attr_name}' value "
          f"'{_to_string(found)}' in database but should have '{_to_string(wanted)}'")
    return False
def _check_table_equal(actual_table: Table, schema_table: Table) -> bool:
    """
    Compare a table reflected from the database with its definition in the python schema.

    Checked are the table name, the columns (see the individual _check_column_* functions) and the constraints.
    Every difference is printed. Checking continues after a difference, so one run reports as much as possible.

    :param actual_table: The table reflected from the database.
    :param schema_table: The table from the python schema.
    :return: True if no difference was found.
    """
    correct = True

    if actual_table.name != schema_table.name:
        correct = False
        print(f"Actual table has name '{actual_table.name}' but should have '{schema_table.name}'")

    for column_name, actual_column in actual_table.columns.items():
        if column_name not in schema_table.columns:
            # A column which only exists in the database is only an error if it has no server default
            if actual_column.server_default is None:
                correct = False
                print(
                    f"Table '{actual_table.name}': Database contains unknown column '{column_name}' which does not have a default value")
            continue

        schema_column = schema_table.columns[column_name]

        # Note that 'default' is only a client-side attribute!
        assert isinstance(actual_column, Column)
        # '&=' (instead of 'and') so that every check runs and prints, even after an earlier one failed
        correct &= _check_column_autoincrement(actual_column, schema_column)
        correct &= _check_column_server_default_with_autoincrement(actual_column, schema_column)
        correct &= _check_column_attribute(actual_column, schema_column, "computed")
        correct &= _check_column_attribute(actual_column, schema_column, "identity")
        # TODO key
        correct &= _check_column_attribute(actual_column, schema_column, "nullable")
        correct &= _check_column_attribute(actual_column, schema_column, "onupdate")
        correct &= _check_column_attribute(actual_column, schema_column, "primary_key")
        correct &= _check_column_attribute(actual_column, schema_column, "server_onupdate")
        correct &= _check_column_attribute(actual_column, schema_column, "system")
        # noinspection PyTypeChecker
        correct &= _check_column_attribute(actual_column, schema_column, "type", _check_types_equal)
        # Unique and Foreign Key is checked with constraints

    for column_name in schema_table.columns.keys():
        if column_name not in actual_table.columns:
            correct = False
            print(f"Table '{schema_table.name}': Missing column '{column_name}' in database")

    unmatched_actual_constraints = set(actual_table.constraints)

    for schema_constraint in schema_table.constraints:
        for actual_constraint in actual_table.constraints:
            if _check_constraint_equal(actual_constraint, schema_constraint):
                # discard() instead of remove(): the same database constraint can match more than one schema
                # constraint (e.g. two unnamed constraints over the same columns). remove() raised KeyError then.
                unmatched_actual_constraints.discard(actual_constraint)
                break
        else:
            # The inner loop did not break, so no database constraint matches
            correct = False
            print(f"Missing constraint\n {_to_string(schema_constraint)}\nin database for table {schema_table.name}. "
                  f"The following constraints do not match:")
            print("\n".join(map(lambda c: f" {_to_string(c)}", actual_table.constraints)))

    # Report every leftover database constraint, not only an arbitrary first one
    for unexpected_constraint in unmatched_actual_constraints:
        correct = False
        print(f"Got unexpected constraint\n {_to_string(unexpected_constraint)}\nin database for "
              f"table {schema_table.name}. The following schema constraints do not match:")
        print("\n".join(map(lambda c: f" {_to_string(c)}", schema_table.constraints)))

    return correct
def _add_column(engine: Engine, column: Column):
    """
    Add a column of the python schema to its table in the database with an ALTER TABLE statement.

    The column definition is the one SQLAlchemy would emit inside CREATE TABLE. The statement runs in its own
    transaction.

    :param engine: The engine of the database to alter.
    :param column: The schema column to add. Its table must already exist in the database.
    """
    import sqlalchemy as sql
    from sqlalchemy.sql import ddl

    # Not really nice but works for the simple cases
    # NOTE: The statement is passed through text(), which treats ':name' sequences as bind parameters. A column
    # definition containing such a sequence (e.g. inside a string default) will presumably fail to execute.
    col_spec = str(ddl.CreateColumn(column).compile(bind=engine))

    # Let the dialect quote the table name. This also includes the schema of the table, if it has one.
    table_name = engine.dialect.identifier_preparer.format_table(column.table)

    # engine.begin() commits when the block is left normally and rolls back on an exception.
    # No explicit commit is needed.
    with engine.begin() as connection:
        connection.execute(sql.text(f"ALTER TABLE {table_name} ADD {col_spec}"))
def check_for_drift_and_migrate(metadata: MetaData, engine: Engine, migrate: bool = False) -> bool:
    """
    Compare the python schema with the actual database and optionally fix simple differences first.

    With 'migrate' enabled, SQLAlchemy first creates all missing entities (create_all) and columns missing in
    existing tables are added. The executed statements (except selects) are printed to stdout. Afterwards, and
    also with 'migrate' disabled, every table of the schema is reflected from the database and compared. All
    differences are printed.

    :param metadata: The metadata holding the python schema.
    :param engine: The engine of the database to check.
    :param migrate: Whether to create missing entities and columns before checking.
    :return: True if the database matches the schema, False if drift was detected.
    """
    correct = True
    actual_metadata = MetaData()

    if migrate:
        print("Letting SQLAlchemy create missing entities...")

        import logging
        sqlalchemy_logger = logging.getLogger("sqlalchemy.engine")
        own_handler = logging.StreamHandler(sys.stdout)
        own_handler.setLevel(logging.INFO)

        def _filter(record: logging.LogRecord) -> bool:
            # Hide selects and the '[cached since ...' / '[generated in ...' lines ending with '}'
            msg = record.getMessage().upper()
            return not msg.startswith("SELECT") and not (
                msg.startswith("[CACHED SINCE") and msg.endswith("}")
            ) and not (
                msg.startswith("[GENERATED IN") and msg.endswith("}")
            )

        own_handler.addFilter(_filter)
        old_level = sqlalchemy_logger.level
        sqlalchemy_logger.setLevel(logging.INFO)
        sqlalchemy_logger.addHandler(own_handler)

        try:
            metadata.create_all(engine)

            # Try our own luck with adding missing columns
            migration_actual_metadata = MetaData()
            for table_name, schema_table in metadata.tables.items():
                try:
                    actual_table = Table(table_name, migration_actual_metadata, autoload_with=engine)
                    for column_name, schema_column in schema_table.columns.items():
                        if column_name in actual_table.columns:
                            continue
                        _add_column(engine, schema_column)
                except NoSuchTableError:
                    # Nothing to add columns to. The missing table is reported by the check below
                    pass
        finally:
            # Restore the logger even if the migration failed, so the handler and level do not leak
            sqlalchemy_logger.removeHandler(own_handler)
            sqlalchemy_logger.setLevel(old_level)

        print("SQLAlchemy is done. See log for executed statements (selects excluded)")

    print("Checking schema against actual entities")

    for table_name, schema_table in metadata.tables.items():
        assert isinstance(schema_table, Table)
        try:
            actual_table = Table(table_name, actual_metadata, autoload_with=engine)
            correct &= _check_table_equal(actual_table, schema_table)
        except NoSuchTableError:
            correct = False
            if migrate:
                print(f"Missing table '{table_name}' in database (Migration is enabled but SQLAlchemy did not create this table for some reason)")
            else:
                print(f"Missing table '{table_name}' in database (Migration is disabled)")

    print(f"Drift detection is done. Schema {'is ok' if correct else 'has drifted'}")
    return correct