Skip to content
Snippets Groups Projects
Commit 89b93c0b authored by Jan Kieseler's avatar Jan Kieseler
Browse files

trying the flow

parent 56355a4e
Branches
No related tags found
No related merge requests found
......@@ -4,7 +4,8 @@ import multiprocessing
import pickle
from generator import Generator, Parameters, SurrogateDataset
from reconstruction import Reconstruction
from surrogate import Surrogate
#from surrogate import Surrogate
from one_dim_flow_surrogate import Surrogate
from optimizer import Optimizer
from matplotlib import pyplot as plt
import time
......@@ -24,7 +25,7 @@ def reset_weights(m):
if __name__ == "__main__":
multiprocessing.set_start_method('spawn')
outpath = 'run2_EM_homo_simple3'
outpath = 'TESTING'
os.system('mkdir '+outpath)
outpath += '/'
......@@ -45,7 +46,7 @@ if __name__ == "__main__":
#'thickness_absorber_6': .1,
#'thickness_absorber_7': .1,
#'thickness_absorber_8': .1,
'thickness_scintillator_0': 0.5,
'thickness_scintillator_0': 10.5,
#'thickness_scintillator_1': 0.5,
#'thickness_scintillator_2': 0.5,
#'thickness_scintillator_3': 0.5,
......@@ -62,8 +63,8 @@ if __name__ == "__main__":
gen = Generator(box,
Parameters(parameters = start_pars),
n_vars = 30,
n_events_per_var = 400//divide,
n_vars = 16, # 30, DEBUG
n_events_per_var = 400//divide, #400
particles=[['gamma',0.22]])#,['pi+',2.11]])#0)
......@@ -80,7 +81,7 @@ if __name__ == "__main__":
optimizer = Optimizer(gen, surrogate_model, reco_model, gen.parameters,
constraints= {'length': 25}) #one meter
n_epochs_pre = 30//divide
n_epochs_pre = 100//divide
n_epochs_main = 100//divide
parameters = gen.parameters
......@@ -189,7 +190,7 @@ if __name__ == "__main__":
surrogate_dataset = SurrogateDataset(ds, reco_result.detach().cpu().numpy())#important to detach here
print('surr pre-training 0')
surrogate_model.train_model(surrogate_dataset, batch_size=256, n_epochs= 10, lr=0.03)
surrogate_model.train_model(surrogate_dataset, batch_size=256, n_epochs= n_epochs_pre//2, lr=0.03)
print('surr pre-training 1')
surrogate_model.train_model(surrogate_dataset, batch_size=256, n_epochs= n_epochs_pre, lr=0.01)
print('surr pre-training 2')
......@@ -198,10 +199,15 @@ if __name__ == "__main__":
surrogate_model.train_model(surrogate_dataset, batch_size=1024, n_epochs= n_epochs_pre, lr=0.001)
print('surr pre-training 4')
surrogate_model.train_model(surrogate_dataset, batch_size=1024, n_epochs= n_epochs_pre, lr=0.0003)
print('surr pre-training 5')
surrogate_model.train_model(surrogate_dataset, batch_size=1024, n_epochs= n_epochs_pre, lr=1e-4)
print('surr pre-training 6')
surrogate_model.train_model(surrogate_dataset, batch_size=1024, n_epochs= n_epochs_pre, lr=3e-5)
#these are un-normalised quantities
surr_out, reco_out, true_in = surrogate_model.apply_model_in_batches(surrogate_dataset, batch_size=512)
surr_out, reco_out, true_in = surrogate_model.apply_model_in_batches(surrogate_dataset, batch_size=1024)
make_check_plot(surr_out, reco_out, true_in, evolution, outname = "pretrain")
pre_train()
exit() #DEBUG
best_surrogate_loss = 1e10
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment