Skip to content
Snippets Groups Projects
Commit d4f38e7c authored by Felix Metzner's avatar Felix Metzner
Browse files

Working on rbf memory issue.

parent ba0aae0b
Branches
No related tags found
No related merge requests found
Pipeline #1246 failed
......@@ -96,7 +96,20 @@ class RbfReweighter:
# and
# https://numpy.org/doc/stable/reference/generated/numpy.hsplit.html
weights = self._rbf(*[origin_sample[variable].values for variable in self._variables])
weights = np.zeros(shape=len(origin_sample.index), dtype=np.float64)
for arrays in np.nditer(
[weights] + [origin_sample[v].values for v in self._variables],
flags=['external_loop', 'buffered'],
op_flags=['readwrite']
):
assert len(arrays) == 1 + len(self._variables), (len(arrays), len(self._variables))
i_weights = arrays[0]
assert all(isinstance(a, np.ndarray) for a in arrays), [type(a) for a in arrays]
assert all(a.shape == i_weights.shape for a in arrays), [a.shape for a in arrays]
i_weights = self._rbf(*arrays[1:])
assert len(i_weights.shape) == 1, i_weights.shape
# weights = self._rbf(*[origin_sample[variable].values for variable in self._variables])
weights[weights < 0] = 0.0
num_origin_events = len(origin_sample.index)
summed_weights = np.sum(weights)
......@@ -108,11 +121,16 @@ class RbfReweighter:
logging.debug(f"Scale factor: {self._scale_factor}")
@staticmethod
def euclidean_norm_numpy(x1, x2):
return np.linalg.norm(x1 - x2, axis=0)
def _create_rbf(self):
self._rbf = Rbf(
*[self._weight_coords[:, i] for i in range(len(self._variables))],
self._hist_weights,
function="cubic"
function="cubic",
norm=RbfReweighter.euclidean_norm_numpy
)
def export_to_json(self, output_path: Union[AnyStr, os.PathLike]):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment