Skip to content
Snippets Groups Projects
Commit 0d0676c8 authored by Krishna Krishna Nikhil's avatar Krishna Krishna Nikhil
Browse files

Update file optimizer.py

parent 61006f1c
Branches
No related tags found
No related merge requests found
......@@ -93,8 +93,6 @@ class Optimizer(object):
diff_loss = torch.mean(100.*torch.nn.ReLU()(abs_maxdiff - self.constraints['diff'])**2)
diff_loss += torch.mean(100.*torch.nn.ReLU()(scint_maxdiff - self.constraints['diff'])**2)
if 'cost' in self.constraints:
# cost = w * cost_material_a + (1-w) * cost_material_b
# then you can scale w in a way that makes it more 'strict', so e.g. with a sigmoid(10*(w-0.5)) (please plot that and check if that makes sense before using it)
def sigmoid(number):
return 1/ (1 + torch.exp(-number))
......@@ -168,7 +166,7 @@ class Optimizer(object):
reco_surrogate_loss = 0
constraint_loss = 0
for epoch in range(n_epochs):
batch_mean_loss = 0
batch_mean_loss = 0 # To use last epochs only
batch_reco_surrogate_loss = 0
batch_constraint_loss = 0
stop_epoch = False
......@@ -226,21 +224,22 @@ class Optimizer(object):
batch_mean_loss /= (batch_idx+1)
batch_reco_surrogate_loss/= (batch_idx+1)
batch_constraint_loss/= (batch_idx+1)
if epoch == (n_epochs - 1):
if epoch == (n_epochs - 1): # Only using Last Epochs(Remove this conditional to use more)
mean_loss += batch_mean_loss
reco_surrogate_loss += batch_reco_surrogate_loss
constraint_loss+= batch_constraint_loss
constraint_loss+= batch_constraint_loss
print('Optimizer Epoch: {} \tLoss: {:.8f}'.format(
epoch, loss.item()))
stopped_at = epoch+1
# stopped_at = epoch+1
if stop_epoch:
break
if not stop_epoch:
scale_factor += 0.08
scale_factor += 0.08 # Simplest approach that works well
print('Scale updated to : ', scale_factor)
# scale_update += 1
self.clamp_parameters()
constraint_loss/= stopped_at
# commented code useful to calculate loss average instead of loss of only last epoch
# constraint_loss/= stopped_at
self.adjust_generator_covariance( self.detector_parameters.detach().cpu().numpy() - initial_parameters )
return self.detector_parameters.detach().cpu().numpy(), True, mean_loss, reco_surrogate_loss, constraint_loss ,scale_factor
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment