Skip to content
Snippets Groups Projects
Commit 7df4ce5c authored by Lorenzo Asfour's avatar Lorenzo Asfour
Browse files

fix: Correct learning rate update in optimizer, reconstruction, surrogate

parent b48fabda
Branches
No related tags found
1 merge request!1fix: Correct learning rate update in optimizer, reconstruction and surrogate
......@@ -88,7 +88,8 @@ class Optimizer(object):
using the reconstruction model loss
'''
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.surrogate_model.eval()
self.reconstruction_model.eval()
......
......@@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module):
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device)
self.train()
for epoch in range(n_epochs):
......
......@@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module):
# train the surrogate model
train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True)
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device)
self.train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment