From f9e6d08c7bbe8a56fa6b54532a8510143933cf36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de> Date: Sun, 9 Feb 2025 23:50:22 +0100 Subject: [PATCH] Fix not using variant's subclass when creating objects --- src/videoag_common/api_object/object_class.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/videoag_common/api_object/object_class.py b/src/videoag_common/api_object/object_class.py index 12db900..672560d 100644 --- a/src/videoag_common/api_object/object_class.py +++ b/src/videoag_common/api_object/object_class.py @@ -54,6 +54,7 @@ class ApiObjectClass: self._parent_relationship_id_by_class_id = {} self._variant_field: ApiEnumField or None = None + self._orm_classes_by_variant: dict[str or None, type[ApiObject]] = {} def set_class(self, orm_class: type[ApiObject]): self.orm_class = orm_class @@ -138,6 +139,9 @@ class ApiObjectClass: base_variant_id = _get_variant_id(self.orm_class, optional=True) if base_variant_id is not None and base_variant_id not in self._variant_field.str_enums: raise Exception(f"Class '{self.orm_class.__name__}' has unknown variant '{base_variant_id}'") + + if base_variant_id is not None: + self._orm_classes_by_variant[base_variant_id] = self.orm_class def _post_init_variant_fields(self, all_classes: dict[str, "ApiObjectClass"]): variant_ignore_classes = set(recursive_flat_map_single(lambda c: c.__bases__, self.orm_class)) @@ -154,6 +158,8 @@ class ApiObjectClass: raise Exception(f"Unknown variant '{variant_id}' for class '{sub_class.__name__}' in enum " f"'{self._variant_field.enum_class.__name__}'") + self._orm_classes_by_variant[variant_id] = sub_class + # noinspection PyUnresolvedReferences for super_sub_class in set(recursive_flat_map_single(lambda c: c.__bases__, sub_class)) - variant_ignore_classes: if not hasattr(super_sub_class, "__all_class_api_fields__"): @@ -175,6 +181,11 @@ class ApiObjectClass: raise Exception( f"API class '{sub2_class.__name__}' is indirect child of API class '{self.orm_class.__name__}'." f" Indirect inheritance is not supported (Only direct)") + + if self._variant_field is not None: + for variant_id in self._variant_field.str_enums: + if variant_id not in self._orm_classes_by_variant: + raise Exception(f"No ORM class for variant '{variant_id}' of API object class '{self.id}'") def _post_init_parents(self, all_classes: dict[str, "ApiObjectClass"]): for parent_relation_id in (self._parent_relationship_config_ids or []): @@ -439,8 +450,19 @@ class ApiObjectClass: raise Exception("Config not enabled") if not self.config_allow_creation: raise Exception("Creation not allowed") + + if self._variant_field is not None: + if variant_id is None: + raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant")) + if variant_id not in self._variant_field.str_enums: + raise ApiClientException(ERROR_OBJECT_ERROR("Unknown object variant")) + + obj = self._orm_classes_by_variant[variant_id]() + else: + if variant_id is not None: + raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no variant")) + obj = self.orm_class() - obj = self.orm_class() parent_class = None if len(self._parent_relationship_id_by_class_id) > 0: if parent_class_id is None: @@ -464,22 +486,12 @@ class ApiObjectClass: if parent_class_id is not None or parent_id is not None: raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no parent")) - if self._variant_field is not None: - if variant_id is None: - raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant")) - if variant_id not in self._variant_field.str_enums: - raise ApiClientException(ERROR_OBJECT_ERROR("Unknown object variant")) - self._variant_field.config_set_value(session, obj, CJsonValue(variant_id)) - else: - if variant_id is not None: - raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no variant")) - from videoag_common.objects import ChangelogModificationEntry, ChangelogCreationEntry changelog_entries: list[ChangelogModificationEntry] = [] remaining_value_keys = set(values.keys()) - for variant_id in {variant_id, None}: - for field in self._fields_by_variant_by_config_id[variant_id].values(): + for var_id in {variant_id, None}: + for field in self._fields_by_variant_by_config_id[var_id].values(): if field is self._variant_field: continue -- GitLab