Skip to content
Snippets Groups Projects
Commit ddf3b005 authored by Krishna Krishna Nikhil's avatar Krishna Krishna Nikhil
Browse files

Update surrogate.py

parent e325f7dc
Branches
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@
# the training loop is also in here
import torch
from torch.utils.data import DataLoader
import json
def ddpm_schedules(beta1, beta2, T):
"""
......@@ -89,7 +90,10 @@ class Surrogate(torch.nn.Module):
self.register_buffer(k, v)
self.loss_mse = torch.nn.MSELoss()
self.device = torch.device('cuda')
with open('config.json', 'r') as config_file:
config = json.load(config_file)
self.device = torch.device(config["device"]) #cuda
self.t_is = torch.tensor([i / self.n_time_steps for i in range(self.n_time_steps+1)]).to(self.device)
......@@ -158,7 +162,8 @@ class Surrogate(torch.nn.Module):
# train the surrogate model
train_loader = DataLoader(surrogate_dataset, batch_size=batch_size, shuffle=True)
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device)
self.train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment