From 952d258fcf7562706727630ac8a6d17232dfd912 Mon Sep 17 00:00:00 2001
From: Felix Metzner <felixmetzner@outlook.com>
Date: Wed, 15 May 2024 08:52:03 +0200
Subject: [PATCH] Adding some checks in shape_sys_evaluater.py.

---
 .../shape_sys_evaluater.py                    | 67 +++++++++++++++++--
 1 file changed, 60 insertions(+), 7 deletions(-)

diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/systematics_manager/shape_sys_evaluater.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/systematics_manager/shape_sys_evaluater.py
index 8740682d9..56d109ce8 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/systematics_manager/shape_sys_evaluater.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/systematics_manager/shape_sys_evaluater.py
@@ -308,7 +308,7 @@ class SysColVarManager:
 
                     sys_var_cols = (applied_track_sys_weight_col,)
                 else:
-                    sys_var_cols = (sys_info.var, )
+                    sys_var_cols = (sys_info.var,)
             else:
                 raise RuntimeError("Unknown case encountered!")
 
@@ -328,13 +328,13 @@ class SysColVarManager:
             if sys_info.up_down_from_df is None:
                 up_down_tag: str = ""
             else:
-                _up_tag, _down_tag = sys_info.up_down_from_df
+                _up_tag, _down_tag = sys_info.up_down_from_df  # type: str, str
                 _this_var_col_name: str = sys_var_cols[sys_i]
                 if _up_tag in _this_var_col_name:
                     up_down_tag = "_up"
                 elif _down_tag in _this_var_col_name:
                     assert len(sys_var_cols) % 2 == 0, (len(sys_var_cols), len(sys_var_cols) % 2, sys_var_cols)
-                    sys_i_mod = - int(len(sys_var_cols) / 2)
+                    sys_i_mod = -int(len(sys_var_cols) / 2)
                     up_down_tag = "_down"
                 else:
                     raise RuntimeError(f"No match for {_up_tag} or {_down_tag} in {_this_var_col_name}!")
@@ -344,8 +344,49 @@ class SysColVarManager:
             new_col_names.append(new_col_name)
             df.loc[:, new_col_name] = w_var
 
+        if sys_info.up_down_from_df is not None:
+            # Testing order of up and down variation column names:
+            self.test_var_col_names_for_up_down_var_case(
+                column_names=new_col_names,
+                systematics_info=sys_info,
+            )
+
         return tuple(new_col_names)
 
+    @staticmethod
+    def test_var_col_names_for_up_down_var_case(
+        column_names: List[str],
+        systematics_info: SysColsInfo,
+    ) -> None:  # Sanity check only: raises AssertionError if the up/down column name layout is violated.
+        assert len(column_names) % 2 == 0, (len(column_names), len(column_names) % 2)
+        _number_of_up_vars: int = int(len(column_names) / 2)
+        # The first half of column_names is treated as the up variations, the second half as the down variations.
+        _up_var_cols: Tuple[str, ...] = tuple(c for c in column_names[:_number_of_up_vars])
+        _down_var_cols: Tuple[str, ...] = tuple(c for c in column_names[_number_of_up_vars:])
+        # up_down_from_df must be set here; it is unpacked into the (up, down) tag pair searched for in the names.
+        assert systematics_info.up_down_from_df is not None
+        _test_up_tag, _test_down_tag = systematics_info.up_down_from_df  # type: str, str
+        # Every name in the up half must contain the up tag; every name in the down half must contain the down tag.
+        assert all(_test_up_tag in c for c in _up_var_cols), (
+            _test_up_tag,
+            [c for c in _up_var_cols if _test_up_tag not in c],
+        )
+        assert all(_test_down_tag in c for c in _down_var_cols), (
+            _test_down_tag,
+            [c for c in _down_var_cols if _test_down_tag not in c],
+        )
+        # Cross-check: no up-half name may contain the down tag and no down-half name may contain the up tag.
+        assert all(_test_down_tag not in c for c in _up_var_cols), (
+            _test_up_tag,
+            _test_down_tag,
+            [c for c in _up_var_cols if _test_down_tag in c],
+        )
+        assert all(_test_up_tag not in c for c in _down_var_cols), (
+            _test_down_tag,
+            _test_up_tag,
+            [c for c in _down_var_cols if _test_up_tag in c],
+        )
+
     def get_var_info_for(
         self,
         df: pd.DataFrame,
@@ -375,7 +416,7 @@ class SysColVarManager:
 
     @staticmethod
     def _get_up_down_int_from(col_name: str) -> int:
-        _all_ints: List[int] = list(map(int, re.findall(r'\d+', col_name)))
+        _all_ints: List[int] = list(map(int, re.findall(r"\d+", col_name)))
         return _all_ints[-1]
 
     @staticmethod
@@ -691,11 +732,13 @@ class FitSetupBinningPerRecoChInfo:
             n_components=fit_setup.n_components,
         )
 
