Skip to content
Snippets Groups Projects
Commit db1e1050 authored by Lars Sowa's avatar Lars Sowa
Browse files

Documentation, comments, minor renames for simplicity

parent 950c4832
Branches
No related tags found
1 merge request!11Tree Batches
......@@ -115,14 +115,28 @@ def _node_selection(
return pred
def tree_batches(tcs_to_compute: list):
def create_tree_batches(tcs_to_compute: list):
"""
Generator to create batches of taylorcoefficients based on their tree structure so that each tree has to be computed only ones.
A tree structure is defined by the tuple of indices.
E.g. [(0, 1), (0, 2), (1, 2)] will be batched as [(0, 1), (0, 2)], [(1, 2)].
Args:
tcs_to_compute (list): List of tuples with indices for which the taylorcoefficients should be computed.
Yields:
list: Batches of taylorcoefficients based on their tree structure.
"""
# get the trees
trees = {}
for tc_key in tcs_to_compute:
tree = tc_key[:-1]
if tree not in trees:
if tree not in trees.keys():
trees[tree] = []
trees[tree].append(tc_key)
# return the batches
for batch in trees.values():
yield batch
......@@ -175,7 +189,7 @@ class BaseTaylorAnalysis(object):
forward_kwargs: CustomForwardDict,
tctensor_features_axis: int,
pred: torch.Tensor,
ind_i_list: int,
ind_i_list: list,
) -> torch.Tensor:
"""Method to compute the first order taylorcoefficients.
......@@ -190,6 +204,8 @@ class BaseTaylorAnalysis(object):
# first order grads
gradients = grad(pred, forward_kwargs.tctensor)[0]
# get relevant taylorcoefficients
tcs = {}
for ind_i in ind_i_list:
tcs[(ind_i,)] = gradients[
......@@ -224,6 +240,8 @@ class BaseTaylorAnalysis(object):
)
# second order gradients
gradients = grad(gradients[ind_i], forward_kwargs.tctensor)[0]
# get relevant taylorcoefficients for ind_i tree
tcs = {}
for ind_j in ind_j_list:
fac = _get_factorial_factors(ind_i, ind_j)
......@@ -269,6 +287,8 @@ class BaseTaylorAnalysis(object):
)
# third order gradients
gradients = grad(gradients[ind_j], forward_kwargs.tctensor)[0]
# get relevant taylorcoefficients for ind_i, ind_j tree
tcs = {}
for ind_k in ind_k_list:
fac = _get_factorial_factors(ind_i, ind_j, ind_k)
......@@ -319,8 +339,9 @@ class BaseTaylorAnalysis(object):
# compute TCs
functions = [self._first_order, self._second_order, self._third_order]
order = len(batch[0]) - 1
tree = batch[0][:-1]
batch = [b[-1] for b in batch]
tree = batch[0][:-1] # get the tree structure
batch = [b[-1] for b in batch] # indices without the tree structure
try:
return functions[order](
forward_kwargs,
......@@ -360,7 +381,7 @@ class BaseTaylorAnalysis(object):
tctensor_features_axis (int, optional): Dimension containing features in tctensor given in forward_kwargs. Defaults to -1.
additional_idx_to_tctensor (int, optional): Index of the tctensor if forward_kwargs[forward_kwargs_tctensor_key] is a list. Defaults to None.
selected_model_output_idx (int, optional): Index of the model output if its output is a sequence. Defaults to 0.
n_threads (int, optional): Number of threads to use for parallelization. If None, no multithreading is used at all. Defaults to None.
n_threads (int, optional): Number of threads to parallelize the computation of TCs in tc_idx_list. If None, no multithreading is used at all. Defaults to None.
Raises:
ValueError: tc_idx_list must be a List of tuples!
......@@ -378,7 +399,6 @@ class BaseTaylorAnalysis(object):
), "Node must be int, tuple or None!"
# check if indices are tuples
# do this better
tc_idx_list = [_check_for_tuple(ind) for ind in tc_idx_list]
args = [
......@@ -390,7 +410,7 @@ class BaseTaylorAnalysis(object):
selected_model_output_idx,
ind,
)
for ind in tree_batches(tc_idx_list)
for ind in create_tree_batches(tc_idx_list)
# for ind in tc_idx_list
]
......@@ -402,9 +422,8 @@ class BaseTaylorAnalysis(object):
# Convert the results into the output dictionary
for result in results:
for key, val in result.items():
output[key] = reduce_func(val)
#output = {key: reduce_func(value) for key, value in result.items()}
for key, value in result.items():
output[key] = reduce_func(value)
else:
for arg in args:
# get TCs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment