Commit c934eea0 authored by Krishna Krishna Nikhil

Scaling is not pressured to reach any certain value

parent bd5d863d
import torch
from torch.utils.data import DataLoader
import numpy as np
import json
class Optimizer(object):
'''
@@ -17,6 +18,8 @@ class Optimizer(object):
detector_start_parameters,
lr=0.001, batch_size=128,
constraints : dict = None,):
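# run-time configuration (e.g. the torch device string used below) is read from config.json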
with open('config.json', 'r') as config_file:
config = json.load(config_file)
self.generator = generator
self.surrogate_model = surrogate_model
self.reconstruction_model = reconstruction_model
@@ -26,7 +29,7 @@ class Optimizer(object):
self.lr = lr
self.batch_size = batch_size
self.constraints = constraints
self.device = torch.device('cuda')
self.device = torch.device(config["device"]) #cuda
self.cu_box = torch.tensor(self.generator.box_size, dtype=torch.float32).to(self.device)
self.detector_parameters = torch.nn.Parameter(torch.tensor(np.array(detector_start_parameters, dtype='float32')).to(self.device), requires_grad=True)
@@ -39,7 +42,7 @@ class Optimizer(object):
self.surrogate_model.to(device)
self.cu_box.to(device)
def other_constraints(self, dataset, ToP, scale_update):
def other_constraints(self, ToP, scale_factor):
# keep the size of the detector within 3m
raw_detector_parameters = self.detector_parameters
@@ -64,8 +67,6 @@ class Optimizer(object):
# ,"G4_BRASS": 16.66
# ,"G4_Si": 8330000
# ,"G4_PbWO4": 2500 }
scale_factor = 1 + scale_update * 0.02
if self.constraints is not None:
zero_length = (torch.zeros(number_parameters//2)).to(self.device)
@@ -129,7 +130,8 @@ class Optimizer(object):
#direction is a vector in parameter space
v = direction
v_length = np.linalg.norm(v)
v_norm = v / v_length
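# a small epsilon keeps the normalisation finite when the direction vector has zero length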
epsilon = 1e-4
v_norm = v / (v_length + epsilon)
s = min_scale *np.max([1., 4.*v_length]) # scale factor at least by a factor of two, if not more
@@ -141,13 +143,14 @@ class Optimizer(object):
# print('new box_covariance', self.generator.box_covariance)
def optimize(self, dataset, ToP , scale_update ,batch_size, n_epochs, lr, add_constraints = False):
def optimize(self, dataset, ToP , scale_factor, iteration ,batch_size, n_epochs, lr, add_constraints = False):
'''
keep both models fixed, train only the detector parameters (self.detector_start_parameters)
using the reconstruction model loss
'''
# set the optimizer
self.optimizer.lr = lr
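# torch.optim optimizers read the learning rate from their param_groups, so each group is updated explicitly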
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.surrogate_model.eval()
self.reconstruction_model.eval()
@@ -163,12 +166,11 @@ class Optimizer(object):
reco_surrogate_loss = 0
constraint_loss = 0
for epoch in range(n_epochs):
mean_loss = 0
reco_surrogate_loss = 0
constraint_loss = 0
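# per-epoch accumulators; they are averaged over this epoch's batches further below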
batch_mean_loss = 0
batch_reco_surrogate_loss = 0
batch_constraint_loss = 0
stop_epoch = False
for batch_idx, (_, true_inputs, true_context, reco_result) in enumerate(data_loader):
# in principle this could also be sampled from the correct distributions; but the distributions are not known in all cases (mostly for context)
# keep in mind for an extension
true_inputs = true_inputs.to(self.device)
@@ -181,10 +183,10 @@ class Optimizer(object):
true_context)
# calculate the loss
loss = self.reconstruction_model.loss(dataset.unnormalise_target(reco_surrogate), dataset.unnormalise_target(true_inputs))
reco_surrogate_loss += loss.item()
batch_reco_surrogate_loss += loss.item()
if add_constraints:
constraint_loss_value = self.other_constraints(dataset, ToP, scale_update)
constraint_loss += constraint_loss_value.item()
constraint_loss_value = self.other_constraints(ToP, scale_factor)
batch_constraint_loss += constraint_loss_value.item()
loss += constraint_loss_value
self.optimizer.zero_grad()
@@ -202,13 +204,14 @@ class Optimizer(object):
return self.detector_parameters.detach().cpu().numpy(), False, mean_loss / (batch_idx+1)
self.optimizer.step()
mean_loss += loss.item()
batch_mean_loss += loss.item()
#record steps
#check if the parameters are still local otherwise stop
if not self.generator.is_local(self.detector_parameters.detach().cpu().numpy(),0.8):#a bit smaller box size to be safe
stop_epoch = True
stopped_at = epoch
print("parameter not local")
break
@@ -218,21 +221,26 @@ class Optimizer(object):
print('current parameters: ')
for k in pdct.keys():
print(k, pdct[k])
batch_mean_loss /= (batch_idx+1)
batch_reco_surrogate_loss/= (batch_idx+1)
batch_constraint_loss/= (batch_idx+1)
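# only the batch-averaged losses of the final epoch are kept for reporting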
if epoch == (n_epochs - 1):
mean_loss += batch_mean_loss
reco_surrogate_loss += batch_reco_surrogate_loss
constraint_loss+= batch_constraint_loss
print('Optimizer Epoch: {} \tLoss: {:.8f}'.format(
epoch, loss.item()))
stopped_at = epoch+1
if stop_epoch:
break
# if not stop_epoch:
# scale_update += 1
# print('Scale updated to : ', scale_update)
scale_update += 1
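# the scale only grows when the parameters stayed local for the whole pass; it is not pushed towards any target value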
if not stop_epoch:
scale_factor += 0.08
print('Scale updated to : ', scale_factor)
# scale_update += 1
self.clamp_parameters()
mean_loss /= batch_idx+1
reco_surrogate_loss/= batch_idx+1
constraint_loss/= batch_idx + 1
constraint_loss/= stopped_at
self.adjust_generator_covariance( self.detector_parameters.detach().cpu().numpy() - initial_parameters )
return self.detector_parameters.detach().cpu().numpy(), True, mean_loss, reco_surrogate_loss, constraint_loss ,scale_update
return self.detector_parameters.detach().cpu().numpy(), True, mean_loss, reco_surrogate_loss, constraint_loss ,scale_factor
def get_optimum(self):
return self.detector_parameters.detach().cpu().numpy()
......
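With this change, scale_factor is threaded through successive optimize() calls: the value returned by one call is passed into the next, growing by 0.08 whenever a pass completes without the parameters leaving the local region. Below is a minimal sketch of such a calling loop; opt, dataset, ToP and n_iterations are hypothetical names standing in for objects that this commit does not show, and the hyperparameter values are only illustrative.

scale_factor = 1.0  # illustrative starting value; no target value is enforced
for iteration in range(n_iterations):
    # keep the surrogate and reconstruction models fixed, update only the detector parameters
    res = opt.optimize(dataset, ToP, scale_factor, iteration,
                       batch_size=128, n_epochs=10, lr=1e-3, add_constraints=True)
    if not res[1]:
        # the early-exit path returns fewer values, so check the success flag before unpacking
        break
    params, _, mean_loss, reco_surrogate_loss, constraint_loss, scale_factor = res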