Skip to content
Snippets Groups Projects
Select Git revision
  • 2d60cfa87b637b9417ca1f1e355d15113f9dc74c
  • master default protected
  • forbid-save-as
  • upload-via-token
  • moodle-integration
  • patch-double-tap-seek
  • patch_datum_anzeigen
  • patch_raum_anzeigen
  • intros
  • live_sources
  • bootstrap4
  • modules
12 results

edit.py

Blame
  • Forked from Video AG Infrastruktur / website
    Source project has a limited visibility.
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    object_class.py 26.67 KiB
    from enum import Enum
    
    from videoag_common.database import *
    from videoag_common.miscellaneous import *
    from .fields import *
    from .object import ApiObject, DeletableApiObject, AC_INCLUDE_DELETED, AC_IS_MOD
    from .util import get_relationship_foreign_class
    
    
    def _get_variant_id(clazz: type[ApiObject], optional: bool = False) -> str or None:
        """
        Read the ``polymorphic_identity`` entry from the ``__mapper_args__`` of ``clazz``.

        If the identity is an ``Enum`` member, its ``value`` is used instead.

        :param clazz: Class whose ``__mapper_args__`` (if it has any) are inspected.
        :param optional: When True, a missing identity yields ``None`` instead of an exception.
        :return: The identity as a string, or ``None`` if it is missing and ``optional`` is set.
        :raises Exception: If the identity is missing and not optional, or if it is not a string
            (after unwrapping an enum).
        """
        identity = getattr(clazz, "__mapper_args__", {}).get("polymorphic_identity")
        if identity is None:
            if not optional:
                raise Exception(f"Missing 'polymorphic_identity' in '__mapper_args__' for class '{clazz.__name__}'")
            return None

        # Enum identities are represented by their value
        resolved = identity.value if isinstance(identity, Enum) else identity
        if isinstance(resolved, str):
            return resolved

        raise Exception(f"'polymorphic_identity' in '__mapper_args__' for class '{clazz.__name__}' is not a string "
                        f"or an enum with a string value")
    
    
    class ApiObjectClass:
        
        def __init__(self,
                     parent_relationship_config_ids: list[str] or None = None,
                     enable_config: bool or None = None,
                     config_allow_creation: bool = True,
                     enable_data: bool or None = None,
                     ):
            self._parent_relationship_config_ids = parent_relationship_config_ids
            self.enable_config = enable_config
            self.config_allow_creation = config_allow_creation
            self.enable_data = enable_data
            
            self.orm_class = None
            self.id = None
            self.sql_table = None
            self.sql_id_column = None
            self._config_load_options: list[ExecutableOption] = []
            
            self._fields_by_variant_by_config_id: dict[str or None, dict[str, ApiConfigField]] = {
                None: {}
            }
            self._fields_by_variant_by_data_id: dict[str or None, dict[str, ApiDataField]] = {
                None: {}
            }
    
            self._parent_relationship_id_by_class_id = {}
            
            self._variant_field: ApiEnumField or None = None
            self._orm_classes_by_variant: dict[str or None, type[ApiObject]] = {}
        
        def set_class(self, orm_class: type[ApiObject]):
            """
            Bind this API class to its ORM class.

            The id is the snake_case form of the ORM class name; the table with that name is looked
            up in ``Base.metadata.tables`` and its ``id`` column is remembered.

            :param orm_class: The ORM class described by this API class.
            """
            self.orm_class = orm_class
            class_id = alphanum_camel_case_to_snake_case(orm_class.__name__)
            self.id = class_id

            # The table is looked up by the id derived above
            table = Base.metadata.tables[class_id]
            self.sql_table = table
            self.sql_id_column = table.columns["id"]
        
        def _post_init(self, all_classes: dict[str, "ApiObjectClass"]):
            """
            Run all post-initialization steps in their required order.

            Fields and variants are collected first. Afterwards ``enable_config`` and
            ``enable_data``, if still ``None``, are set to whether any config/data field exists in
            any variant. Finally the parents and the creation config are initialized.

            :param all_classes: Passed on to the field and parent initialization.
            """
            self._post_init_fields(all_classes)
            self._post_init_variants()
            self._post_init_variant_fields(all_classes)

            # Auto-detect the flags which were not given explicitly
            if self.enable_config is None:
                self.enable_config = any(
                    len(fields_by_id) > 0
                    for fields_by_id in self._fields_by_variant_by_config_id.values()
                )
            if self.enable_data is None:
                self.enable_data = any(
                    len(fields_by_id) > 0
                    for fields_by_id in self._fields_by_variant_by_data_id.values()
                )

            self._post_init_parents(all_classes)
            self._post_init_creation_config()
        
        def _post_init_fields(self, all_classes: dict[str, "ApiObjectClass"]):
            """
            Register the non-variant API fields of the ORM class.

            Every class returned by following ``__bases__`` from the ORM class is inspected; the
            fields in its ``__all_class_api_fields__`` (if present) are added without a variant.

            :param all_classes: Passed on to ``_add_field``.
            :raises Exception: Wraps any error raised while adding a field, naming the member and
                the declaring class.
            """
            ancestors = set(recursive_flat_map_single(lambda c: c.__bases__, self.orm_class))
            for ancestor in ancestors:
                if hasattr(ancestor, "__all_class_api_fields__"):
                    # noinspection PyUnresolvedReferences
                    for api_field in ancestor.__all_class_api_fields__:
                        assert isinstance(api_field, ApiField)
                        try:
                            self._add_field(all_classes, api_field, None, None)
                        except Exception as error:
                            raise Exception(
                                f"While initializing field for member '{api_field.member_name}' in class '{ancestor.__name__}' for "
                                f"'{self.orm_class.__name__}'") from error
        
        def _post_init_variants(self):
            variant_field_id = None
            if hasattr(self.orm_class, "__mapper_args__"):
                # May be none
                variant_field_id = self.orm_class.__mapper_args__.get("polymorphic_on")
            
            if variant_field_id is None:
                return
            
            var_field = self._fields_by_variant_by_config_id[None].get(variant_field_id)
            if var_field is None:
                var_field = self._fields_by_variant_by_data_id[None].get(variant_field_id)
            
            if var_field is None:
                raise Exception(f"Unknown variant field '{variant_field_id}' (determined by 'polymorphic_on' with config id, "
                                f"or data id if config id not found) for API class '{self.orm_class.__name__}' (Variant "
                                f"field may not be in a subclass!, and must be included in the config or data)")
            if self.enable_config and self.config_allow_creation and not var_field.include_in_config:
                raise Exception(f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' must be included "
                                f"in the config if the config allows creation")
            if not isinstance(var_field, ApiEnumField):
                raise Exception(f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' must be an "
                                f"enum!")
            self._variant_field = var_field
            if self._variant_field.may_be_none:
                raise Exception(f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' may not be "
                                f"nullable")
            if self.enable_config and self.config_allow_creation and not self._variant_field.config_only_at_creation:
                raise Exception(
                    f"Variant field '{variant_field_id}' for class '{self.orm_class.__name__}' must be only "
                    f"at creation")
            
            for variant_id in self._variant_field.str_enums:
                self._fields_by_variant_by_config_id[variant_id] = {}
                self._fields_by_variant_by_data_id[variant_id] = {}
            
            # Note that the base variant actually can't have any variant fields (All fields without a variant belong
            # to it (and to all other variants))
            # The base variant is also optional
            base_variant_id = _get_variant_id(self.orm_class, optional=True)
            if base_variant_id is not None and base_variant_id not in self._variant_field.str_enums:
                raise Exception(f"Class '{self.orm_class.__name__}' has unknown variant '{base_variant_id}'")
            
            if base_variant_id is not None:
                self._orm_classes_by_variant[base_variant_id] = self.orm_class
        
        def _post_init_variant_fields(self, all_classes: dict[str, "ApiObjectClass"]):
            """
            Register the fields declared by the variant subclasses of the ORM class.

            Only direct subclasses which have ``__all_class_api_fields__`` are considered. Each one
            is registered as the ORM class of its variant, and the fields of all its ancestors
            which are not also ancestors of the base ORM class are added for that variant. At the
            end every enum value of the variant field must have an ORM class.

            :param all_classes: Passed on to ``_add_field``.
            :raises Exception: If an API subclass exists without a variant field, a subclass has an
                unknown variant, adding a field fails, an indirect API subclass exists, or a
                variant has no ORM class.
            """
            # Classes whose fields were already handled as non-variant fields
            base_ancestors = set(recursive_flat_map_single(lambda c: c.__bases__, self.orm_class))

            api_sub_classes = [
                candidate
                for candidate in self.orm_class.__subclasses__()
                if hasattr(candidate, "__all_class_api_fields__")
            ]
            for sub_class in api_sub_classes:
                if self._variant_field is None:
                    raise Exception(f"Found API subclass '{sub_class.__name__}' of API class '{self.orm_class.__name__}' "
                                    f"but 'polymorphic_on' is missing in the orm '__mapper_args__'")

                sub_variant = _get_variant_id(sub_class)
                if sub_variant not in self._variant_field.str_enums:
                    raise Exception(f"Unknown variant '{sub_variant}' for class '{sub_class.__name__}' in enum "
                                    f"'{self._variant_field.enum_class.__name__}'")

                self._orm_classes_by_variant[sub_variant] = sub_class

                variant_only_classes = set(recursive_flat_map_single(lambda c: c.__bases__, sub_class)) - base_ancestors
                for declaring_class in variant_only_classes:
                    if hasattr(declaring_class, "__all_class_api_fields__"):
                        # noinspection PyUnresolvedReferences
                        for variant_field in declaring_class.__all_class_api_fields__:
                            assert isinstance(variant_field, ApiField)
                            try:
                                self._add_field(all_classes, variant_field, sub_class, sub_variant)
                            except Exception as error:
                                raise Exception(
                                    f"While initializing field for member '{variant_field.member_name}' in class '{declaring_class.__name__}'") from error

                # Any API class below a direct subclass is rejected
                descendants = recursive_flat_map(
                    lambda c: c.__subclasses__(),
                    sub_class.__subclasses__()
                )
                for descendant in descendants:
                    if hasattr(descendant, "__all_class_api_fields__"):
                        raise Exception(
                            f"API class '{descendant.__name__}' is indirect child of API class '{self.orm_class.__name__}'."
                            f" Indirect inheritance is not supported (Only direct)")

            if self._variant_field is None:
                return
            for expected_variant in self._variant_field.str_enums:
                if expected_variant not in self._orm_classes_by_variant:
                    raise Exception(f"No ORM class for variant '{expected_variant}' of API object class '{self.id}'")
        
        def _post_init_parents(self, all_classes: dict[str, "ApiObjectClass"]):
            for parent_relation_id in (self._parent_relationship_config_ids or []):
                parent_relation_attr = None
                if hasattr(self.orm_class, parent_relation_id):
                    parent_relation_attr = getattr(self.orm_class, parent_relation_id)
                
                if parent_relation_attr is None:
                    raise TypeError(f"Unknown parent relationship id '{parent_relation_id}' for class "
                                    f"'{self.orm_class.__name__}'")
                parent_relation = parent_relation_attr.prop
                if not isinstance(parent_relation, orm.Relationship):
                    raise TypeError(f"Parent relationship '{parent_relation_id}' for class '{self.orm_class.__name__}' "
                                    f"is not a relationship")
                
                if parent_relation.collection_class is not None:
                    raise TypeError(f"Parent relationship '{parent_relation_id}' for class '{self.orm_class.__name__}' "
                                    f"can reference more than one object")
                
                parent_class = get_relationship_foreign_class(all_classes, parent_relation)
                if parent_class.id in self._parent_relationship_id_by_class_id:
                    raise ValueError(f"Multiple parent fields for parent class '{parent_class.id}' in class '{self.orm_class.__name__}' ")
                self._parent_relationship_id_by_class_id[parent_class.id] = parent_relation_id
        
        def _post_init_creation_config(self):
            self._creation_config_json = None
            if not self.enable_config or not self.config_allow_creation:
                return
            
            var_fields = {
                variant_id: [
                    field.config_get_description()
                    for field in var_fields.values()
                    if field is not self._variant_field
                ] for variant_id, var_fields in self._fields_by_variant_by_config_id.items()
            }
            self._creation_config_json = {
                "fields": var_fields[None]
            }
            if self._variant_field is not None:
                var_fields.pop(None)
                self._creation_config_json["variant_fields"] = var_fields
        
        def _add_field(self,
                       all_classes: dict[str, "ApiObjectClass"],
                       field: ApiField,
                       variant_class: type or None,
                       variant_id: str or None):
            """
            Initialize a copy of ``field`` and register it for the given variant.

            The copy is post-initialized with a ``FieldContext`` and then stored by its config id
            (if it is a config field included in the config) and/or by its data id (if it is a data
            field included in the data). Ids must be unique within the variant and, for variant
            fields, must not clash with a non-variant field.

            :param all_classes: Stored in the field context.
            :param field: The declared field; it is copied via ``pre_post_init_copy`` before use.
            :param variant_class: Class on which the member is looked up for variant fields.
            :param variant_id: Variant to register the field for, or ``None`` for non-variant fields.
            :raises ValueError: On a config id or data id conflict.
            """
            # Non-variant members live on the base ORM class, variant members on the variant class
            owner = variant_class if variant_id is not None else self.orm_class

            member = None
            if field.member_name is not None and hasattr(owner, field.member_name):
                member = getattr(owner, field.member_name)

            prepared = field.pre_post_init_copy()
            context = FieldContext(
                all_classes=all_classes,
                api_class=self,
                own_member=member,
                variant_id=variant_id
            )
            prepared.post_init(context)
            prepared.post_init_check(context)

            if isinstance(prepared, ApiConfigField) and prepared.include_in_config:
                config_fields = self._fields_by_variant_by_config_id[variant_id]
                if prepared.config_id in config_fields:
                    raise ValueError(f"Config id conflict with '{prepared.config_id}' in variant '{variant_id}'")
                config_fields[prepared.config_id] = prepared
                if variant_id is not None and prepared.config_id in self._fields_by_variant_by_config_id[None]:
                    raise ValueError(f"Config id conflict with '{prepared.config_id}' in variant '{variant_id}' and non-variant")

                self._config_load_options.extend(prepared.config_get_load_options())

            if isinstance(prepared, ApiDataField) and prepared.include_in_data:
                data_fields = self._fields_by_variant_by_data_id[variant_id]
                if prepared.data_id in data_fields:
                    raise ValueError(f"Data id conflict with '{prepared.data_id}' in variant '{variant_id}'")
                data_fields[prepared.data_id] = prepared
                if variant_id is not None and prepared.data_id in self._fields_by_variant_by_data_id[None]:
                    raise ValueError(f"Data id conflict with '{prepared.data_id}' in variant '{variant_id}' and non-variant")
        
        def _get_current_variant(self, obj) -> str or None:
            if self._variant_field is None:
                return None
            assert isinstance(self._variant_field, ApiEnumField)
            variant_id = self._variant_field.config_get_value(obj)  # For enums the same as data_get_value
            if variant_id not in self._variant_field.str_enums:
                raise ValueError(f"Unknown variant for class '{self.id}' in database: '{variant_id}' (object id: {obj.id})")
            return variant_id
        
        def _get_current_variant_set(self, obj) -> set[str or None]:
            return {self._get_current_variant(obj), None}
        
        def get_variants(self) -> list[str] or None:
            if self._variant_field is None:
                return None
            return self._variant_field.str_enums
        
        def get_data_fields(self) -> list[ApiDataField]:
            """Return the data fields of all variants (including the non-variant fields) as one list."""
            registries = self._fields_by_variant_by_data_id.values()
            return list(flat_map(lambda by_id: by_id.values(), registries))
        
        def get_config_fields(self) -> list[ApiConfigField]:
            """Return the config fields of all variants (including the non-variant fields) as one list."""
            registries = self._fields_by_variant_by_config_id.values()
            return list(flat_map(lambda by_id: by_id.values(), registries))
        
        def get_data_fields_with_id(self, data_id: str) -> list[ApiDataField]:
            field_list = []
            for by_id in self._fields_by_variant_by_data_id.values():
                field = by_id.get(data_id)
                if field is not None:
                    field_list.append(field)
            return field_list
        
        def get_config_fields_with_id(self, config_id: str) -> list[ApiConfigField]:
            field_list = []
            for by_id in self._fields_by_variant_by_config_id.values():
                field = by_id.get(config_id)
                if field is not None:
                    field_list.append(field)
            return field_list
        
        def serialize(self, obj, **kwargs) -> dict:
            return self._serialize_catch_error(obj, to_context=False, **kwargs)
        
        def serialize_to_context(self, obj, **kwargs) -> int:
            return self._serialize_catch_error(obj, to_context=True, **kwargs)
            
        def _serialize_catch_error(self, obj, to_context: bool, **kwargs) -> dict or int:
            if not self.enable_data:
                raise Exception("Serialization not enabled")
            try:
                return self.serialize_object_args(obj, ArgAttributeObject(**kwargs), to_context)
            except AttributeError as e:
                if not str(e).startswith("'ArgAttributeObject' object has no attribute"):
                    raise e
                raise AttributeError(f"Missing a keyword argument for serialization of object '{self.id}': " + str(e)) from e
        
        def serialize_object_args(self, obj, args, to_context: bool = False) -> dict or int:
            if not self.enable_data:
                raise Exception("Serialization not enabled")
            if not to_context:
                res = {}
                for variant_id in self._get_current_variant_set(obj):
                    for field in self._fields_by_variant_by_data_id[variant_id].values():
                        field.add_to_data(obj, res, args)
                return res
            
            context_name = f"{self.id}_context"
            if not hasattr(args, context_name):
                raise Exception(f"Missing serialization argument '{context_name}'")
            context = getattr(args, context_name)
            if not isinstance(context, dict):
                raise Exception(f"serialization argument '{context_name}' must be a dict")
            if str(obj.id) not in context:
                context[str(obj.id)] = None  # Indicate that object is being serialized. Prevent loops
                context[str(obj.id)] = self.serialize_object_args(obj, args, to_context=False)
            return obj.id
        
        def get_creation_config(self) -> JsonTypes or None:
            if not self.enable_config:
                raise Exception("Config not enabled")
            return self._creation_config_json
        
        def is_deletion_allowed(self) -> bool:
            """Return whether the ORM class is a subclass of ``DeletableApiObject``."""
            orm_class = self.orm_class
            return issubclass(orm_class, DeletableApiObject)
        
        def get_current_config(self, session: SessionDb, object_id: int):
            if not self.enable_config:
                raise Exception("Config not enabled")
            obj = session.scalar(
                self.orm_class.sudo_select()
                    .where(self.sql_id_column == object_id)
                    .options(*self._config_load_options)
            )
            if obj is None:
                raise ApiClientException(ERROR_UNKNOWN_OBJECT)
            
            config = {
                "fields": [
                    field.config_get_value_with_description(obj)
                    for field in flat_map(
                        lambda var_id: self._fields_by_variant_by_config_id[var_id].values(),
                        self._get_current_variant_set(obj)
                    )
                    if not field.config_only_at_creation
                ]
            }
            session.rollback()
            return config
        
        def modify_current_config(self,
                                  session: SessionDb,
                                  modifying_user_id: int,
                                  object_id: int,
                                  expected_old_values: CJsonObject,
                                  new_values: CJsonObject):
            if not self.enable_config:
                raise Exception("Config not enabled")
            obj = session.scalar(
                self.orm_class.sudo_select()
                    .where(self.sql_id_column == object_id)
                    .options(*self._config_load_options)
            )
            if obj is None:
                raise ApiClientException(ERROR_UNKNOWN_OBJECT)
            
            variant_id = self._get_current_variant(obj)
            
            from videoag_common.objects import ChangelogModificationEntry
            
            expected_keys = set(expected_old_values.keys())
            for field_id in new_values.keys():
                new_value = new_values.get(field_id)
                
                field = self._fields_by_variant_by_config_id[variant_id].get(field_id)
                if field is None and variant_id is not None:
                    field = self._fields_by_variant_by_config_id[None].get(field_id)
                
                if field is None:
                    new_value.raise_error(f"Unknown field (object variant is: '{variant_id}')")
                
                if field.config_only_at_creation:
                    new_value.raise_error("Field may only be set at creation")
                
                expected_old = expected_old_values.get(field_id, optional=True)
                
                old_value_json = field.config_get_value(obj)
                if expected_old is not None and not expected_old.equals_json(old_value_json):
                    raise ApiClientException(ERROR_MODIFICATION_UNEXPECTED_CURRENT_VALUE)
                
                field.config_set_value(session, obj, new_value)
                
                new_value_json = field.config_get_value(obj)
                
                if old_value_json != new_value_json:
                    session.add(ChangelogModificationEntry(
                        modifying_user_id=modifying_user_id,
                        object_type=self.id,
                        object_id=obj.id,
                        field_id=field.config_id,
                        old_value=old_value_json,
                        new_value=new_value_json,
                    ))
                expected_keys.discard(field_id)
            
            if len(expected_keys) > 0:
                expected_old_values.raise_error(f"No new value was set for fields '{expected_keys}' which have expected old values")
        
        def create_new_object(self,
                              session: SessionDb,
                              modifying_user_id: int,
                              parent_class_id: str or None,
                              parent_id: int or None,
                              variant_id: str or None,
                              values: CJsonObject) -> int:
            if not self.enable_config:
                raise Exception("Config not enabled")
            if not self.config_allow_creation:
                raise Exception("Creation not allowed")
    
            if self._variant_field is not None:
                if variant_id is None:
                    raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant"))
                if variant_id not in self._variant_field.str_enums:
                    raise ApiClientException(ERROR_OBJECT_ERROR("Unknown object variant"))
                
                obj = self._orm_classes_by_variant[variant_id]()
            else:
                if variant_id is not None:
                    raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no variant"))
                obj = self.orm_class()
            
            parent_class = None
            if len(self._parent_relationship_id_by_class_id) > 0:
                if parent_class_id is None:
                    raise ApiClientException(ERROR_OBJECT_ERROR("Missing parent type"))
                if parent_id is None:
                    raise ApiClientException(ERROR_OBJECT_ERROR("Missing parent id"))
                if parent_class_id not in self._parent_relationship_id_by_class_id:
                    raise ApiClientException(ERROR_OBJECT_ERROR("Unknown parent type"))
                from ..objects import API_CLASSES_BY_ID
                parent_class = API_CLASSES_BY_ID[parent_class_id]
                parent_relationship_id = self._parent_relationship_id_by_class_id[parent_class_id]
                
                parent_obj = session.scalar(
                    parent_class.orm_class.sudo_select()
                                .where(parent_class.sql_id_column == parent_id)
                )
                if parent_obj is None:
                    raise ApiClientException(ERROR_OBJECT_ERROR("Unknown parent object"))
                setattr(obj, parent_relationship_id, parent_obj)
            else:
                if parent_class_id is not None or parent_id is not None:
                    raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no parent"))
            
            from videoag_common.objects import ChangelogModificationEntry, ChangelogCreationEntry
            changelog_entries: list[ChangelogModificationEntry] = []
            
            remaining_value_keys = set(values.keys())
            for var_id in {variant_id, None}:
                for field in self._fields_by_variant_by_config_id[var_id].values():
                    if field is self._variant_field:
                        continue
                    
                    remaining_value_keys.discard(field.config_id)
                    
                    if values.has(field.config_id):
                        field_value = values.get(field.config_id)
                    else:
                        field_value = field.config_get_default()
                        if field_value is None:
                            raise ApiClientException(ERROR_OBJECT_ERROR(f"Missing value for field '{field.config_id}'"))
                        if field_value == (None,):
                            field_value = None
                        field_value = CJsonValue(field_value)
                    
                    field.config_set_value(session, obj, field_value)
                    
                    new_value_json = field.config_get_value(obj)
                    
                    from videoag_common.objects import ChangelogModificationEntry
                    changelog_entries.append(ChangelogModificationEntry(
                        modifying_user_id=modifying_user_id,
                        object_type=self.id,
                        field_id=field.config_id,
                        old_value=None,
                        new_value=new_value_json,
                    ))
            
            if len(remaining_value_keys) > 0:
                values.raise_error(f"Unknown fields: '{remaining_value_keys}' (May only be available in other variants)")
            
            session.add(obj)
            session.flush()
            
            for entry in changelog_entries:
                entry.object_id = obj.id
                session.add(entry)
            
            session.add(ChangelogCreationEntry(
                modifying_user_id=modifying_user_id,
                object_type=self.id,
                object_id=obj.id,
                parent_type=None if parent_class is None else parent_class.id,
                parent_id=None if parent_class is None else parent_id,
                variant=variant_id
            ))
            session.flush()
            return obj.id
        
        def modify_deletion_state(self,
                                  session: SessionDb,
                                  modifying_user_id: int,
                                  object_id: int,
                                  new_deleted: bool):
            if not self.enable_config:
                raise Exception("Config not enabled")
            if not self.is_deletion_allowed():
                raise Exception("Deletion not allowed")
            obj = session.scalar(
                self.orm_class.select({
                    AC_IS_MOD: True,
                    AC_INCLUDE_DELETED: True,
                })
                .where(self.sql_id_column == object_id)
            )
            if obj is None:
                raise ApiClientException(ERROR_UNKNOWN_OBJECT)
            if obj.deleted == new_deleted:
                raise ApiClientException(ERROR_OBJECT_ERROR("Object is already deleted" if new_deleted else "Object is not deleted"))
            obj.deleted = new_deleted
            from videoag_common.objects import ChangelogDeletionChangeEntry
            session.add(ChangelogDeletionChangeEntry(
                modifying_user_id=modifying_user_id,
                object_type=self.id,
                object_id=obj.id,
                is_now_deleted=new_deleted
            ))
    
    
    def _check_no_api_object_subclass(clazz: type):
        """
        Ensure that neither ``clazz`` nor any of its (indirect) subclasses is an ``ApiObject``.

        :param clazz: Root of the class hierarchy to check.
        :raises Exception: For the first class found (depth first, parents before children) which
            is a subclass of ``ApiObject``.
        """
        pending = [clazz]
        while pending:
            current = pending.pop()
            if issubclass(current, ApiObject):
                raise Exception(f"Class '{current}' is a sub class of '{ApiObject}' but the ORM Base class is inherited by an "
                                f"(indirect) parent. This is not supported")
            # Reversed so that the subclasses are popped in their original order
            pending.extend(reversed(current.__subclasses__()))