Skip to content
Snippets Groups Projects
Commit fc1d05cd authored by Lars Sowa's avatar Lars Sowa
Browse files

fix tests, add docu

parent 41f792d4
Branches
No related tags found
1 merge request!9Flex tensor dev
Pipeline #3452 passed
......@@ -338,6 +338,7 @@ class BaseTaylorAnalysis(object):
reduce_func (Callable, optional): Function to reduce the taylorcoefficients. Defaults to identity.
features_axis (int, optional): Dimension containing features in tensor forward_kwargs.deriv_target. Defaults to -1.
output_idx (Union[int, None], optional): Index of the target tensor if forward_kwargs[target_key] is a list. Defaults to None.
keep_model_output (int, optional): Index of the model output if its output is a sequence. Defaults to 0.
Raises:
ValueError: index_list must be a List of tuples!
......
......@@ -54,7 +54,7 @@ class TestBaseClass:
)
]
for index in combinations:
for eval_max_node_only in [False, True]:
for eval_max_output_node_only in [False, True]:
node_outputs = []
for node in range(mlp_specs["output_neurons"]):
# get singe node results
......@@ -62,8 +62,8 @@ class TestBaseClass:
"x",
forward_kwargs={"x": x_data},
index_list=[index],
node=node,
eval_max_node_only=eval_max_node_only,
output_node=node,
eval_max_output_node_only=eval_max_output_node_only,
)
tc = tc[index]
node_outputs.append(tc)
......@@ -72,14 +72,14 @@ class TestBaseClass:
"x",
forward_kwargs={"x": x_data},
index_list=[index],
node=None,
eval_max_node_only=eval_max_node_only,
output_node=None,
eval_max_output_node_only=eval_max_output_node_only,
)
tc = tc[index]
# sums should be equal
node_outputs = torch.stack(node_outputs, dim=-1).sum(dim=-1)
with unittest.TestCase().subTest(
index=index, eval_max_node_only=eval_max_node_only
index=index, eval_max_output_node_only=eval_max_output_node_only
):
# check if tensors are close
is_close = torch.testing.assert_close(tc, node_outputs) == None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment