Skip to content
Snippets Groups Projects
Commit e325f7dc authored by Krishna Nikhil
Browse files

Update reconstruction.py

parent c934eea0
Branches
No related tags found
No related merge requests found
......@@ -7,12 +7,17 @@
import torch
import torch.utils.data
from torch.utils.data import DataLoader
import json
#a simple reconstruction model, just feed-forward for now
class Reconstruction(torch.nn.Module):
def __init__(self, n_detector_parameters, n_input_parameters, n_target_parameters):
super(Reconstruction, self).__init__()
with open('config.json', 'r') as config_file:
config = json.load(config_file)
self.n_parameters = n_detector_parameters
self.preprocess = torch.nn.Sequential(
......@@ -20,7 +25,8 @@ class Reconstruction(torch.nn.Module):
torch.nn.ELU(),
torch.nn.Linear(100,100),
torch.nn.ELU(),
torch.nn.Linear(100, n_input_parameters)
torch.nn.Linear(100, n_input_parameters),
torch.nn.ReLU()
)
# take into account that
self.layers = torch.nn.Sequential(
......@@ -38,7 +44,7 @@ class Reconstruction(torch.nn.Module):
# some placeholders for a simpler training loop
self.optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
self.device = torch.device('cuda')
self.device = torch.device(config["device"]) #cuda
def forward(self, detector_parameters, x):
#concatenate the detector parameters and the input
......@@ -61,7 +67,8 @@ class Reconstruction(torch.nn.Module):
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# set the optimizer
self.optimizer.lr = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.to(self.device)
self.train()
for epoch in range(n_epochs):
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment