Skip to content
Snippets Groups Projects
Commit b7d4b8e5 authored by Lars Sowa's avatar Lars Sowa
Browse files

introduce multiprocessing with extra parameter to specify the the number of processes

parent c48403e8
Branches
No related tags found
1 merge request!10Multiprocessing
import torch
import multiprocessing as mp
from collections import Counter
from math import factorial
......@@ -6,9 +7,14 @@ from torch.autograd import grad
from typing import Tuple, List, Dict, Optional, Any, Union, Callable
from collections.abc import Sequence
##############################################
# Helpers
def _check_for_tuple(ind):
if not isinstance(ind, tuple):
raise ValueError("Indices must be tuple!")
return ind
def _get_factorial_factors(*indices: int) -> float:
"""Function to compute the factorial factors for the taylorcoefficients: Prod_n^len(indices) 1/n!
......@@ -286,11 +292,14 @@ class BaseTaylorAnalysis(object):
# compute TCs
functions = [self._first_order, self._second_order, self._third_order]
try:
return functions[len(indices) - 1](
return (
functions[len(indices) - 1](
forward_kwargs,
tctensor_features_axis,
pred,
*indices,
),
indices,
)
except KeyError:
raise NotImplementedError(
......@@ -309,6 +318,7 @@ class BaseTaylorAnalysis(object):
tctensor_features_axis: int = -1,
additional_idx_to_tctensor: Optional[int] = None,
selected_model_output_idx: Optional[int] = None,
use_parallelization: Optional[int] = None,
) -> Dict[Tuple[int, ...], Any]:
"""Function to handle multiple indices and return the taylorcoefficients as a dictionary: to be used by the user.
......@@ -322,13 +332,13 @@ class BaseTaylorAnalysis(object):
tctensor_features_axis (int, optional): Dimension containing features in tctensor given in forward_kwargs. Defaults to -1.
additional_idx_to_tctensor (int, optional): Index of the tctensor if forward_kwargs[forward_kwargs_tctensor_key] is a list. Defaults to None.
selected_model_output_idx (int, optional): Index of the model output if its output is a sequence. Defaults to 0.
use_parallelization (int, optional): Number of processes to use for parallelization. If None, no multiprocessing is used at all. Defaults to None.
Raises:
ValueError: tc_idx_list must be a List of tuples!
Returns:
Dict: Dictionary with taylorcoefficients. Values are set by the user within the reduce function. Keys are the indices (tuple).
"""
forward_kwargs = CustomForwardDict(
forward_kwargs_tctensor_key, additional_idx_to_tctensor, forward_kwargs
)
......@@ -338,20 +348,34 @@ class BaseTaylorAnalysis(object):
selected_output_node, (int, tuple, type(None))
), "Node must be int, tuple or None!"
# loop over all tc to compute
output = {}
for ind in tc_idx_list:
if not isinstance(ind, tuple):
raise ValueError("tc_idx_list must be a list of tuples!")
# get TCs
out = self._calculate_tc(
args = [
(
forward_kwargs,
selected_output_node,
eval_max_output_node_only,
tctensor_features_axis,
selected_model_output_idx,
*ind,
*_check_for_tuple(ind),
)
for ind in tc_idx_list
]
output = {}
if use_parallelization is not None:
ctx = mp.get_context("spawn")
with ctx.Pool(processes=use_parallelization) as pool:
# Map the process_individual_tc function to the arguments
results = pool.starmap(self._calculate_tc, args)
pool.close()
pool.join()
# Convert the results into the output dictionary
output = {ind: reduce_func(result) for result, ind in results}
else:
for arg in args:
# get TCs
result, ind= self._calculate_tc(*arg)
# apply reduce function
output[ind] = reduce_func(out)
output[ind] = reduce_func(result)
return output
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment