Skip to content
Snippets Groups Projects
Commit c9da4177 authored by Jan Kieseler's avatar Jan Kieseler
Browse files

single-batch wise gravnet without geometric

parent 68538d50
Branches
No related tags found
No related merge requests found
import torch
# implements the gravnet layer without torch geometric
# this expects single-event batches
class GravNet_Layer(torch.nn.Module):
def __init__(self,
n_input_dimensions,
n_neighbours,
n_S_dimensions,
n_FLR,
n_output_filters):
super(GravNet_Layer, self).__init__()
self.n_neighbours = n_neighbours
self.n_S_dimensions = n_S_dimensions
self.n_FLR = n_FLR
self.n_output_filters = n_output_filters
self.S = torch.nn.Linear(n_input_dimensions, n_S_dimensions)
self.FLR = torch.nn.Linear(n_input_dimensions, n_FLR)
self.out = torch.nn.Linear(2*n_FLR + n_input_dimensions, n_output_filters)
def forward(self, x):
# x is of shape [n_nodes, n_dimensions]
# batch is of shape [n_nodes]
#assert that x is n_nodes x F, and not further nested
assert len(x.shape) == 2
S_coords = self.S(x)
FLR_feat = self.FLR(x)
#build the distance matrix in the S space
S_coords = S_coords.unsqueeze(0)
S_coords_T = S_coords.transpose(0,1)
S_diff = S_coords - S_coords_T
S_dist_sq = torch.sum(S_diff**2, dim=-1)
#find the n_neighbours neighbour indices for each node
neigh_dist_sq, neigh_idx = torch.topk(S_dist_sq, self.n_neighbours, largest=False)
#get the neighbour features
neigh_feat = torch.gather(FLR_feat, 0, neigh_idx) # -> [n_nodes, n_neighbours, n_FLR]
#weight by exp(-10*dist_sq)
weights = torch.exp(-10*neigh_dist_sq) # -> [n_nodes, n_neighbours]
#get the sum of the weighted neighbour features
sum_feat = torch.sum(neigh_feat * weights[:,:,None], dim=1) # -> [n_nodes, n_FLR]
#get the max of the weighted neighbour features
max_feat = torch.max(neigh_feat * weights[:,:,None], dim=1) # -> [n_nodes, n_FLR]
#concatenate the input features with the sum and max features
out = torch.cat([x, sum_feat, max_feat], dim=-1)
#apply the final linear layer
out = self.out(out)
return out
#implements a mean and pool layer over all nodes
class GlobalMeanMaxPool(torch.nn.Module):
def __init__(self):
super(GlobalMeanMaxPool, self).__init__()
def forward(self, x):
#man and keep 0 dimension
mean = torch.mean(x, dim=0, keepdim=True)
max = torch.max(x, dim=0, keepdim=True)
return torch.cat([mean, max], dim=-1) # -> [1, 2*n_output_filters]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment