Skip to content
Snippets Groups Projects
Commit 950c4832 authored by Lars Sowa's avatar Lars Sowa
Browse files

skeleton for improvements

parent 276f373e
Branches
No related tags found
1 merge request!11Tree Batches
......@@ -115,6 +115,18 @@ def _node_selection(
return pred
def tree_batches(tcs_to_compute: list):
trees = {}
for tc_key in tcs_to_compute:
tree = tc_key[:-1]
if tree not in trees:
trees[tree] = []
trees[tree].append(tc_key)
for batch in trees.values():
yield batch
class CustomForwardDict(dict):
# custom dict wrapper to dynamically access the tctensor
def __init__(
......@@ -163,7 +175,7 @@ class BaseTaylorAnalysis(object):
forward_kwargs: CustomForwardDict,
tctensor_features_axis: int,
pred: torch.Tensor,
ind_i: int,
ind_i_list: int,
) -> torch.Tensor:
"""Method to compute the first order taylorcoefficients.
......@@ -178,7 +190,12 @@ class BaseTaylorAnalysis(object):
# first order grads
gradients = grad(pred, forward_kwargs.tctensor)[0]
return gradients[_get_slice(gradients.shape, ind_i, tctensor_features_axis)]
tcs = {}
for ind_i in ind_i_list:
tcs[(ind_i,)] = gradients[
_get_slice(gradients.shape, ind_i, tctensor_features_axis)
]
return tcs
@torch.enable_grad()
def _second_order(
......@@ -187,7 +204,7 @@ class BaseTaylorAnalysis(object):
tctensor_features_axis: int,
pred: torch.Tensor,
ind_i: int,
ind_j: int,
ind_j_list: list,
) -> torch.Tensor:
"""Method to compute the second order taylorcoefficients.
......@@ -207,9 +224,14 @@ class BaseTaylorAnalysis(object):
)
# second order gradients
gradients = grad(gradients[ind_i], forward_kwargs.tctensor)[0]
# factor for second order taylor terms
gradients *= _get_factorial_factors(ind_i, ind_j)
return gradients[_get_slice(gradients.shape, ind_j, tctensor_features_axis)]
tcs = {}
for ind_j in ind_j_list:
fac = _get_factorial_factors(ind_i, ind_j)
tcs[(ind_i, ind_j)] = (
fac
* gradients[_get_slice(gradients.shape, ind_j, tctensor_features_axis)]
)
return tcs
@torch.enable_grad()
def _third_order(
......@@ -219,7 +241,7 @@ class BaseTaylorAnalysis(object):
pred: torch.Tensor,
ind_i: int,
ind_j: int,
ind_k: int,
ind_k_list: list,
) -> torch.Tensor:
"""Method to compute the third order taylorcoefficients.
......@@ -247,9 +269,14 @@ class BaseTaylorAnalysis(object):
)
# third order gradients
gradients = grad(gradients[ind_j], forward_kwargs.tctensor)[0]
# factor for all third order taylor terms
gradients *= _get_factorial_factors(ind_i, ind_j, ind_k)
return gradients[_get_slice(gradients.shape, ind_k, tctensor_features_axis)]
tcs = {}
for ind_k in ind_k_list:
fac = _get_factorial_factors(ind_i, ind_j, ind_k)
tcs[(ind_i, ind_j, ind_k)] = (
fac
* gradients[_get_slice(gradients.shape, ind_k, tctensor_features_axis)]
)
return tcs
def _calculate_tc(
self,
......@@ -258,7 +285,7 @@ class BaseTaylorAnalysis(object):
eval_max_output_node_only: bool,
tctensor_features_axis: int,
selected_model_output_idx: int,
*indices,
batch: List[Tuple[int, ...]],
) -> torch.Tensor:
"""Method to calculate the taylorcoefficients based on the indices.
......@@ -275,7 +302,6 @@ class BaseTaylorAnalysis(object):
_type_: Output type is specified by the user defined reduce function.
"""
# Make prediction
forward_kwargs.tctensor.requires_grad = True
self.zero_grad()
forward_kwargs.tctensor.grad = None
......@@ -292,15 +318,16 @@ class BaseTaylorAnalysis(object):
# compute TCs
functions = [self._first_order, self._second_order, self._third_order]
order = len(batch[0]) - 1
tree = batch[0][:-1]
batch = [b[-1] for b in batch]
try:
return (
functions[len(indices) - 1](
forward_kwargs,
tctensor_features_axis,
pred,
*indices,
),
indices,
return functions[order](
forward_kwargs,
tctensor_features_axis,
pred,
*tree,
batch,
)
except KeyError:
raise NotImplementedError(
......@@ -350,6 +377,10 @@ class BaseTaylorAnalysis(object):
selected_output_node, (int, tuple, type(None))
), "Node must be int, tuple or None!"
# check if indices are tuples
# do this better
tc_idx_list = [_check_for_tuple(ind) for ind in tc_idx_list]
args = [
(
forward_kwargs,
......@@ -357,9 +388,10 @@ class BaseTaylorAnalysis(object):
eval_max_output_node_only,
tctensor_features_axis,
selected_model_output_idx,
*_check_for_tuple(ind),
ind,
)
for ind in tc_idx_list
for ind in tree_batches(tc_idx_list)
# for ind in tc_idx_list
]
output = {}
......@@ -369,11 +401,15 @@ class BaseTaylorAnalysis(object):
results = [future.result() for future in futures]
# Convert the results into the output dictionary
output = {ind: reduce_func(result) for result, ind in results}
for result in results:
for key, val in result.items():
output[key] = reduce_func(val)
#output = {key: reduce_func(value) for key, value in result.items()}
else:
for arg in args:
# get TCs
result, ind = self._calculate_tc(*arg)
result = self._calculate_tc(*arg)
# apply reduce function
output[ind] = reduce_func(result)
for key, value in result.items():
output[key] = reduce_func(value)
return output
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment