From 7df4ce5ce09d0ca2e28496f49f9f91c5a16e1763 Mon Sep 17 00:00:00 2001
From: Lorenzo Asfour <lorenzo.asfour@student.kit.edu>
Date: Fri, 19 Jul 2024 13:40:03 +0000
Subject: [PATCH] fix: Correct learning rate update in optimizer,
 reconstruction, surrogate

---
 modules/optimizer.py      | 3 ++-
 modules/reconstruction.py | 3 ++-
 modules/surrogate.py      | 3 ++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/modules/optimizer.py b/modules/optimizer.py
index 278b6cf..92ef14d 100644
--- a/modules/optimizer.py
+++ b/modules/optimizer.py
@@ -88,7 +88,8 @@ class Optimizer(object):
         using the reconstruction model loss
         '''
         # set the optimizer
-        self.optimizer.lr = lr
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = lr
 
         self.surrogate_model.eval()
         self.reconstruction_model.eval()
diff --git a/modules/reconstruction.py b/modules/reconstruction.py
index fb30f99..78076cb 100644
--- a/modules/reconstruction.py
+++ b/modules/reconstruction.py
@@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module):
 
         train_loader = DataLoader(dataset,  batch_size=batch_size, shuffle=True)
         # set the optimizer
-        self.optimizer.lr = lr
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = lr
         self.to(self.device)
         self.train()
         for epoch in range(n_epochs):
diff --git a/modules/surrogate.py b/modules/surrogate.py
index 61cde54..c0ea01f 100644
--- a/modules/surrogate.py
+++ b/modules/surrogate.py
@@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module):
         # train the surrogate model
         train_loader = DataLoader(surrogate_dataset,  batch_size=batch_size, shuffle=True)
         # set the optimizer
-        self.optimizer.lr = lr
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = lr
         self.to(self.device)
         
         self.train()
-- 
GitLab