From ae3fcd664ca1d3b4660b23e25793cb400a6d6a23 Mon Sep 17 00:00:00 2001 From: Klaus Rabbertz <klaus.rabbertz@cern.ch> Date: Fri, 21 Mar 2025 14:41:24 +0100 Subject: [PATCH] Some work on nnlojet-combine procedure for my workflow --- tools/nnlojet/nnlojet-combine.py | 67 ++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/tools/nnlojet/nnlojet-combine.py b/tools/nnlojet/nnlojet-combine.py index b24552fd..d7f09576 100755 --- a/tools/nnlojet/nnlojet-combine.py +++ b/tools/nnlojet/nnlojet-combine.py @@ -12,8 +12,8 @@ import multiprocessing as mp import logging from logging.handlers import QueueHandler, QueueListener -from nnlojet_util import NNLOJETHistogram, NNLOJETContainer -import nnlojet_plot +from nnlojet.util import NNLOJETHistogram, NNLOJETContainer +from nnlojet.plot import plot_merge_and, plot_merge_plus class Task_part(): @@ -33,6 +33,9 @@ class Task_part(): self._columns = kwargs.get('columns', None) self._rebin = kwargs.get('rebin', None) self._cumulant = kwargs.get('cumulant', None) + # for combined data files we want to know the observable name + # that we need to extract for the merge + self._obs_name = kwargs.get('obs_name', None) # copy to instance variable to preserve between processes... self._trim_threshold = Task_part._default_trim_threshold self._trim_max_frac = Task_part._default_trim_max_frac @@ -50,7 +53,7 @@ class Task_part(): container = NNLOJETContainer(size=len(self._files), weights=self._qweights) for file in self._files: try: - container.append(NNLOJETHistogram(nx=self._nx, filename=file, columns=self._columns, rebin=self._rebin, cumulant=self._cumulant)) + container.append(NNLOJETHistogram(nx=self._nx, filename=file, obs_name=self._obs_name, columns=self._columns, rebin=self._rebin, cumulant=self._cumulant)) except ValueError as e: print(e) print("error reading file:", file) @@ -91,7 +94,7 @@ def process_parts_queue(parts_queue): parts_queue.task_done() -def read_APPLfast(wgt_file): +def read_weights(wgt_file): with open(wgt_file, 'rt') as file: hist_acc = NNLOJETHistogram() nx = 0 @@ -109,7 +112,7 @@ def read_APPLfast(wgt_file): print(hist_acc) -def obs_generator(obs_list): +def obs_generator(obs_list, obs_dict): cross_pattern = re.compile(r'.*cross.*') for obs_line in obs_list: obs_parse = [it.strip() for it in obs_line.split('>')] @@ -124,6 +127,7 @@ def obs_generator(obs_list): nx = 0 if re.match(cross_pattern, obs_in) else 3 #> do we have any cumulant labels? obs_cum = None + obs_name = None obs_parse = obs_in.split() if (len(obs_parse) == 2) and (nx == 3): if obs_in == obs_out: @@ -134,7 +138,14 @@ def obs_generator(obs_list): elif len(obs_parse) != 1: raise ValueError('invalid observables specification: {}'.format(obs_line)) # print([obs_line, obs_in, obs_out, nx, obs_cum]) - yield (obs_line, obs_in, obs_out, nx, obs_cum) + yield (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) + for obs_in in obs_dict.keys(): + obs_line = obs_in + obs_cum = None + for obs_out in obs_dict[obs_in]: + obs_name = obs_out + nx = 0 if re.match(cross_pattern, obs_out) else 3 + yield (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) def main(): @@ -146,9 +157,9 @@ def main(): # multi-process parser.add_argument('-j', '--jobs', type=int, nargs='?', action='store', default='1', help='Specifies the number of jobs to run simultaneously.') - # read in APPLfast weight table - parser.add_argument('--APPLfast', action='store', default=None, - help='Read in an APPLfast weight table and combine.') + # read in weights weight table + parser.add_argument('--weights', action='store', default=None, + help='Read in a weight table and combine.') # Parse the input arguments! args = parser.parse_args() @@ -156,10 +167,10 @@ def main(): config.optionxform = lambda option: option # do not convert to lower case config.read(args.config) - # read in APPLfast weight table? - wgt_file = args.APPLfast + # read in weight table? + wgt_file = args.weights if wgt_file is not None: - read_APPLfast(wgt_file) + read_weights(wgt_file) return cross_pattern = re.compile(r'.*cross.*') @@ -293,6 +304,22 @@ def main(): else: print("couldn't extract observable name from file: {}".format(file)) print("found observables: {}".format(obs_list)) + # now loop over the observables to check if they contain internal observables + # these will be stored in a dict + obs_dict = dict() + for obs in obs_list: + files = glob.glob(config.get('Paths', 'raw_dir') + '/**/*' + obs + '*.dat', recursive=True) + if len(files) > 0: + with open(files[0], 'rt') as histfile: + for line in histfile: + if re.match(r'^\s*#name', line, re.IGNORECASE): + if obs in obs_list: obs_list.remove(obs) + if obs not in obs_dict: obs_dict[obs] = [] + obs_name = line.split()[1] + obs_dict[obs].append(obs_name) + else: + raise ValueError("no file?!") + # # if we want to perform the rebinning *after* merging # # make a reduced set of distinct observables @@ -306,7 +333,7 @@ def main(): /////////// """) # for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_set): - for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_list): + for (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) in obs_generator(obs_list, obs_dict): print("processing observable {}...".format(obs_out)) # combine the different observables @@ -326,7 +353,7 @@ def main(): #rebin = [ float(i) for i in rebin ] #print('rebin = ', rebin) - task = Task_part(nx=nx, files=files, outfile=outfile, weights=qweights, columns=columns, rebin=rebin, cumulant=obs_cum) + task = Task_part(nx=nx, files=files, outfile=outfile, weights=qweights, columns=columns, rebin=rebin, cumulant=obs_cum, obs_name=obs_name) # print(" > submitting {}".format(outfile)) parts_queue.put(task) @@ -371,7 +398,7 @@ def main(): // Merge // /////////// """) - for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_list): + for (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) in obs_generator(obs_list, obs_dict): print("processing observable {}...".format(obs_out)) for mrg in config.options('Merge'): parts = config.get('Merge', mrg) @@ -398,7 +425,7 @@ def main(): # debug plots if qplot: with open(outdir_plots + '/' + mrg + '.' + obs_out + '.plus.plt', 'wt') as fout: - print(nnlojet_plot.plot_merge_plus(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout) + print(plot_merge_plus(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout) elif "|" in parts: parts = list(map(str.strip, parts.split('|'))) @@ -429,7 +456,7 @@ def main(): # debug plots if qplot: with open(outdir_plots + '/' + mrg + '.' + obs_out + '.ratio.plt', 'wt') as fout: - print(nnlojet_plot.plot_merge_and(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout) + print(plot_merge_and(path='../Parts', obs=obs_out, merge=mrg, parts=' '.join(parts)), file=fout) else: raise ValueError("couldn't find '|', '&' nor '+' in the merge:", parts) @@ -441,7 +468,7 @@ def main(): // Final // /////////// """) - for (obs_line, obs_in, obs_out, nx, obs_cum) in obs_generator(obs_list): + for (obs_line, obs_in, obs_out, nx, obs_cum, obs_name) in obs_generator(obs_list, obs_dict): print("processing observable {}...".format(obs_out)) # assemble final histograms for fin in config.options('Final'): @@ -455,8 +482,8 @@ def main(): hist.write_to_file(outfile) # write to the file if qweights: - with open(outdir_final + '/' + fin + '.' + obs_out + '.APPLfast.txt', 'wt') as fout: - print(hist.to_APPLfast_weight(), file=fout) + with open(outdir_final + '/' + fin + '.' + obs_out + '.weights.txt', 'wt') as fout: + print(hist.to_weights(), file=fout) if __name__ == "__main__": -- GitLab