diff --git a/modules/optimizer.py b/modules/optimizer.py index 278b6cf3e06e8bd8aa4f45b5b14d941371b4ad8f..92ef14d9c10923aea974fcf803a51c0d08c6f7d3 100644 --- a/modules/optimizer.py +++ b/modules/optimizer.py @@ -88,7 +88,8 @@ class Optimizer(object): using the reconstruction model loss ''' # set the optimizer - self.optimizer.lr = lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr self.surrogate_model.eval() self.reconstruction_model.eval() diff --git a/modules/reconstruction.py b/modules/reconstruction.py index fb30f99d989ce575d0cee0698cc89edddb7e6586..78076cb47c961d33c9d9657438b5255113a0d1a0 100644 --- a/modules/reconstruction.py +++ b/modules/reconstruction.py @@ -52,7 +52,8 @@ class Reconstruction(torch.nn.Module): train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # set the optimizer - self.optimizer.lr = lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr self.to(self.device) self.train() for epoch in range(n_epochs): diff --git a/modules/surrogate.py b/modules/surrogate.py index 61cde541467f942cc9ef5b4c763ad30674d5af92..c0ea01fea748c6270ad72a7e41b426c17955e3ea 100644 --- a/modules/surrogate.py +++ b/modules/surrogate.py @@ -158,7 +158,8 @@ class Surrogate(torch.nn.Module): # train the surrogate model train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True) # set the optimizer - self.optimizer.lr = lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr self.to(self.device) self.train()