From f9e6d08c7bbe8a56fa6b54532a8510143933cf36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de>
Date: Sun, 9 Feb 2025 23:50:22 +0100
Subject: [PATCH] Fix not using variant's subclass when creating objects

---
 src/videoag_common/api_object/object_class.py | 38 ++++++++++++-------
 1 file changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/videoag_common/api_object/object_class.py b/src/videoag_common/api_object/object_class.py
index 12db900..672560d 100644
--- a/src/videoag_common/api_object/object_class.py
+++ b/src/videoag_common/api_object/object_class.py
@@ -54,6 +54,7 @@ class ApiObjectClass:
         self._parent_relationship_id_by_class_id = {}
         
         self._variant_field: ApiEnumField or None = None
+        self._orm_classes_by_variant: dict[str or None, type[ApiObject]] = {}
     
     def set_class(self, orm_class: type[ApiObject]):
         self.orm_class = orm_class
@@ -138,6 +139,9 @@ class ApiObjectClass:
         base_variant_id = _get_variant_id(self.orm_class, optional=True)
         if base_variant_id is not None and base_variant_id not in self._variant_field.str_enums:
             raise Exception(f"Class '{self.orm_class.__name__}' has unknown variant '{base_variant_id}'")
+        
+        if base_variant_id is not None:
+            self._orm_classes_by_variant[base_variant_id] = self.orm_class
     
     def _post_init_variant_fields(self, all_classes: dict[str, "ApiObjectClass"]):
         variant_ignore_classes = set(recursive_flat_map_single(lambda c: c.__bases__, self.orm_class))
@@ -154,6 +158,8 @@ class ApiObjectClass:
                 raise Exception(f"Unknown variant '{variant_id}' for class '{sub_class.__name__}' in enum "
                                 f"'{self._variant_field.enum_class.__name__}'")
             
+            self._orm_classes_by_variant[variant_id] = sub_class
+            
             # noinspection PyUnresolvedReferences
             for super_sub_class in set(recursive_flat_map_single(lambda c: c.__bases__, sub_class)) - variant_ignore_classes:
                 if not hasattr(super_sub_class, "__all_class_api_fields__"):
@@ -175,6 +181,11 @@ class ApiObjectClass:
                 raise Exception(
                     f"API class '{sub2_class.__name__}' is indirect child of API class '{self.orm_class.__name__}'."
                     f" Indirect inheritance is not supported (Only direct)")
+        
+        if self._variant_field is not None:
+            for variant_id in self._variant_field.str_enums:
+                if variant_id not in self._orm_classes_by_variant:
+                    raise Exception(f"No ORM class for variant '{variant_id}' of API object class '{self.id}'")
     
     def _post_init_parents(self, all_classes: dict[str, "ApiObjectClass"]):
         for parent_relation_id in (self._parent_relationship_config_ids or []):
@@ -439,8 +450,19 @@ class ApiObjectClass:
             raise Exception("Config not enabled")
         if not self.config_allow_creation:
             raise Exception("Creation not allowed")
+
+        if self._variant_field is not None:
+            if variant_id is None:
+                raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant"))
+            if variant_id not in self._variant_field.str_enums:
+                raise ApiClientException(ERROR_OBJECT_ERROR("Unknown object variant"))
+            
+            obj = self._orm_classes_by_variant[variant_id]()
+        else:
+            if variant_id is not None:
+                raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no variant"))
+            obj = self.orm_class()
         
-        obj = self.orm_class()
         parent_class = None
         if len(self._parent_relationship_id_by_class_id) > 0:
             if parent_class_id is None:
@@ -464,22 +486,12 @@ class ApiObjectClass:
             if parent_class_id is not None or parent_id is not None:
                 raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no parent"))
         
-        if self._variant_field is not None:
-            if variant_id is None:
-                raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant"))
-            if variant_id not in self._variant_field.str_enums:
-                raise ApiClientException(ERROR_OBJECT_ERROR("Unknown object variant"))
-            self._variant_field.config_set_value(session, obj, CJsonValue(variant_id))
-        else:
-            if variant_id is not None:
-                raise ApiClientException(ERROR_OBJECT_ERROR("This object may have no variant"))
-
         from videoag_common.objects import ChangelogModificationEntry, ChangelogCreationEntry
         changelog_entries: list[ChangelogModificationEntry] = []
         
         remaining_value_keys = set(values.keys())
-        for variant_id in {variant_id, None}:
-            for field in self._fields_by_variant_by_config_id[variant_id].values():
+        for var_id in {variant_id, None}:
+            for field in self._fields_by_variant_by_config_id[var_id].values():
                 if field is self._variant_field:
                     continue
                 
-- 
GitLab