From 00cf7ae9b2afbbd16ec37562faee2a7fe0b7979c Mon Sep 17 00:00:00 2001
From: Felix Metzner <felixmetzner@outlook.com>
Date: Thu, 18 Apr 2024 18:30:45 +0200
Subject: [PATCH] Fixing bugs in and improvment of handling of axes in
 SystematicsShapePlotter.

---
 .../dedicated_fit_approach/plotting_tools.py  | 73 +++++++++++++------
 1 file changed, 50 insertions(+), 23 deletions(-)

diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
index be8de9e0e..ceee6329e 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
@@ -304,7 +304,7 @@ class NuisancePullPlotter:
         export(
             fig=fig,
             filename=plot_file_name,
-            target_dir=self.output_dir_path,
+            target_dir=self.overview_output_dir_path,
             close_figure=False,
         )
 
@@ -424,21 +424,21 @@ class SystematicsShapePlotter:
             fit_setup=self._asimov_fit_setup,
         )
 
-        self._comp_to_axis_map: Dict[str, int] = {
-            "BpDztau": 3,
-            "BzDmtau": 2,
-            "BpDzStau": 1,
-            "BzDmStau": 0,
-            "BpDzl": 7,
-            "BzDml": 6,
-            "BpDzSl": 5,
-            "BzDmSl": 4,
-            "DSS_in_cB": 8,
-            "DSS_in_nB": 8,
-            "BBbarBKG_in_cB": 9,
-            "BBbarBKG_in_nB": 9,
-            "CBKG_in_cB": 10,
-            "CBKG_in_nB": 10,
+        self._comp_to_axis_map: Dict[str, Tuple[int, int]] = {
+            "BpDztau": (0, 3),
+            "BzDmtau": (0, 2),
+            "BpDzStau": (0, 1),
+            "BzDmStau": (0, 0),
+            "BpDzl": (1, 3),
+            "BzDml": (1, 2),
+            "BpDzSl": (1, 1),
+            "BzDmSl": (1, 0),
+            "DSS_in_cB": (2, 0),
+            "DSS_in_nB": (2, 0),
+            "BBbarBKG_in_cB": (2, 1),
+            "BBbarBKG_in_nB": (2, 1),
+            "CBKG_in_cB": (2, 2),
+            "CBKG_in_nB": (2, 2),
         }
 
         self._charged_suffix: str = "_in_cB"
@@ -451,6 +451,22 @@ class SystematicsShapePlotter:
             "Bm_to_Dsz": self._charged_suffix,
         }
 
+    def get_comp_to_axis_map(
+        self,
+        components_to_consider: Optional[Tuple[str, ...]],
+    ) -> Dict[str, Tuple[int, int]]:
+        if components_to_consider is None:
+            return {k: v for k, v in self._comp_to_axis_map.items()}
+
+        relevant_row_indices: Tuple[int, ...] = tuple(
+            set(v[0] for k, v in self._comp_to_axis_map.items() if k in components_to_consider)
+        )
+        _new_comp_to_axis_map: Dict[str, Tuple[int, int]] = {
+            k: (v[0] - sum([1 for r in range(v[0]) if r not in relevant_row_indices]), v[1])
+            for k, v in self._comp_to_axis_map.items()
+        }
+        return _new_comp_to_axis_map
+
     @property
     def fit_components(self) -> Tuple[ComponentInfo, ...]:
         return self._asimov_fit_setup.components
@@ -540,17 +556,28 @@ class SystematicsShapePlotter:
         _shape_plot_infos: List[SpecificShapePlotInfoEntry] = []
 
         # TODO: Use!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-        # sys_shape_info.components_to_consider
+        n_plot_rows: int = 3
+        if sys_shape_info.components_to_consider is not None:
+            n_rows_needed: int = int(np.ceil(len(sys_shape_info.components_to_consider) / 4.0))
+            assert n_rows_needed <= 3, (n_rows_needed, len(sys_shape_info.components_to_consider))
+            n_plot_rows = n_rows_needed
+
+        comp_to_axis_map: Dict[str, Tuple[int, int]] = self.get_comp_to_axis_map(
+            components_to_consider=sys_shape_info.components_to_consider,
+        )
 
         fig, axes = plt.subplots(
-            nrows=4,
-            ncols=3,
+            nrows=n_plot_rows,
+            ncols=4,
             figsize=self.fig_size,
             dpi=300,
             sharex="none",
             sharey="row",
-        )  # type: FigureType, Sequence[AxesType]
-        assert len(axes) == 12, len(axes)
+        )  # type: FigureType, np.ndarray
+
+        assert isinstance(axes, np.ndarray), type(axes)
+        assert all(isinstance(_ax, AxesType) for _ax in axes.flatten()), [type(a) for a in axes.flatten()]
+        assert len(axes.shape) == 2, (len(axes.shape), axes.shape)
 
         for comp_id, component in enumerate(self.fit_components):
             if self._skip_this_component(reco_ch_info=reco_ch_info, component_info=component):
@@ -572,8 +599,8 @@ class SystematicsShapePlotter:
                 scale_factor=sys_shape_info.scale_factor,
             )
 
-            axis_index: int = self._comp_to_axis_map[component.name]
-            this_axis: AxesType = axes[axis_index]
+            axis_index_pair: Tuple[int, int] = comp_to_axis_map[component.name]
+            this_axis: AxesType = axes[axis_index_pair[0], axis_index_pair[1]]
 
             self._plot_shape_overview(
                 ax=this_axis,
-- 
GitLab