From 7df4ce5ce09d0ca2e28496f49f9f91c5a16e1763 Mon Sep 17 00:00:00 2001 From: Lorenzo Asfour <lorenzo.asfour@student.kit.edu> Date: Fri, 19 Jul 2024 13:40:03 +0000 Subject: [PATCH] fix: Correct learning rate update in optimizer, reconstruction, surrogate --- modules/optimizer.py | 3 ++- modules/reconstruction.py | 3 ++- modules/surrogate.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/optimizer.py b/modules/optimizer.py index 278b6cf..92ef14d 100644 --- a/modules/optimizer.py +++ b/modules/optimizer.py @@ -88,7 +88,8 @@ class Optimizer(object): using the reconstruction model loss ''' # set the optimizer - self.optimizer.lr = lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr self.surrogate_model.eval() self.reconstruction_model.eval() diff --git a/modules/reconstruction.py b/modules/reconstruction.py index fb30f99..78076cb 100644 --- a/modules/reconstruction.py +++ b/modules/reconstruction.py @@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module): train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # set the optimizer - self.optimizer.lr = lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr self.to(self.device) self.train() for epoch in range(n_epochs): diff --git a/modules/surrogate.py b/modules/surrogate.py index 61cde54..c0ea01f 100644 --- a/modules/surrogate.py +++ b/modules/surrogate.py @@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module): # train the surrogate model train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True) # set the optimizer - self.optimizer.lr = lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr self.to(self.device) self.train() -- GitLab