From aa768301af35e9e8a76250c9914cd2109e29ef0e Mon Sep 17 00:00:00 2001
From: Felix Metzner <felixmetzner@outlook.com>
Date: Thu, 18 Apr 2024 18:44:29 +0200
Subject: [PATCH] More changes to handling of axes and figure size.

---
 .../dedicated_fit_approach/plotting_tools.py    | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
index ceee6329e..b694b123a 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
@@ -409,14 +409,14 @@ class SystematicsShapePlotter:
         self,
         base_output_dir_path: PathType,
         fit_setup: FitSetupTriplet,
-        fig_size: Tuple[float, float] = (6.0, 5.0),
+        fig_size: Optional[Tuple[float, float]] = None,
         height_ratio: Tuple[float, float] = (3.5, 1.0),
     ) -> None:
 
         assert os.path.isdir(base_output_dir_path), base_output_dir_path
 
         self.output_dir_path: PathType = os.path.join(base_output_dir_path, self.output_dir_name)
-        self.fig_size: Tuple[float, float] = fig_size
+        self.fig_size: Optional[Tuple[float, float]] = fig_size
         self.height_ratio: Tuple[float, float] = height_ratio
 
         self._asimov_fit_setup: FitSetupInfoContainer = fit_setup.asimov
@@ -555,7 +555,6 @@ class SystematicsShapePlotter:
 
         _shape_plot_infos: List[SpecificShapePlotInfoEntry] = []
 
-        # TODO: Use!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
         n_plot_rows: int = 3
         if sys_shape_info.components_to_consider is not None:
             n_rows_needed: int = int(np.ceil(len(sys_shape_info.components_to_consider) / 4.0))
@@ -566,10 +565,15 @@ class SystematicsShapePlotter:
             components_to_consider=sys_shape_info.components_to_consider,
         )
 
+        if self.fig_size is None:
+            fig_size: Tuple[float, float] = (8.0, n_plot_rows * 2.0)
+        else:
+            fig_size = self.fig_size
+
         fig, axes = plt.subplots(
             nrows=n_plot_rows,
             ncols=4,
-            figsize=self.fig_size,
+            figsize=fig_size,
             dpi=300,
             sharex="none",
             sharey="row",
@@ -577,7 +581,10 @@ class SystematicsShapePlotter:
 
         assert isinstance(axes, np.ndarray), type(axes)
         assert all(isinstance(_ax, AxesType) for _ax in axes.flatten()), [type(a) for a in axes.flatten()]
-        assert len(axes.shape) == 2, (len(axes.shape), axes.shape)
+        if n_plot_rows > 1:
+            assert len(axes.shape) == 2, (len(axes.shape), axes.shape)
+        else:
+            assert len(axes.shape) == 1, (len(axes.shape), axes.shape)
 
         for comp_id, component in enumerate(self.fit_components):
             if self._skip_this_component(reco_ch_info=reco_ch_info, component_info=component):
-- 
GitLab