+
 # endregion
 
 
 # region Systematics Info Container helper functions
 
+
 def get_rearranged_shape_sys_array_for(
     input_array: np.ndarray,
     shape_sys_dimension: int,
@@ -735,13 +778,17 @@ def get_rearranged_shape_sys_array_for(
 
 
 def write_sys_shape_info_object_to(
-    sys_shape_info_object,
+    sys_shape_info_object: Union["ReducedSystematicsDetails", "UpDownAndDiffSystematicsDetails"],
     filename: str,
     target_dir_path: PathType,
     overwrite: bool = False,
 ) -> None:
     os.makedirs(target_dir_path, exist_ok=True)
 
+    assert isinstance(sys_shape_info_object, (ReducedSystematicsDetails, UpDownAndDiffSystematicsDetails)), (
+        type(sys_shape_info_object),
+    )
+
     file_path: PathType = os.path.join(target_dir_path, filename)
     tmp_file_path: PathType = os.path.join(target_dir_path, f"tmp_{filename}")
     assert os.path.splitext(file_path)[1] == ".pkl", (
@@ -765,6 +812,7 @@ def write_sys_shape_info_object_to(
         pickle.dump(sys_shape_info_object, output_file, pickle.HIGHEST_PROTOCOL)
     os.rename(src=tmp_file_path, dst=file_path)
 
+
 # endregion
 
 
@@ -906,6 +954,8 @@ class ReducedSystematicsDetails:
         with open(file_path, "rb") as input_file:
             reduced_shape_sys_details: ReducedSystematicsDetails = pickle.load(input_file)
 
+        assert isinstance(reduced_shape_sys_details, ReducedSystematicsDetails), type(reduced_shape_sys_details)
+
         return reduced_shape_sys_details
 
 
@@ -914,6 +964,7 @@ class ReducedSystematicsDetails:
 
 # region Up-Down Systematics Info Container
 
+
 @dataclass(frozen=True)
 class UpDownAndDiffSystematicsDetails:
     # Fit Setup Info
@@ -979,6 +1030,8 @@ class UpDownAndDiffSystematicsDetails:
         with open(file_path, "rb") as input_file:
             up_down_shape_sys_details: UpDownAndDiffSystematicsDetails = pickle.load(input_file)
 
+        assert isinstance(up_down_shape_sys_details, UpDownAndDiffSystematicsDetails), type(up_down_shape_sys_details)
+
         return up_down_shape_sys_details
 
 
@@ -1396,7 +1449,7 @@ class SysCovEvaluationManager:
         )
 
         assert all(v in self.possible_sys_col_base_names for v in self._default_eval_settings.keys()), (
-            [v for v in self._default_eval_settings.keys() if v not in self.possible_sys_col_base_names]
+            [v for v in self._default_eval_settings.keys() if v not in self.possible_sys_col_base_names],
         )
 
         self._evaluated_systematics: Dict[str, CovEvaluator] = dict()
@@ -1488,7 +1541,7 @@ class SysDiffEvaluationManager:
         )
 
         assert all(v in self.possible_sys_col_base_names for v in self.sys_diff_infos.keys()), (
-            [v for v in self.sys_diff_infos.keys() if v not in self.possible_sys_col_base_names]
+            [v for v in self.sys_diff_infos.keys() if v not in self.possible_sys_col_base_names],
         )
 
         self._evaluated_systematics: Dict[str, UpDownAndDiffSystematicsDetails] = dict()
-- 
GitLab