Skip to content
Snippets Groups Projects
Commit 7df4ce5c authored by Lorenzo Asfour's avatar Lorenzo Asfour
Browse files

fix: Correct learning rate update in optimizer, reconstruction, surrogate

parent b48fabda
Branches main
No related tags found
No related merge requests found
...@@ -88,7 +88,8 @@ class Optimizer(object): ...@@ -88,7 +88,8 @@ class Optimizer(object):
using the reconstruction model loss using the reconstruction model loss
''' '''
# set the optimizer # set the optimizer
self.optimizer.lr = lr for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.surrogate_model.eval() self.surrogate_model.eval()
self.reconstruction_model.eval() self.reconstruction_model.eval()
......
...@@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module): ...@@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module):
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# set the optimizer # set the optimizer
self.optimizer.lr = lr for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device) self.to(self.device)
self.train() self.train()
for epoch in range(n_epochs): for epoch in range(n_epochs):
......
...@@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module): ...@@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module):
# train the surrogate model # train the surrogate model
train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True) train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True)
# set the optimizer # set the optimizer
self.optimizer.lr = lr for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device) self.to(self.device)
self.train() self.train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment