Select Git revision
Forked from
Video AG Infrastruktur / website
Source project has a limited visibility.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
object_class.py 26.67 KiB
from enum import Enum
from videoag_common.database import *
from videoag_common.miscellaneous import *
from .fields import *
from .object import ApiObject, DeletableApiObject, AC_INCLUDE_DELETED, AC_IS_MOD
from .util import get_relationship_foreign_class
def _get_variant_id(clazz: type[ApiObject], optional: bool = False) -> str or None:
    """
    Return the polymorphic identity (variant id) declared for ``clazz``.

    The identity is read from ``polymorphic_identity`` in ``clazz.__mapper_args__`` (normal attribute lookup, so a
    ``__mapper_args__`` inherited from a parent class is used as well). Enum identities are converted to their
    ``value``.

    :param clazz: ORM class to inspect
    :param optional: If True, a missing identity returns None instead of raising
    :return: The identity string, or None if it is missing and ``optional`` is set
    :raises Exception: If the identity is missing (and not ``optional``), or is neither a string nor an enum whose
                       value is a string
    """
    identity = clazz.__mapper_args__.get("polymorphic_identity") if hasattr(clazz, "__mapper_args__") else None

    if identity is None:
        if not optional:
            raise Exception(f"Missing 'polymorphic_identity' in '__mapper_args__' for class '{clazz.__name__}'")
        return None

    if isinstance(identity, Enum):
        identity = identity.value

    if isinstance(identity, str):
        return identity
    raise Exception(f"'polymorphic_identity' in '__mapper_args__' for class '{clazz.__name__}' is not a string "
                    f"or an enum with a string value")
class ApiObjectClass:
    """
    API-level description of an ORM class derived from ``ApiObject``.

    Holds the API fields of the ORM class (config fields and data fields, each grouped by variant), its variants (for
    polymorphic ORM classes) and its parent relationships. Implements serialization (data fields) as well as reading,
    modifying, creating and (un)deleting objects through the config (config fields), recording changelog entries.

    The per-variant field dicts are keyed by variant id; the key ``None`` holds the fields shared by all variants.
    """

    def __init__(self,
                 parent_relationship_config_ids: list[str] or None = None,
                 enable_config: bool or None = None,
                 config_allow_creation: bool = True,
                 enable_data: bool or None = None,
                 ):
        """
        :param parent_relationship_config_ids: Names of relationship attributes on the ORM class which point to the
                                               possible parent objects (resolved in ``_post_init_parents``)
        :param enable_config: Whether config access is enabled. None: enabled iff the class has any config field
        :param config_allow_creation: Whether new objects may be created through the config
        :param enable_data: Whether serialization is enabled. None: enabled iff the class has any data field
        """
        self._parent_relationship_config_ids = parent_relationship_config_ids
        self.enable_config = enable_config
        self.config_allow_creation = config_allow_creation
        self.enable_data = enable_data

        # Set by set_class
        self.orm_class = None
        self.id = None
        self.sql_table = None
        self.sql_id_column = None

        self._config_load_options: list[ExecutableOption] = []
        # variant id (None = shared by all variants) -> config id -> field
        self._fields_by_variant_by_config_id: dict[str or None, dict[str, ApiConfigField]] = {
            None: {}
        }
        # variant id (None = shared by all variants) -> data id -> field
        self._fields_by_variant_by_data_id: dict[str or None, dict[str, ApiDataField]] = {
            None: {}
        }
        # parent API class id -> name of the relationship attribute on the ORM class
        self._parent_relationship_id_by_class_id = {}
        self._variant_field: ApiEnumField or None = None
        self._orm_classes_by_variant: dict[str or None, type[ApiObject]] = {}
        # Computed by _post_init_creation_config; initialized here so the attribute always exists
        self._creation_config_json = None

    def set_class(self, orm_class: type[ApiObject]):
        """
        Bind this API class to ``orm_class``.

        The API id is the snake_case form of the ORM class name; a table with that name must exist in
        ``Base.metadata`` and must have an ``id`` column.
        """
        self.orm_class = orm_class
        self.id = alphanum_camel_case_to_snake_case(orm_class.__name__)
        self.sql_table = Base.metadata.tables[self.id]
        self.sql_id_column = self.sql_table.columns["id"]

    @staticmethod
    def _ordered_class_hierarchy(clazz: type) -> list[type]:
        """
        Return the classes yielded by ``recursive_flat_map_single`` over ``__bases__`` starting at ``clazz``
        (presumably ``clazz`` and all its ancestors), de-duplicated while keeping the first-seen order.

        A ``set`` would give the same classes, but its iteration order depends on object addresses and therefore
        changes between runs, which would make the field order non-deterministic.
        """
        return list(dict.fromkeys(recursive_flat_map_single(lambda c: c.__bases__, clazz)))

    @staticmethod
    def _variant_ids_in_lookup_order(variant_id: str or None) -> tuple:
        """
        Return the keys of the per-variant field dicts that apply to an object of ``variant_id``.

        The shared fields (key None) come first, then the variant-specific ones. The order is fixed, unlike that of
        the set ``{variant_id, None}``, whose iteration order depends on string hash randomization.
        """
        if variant_id is None:
            return (None,)
        return None, variant_id

    def _post_init(self, all_classes: dict[str, "ApiObjectClass"]):
        """
        Finish initialization once all API classes are known.

        Collects the shared and the variant-specific fields, resolves ``enable_config``/``enable_data`` if they were
        left as None (enabled iff at least one field of that kind exists), resolves the parent relationships and
        precomputes the creation config.

        :param all_classes: All API classes (presumably keyed by API class id)
        """
        self._post_init_fields(all_classes)
        self._post_init_variants()
        self._post_init_variant_fields(all_classes)
        if self.enable_config is None:
            self.enable_config = any(len(fields_by_id) > 0
                                     for fields_by_id in self._fields_by_variant_by_config_id.values())
        if self.enable_data is None:
            self.enable_data = any(len(fields_by_id) > 0
                                   for fields_by_id in self._fields_by_variant_by_data_id.values())
        self._post_init_parents(all_classes)
        self._post_init_creation_config()

    def _post_init_fields(self, all_classes: dict[str, "ApiObjectClass"]):
        """Add the shared (non-variant) fields declared on the ORM class and its base classes."""
        for clazz in self._ordered_class_hierarchy(self.orm_class):
            if not hasattr(clazz, "__all_class_api_fields__"):
                continue
            # noinspection PyUnresolvedReferences
            for field in clazz.__all_class_api_fields__:
                assert isinstance(field, ApiField)
                try:
                    self._add_field(all_classes, field, None, None)
                except Exception as e:
                    raise Exception(
                        f"While initializing field for member '{field.member_name}' in class '{clazz.__name__}' for "
                        f"'{self.orm_class.__name__}'") from e

    def _post_init_variants(self):
        """
        Set up variants if the ORM class declares ``polymorphic_on`` in its ``__mapper_args__``.

        The variant field is looked up by that name as a shared config id, then as a shared data id. It must be a
        non-nullable ``ApiEnumField``. One (initially empty) field dict is created per enum value, and the ORM class
        itself is registered for its own ``polymorphic_identity`` if it has one.
        """
        variant_field_id = None
        if hasattr(self.orm_class, "__mapper_args__"):
            # May be none
            variant_field_id = self.orm_class.__mapper_args__.get("polymorphic_on")
        if variant_field_id is None:
            return

        var_field = self._fields_by_variant_by_config_id[None].get(variant_field_id)
        if var_field is None:
            var_field = self._fields_by_variant_by_data_id[None].get(variant_field_id)
        if var_field is None:
            raise Exception(f"Unknown variant field '{variant_field_id}' (determined by 'polymorphic_on' with config id, "
                            f"or data id if config id not found) for API class '{self.orm_class.__name__}' (Variant "
                            f"field may not be in a subclass!, and must be included in the config or data)")
        # Check the type first so the attribute accesses below operate on an enum field
        if not isinstance(var_field, ApiEnumField):
            raise Exception(f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' must be an "
                            f"enum!")

        # NOTE(review): this runs before _post_init resolves an 'enable_config' of None (auto), so the two config
        # checks below only apply when config was enabled explicitly; classes with auto-enabled config skip them.
        # Confirm whether that is intended.
        if self.enable_config and self.config_allow_creation and not var_field.include_in_config:
            raise Exception(f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' must be included "
                            f"in the config if the config allows creation")
        self._variant_field = var_field
        if self._variant_field.may_be_none:
            raise Exception(f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' may not be "
                            f"nullable")
        if self.enable_config and self.config_allow_creation and not self._variant_field.config_only_at_creation:
            raise Exception(
                f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' must be only "
                f"at creation")

        for variant_id in self._variant_field.str_enums:
            self._fields_by_variant_by_config_id[variant_id] = {}
            self._fields_by_variant_by_data_id[variant_id] = {}

        # Note that the base variant actually can't have any variant fields (All fields without a variant belong
        # to it (and to all other variants))
        # The base variant is also optional
        base_variant_id = _get_variant_id(self.orm_class, optional=True)
        if base_variant_id is not None:
            if base_variant_id not in self._variant_field.str_enums:
                raise Exception(f"Class '{self.orm_class.__name__}' has unknown variant '{base_variant_id}'")
            self._orm_classes_by_variant[base_variant_id] = self.orm_class

    def _post_init_variant_fields(self, all_classes: dict[str, "ApiObjectClass"]):
        """
        Register the direct API subclasses of the ORM class as variants and add their variant-specific fields.

        Only fields declared on classes outside the base class hierarchy are added, under the subclass' variant id.
        Indirect API subclasses are rejected, each variant may be claimed by only one ORM class, and every value of
        the variant enum must end up with an ORM class.
        """
        variant_ignore_classes = set(self._ordered_class_hierarchy(self.orm_class))
        for sub_class in self.orm_class.__subclasses__():
            if not hasattr(sub_class, "__all_class_api_fields__"):
                continue
            if self._variant_field is None:
                raise Exception(f"Found API subclass '{sub_class.__name__}' of API class '{self.orm_class.__name__}' "
                                f"but 'polymorphic_on' is missing in the orm '__mapper_args__'")
            variant_id = _get_variant_id(sub_class)
            if variant_id not in self._variant_field.str_enums:
                raise Exception(f"Unknown variant '{variant_id}' for class '{sub_class.__name__}' in enum "
                                f"'{self._variant_field.enum_class.__name__}'")
            existing_class = self._orm_classes_by_variant.get(variant_id)
            if existing_class is not None:
                # E.g. a subclass without its own 'polymorphic_identity' inherits its parent's '__mapper_args__'.
                # Silently overwriting would make creation instantiate the wrong ORM class for this variant
                raise Exception(f"Variant '{variant_id}' of API class '{self.orm_class.__name__}' is claimed by both "
                                f"'{existing_class.__name__}' and '{sub_class.__name__}'")
            self._orm_classes_by_variant[variant_id] = sub_class

            for super_sub_class in self._ordered_class_hierarchy(sub_class):
                if super_sub_class in variant_ignore_classes:
                    continue
                if not hasattr(super_sub_class, "__all_class_api_fields__"):
                    continue
                # noinspection PyUnresolvedReferences
                for sub_field in super_sub_class.__all_class_api_fields__:
                    assert isinstance(sub_field, ApiField)
                    try:
                        self._add_field(all_classes, sub_field, sub_class, variant_id)
                    except Exception as e:
                        raise Exception(
                            f"While initializing field for member '{sub_field.member_name}' in class "
                            f"'{super_sub_class.__name__}' for '{self.orm_class.__name__}'") from e

            for sub2_class in filter(
                    lambda c: hasattr(c, "__all_class_api_fields__"),
                    recursive_flat_map(
                        lambda c: c.__subclasses__(),
                        sub_class.__subclasses__()
                    )):
                raise Exception(
                    f"API class '{sub2_class.__name__}' is indirect child of API class '{self.orm_class.__name__}'."
                    f" Indirect inheritance is not supported (Only direct)")

        if self._variant_field is not None:
            for variant_id in self._variant_field.str_enums:
                if variant_id not in self._orm_classes_by_variant:
                    raise Exception(f"No ORM class for variant '{variant_id}' of API object class '{self.id}'")

    def _post_init_parents(self, all_classes: dict[str, "ApiObjectClass"]):
        """
        Resolve the configured parent relationships.

        Each must be a single-object (non-collection) relationship attribute of the ORM class, and at most one
        relationship may point to each parent API class.

        :raises TypeError: If an id is unknown, not a relationship, or a collection relationship
        :raises ValueError: If two relationships point to the same parent class
        """
        for parent_relation_id in (self._parent_relationship_config_ids or []):
            parent_relation_attr = getattr(self.orm_class, parent_relation_id, None)
            if parent_relation_attr is None:
                raise TypeError(f"Unknown parent relationship id '{parent_relation_id}' for class "
                                f"'{self.orm_class.__name__}'")
            # Attributes which are not ORM-instrumented have no 'prop'; report them as "not a relationship"
            parent_relation = getattr(parent_relation_attr, "prop", None)
            if not isinstance(parent_relation, orm.Relationship):
                raise TypeError(f"Parent relationship '{parent_relation_id}' for class '{self.orm_class.__name__}' "
                                f"is not a relationship")
            if parent_relation.collection_class is not None:
                raise TypeError(f"Parent relationship '{parent_relation_id}' for class '{self.orm_class.__name__}' "
                                f"can reference more than one object")
            parent_class = get_relationship_foreign_class(all_classes, parent_relation)
            if parent_class.id in self._parent_relationship_id_by_class_id:
                raise ValueError(f"Multiple parent fields for parent class '{parent_class.id}' in class '{self.orm_class.__name__}' ")
            self._parent_relationship_id_by_class_id[parent_class.id] = parent_relation_id

    def _post_init_creation_config(self):
        """
        Precompute the JSON returned by ``get_creation_config``.

        ``"fields"`` lists the descriptions of the shared config fields. For classes with variants,
        ``"variant_fields"`` maps each variant id to the descriptions of its variant-specific config fields. The
        variant field itself is omitted (the variant is passed separately on creation). Stays None if config or
        creation is disabled.
        """
        self._creation_config_json = None
        if not self.enable_config or not self.config_allow_creation:
            return
        descriptions_by_variant = {
            variant_id: [
                field.config_get_description()
                for field in fields_by_id.values()
                if field is not self._variant_field
            ]
            for variant_id, fields_by_id in self._fields_by_variant_by_config_id.items()
        }
        self._creation_config_json = {
            "fields": descriptions_by_variant.pop(None)
        }
        if self._variant_field is not None:
            self._creation_config_json["variant_fields"] = descriptions_by_variant

    def _add_field(self,
                   all_classes: dict[str, "ApiObjectClass"],
                   field: ApiField,
                   variant_class: type or None,
                   variant_id: str or None):
        """
        Post-initialize a copy of ``field`` and register it as config and/or data field of ``variant_id``.

        The field's member is looked up on the ORM class (shared fields) or on ``variant_class`` (variant fields).
        A variant field may not reuse an id of a shared field.

        :raises ValueError: On config/data id conflicts
        """
        attr_class = self.orm_class if variant_id is None else variant_class
        attr = None
        if field.member_name is not None and hasattr(attr_class, field.member_name):
            attr = getattr(attr_class, field.member_name)

        # The same field definition may be added for several API classes, so each class gets its own copy
        field = field.pre_post_init_copy()
        context = FieldContext(
            all_classes=all_classes,
            api_class=self,
            own_member=attr,
            variant_id=variant_id
        )
        field.post_init(context)
        field.post_init_check(context)

        if isinstance(field, ApiConfigField) and field.include_in_config:
            fields_by_id = self._fields_by_variant_by_config_id[variant_id]
            if field.config_id in fields_by_id:
                raise ValueError(f"Config id conflict with '{field.config_id}' in variant '{variant_id}'")
            fields_by_id[field.config_id] = field
            if variant_id is not None and field.config_id in self._fields_by_variant_by_config_id[None]:
                raise ValueError(f"Config id conflict with '{field.config_id}' in variant '{variant_id}' and non-variant")
            self._config_load_options.extend(field.config_get_load_options())

        if isinstance(field, ApiDataField) and field.include_in_data:
            fields_by_id = self._fields_by_variant_by_data_id[variant_id]
            if field.data_id in fields_by_id:
                raise ValueError(f"Data id conflict with '{field.data_id}' in variant '{variant_id}'")
            fields_by_id[field.data_id] = field
            if variant_id is not None and field.data_id in self._fields_by_variant_by_data_id[None]:
                raise ValueError(f"Data id conflict with '{field.data_id}' in variant '{variant_id}' and non-variant")

    def _get_current_variant(self, obj) -> str or None:
        """
        Return the variant id of ``obj``, or None if this class has no variants.

        :raises ValueError: If the stored variant is not a value of the variant enum
        """
        if self._variant_field is None:
            return None
        assert isinstance(self._variant_field, ApiEnumField)
        variant_id = self._variant_field.config_get_value(obj)  # For enums the same as data_get_value
        if variant_id not in self._variant_field.str_enums:
            raise ValueError(f"Unknown variant for class '{self.id}' in database: '{variant_id}' (object id: {obj.id})")
        return variant_id

    def _get_current_variant_set(self, obj) -> set[str or None]:
        """
        Return the field dict keys that apply to ``obj`` as an (unordered) set.

        Kept for compatibility; use ``_get_current_variant_ids`` when the iteration order matters.
        """
        return {self._get_current_variant(obj), None}

    def _get_current_variant_ids(self, obj) -> tuple:
        """Return the field dict keys that apply to ``obj``: None (shared fields) first, then its variant."""
        return self._variant_ids_in_lookup_order(self._get_current_variant(obj))

    def get_variants(self) -> list[str] or None:
        """Return the variant ids (the variant enum's string values), or None if this class has no variants."""
        if self._variant_field is None:
            return None
        return self._variant_field.str_enums

    def get_data_fields(self) -> list[ApiDataField]:
        """Return the data fields of all variants (including the shared ones)."""
        return list(flat_map(lambda by_id: by_id.values(), self._fields_by_variant_by_data_id.values()))

    def get_config_fields(self) -> list[ApiConfigField]:
        """Return the config fields of all variants (including the shared ones)."""
        return list(flat_map(lambda by_id: by_id.values(), self._fields_by_variant_by_config_id.values()))

    def get_data_fields_with_id(self, data_id: str) -> list[ApiDataField]:
        """Return every data field (of any variant) with the given data id."""
        return [
            field
            for by_id in self._fields_by_variant_by_data_id.values()
            if (field := by_id.get(data_id)) is not None
        ]

    def get_config_fields_with_id(self, config_id: str) -> list[ApiConfigField]:
        """Return every config field (of any variant) with the given config id."""
        return [
            field
            for by_id in self._fields_by_variant_by_config_id.values()
            if (field := by_id.get(config_id)) is not None
        ]

    def serialize(self, obj, **kwargs) -> dict:
        """Serialize ``obj`` into a dict; ``kwargs`` are the serialization arguments passed on to the data fields."""
        return self._serialize_catch_error(obj, to_context=False, **kwargs)

    def serialize_to_context(self, obj, **kwargs) -> int:
        """
        Serialize ``obj`` into the dict passed as keyword argument ``<id>_context`` and return the object id.

        See ``serialize_object_args`` for details.
        """
        return self._serialize_catch_error(obj, to_context=True, **kwargs)

    def _serialize_catch_error(self, obj, to_context: bool, **kwargs) -> dict or int:
        """
        Serialize with ``kwargs`` wrapped in an ``ArgAttributeObject``.

        A missing argument (an ``AttributeError`` from the ``ArgAttributeObject``) is re-raised with the API class id
        in the message; other ``AttributeError``s propagate unchanged.
        """
        if not self.enable_data:
            raise Exception("Serialization not enabled")
        try:
            return self.serialize_object_args(obj, ArgAttributeObject(**kwargs), to_context)
        except AttributeError as e:
            if not str(e).startswith("'ArgAttributeObject' object has no attribute"):
                raise
            raise AttributeError(f"Missing a keyword argument for serialization of object '{self.id}': " + str(e)) from e

    def serialize_object_args(self, obj, args, to_context: bool = False) -> dict or int:
        """
        Serialize ``obj`` using the shared data fields and those of its variant (in that order).

        :param obj: The ORM object
        :param args: Object holding the serialization arguments as attributes (passed on to each data field)
        :param to_context: If False, return the serialized dict. If True, store it in the dict ``args.<id>_context``
                           under ``str(obj.id)`` (unless that key is already present) and return ``obj.id``
        """
        if not self.enable_data:
            raise Exception("Serialization not enabled")
        if not to_context:
            res = {}
            for variant_id in self._get_current_variant_ids(obj):
                for field in self._fields_by_variant_by_data_id[variant_id].values():
                    field.add_to_data(obj, res, args)
            return res

        context_name = f"{self.id}_context"
        if not hasattr(args, context_name):
            raise Exception(f"Missing serialization argument '{context_name}'")
        context = getattr(args, context_name)
        if not isinstance(context, dict):
            raise Exception(f"serialization argument '{context_name}' must be a dict")
        object_key = str(obj.id)
        if object_key not in context:
            context[object_key] = None  # Indicate that object is being serialized. Prevent loops
            context[object_key] = self.serialize_object_args(obj, args, to_context=False)
        return obj.id

    def get_creation_config(self) -> JsonTypes or None:
        """Return the precomputed creation config JSON (None if creation is not allowed)."""
        if not self.enable_config:
            raise Exception("Config not enabled")
        return self._creation_config_json

    def is_deletion_allowed(self) -> bool:
        """Return whether the ORM class is a ``DeletableApiObject``."""
        return issubclass(self.orm_class, DeletableApiObject)

    def get_current_config(self, session: SessionDb, object_id: int):
        """
        Return the current config of the object ``object_id``.

        The result is ``{"fields": [...]}`` with the value and description of every applicable config field (shared
        fields first, then variant fields), excluding fields which may only be set at creation.

        :raises ApiClientException: ``ERROR_UNKNOWN_OBJECT`` if no such object exists
        """
        if not self.enable_config:
            raise Exception("Config not enabled")
        obj = session.scalar(
            self.orm_class.sudo_select()
            .where(self.sql_id_column == object_id)
            .options(*self._config_load_options)
        )
        if obj is None:
            raise ApiClientException(ERROR_UNKNOWN_OBJECT)
        config = {
            "fields": [
                field.config_get_value_with_description(obj)
                for variant_id in self._get_current_variant_ids(obj)
                for field in self._fields_by_variant_by_config_id[variant_id].values()
                if not field.config_only_at_creation
            ]
        }
        # NOTE(review): presumably ends the read so nothing touched while reading stays pending in the session;
        # confirm
        session.rollback()
        return config

    def modify_current_config(self,
                              session: SessionDb,
                              modifying_user_id: int,
                              object_id: int,
                              expected_old_values: CJsonObject,
                              new_values: CJsonObject):
        """
        Set config fields of the object ``object_id`` and add a changelog entry for each changed value.

        :param expected_old_values: Optional expected current value per field id. Each one must match the current
                                    value, and each one must have a corresponding entry in ``new_values``
        :param new_values: New value per config field id. Fields that may only be set at creation are rejected
        :raises ApiClientException: If the object is unknown or an expected old value does not match
        """
        if not self.enable_config:
            raise Exception("Config not enabled")
        obj = session.scalar(
            self.orm_class.sudo_select()
            .where(self.sql_id_column == object_id)
            .options(*self._config_load_options)
        )
        if obj is None:
            raise ApiClientException(ERROR_UNKNOWN_OBJECT)
        variant_id = self._get_current_variant(obj)

        from videoag_common.objects import ChangelogModificationEntry

        expected_keys = set(expected_old_values.keys())
        for field_id in new_values.keys():
            new_value = new_values.get(field_id)
            field = self._fields_by_variant_by_config_id[variant_id].get(field_id)
            if field is None and variant_id is not None:
                field = self._fields_by_variant_by_config_id[None].get(field_id)
            if field is None:
                new_value.raise_error(f"Unknown field (object variant is: '{variant_id}')")
            if field.config_only_at_creation:
                new_value.raise_error("Field may only be set at creation")

            expected_old = expected_old_values.get(field_id, optional=True)
            old_value_json = field.config_get_value(obj)
            if expected_old is not None and not expected_old.equals_json(old_value_json):
                raise ApiClientException(ERROR_MODIFICATION_UNEXPECTED_CURRENT_VALUE)

            field.config_set_value(session, obj, new_value)
            new_value_json = field.config_get_value(obj)
            if old_value_json != new_value_json:
                session.add(ChangelogModificationEntry(
                    modifying_user_id=modifying_user_id,
                    object_type=self.id,
                    object_id=obj.id,
                    field_id=field.config_id,
                    old_value=old_value_json,
                    new_value=new_value_json,
                ))
            expected_keys.discard(field_id)

        if len(expected_keys) > 0:
            expected_old_values.raise_error(f"No new value was set for fields '{expected_keys}' which have expected old values")

    def create_new_object(self,
                          session: SessionDb,
                          modifying_user_id: int,
                          parent_class_id: str or None,
                          parent_id: int or None,
                          variant_id: str or None,
                          values: CJsonObject) -> int:
        """
        Create a new object from config values and add the corresponding changelog entries.

        Every applicable config field (shared fields first, then those of ``variant_id``) except the variant field is
        set from ``values`` or, if absent, from the field's default. Values for unknown fields are rejected.

        :param parent_class_id: API class id of the parent. Required iff this class has parent relationships
        :param parent_id: Id of the parent object. Required iff this class has parent relationships
        :param variant_id: Variant of the new object. Required iff this class has variants
        :return: The id of the new object (assigned by flushing the session)
        :raises ApiClientException: On invalid variant/parent information or missing values
        """
        if not self.enable_config:
            raise Exception("Config not enabled")
        if not self.config_allow_creation:
            raise Exception("Creation not allowed")

        if self._variant_field is not None:
            if variant_id is None:
                raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant"))
            if variant_id not in self._variant_field.str_enums:
                raise ApiClientException(ERROR_OBJECT_ERROR("Unknown object variant"))
            obj = self._orm_classes_by_variant[variant_id]()
        else:
            if variant_id is not None:
                raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no variant"))
            obj = self.orm_class()

        parent_class = None
        if self._parent_relationship_id_by_class_id:
            if parent_class_id is None:
                raise ApiClientException(ERROR_OBJECT_ERROR("Missing parent type"))
            if parent_id is None:
                raise ApiClientException(ERROR_OBJECT_ERROR("Missing parent id"))
            if parent_class_id not in self._parent_relationship_id_by_class_id:
                raise ApiClientException(ERROR_OBJECT_ERROR("Unknown parent type"))
            from ..objects import API_CLASSES_BY_ID
            parent_class = API_CLASSES_BY_ID[parent_class_id]
            parent_relationship_id = self._parent_relationship_id_by_class_id[parent_class_id]
            parent_obj = session.scalar(
                parent_class.orm_class.sudo_select()
                .where(parent_class.sql_id_column == parent_id)
            )
            if parent_obj is None:
                raise ApiClientException(ERROR_OBJECT_ERROR("Unknown parent object"))
            setattr(obj, parent_relationship_id, parent_obj)
        else:
            if parent_class_id is not None or parent_id is not None:
                raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no parent"))

        from videoag_common.objects import ChangelogModificationEntry, ChangelogCreationEntry

        changelog_entries: list[ChangelogModificationEntry] = []
        remaining_value_keys = set(values.keys())
        for var_id in self._variant_ids_in_lookup_order(variant_id):
            for field in self._fields_by_variant_by_config_id[var_id].values():
                # The variant is determined by the ORM class chosen above, not by a config value
                if field is self._variant_field:
                    continue
                remaining_value_keys.discard(field.config_id)
                if values.has(field.config_id):
                    field_value = values.get(field.config_id)
                else:
                    field_value = field.config_get_default()
                    if field_value is None:
                        raise ApiClientException(ERROR_OBJECT_ERROR(f"Missing value for field '{field.config_id}'"))
                    # (None,) is the marker for "the default value is None" (plain None means "no default")
                    if field_value == (None,):
                        field_value = None
                    field_value = CJsonValue(field_value)
                field.config_set_value(session, obj, field_value)
                new_value_json = field.config_get_value(obj)
                # object_id is filled in below, once the id has been assigned
                changelog_entries.append(ChangelogModificationEntry(
                    modifying_user_id=modifying_user_id,
                    object_type=self.id,
                    field_id=field.config_id,
                    old_value=None,
                    new_value=new_value_json,
                ))

        if len(remaining_value_keys) > 0:
            values.raise_error(f"Unknown fields: '{remaining_value_keys}' (May only be available in other variants)")

        session.add(obj)
        session.flush()
        for entry in changelog_entries:
            entry.object_id = obj.id
            session.add(entry)
        session.add(ChangelogCreationEntry(
            modifying_user_id=modifying_user_id,
            object_type=self.id,
            object_id=obj.id,
            parent_type=None if parent_class is None else parent_class.id,
            parent_id=None if parent_class is None else parent_id,
            variant=variant_id
        ))
        session.flush()
        return obj.id

    def modify_deletion_state(self,
                              session: SessionDb,
                              modifying_user_id: int,
                              object_id: int,
                              new_deleted: bool):
        """
        Set the ``deleted`` flag of the object ``object_id`` and add a changelog entry.

        :raises ApiClientException: If the object is unknown or already in the requested state
        """
        if not self.enable_config:
            raise Exception("Config not enabled")
        if not self.is_deletion_allowed():
            raise Exception("Deletion not allowed")
        obj = session.scalar(
            self.orm_class.select({
                AC_IS_MOD: True,
                AC_INCLUDE_DELETED: True,
            })
            .where(self.sql_id_column == object_id)
        )
        if obj is None:
            raise ApiClientException(ERROR_UNKNOWN_OBJECT)
        if obj.deleted == new_deleted:
            raise ApiClientException(ERROR_OBJECT_ERROR("Object is already deleted" if new_deleted else "Object is not deleted"))
        obj.deleted = new_deleted

        from videoag_common.objects import ChangelogDeletionChangeEntry
        session.add(ChangelogDeletionChangeEntry(
            modifying_user_id=modifying_user_id,
            object_type=self.id,
            object_id=obj.id,
            is_now_deleted=new_deleted
        ))
def _check_no_api_object_subclass(clazz: type):
    """
    Raise if ``clazz`` or any class in its (transitive) subclass tree is a subclass of ``ApiObject``.

    Judging by the error message, this is used for classes that inherit the ORM Base class without being API objects
    themselves; an API object below such a class is not supported.

    :param clazz: Root of the subclass tree to check
    :raises Exception: For the first offending class, visited depth-first in pre-order
    """
    pending = [clazz]
    while pending:
        current = pending.pop()
        if issubclass(current, ApiObject):
            raise Exception(f"Class '{current}' is a sub class of '{ApiObject}' but the ORM Base class is inherited by an "
                            f"(indirect) parent. This is not supported")
        # Push the children reversed so they are popped (visited) in their original order
        pending.extend(reversed(current.__subclasses__()))