Skip to content
Snippets Groups Projects
Commit dae3f6b0 authored by Lars Sowa's avatar Lars Sowa
Browse files

add test for mutliprocessing

parent b7d4b8e5
Branches
No related tags found
1 merge request!10Multiprocessing
import torch
import unittest
from src.tayloranalysis import extend_model
from src.tayloranalysis import extend_model, BaseTaylorAnalysis
class Polynom(torch.nn.Module):
class Polynom(torch.nn.Module, BaseTaylorAnalysis):
def __init__(self):
torch.nn.Module.__init__(self)
......@@ -14,7 +14,7 @@ class Polynom(torch.nn.Module):
return x * y + x * y**2 + y**3
class FlippedPolynom(torch.nn.Module):
class FlippedPolynom(torch.nn.Module, BaseTaylorAnalysis):
def __init__(self):
torch.nn.Module.__init__(self)
......@@ -28,8 +28,8 @@ class TestTCComputation(unittest.TestCase):
def setUp(self):
# setup models
self.models = [
extend_model(Polynom)(),
extend_model(FlippedPolynom)(),
Polynom(),
FlippedPolynom(),
]
self.feature_axis = (-1, -2)
......@@ -75,6 +75,26 @@ class TestTCComputation(unittest.TestCase):
# compare result to expected value
self.assertAlmostEqual(tc, self.solution_dict[combination])
def test_coefficients_multiprocessing(self):
for combination in self.solution_dict.keys():
for _model, _feature_axis, _point in zip(
self.models,
self.feature_axis,
self.points,
):
# compute TC
tc = _model.get_tc(
"point",
forward_kwargs={"point": _point},
tc_idx_list=[combination],
tctensor_features_axis=_feature_axis,
use_parallelization=3,
)
tc = tc[combination].item() # as float
with self.subTest(combination=combination, feature_axis=_feature_axis):
# compare result to expected value
self.assertAlmostEqual(tc, self.solution_dict[combination])
if __name__ == "__main__":
# run TestTCComputation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment