Skip to content
Snippets Groups Projects
Commit 839ed40a authored by Krishna Krishna Nikhil's avatar Krishna Krishna Nikhil
Browse files

Update 3 optimizer.py

parent d2bf339d
Branches
No related tags found
No related merge requests found
......@@ -84,18 +84,21 @@ class Optimizer(object):
diff_loss += torch.mean(100.*torch.nn.ReLU()(scint_maxdiff - self.constraints['diff'])**2)
if 'cost' in self.constraints:
abs_mask = (abs_parameters >= 0.5)
cost_abs = torch.where(abs_mask, 25 , 4.166)
scint_mask = (scint_parameters >= 0.5)
cost_scint = torch.where(scint_mask, 2500, 0.001)
# cost = w * cost_material_a + (1-w) * cost_material_b
# then you can scale w in a way that makes it more 'strict', so e.g. with a sigmoid(10*(w-0.5)) (please plot that and check if that makes sense before using it)
def sigmoid(number):
return 1/ (1 + torch.exp(-number))
abs_sigm = sigmoid(10*(abs_parameters-0.5))
cost_abs = abs_sigm * 25 + (1 - abs_sigm) * 4.166
scint_sigm = sigmoid(10*(scint_parameters-0.5))
cost_scint = scint_sigm * 2500 + (1 - scint_sigm) * 0.001
combined = torch.cat((cost_abs, cost_scint), dim=0)
combined_cost = combined.to(self.device)
# bounded_thickness = torch.tensor(thickness_parameters)
# bounded_thickness = torch.where(bounded_thickness > 0, bounded_thickness, 0)
cost = torch.sum(combined_cost * thickness_parameters)
cost_loss = torch.mean(50.* torch.nn.ReLU()(cost - self.constraints['cost'])**2)
......@@ -114,6 +117,7 @@ class Optimizer(object):
return total_length_loss +lower_loss + upper_loss + diff_loss + cost_loss
def clamp_parameters(self):
return
self.detector_parameters.data = self.detector_parameters.data.clamp(1e-3) #DEBUG, NE
def adjust_generator_covariance(self, direction, min_scale=2.0):
......@@ -168,7 +172,7 @@ class Optimizer(object):
true_inputs,
true_context)
# calculate the loss
loss = self.reconstruction_model.loss(dataset.unnormalise_target(reco_surrogate), dataset.unnormalise_target(true_inputs))
loss = 5 * self.reconstruction_model.loss(dataset.unnormalise_target(reco_surrogate), dataset.unnormalise_target(true_inputs))
if add_constraints:
loss += self.other_constraints(dataset)
#
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment