Skip to content
Snippets Groups Projects
Commit 5e4bb512 authored by Jan Kieseler's avatar Jan Kieseler
Browse files

Merge branch 'main' into 'main'

fix: Correct learning rate update in optimizer, reconstruction and surrogate

See merge request !1
parents b48fabda 7df4ce5c
Branches
No related tags found
1 merge request!1fix: Correct learning rate update in optimizer, reconstruction and surrogate
......@@ -88,7 +88,8 @@ class Optimizer(object):
using the reconstruction model loss
'''
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.surrogate_model.eval()
self.reconstruction_model.eval()
......
......@@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module):
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device)
self.train()
for epoch in range(n_epochs):
......
......@@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module):
# train the surrogate model
train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True)
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device)
self.train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment