From ae3fcd664ca1d3b4660b23e25793cb400a6d6a23 Mon Sep 17 00:00:00 2001
From: Klaus Rabbertz <klaus.rabbertz@cern.ch>
Date: Fri, 21 Mar 2025 14:41:24 +0100
Subject: [PATCH] Some work on nnlojet-combine procedure for my workflow

---
 tools/nnlojet/nnlojet-combine.py | 67 ++++++++++++++++++++++----------
 1 file changed, 47 insertions(+), 20 deletions(-)

diff --git a/tools/nnlojet/nnlojet-combine.py b/tools/nnlojet/nnlojet-combine.py
index b24552fd..d7f09576 100755
--- a/tools/nnlojet/nnlojet-combine.py
+++ b/tools/nnlojet/nnlojet-combine.py
@@ -12,8 +12,8 @@ import multiprocessing as mp
 import logging
 from logging.handlers import QueueHandler, QueueListener
 
-from nnlojet_util import NNLOJETHistogram, NNLOJETContainer
-import nnlojet_plot
+from nnlojet.util import NNLOJETHistogram, NNLOJETContainer
+from nnlojet.plot import plot_merge_and, plot_merge_plus
 
 
 class Task_part():
@@ -33,6 +33,9 @@ class Task_part():
         self._columns = kwargs.get('columns', None)
         self._rebin = kwargs.get('rebin', None)
         self._cumulant = kwargs.get('cumulant', None)
+        # for combined data files we want to know the observable name
+        # that we need to extract for the merge
+        self._obs_name = kwargs.get('obs_name', None)
         # copy to instance variable to preserve between processes...
         self._trim_threshold     = Task_part._default_trim_threshold
         self._trim_max_frac      = Task_part._default_trim_max_frac
@@ -50,7 +53,7 @@ class Task_part():
         container = NNLOJETContainer(size=len(self._files), weights=self._qweights)
         for file in self._files:
             try:
-                container.append(NNLOJETHistogram(nx=self._nx, filename=file, columns=self._columns, rebin=self._rebin, cumulant=self._cumulant))
+                container.append(NNLOJETHistogram(nx=self._nx, filename=file, obs_name=self._obs_name, columns=self._columns, rebin=self._rebin, cumulant=self._cumulant))
             except ValueError as e:
                 print(e)
                 print("error reading file:", file)
@@ -91,7 +94,7 @@ def process_parts_queue(parts_queue):
         parts_queue.task_done()
 
 
-def read_APPLfast(wgt_file):
+def read_weights(wgt_file):
     with open(wgt_file, 'rt') as file:
         hist_acc = NNLOJETHistogram()
         nx = 0
@@ -109,7 +112,7 @@ def read_APPLfast(wgt_file):
         print(hist_acc)
 
 
-def obs_generator(obs_list):
+def obs_generator(obs_list, obs_dict):
     cross_pattern = re.compile(r'.*cross.*')
     for obs_line in obs_list:
         obs_parse = [it.strip() for it in obs_line.split('>')]
@@ -124,6 +127,7 @@ def obs_generator(obs_list):
         nx = 0 if re.match(cross_pattern, obs_in) else 3
         #> do we have any cumulant labels?
         obs_cum = None
+        obs_name = None
         obs_parse = obs_in.split()
         if (len(obs_parse) == 2) and (nx == 3):
             if obs_in == obs_out:
@@ -134,7 +138,14 @@ def obs_generator(obs_list):
         elif len(obs_parse) != 1:
             raise ValueError('invalid observables specification: {}'.format(obs_line))
         # print([obs_line, obs_in, obs_out, nx, obs_cum])
-        yield (obs_line, obs_in, obs_out, nx, obs_cum)
+        yield (obs_line, obs_in, obs_out, nx, obs_cum, obs_name)
+    for obs_in in obs_dict.keys():
+        obs_line =  obs_in
+        obs_cum = None
+        for obs_out in obs_dict[obs_in]:
+            obs_name = obs_out
+            nx = 0 if re.match(cross_pattern, obs_out) else 3
+            yield (obs_line, obs_in, obs_out, nx, obs_cum, obs_name)
 
 
 def main():
@@ -146,9 +157,9 @@ def main():
     # multi-process
     parser.add_argument('-j', '--jobs', type=int, nargs='?', action='store', default='1',
                         help='Specifies the number of jobs to run simultaneously.')
