Skip to content
Snippets Groups Projects
Commit 9bb8cf2d authored by Lars Sowa's avatar Lars Sowa
Browse files

implement arturs comments

parent a03bafb8
Branches
No related tags found
1 merge request: !11 "Tree Batches"
Pipeline #4173 passed
import torch
from collections import Counter
from collections import Counter, defaultdict
from math import factorial
from torch.autograd import grad
from typing import Tuple, List, Dict, Optional, Any, Union, Callable
......@@ -11,12 +11,6 @@ from concurrent.futures import ThreadPoolExecutor
# Helpers
def _check_for_tuple(ind):
if not isinstance(ind, tuple):
raise ValueError("Indices must be tuple!")
return ind
def _get_factorial_factors(*indices: int) -> float:
"""Function to compute the factorial factors for the taylorcoefficients: Prod_n^len(indices) 1/n!
......@@ -30,7 +24,9 @@ def _get_factorial_factors(*indices: int) -> float:
return 1.0 / factor
def _get_summation_indices(shape: torch.Tensor.shape, drop_axis) -> Tuple[int, ...]:
def _get_summation_indices(
shape: torch.Tensor.shape, drop_axis: Union[int, Tuple[int, ...]]
) -> Tuple[int, ...]:
"""Function to get the summation indices for the gradient.
Args:
......@@ -115,36 +111,14 @@ def _node_selection(
return pred
def create_tree_batches(tcs_to_compute: list):
    """Generator to batch taylorcoefficient index tuples by their tree structure.

    The tree of an index tuple is the tuple without its last element
    (``tc_key[:-1]``). All index tuples that share a tree are put into the
    same batch, so that each tree has to be computed only once.
    E.g. [(0, 1), (0, 2), (1, 2)] will be batched as [(0, 1), (0, 2)], [(1, 2)].

    Args:
        tcs_to_compute (list): List of tuples with indices for which the
            taylorcoefficients should be computed.

    Yields:
        list: Batches of taylorcoefficients based on their tree structure,
            in the order in which each tree first appears in the input.
    """
    # group the index tuples by tree; a dict keeps the first-seen order
    trees = {}
    for tc_key in tcs_to_compute:
        trees.setdefault(tc_key[:-1], []).append(tc_key)
    # return the batches
    yield from trees.values()
class CustomForwardDict(dict):
# custom dict wrapper to dynamically access the tctensor
def __init__(
self, forward_kwargs_tctensor_key: str, idx_to_tctensor: int, *args, **kwargs
self,
forward_kwargs_tctensor_key: str,
idx_to_tctensor: int,
*args: Any,
**kwargs: Dict[str, Any],
):
super().__init__(*args, **kwargs)
......@@ -189,7 +163,7 @@ class BaseTaylorAnalysis(object):
forward_kwargs: CustomForwardDict,
tctensor_features_axis: int,
pred: torch.Tensor,
ind_i_list: list,
indices_i: List[int],
) -> torch.Tensor:
"""Method to compute the first order taylorcoefficients.
......@@ -207,7 +181,7 @@ class BaseTaylorAnalysis(object):
# get relevant taylorcoefficients
tcs = {}
for ind_i in ind_i_list:
for ind_i in indices_i:
tcs[(ind_i,)] = gradients[
_get_slice(gradients.shape, ind_i, tctensor_features_axis)
]
......@@ -220,7 +194,7 @@ class BaseTaylorAnalysis(object):
tctensor_features_axis: int,
pred: torch.Tensor,
ind_i: int,
ind_j_list: list,
indices_j: List[int],
) -> torch.Tensor:
"""Method to compute the second order taylorcoefficients.
......@@ -247,7 +221,7 @@ class BaseTaylorAnalysis(object):
# get relevant taylorcoefficients for ind_i tree
tcs = {}
for ind_j in ind_j_list:
for ind_j in indices_j:
fac = _get_factorial_factors(ind_i, ind_j)
tcs[(ind_i, ind_j)] = (
fac
......@@ -263,7 +237,7 @@ class BaseTaylorAnalysis(object):
pred: torch.Tensor,
ind_i: int,
ind_j: int,
ind_k_list: list,
indices_k: List[int],
) -> torch.Tensor:
"""Method to compute the third order taylorcoefficients.
......@@ -301,7 +275,7 @@ class BaseTaylorAnalysis(object):
# get relevant taylorcoefficients for ind_i, ind_j tree
tcs = {}
for ind_k in ind_k_list:
for ind_k in indices_k:
fac = _get_factorial_factors(ind_i, ind_j, ind_k)
tcs[(ind_i, ind_j, ind_k)] = (
fac
......@@ -311,7 +285,7 @@ class BaseTaylorAnalysis(object):
def _calculate_tc(
self,
pred,
pred: torch.Tensor,
forward_kwargs: CustomForwardDict,
selected_output_node: int,
eval_max_output_node_only: bool,
......@@ -345,18 +319,18 @@ class BaseTaylorAnalysis(object):
pred = _node_selection(pred, selected_output_node, eval_max_output_node_only)
# compute TCs
functions = [self._first_order, self._second_order, self._third_order]
order = len(batch[0]) - 1
functions = {1: self._first_order, 2: self._second_order, 3: self._third_order}
order = len(batch[0]) # get tc order for current batch
tree = batch[0][:-1] # get the tree structure
batch = [b[-1] for b in batch] # indices without the tree structure
indices_last = [b[-1] for b in batch] # indices without the tree structure
try:
return functions[order](
forward_kwargs,
tctensor_features_axis,
pred,
*tree,
batch,
indices_last,
)
except KeyError:
raise NotImplementedError(
......@@ -399,7 +373,9 @@ class BaseTaylorAnalysis(object):
"""
# check input
tc_idx_list = [_check_for_tuple(ind) for ind in tc_idx_list]
assert all(
isinstance(tc, tuple) for tc in tc_idx_list
), "Indices must be tuple!"
assert isinstance(reduce_func, Callable), "Reduce function must be callable!"
assert isinstance(
selected_output_node, (int, tuple, type(None))
......@@ -416,6 +392,13 @@ class BaseTaylorAnalysis(object):
forward_kwargs.tctensor.requires_grad = True
pred = self(**forward_kwargs)
# create tree batches
trees = defaultdict(list)
for ind in tc_idx_list:
tree = ind[:-1]
trees[tree].append(ind)
# create args for (parallel) computation
args = [
(
pred,
......@@ -424,9 +407,9 @@ class BaseTaylorAnalysis(object):
eval_max_output_node_only,
tctensor_features_axis,
selected_model_output_idx,
ind,
batch,
)
for ind in create_tree_batches(tc_idx_list)
for batch in trees.values()
]
output = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment