From f358e0b216c7b2fa6af8b5b897804636139e3290 Mon Sep 17 00:00:00 2001
From: Felix Metzner <felixmetzner@outlook.com>
Date: Wed, 17 Apr 2024 17:42:44 +0200
Subject: [PATCH] Adding nuisance pull overview plots.

---
 .../dedicated_fit_routine.py                  |  21 ++-
 .../dedicated_fit_approach/plotting_tools.py  | 162 +++++++++++++++++-
 2 files changed, 172 insertions(+), 11 deletions(-)

diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py
index 3dcbdd066..9361b8fee 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/dedicated_fit_routine.py
@@ -1103,16 +1103,31 @@ class RDStarFitEvaluator:
         self,
         fit_result: RDStarFitResultContainer,
     ) -> None:
-        nuisance_pull_fitter: NuisancePullPlotter = NuisancePullPlotter(base_output_dir_path=self.base_output_dir_path)
+        nuisance_pull_plotter: NuisancePullPlotter = NuisancePullPlotter(base_output_dir_path=self.base_output_dir_path)
+
+        logger.info("Creating nuisance pull overview plot for additive systematics.")
+        sys_infos = tuple(v for v in self.full_systematics_container.values())
+        nuisance_pull_plotter.create_nuisance_pull_overview_plots(
+            additive=True,
+            sys_infos=sys_infos,
+            fit_result=fit_result,
+        )
+
+        logger.info("Creating nuisance pull overview plot for multiplicative systematics.")
+        nuisance_pull_plotter.create_nuisance_pull_overview_plots(
+            additive=False,
+            sys_infos=sys_infos,
+            fit_result=fit_result,
+        )
 
-        logger.info(f"Creating nuisance pull plots and saving them to\n\t{nuisance_pull_fitter.output_dir_path}")
+        logger.info(f"Creating single nuisance pull plots and saving them to\n\t{nuisance_pull_plotter.output_dir_path}")
 
         for sys_name, sys_container in self.full_systematics_container.items():
             if sys_name == RDStarFitter.full_systematics_container.add_sys__mc_statistics.systematics_key:
                 logger.info(f"\tSkipping nuisance pull plot for {sys_name}")
                 continue
             logger.info(f"\tPlotting nuisance pulls for {sys_name}")
-            nuisance_pull_fitter.create_nuisance_pull_plot_from(
+            nuisance_pull_plotter.create_nuisance_pull_plot_from(
                 sys_info=sys_container,
                 fit_result=fit_result,
             )
diff --git a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
index 1ec0538e4..ecaca24e1 100644
--- a/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
+++ b/rdstar/offline_analysis/fitting/dedicated_fit_approach/plotting_tools.py
@@ -82,10 +82,81 @@ class NuisancePullInfo:
         )
 
 
+class NuisancePullOverviewInfos:
+    def __init__(
+        self,
+        additive: bool,
+        sys_infos: Tuple[SystematicsInfo, ...],
+        fit_result: RDStarFitResultContainer,
+    ) -> None:
+        self._additive_systematics: bool = additive
+
+        if self.additive_systematics:
+            overview_sys_infos: Tuple[SystematicsInfo, ...] = tuple(si for si in sys_infos if si.is_additive_sys)
+        else:
+            overview_sys_infos = tuple(si for si in sys_infos if si.is_multiplicative_sys)
+
+        _param_ids_list: List[np.ndarray] = []
+        _nu_latex_label_str_list: List[str] = []
+
+        for sys_info in overview_sys_infos:
+            assert sys_info.is_active, sys_info.systematics_key
+
+            _current_param_ids: np.ndarray = np.array(sys_info.nuisance_param_ids)
+            assert len(_current_param_ids.shape) == 1, (_current_param_ids.shape, sys_info.systematics_key)
+            _param_ids_list.append(_current_param_ids)
+
+            _nn: int = sys_info.number_of_nuisances
+            _current_labels: List[str] = [sys_info.latex_str + "" if _nn == 1 else f" ({i}+1)" for i in range(_nn)]
+            _nu_latex_label_str_list.extend(_current_labels)
+
+        self._param_ids: np.ndarray = np.concatenate(_param_ids_list)
+        self._param_values: np.ndarray = fit_result.param_values[self._param_ids]
+        self._param_errors: np.ndarray = fit_result.errors[self._param_ids]
+
+        self._nu_latex_label_strings: Tuple[str, ...] = tuple(_nu_latex_label_str_list)
+
+        assert self._param_ids.shape == self._param_values.shape, (self._param_ids.shape, self._param_values.shape)
+        assert self._param_ids.shape == self._param_errors.shape, (self._param_ids.shape, self._param_errors.shape)
+        assert len(self._param_ids) == len(self._nu_latex_label_strings), (
+            len(self._param_ids),
+            len(self._nu_latex_label_strings),
+        )
+
+    @property
+    def additive_systematics(self) -> bool:
+        return self._additive_systematics
+
+    @property
+    def name(self) -> str:
+        return "additive_systematics" if self.additive_systematics else "multiplicative_systematics"
+
+    @property
+    def number_of_nuisances(self) -> int:
+        return len(self.parameter_ids)
+
+    @property
+    def parameter_ids(self) -> np.ndarray:
+        return self._param_ids
+
+    @property
+    def parameter_values(self) -> np.ndarray:
+        return self._param_values
+
+    @property
+    def parameter_errors(self) -> np.ndarray:
+        return self._param_errors
+
+    @property
+    def nuisance_pull_latex_labels(self) -> Tuple[str, ...]:
+        return self._nu_latex_label_strings
+
+
 class NuisancePullPlotter:
 
     output_dir_name: ClassVar[str] = "NuisancePulls"
     plot_name_prefix: ClassVar[str] = "nuisance_pulls_for"
+    overview_plot_name_prefix: ClassVar[str] = "nuisance_pulls_overview_for"
 
     def __init__(
         self,
@@ -95,13 +166,16 @@ class NuisancePullPlotter:
         assert os.path.isdir(base_output_dir_path), base_output_dir_path
 
         self.output_dir_path: PathType = os.path.join(base_output_dir_path, self.output_dir_name)
+        self.overview_output_dir_path: PathType = os.path.join(base_output_dir_path, self.output_dir_name, "Overviews")
 
     def create_nuisance_pull_plot_from(
         self,
         sys_info: SystematicsInfo,
         fit_result: RDStarFitResultContainer,
-        fig_size: Tuple[float, float] = (8.0, 5.0),
+        fig_size: Optional[Tuple[float, float]] = None,
     ) -> None:
+        os.makedirs(self.output_dir_path, exist_ok=True)
+
         nuisance_pull_info: NuisancePullInfo = NuisancePullInfo.create_info_from(
             sys_info=sys_info,
             fit_result=fit_result,
@@ -114,15 +188,17 @@ class NuisancePullPlotter:
     def plot_nuisance_pull(
         self,
         infos: NuisancePullInfo,
-        fig_size: Tuple[float, float] = (8.0, 5.0),
+        fig_size: Optional[Tuple[float, float]] = None,
     ) -> None:
         set_matplotlibrc_params()
 
-        os.makedirs(self.output_dir_path, exist_ok=True)
-
         plot_file_name: str = f"{self.plot_name_prefix}_{infos.name}"
+        if fig_size is None:
+            _fig_size: Tuple[float, float] = (1.0 + infos.number_of_nuisances / 4.0, 5.0)
+        else:
+            _fig_size = fig_size
 
-        fig, ax = plt.subplots(figsize=fig_size, dpi=300)  # type: FigureType, AxesType
+        fig, ax = plt.subplots(figsize=_fig_size, dpi=300)  # type: FigureType, AxesType
 
         ax.set_xlim(-0.4, infos.number_of_nuisances - 0.6)
         ax.set_ylim(-2.2, 2.2)
@@ -131,10 +207,10 @@ class NuisancePullPlotter:
         ls_max: int = infos.number_of_nuisances
         ls_num: int = infos.number_of_nuisances + 2
 
-        ax.fill_between(x=np.linspace(ls_min, ls_max, ls_num), y1=-1.0, y2=+1.0, color=TangoColors.chameleon2)
+        ax.fill_between(x=np.linspace(ls_min, ls_max, ls_num), y1=-2.0, y2=+2.0, color=TangoColors.butter2, alpha=1.0)
+        ax.fill_between(x=np.linspace(ls_min, ls_max, ls_num), y1=-1.0, y2=+1.0, color=TangoColors.chameleon2, alpha=1.0)
 
-        ax.fill_between(x=np.linspace(ls_min, ls_max, ls_num), y1=-2.0, y2=-1.0, color=TangoColors.butter2)
-        ax.fill_between(x=np.linspace(ls_min, ls_max, ls_num), y1=+1.0, y2=+2.0, color=TangoColors.butter2)
+        ax.hlines(y=0.0, xmin=ls_min, xmax=ls_max, linestyles="--", lw=1.0, color=KITColors.kit_black)
 
         ax.errorbar(
             x=np.arange(infos.number_of_nuisances),
@@ -158,6 +234,76 @@ class NuisancePullPlotter:
             close_figure=False,
         )
 
+    def create_nuisance_pull_overview_plots(
+        self,
+        additive: bool,
+        sys_infos: Tuple[SystematicsInfo, ...],
+        fit_result: RDStarFitResultContainer,
+        fig_size: Optional[Tuple[float, float]] = None,
+    ) -> None:
+        os.makedirs(self.overview_output_dir_path, exist_ok=True)
+
+        nuisance_pull_overview_infos: NuisancePullOverviewInfos = NuisancePullOverviewInfos(
+            additive=additive,
+            sys_infos=sys_infos,
+            fit_result=fit_result,
+        )
+
+        self.plot_nuisance_pull_overview(
+            infos=nuisance_pull_overview_infos,
+            fig_size=fig_size,
+        )
+
+    def plot_nuisance_pull_overview(
+        self,
+        infos: NuisancePullOverviewInfos,
+        fig_size: Optional[Tuple[float, float]] = None,
+    ) -> None:
+        set_matplotlibrc_params()
+
+        plot_file_name: str = f"{self.overview_plot_name_prefix}_{infos.name}"
+
+        if fig_size is None:
+            _fig_size: Tuple[float, float] = (6.0, 1.0 + infos.number_of_nuisances / 4.0)
+        else:
+            _fig_size = fig_size
+
+        fig, ax = plt.subplots(figsize=_fig_size, dpi=300)  # type: FigureType, AxesType
+
+        ax.set_xlim(-2.2, 2.2)
+        ax.set_ylim(-0.4, infos.number_of_nuisances - 0.6)
+
+        ls_min: int = -1
+        ls_max: int = infos.number_of_nuisances
+        ls_num: int = infos.number_of_nuisances + 2
+
+        ax.fill_betweenx(y=np.linspace(ls_min, ls_max, ls_num), x1=-2.0, x2=+2.0, color=TangoColors.butter2, alpha=1.0)
+        ax.fill_betweenx(y=np.linspace(ls_min, ls_max, ls_num), x1=-1.0, x2=+1.0, color=TangoColors.chameleon2, alpha=1.0)
+
+        ax.vlines(x=0.0, ymin=ls_min, ymax=ls_max, linestyles="--", lw=1.0, color=KITColors.kit_black)
+
+        ax.errorbar(
+            y=np.arange(infos.number_of_nuisances),
+            x=[nu_v for nu_v in infos.parameter_values],
+            xerr=[nu_e for nu_e in infos.parameter_errors],
+            marker=".",
+            color=KITColors.kit_black,
+            linestyle="",
+            capsize=6,
+        )
+
+        plt.title(r"$\mathrm{Nuisance\;Parameter\;Pulls}$", fontsize=22)
+        plt.xlabel(r"$\mathrm{Standard\;Deviations}$", fontsize=18, **xlabel_pos)
+
+        ax.set_yticklabels(infos.nuisance_pull_latex_labels)
+
+        export(
+            fig=fig,
+            filename=plot_file_name,
+            target_dir=self.output_dir_path,
+            close_figure=False,
+        )
+
 
 # endregion
 
-- 
GitLab