-    # read in APPLfast weight table
-    parser.add_argument('--APPLfast', action='store', default=None,
-                        help='Read in an APPLfast weight table and combine.')
+    # read in weight table
+    parser.add_argument('--weights', action='store', default=None,
+                        help='Read in a weight table and combine.')
     # Parse the input arguments!
     args = parser.parse_args()
 
@@ -156,10 +167,10 @@ def main():
     config.optionxform = lambda option: option  # do not convert to lower case
     config.read(args.config)
 
-    # read in APPLfast weight table?
-    wgt_file = args.APPLfast
+    # read in weight table?
+    wgt_file = args.weights
     if wgt_file is not None:
-        read_APPLfast(wgt_file)
+        read_weights(wgt_file)
         return
 
     cross_pattern = re.compile(r'.*cross.*')
@@ -293,6 +304,22 @@ def main():
                 else:
                     print("couldn't extract observable name from file: {}".format(file))
         print("found observables: {}".format(obs_list))
+    # now loop over the observables to check whether their data files declare
+    # internal observables ('#name' lines); these are collected in a dict
+    obs_dict = dict()
+    for obs in obs_list:
+        files = glob.glob(config.get('Paths', 'raw_dir') + '/**/*' + obs + '*.dat', recursive=True)
+        if len(files) > 0:
+            with open(files[0], 'rt') as histfile:
+                for line in histfile:
+                    if re.match(r'^\s*#name', line, re.IGNORECASE):
+                        if obs in obs_list:  obs_list.remove(obs)
+                        if obs not in obs_dict:  obs_dict[obs] = []
+                        obs_name = line.split()[1]
+                        obs_dict[obs].append(obs_name)
+        else:
+            raise ValueError("no file?!")
+
 
 #    # if we want to perform the rebinning *after* merging
 #    # make a reduced set of distinct observables
@@ -306,7 +333,7 @@ def main():
 ///////////
 """)
 #    for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_set):
-    for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_list):
+    for (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) in obs_generator(obs_list, obs_dict):
         print("processing observable {}...".format(obs_out))
 
         # combine the different observables
@@ -326,7 +353,7 @@ def main():
                 #rebin = [ float(i) for i in rebin ]
                 #print('rebin = ', rebin)
 
-            task = Task_part(nx=nx, files=files, outfile=outfile, weights=qweights, columns=columns, rebin=rebin, cumulant=obs_cum)
+            task = Task_part(nx=nx, files=files, outfile=outfile, weights=qweights, columns=columns, rebin=rebin, cumulant=obs_cum, obs_name=obs_name)
             # print(" > submitting {}".format(outfile))
             parts_queue.put(task)
 
@@ -371,7 +398,7 @@ def main():
 // Merge //
 ///////////
 """)
-        for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_list):
+        for (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) in obs_generator(obs_list, obs_dict):
             print("processing observable {}...".format(obs_out))
             for mrg in config.options('Merge'):
                 parts = config.get('Merge', mrg)
@@ -398,7 +425,7 @@ def main():
                     # debug plots
                     if qplot:
                         with open(outdir_plots + '/' + mrg + '.' + obs_out + '.plus.plt', 'wt') as fout:
-                            print(nnlojet_plot.plot_merge_plus(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout)
+                            print(plot_merge_plus(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout)
 
                 elif "|" in parts:
                     parts = list(map(str.strip, parts.split('|')))
@@ -429,7 +456,7 @@ def main():
                     # debug plots
                     if qplot:
                         with open(outdir_plots + '/' + mrg + '.' + obs_out + '.ratio.plt', 'wt') as fout:
-                            print(nnlojet_plot.plot_merge_and(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout)
+                            print(plot_merge_and(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout)
 
                 else:
                     raise ValueError("couldn't find '|', '&' nor '+' in the merge:", parts)
@@ -441,7 +468,7 @@ def main():
 // Final //
 ///////////
 """)
-    for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_list):
+    for (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) in obs_generator(obs_list, obs_dict):
         print("processing observable {}...".format(obs_out))
         # assemble final histograms
         for fin in config.options('Final'):
@@ -455,8 +482,8 @@ def main():
             hist.write_to_file(outfile)
             # write to the file
             if qweights:
-                with open(outdir_final + '/' + fin + '.' + obs_out + '.APPLfast.txt', 'wt') as fout:
-                    print(hist.to_APPLfast_weight(), file=fout)
+                with open(outdir_final + '/' + fin + '.' + obs_out + '.weights.txt', 'wt') as fout:
+                    print(hist.to_weights(), file=fout)
 
 
 if __name__ == "__main__":
-- 
GitLab