From dfb1ac04fb09908643e3d83c281d0930162c0c23 Mon Sep 17 00:00:00 2001
From: Felix Metzner <felixmetzner@outlook.com>
Date: Thu, 18 Apr 2024 17:36:26 +0200
Subject: [PATCH] Adding shape sys overview plotter.

---
 .../dedicated_fit_routine.py                  |  20 +-
 .../dedicated_fit_approach/plotting_tools.py  | 286 +++++++++++++++++-
 2 files changed, 285 insertions(+), 21 deletions(-)

diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py
index 9361b8fee..f6e0ffef6 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py
@@ -1160,8 +1160,8 @@ class RDStarFitEvaluator:
 
         yield from sys_shape_plotter.create_systematics_shape_plots(
             sys_shape_info=example_sys_shape_info,
-            add_statistical_uncertainties=add_statistical_uncertainties,
-            normalize_pulls=normalize_pulls,
+            # add_statistical_uncertainties=add_statistical_uncertainties,
+            # normalize_pulls=normalize_pulls,
         )
 
     def plot_norm_ff_sys_shape_effects(
@@ -1192,8 +1192,8 @@ class RDStarFitEvaluator:
 
         yield from sys_shape_plotter.create_systematics_shape_plots(
             sys_shape_info=example_sys_shape_info,
-            add_statistical_uncertainties=add_statistical_uncertainties,
-            normalize_pulls=normalize_pulls,
+            # add_statistical_uncertainties=add_statistical_uncertainties,
+            # normalize_pulls=normalize_pulls,
         )
 
     def plot_l_id_sys_shape_effects(
@@ -1224,8 +1224,8 @@ class RDStarFitEvaluator:
 
         yield from sys_shape_plotter.create_systematics_shape_plots(
             sys_shape_info=example_sys_shape_info,
-            add_statistical_uncertainties=add_statistical_uncertainties,
-            normalize_pulls=normalize_pulls,
+            # add_statistical_uncertainties=add_statistical_uncertainties,
+            # normalize_pulls=normalize_pulls,
         )
 
     def plot_tracking_sys_shape_effects(
@@ -1256,8 +1256,8 @@ class RDStarFitEvaluator:
 
         yield from sys_shape_plotter.create_systematics_shape_plots(
             sys_shape_info=example_sys_shape_info,
-            add_statistical_uncertainties=add_statistical_uncertainties,
-            normalize_pulls=normalize_pulls,
+            # add_statistical_uncertainties=add_statistical_uncertainties,
+            # normalize_pulls=normalize_pulls,
         )
 
     def plot_k_short_sys_shape_effects(
@@ -1288,8 +1288,8 @@ class RDStarFitEvaluator:
 
         yield from sys_shape_plotter.create_systematics_shape_plots(
             sys_shape_info=example_sys_shape_info,
-            add_statistical_uncertainties=add_statistical_uncertainties,
-            normalize_pulls=normalize_pulls,
+            # add_statistical_uncertainties=add_statistical_uncertainties,
+            # normalize_pulls=normalize_pulls,
         )
 
     def dump_data_for_external_fitter(
diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
index b5631b305..be8de9e0e 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
@@ -10,7 +10,8 @@ import numpy as np
 import matplotlib.pyplot as plt
 
 from dataclasses import dataclass
-from typing import Tuple, List, Sequence, Optional, ClassVar, Generator
+from collections.abc import Mapping
+from typing import Tuple, List, Dict, Sequence, Optional, ClassVar, Generator, Iterator
 
 from templatefitter.plotter.plot_utilities import export, AxesType, FigureType
 from templatefitter.plotter.plot_style import KITColors, TangoColors, set_matplotlibrc_params, xlabel_pos, ylabel_pos
@@ -35,8 +36,10 @@ from rdstar.offline_analysis.fitting.dedicated_fit_approach.fit_info_container i
 __all__ = [
     "NuisancePullPlotter",
     "ShapePlotInfoContainer",
-    "SystematicsShapePlotter",
     "SpecificShapePlotInfoContainer",
+    "SystematicsShapePlotter",
+    "SpecificShapeProjectionPlotInfoContainer",
+    "SystematicsShapeProjectionPlotter",
 ]
 
 
@@ -350,20 +353,21 @@ class ShapePlotInfoContainer:
         )
         return np.sqrt(self.pure_bin_counts) / norm
 
+    @property
+    def number_of_eigendirections(self) -> int:
+        return self.relative_shape_error.shape[2]  # axis 2 is the one indexed by the eigendirection (subset) id
+
 
 @dataclass(frozen=True)
-class SpecificShapePlotInfoContainer:
+class SpecificShapePlotInfoEntry:
     name: str
     latex_str: str
     subset_index: int
     normed_base_shape: np.ndarray
     shape_error: np.ndarray
     stat_error: np.ndarray
-    observable: FitObservableInfo
     reco_ch_info: RecoChannelInfo
     component_info: ComponentInfo
-    add_statistical_uncertainty: bool
-    normalize_pulls: bool
     scale_factor: Optional[float] = None
 
     def __post_init__(self) -> None:
@@ -373,6 +377,27 @@ class SpecificShapePlotInfoContainer:
             self.normed_base_shape.shape,
             self.shape_error.shape,
         )
+        assert self.normed_base_shape.shape == self.stat_error.shape, (
+            self.normed_base_shape.shape,
+            self.stat_error.shape,
+        )
+
+
+class SpecificShapePlotInfoContainer(Mapping):
+    """Read-only mapping from the name of a shape plot info entry to the entry itself."""
+
+    def __init__(self, shape_plot_infos: Tuple[SpecificShapePlotInfoEntry, ...]) -> None:
+        # Entries sharing a name overwrite each other; only the last one is kept.
+        self._entries: Dict[str, SpecificShapePlotInfoEntry] = {info.name: info for info in shape_plot_infos}
+
+    def __getitem__(self, item: str) -> SpecificShapePlotInfoEntry:
+        return self._entries[item]
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._entries)
+
+    def __len__(self) -> int:
+        return len(self._entries)
 
 
 class SystematicsShapePlotter:
@@ -380,6 +405,245 @@ class SystematicsShapePlotter:
     output_dir_name: ClassVar[str] = "SystematicsShape"
     plot_name_prefix: ClassVar[str] = "sys_shape_effect_for"
 
+    def __init__(
+        self,
+        base_output_dir_path: PathType,
+        fit_setup: FitSetupTriplet,
+        fig_size: Tuple[float, float] = (6.0, 5.0),
+        height_ratio: Tuple[float, float] = (3.5, 1.0),
+    ) -> None:
+        """Store the plot settings and derive the per-reco-channel binning info from the Asimov fit setup."""
+        assert os.path.isdir(base_output_dir_path), base_output_dir_path
+
+        self.output_dir_path: PathType = os.path.join(base_output_dir_path, self.output_dir_name)
+        self.fig_size: Tuple[float, float] = fig_size
+        self.height_ratio: Tuple[float, float] = height_ratio
+
+        self._asimov_fit_setup: FitSetupInfoContainer = fit_setup.asimov
+        self._fit_binning_info: FitSetupBinningPerRecoChInfo = FitSetupBinningPerRecoChInfo.init_from(
+            fit_setup=self._asimov_fit_setup,
+        )
+        # Index of the overview plot panel each component is drawn into; `_in_cB` / `_in_nB` variants share a panel.
+        self._comp_to_axis_map: Dict[str, int] = {
+            "BpDztau": 3,
+            "BzDmtau": 2,
+            "BpDzStau": 1,
+            "BzDmStau": 0,
+            "BpDzl": 7,
+            "BzDml": 6,
+            "BpDzSl": 5,
+            "BzDmSl": 4,
+            "DSS_in_cB": 8,
+            "DSS_in_nB": 8,
+            "BBbarBKG_in_cB": 9,
+            "BBbarBKG_in_nB": 9,
+            "CBKG_in_cB": 10,
+            "CBKG_in_nB": 10,
+        }
+
+        self._charged_suffix: str = "_in_cB"
+        self._neutral_suffix: str = "_in_nB"
+        # Suffix a charge-specific component must carry to be plotted for the respective reco channel.
+        self._reco_ch_name_to_charge_suffix: Dict[str, str] = {
+            "Bz_to_Dm": self._neutral_suffix,
+            "Bm_to_Dz": self._charged_suffix,
+            "Bz_to_Dsm": self._neutral_suffix,
+            "Bm_to_Dsz": self._charged_suffix,
+        }
+
+    @property
+    def fit_components(self) -> Tuple[ComponentInfo, ...]:
+        return self._asimov_fit_setup.components  # components of the Asimov fit setup
+
+    @property
+    def fit_reco_channels(self) -> Tuple[RecoChannelInfo, ...]:
+        return self._asimov_fit_setup.reco_channels  # reco channels of the Asimov fit setup
+
+    @property
+    def fit_observable_infos(self) -> Tuple[FitObservableInfo, ...]:
+        return self._asimov_fit_setup.observable_infos  # observable infos of the Asimov fit setup
+
+    def create_systematics_shape_plots(
+        self,
+        sys_shape_info: ShapePlotInfoContainer,
+    ) -> Generator[SpecificShapePlotInfoContainer, None, None]:
+        """Create the overview plots of `sys_shape_info` for every reco channel.
+
+        The arrays of `sys_shape_info` are split along their bin axis (axis 0) at the reco channel
+        boundaries, and one overview per eigendirection is produced for each reco channel.
+
+        Yields the plot info container of every created overview.
+        """
+        target_dir_path: PathType = os.path.join(self.output_dir_path, sys_shape_info.name)
+        os.makedirs(target_dir_path, exist_ok=True)
+
+        split_ids = self._fit_binning_info.reco_ch_split_ids
+        normed_shape_splits: List[np.ndarray] = np.split(sys_shape_info.normed_base_shape, split_ids, axis=0)
+        shape_error_splits: List[np.ndarray] = np.split(sys_shape_info.relative_shape_error, split_ids, axis=0)
+
+        # ShapePlotInfoContainer derives relative_stat_error from pure_bin_counts itself. Hence the raw counts
+        # have to be split and handed over, not the already derived relative statistical error.
+        # NOTE(review): assumes pure_bin_counts shares the bin axis 0 with the other arrays - confirm.
+        bin_count_splits: List[np.ndarray] = np.split(sys_shape_info.pure_bin_counts, split_ids, axis=0)
+
+        for reco_ch_id, reco_channel_info in enumerate(self.fit_reco_channels):
+            subset_sys_shape_info: ShapePlotInfoContainer = ShapePlotInfoContainer(
+                name=f"{sys_shape_info.name}_for_reco_ch_{reco_channel_info.name}",
+                latex_str=sys_shape_info.latex_str,
+                normed_base_shape=normed_shape_splits[reco_ch_id],
+                relative_shape_error=shape_error_splits[reco_ch_id],
+                pure_bin_counts=bin_count_splits[reco_ch_id],
+                scale_factor=sys_shape_info.scale_factor,
+                components_to_consider=sys_shape_info.components_to_consider,
+                plot_subsets=sys_shape_info.plot_subsets,
+            )
+
+            yield from self.create_systematics_shape_overview_plot_for(
+                sys_shape_info=subset_sys_shape_info,
+                reco_ch_info=reco_channel_info,
+            )
+
+    def create_systematics_shape_overview_plot_for(
+        self,
+        sys_shape_info: ShapePlotInfoContainer,
+        reco_ch_info: RecoChannelInfo,
+    ) -> Generator[SpecificShapePlotInfoContainer, None, None]:
+        """Yield the overview plot info container of each eigendirection of `sys_shape_info`, in index order."""
+        shared_kwargs = dict(sys_shape_info=sys_shape_info, reco_ch_info=reco_ch_info)
+        n_eigendirections: int = sys_shape_info.number_of_eigendirections
+
+        for direction_id in range(n_eigendirections):
+            yield self.plot_systematics_shape_overview_plot_for(subset_id=direction_id, **shared_kwargs)
+
+    def _skip_this_component(
+        self,
+        reco_ch_info: RecoChannelInfo,
+        component_info: ComponentInfo,
+    ) -> bool:
+        """Return True if `component_info` carries the charge suffix not expected for `reco_ch_info`.
+
+        Components whose name ends with none of the known charge suffixes are never skipped.
+        """
+        comp_name: str = component_info.name
+        if not comp_name.endswith(tuple(self._reco_ch_name_to_charge_suffix.values())):
+            return False
+
+        return not comp_name.endswith(self._reco_ch_name_to_charge_suffix[reco_ch_info.name])
+
+    def plot_systematics_shape_overview_plot_for(
+        self,
+        subset_id: int,
+        sys_shape_info: ShapePlotInfoContainer,
+        reco_ch_info: RecoChannelInfo,
+    ) -> SpecificShapePlotInfoContainer:
+        """Draw eigendirection `subset_id` for each component of `reco_ch_info` and return the plotted infos."""
+        _shape_plot_infos: List[SpecificShapePlotInfoEntry] = []
+
+        # TODO: Restrict the plotted components to sys_shape_info.components_to_consider.
+        # NOTE(review): The figure is neither exported nor closed yet, so every call keeps one figure open.
+
+        fig, axes = plt.subplots(
+            nrows=4,
+            ncols=3,
+            figsize=self.fig_size,
+            dpi=300,
+            sharex="none",
+            sharey="row",
+        )  # type: FigureType, np.ndarray
+        assert axes.shape == (4, 3), axes.shape
+
+        for comp_id, component in enumerate(self.fit_components):
+            if self._skip_this_component(reco_ch_info=reco_ch_info, component_info=component):
+                continue
+
+            normed_base_shape: np.ndarray = sys_shape_info.normed_base_shape[:, comp_id]
+            shape_error: np.ndarray = sys_shape_info.relative_shape_error[:, comp_id, subset_id]
+            stat_error: np.ndarray = sys_shape_info.relative_stat_error[:, comp_id]
+            # The component name keeps the entry names, which are the keys of the returned mapping, unique.
+            _shape_plot_info = SpecificShapePlotInfoEntry(
+                name=f"{sys_shape_info.name}_{subset_id}_{component.name}",
+                latex_str=sys_shape_info.latex_str + f" ({subset_id})",
+                subset_index=subset_id,
+                normed_base_shape=normed_base_shape,
+                shape_error=shape_error,
+                stat_error=stat_error,
+                reco_ch_info=reco_ch_info,
+                component_info=component,
+                scale_factor=sys_shape_info.scale_factor,
+            )
+            # plt.subplots returns a 2D (4, 3) array of axes; the panel indices refer to its flattened view.
+            axis_index: int = self._comp_to_axis_map[component.name]
+            this_axis: AxesType = axes.flat[axis_index]
+
+            self._plot_shape_overview(
+                ax=this_axis,
+                shape_plot_info=_shape_plot_info,
+            )
+
+            _shape_plot_infos.append(_shape_plot_info)
+
+        return SpecificShapePlotInfoContainer(shape_plot_infos=tuple(_shape_plot_infos))
+
+    @staticmethod
+    def _plot_shape_overview(
+        ax: AxesType,
+        shape_plot_info: SpecificShapePlotInfoEntry,
+    ) -> None:
+        """Draw 1 + shape_error and 1 - shape_error as unit-width step histograms onto `ax`."""
+        rel_error: np.ndarray = shape_plot_info.shape_error
+        assert rel_error.ndim == 1, rel_error.shape
+
+        n_bins: int = rel_error.shape[0]
+        bin_edges: np.ndarray = np.arange(n_bins + 1) - 0.5
+        bin_mids: np.ndarray = (bin_edges[1:] + bin_edges[:-1]) / 2.0
+
+        variations: Tuple[Tuple[str, np.ndarray, str], ...] = (
+            (KITColors.kit_blue, 1.0 + rel_error, "Up/Nom"),
+            (KITColors.kit_green, 1.0 - rel_error, "Down/Nom"),
+        )
+
+        for h_color, h_weights, h_label in variations:
+            ax.hist(
+                x=bin_mids,
+                bins=bin_edges,
+                weights=h_weights,
+                stacked=False,
+                color=h_color,
+                lw=0.8,
+                label=h_label,
+                histtype="step",
+            )
+
+
+@dataclass(frozen=True)
+class SpecificShapeProjectionPlotInfoContainer:
+    """Arrays and meta information describing a single systematics shape projection plot."""
+
+    name: str
+    latex_str: str
+    subset_index: int
+    normed_base_shape: np.ndarray
+    shape_error: np.ndarray
+    stat_error: np.ndarray
+    observable: FitObservableInfo
+    reco_ch_info: RecoChannelInfo
+    component_info: ComponentInfo
+    add_statistical_uncertainty: bool
+    normalize_pulls: bool
+    scale_factor: Optional[float] = None
+
+    def __post_init__(self) -> None:
+        # The base shape has to be one-dimensional and the shape error has to match its shape.
+        base_shape: Tuple[int, ...] = self.normed_base_shape.shape
+        assert len(base_shape) == 1, (len(base_shape), base_shape)
+        assert base_shape == self.shape_error.shape, (base_shape, self.shape_error.shape)
+
+
+class SystematicsShapeProjectionPlotter:
+
+    output_dir_name: ClassVar[str] = "SystematicsShapeProjections"
+    plot_name_prefix: ClassVar[str] = "sys_shape_effect_on_projection_for"
+
     def __init__(
         self,
         base_output_dir_path: PathType,
@@ -425,7 +689,7 @@ class SystematicsShapePlotter:
         component_info: ComponentInfo,
         add_statistical_uncertainty: bool,
         normalize_pulls: bool,
-    ) -> Generator[SpecificShapePlotInfoContainer, None, None]:
+    ) -> Generator[SpecificShapeProjectionPlotInfoContainer, None, None]:
 
         assert len(selected_normed_base_shape.shape) == 1, selected_normed_base_shape.shape
         assert len(selected_relative_shape_error.shape) == 2, selected_relative_shape_error.shape
@@ -499,7 +763,7 @@ class SystematicsShapePlotter:
 
             if sys_shape_info.plot_subsets:
                 for subset_id in range(sys_shape_info.relative_shape_error.shape[2]):
-                    yield SpecificShapePlotInfoContainer(
+                    yield SpecificShapeProjectionPlotInfoContainer(
                         name=sys_shape_info.name,
                         latex_str=sys_shape_info.latex_str,
                         subset_index=subset_id,
@@ -514,7 +778,7 @@ class SystematicsShapePlotter:
                         scale_factor=sys_shape_info.scale_factor,
                     )
 
-            yield SpecificShapePlotInfoContainer(
+            yield SpecificShapeProjectionPlotInfoContainer(
                 name=sys_shape_info.name,
                 latex_str=sys_shape_info.latex_str,
                 subset_index=-1,
@@ -534,7 +798,7 @@ class SystematicsShapePlotter:
         sys_shape_info: ShapePlotInfoContainer,
         add_statistical_uncertainties: bool = True,
         normalize_pulls: bool = True,
-    ) -> Generator[SpecificShapePlotInfoContainer, None, None]:
+    ) -> Generator[SpecificShapeProjectionPlotInfoContainer, None, None]:
         target_dir_path: PathType = os.path.join(self.output_dir_path, sys_shape_info.name)
         os.makedirs(target_dir_path, exist_ok=True)
 
@@ -601,7 +865,7 @@ class SystematicsShapePlotter:
 
     def plot_systematics_shape_effect_for(
         self,
-        spec_sys_shape_info: SpecificShapePlotInfoContainer,
+        spec_sys_shape_info: SpecificShapeProjectionPlotInfoContainer,
         target_dir_path: PathType,
     ) -> None:
         set_matplotlibrc_params()
-- 
GitLab