Skip to content
Snippets Groups Projects
Commit 270a63f1 authored by Aritra Bal's avatar Aritra Bal
Browse files

Fixed minor issue with torch.gather output shape, added MeanMaxSum Pooling...

Fixed minor issue with torch.gather output shape, added MeanMaxSum Pooling Layer and added example GravNet based model
parent c9da4177
Branches
No related tags found
No related merge requests found
import torch
# implements the gravnet layer without torch geometric
# this expects single-event batches
import torch.nn as nn
class GravNet_Layer(torch.nn.Module):
def __init__(self,
......@@ -37,29 +35,42 @@ class GravNet_Layer(torch.nn.Module):
S_coords_T = S_coords.transpose(0,1)
S_diff = S_coords - S_coords_T
S_dist_sq = torch.sum(S_diff**2, dim=-1)
#find the n_neighbours neighbour indices for each node
neigh_dist_sq, neigh_idx = torch.topk(S_dist_sq, self.n_neighbours, largest=False)
#get the neighbour features
neigh_feat = torch.gather(FLR_feat, 0, neigh_idx) # -> [n_nodes, n_neighbours, n_FLR]
neigh_feat = torch.gather(torch.transpose(FLR_feat.unsqueeze(1).expand(-1, neigh_idx.size(0), -1),0,1), 1, neigh_idx.unsqueeze(-1).expand(-1, -1, FLR_feat.size(-1))) # -> [n_nodes, n_neighbours, n_FLR]
#weight by exp(-10*dist_sq)
weights = torch.exp(-10*neigh_dist_sq) # -> [n_nodes, n_neighbours]
#get the sum of the weighted neighbour features
sum_feat = torch.sum(neigh_feat * weights[:,:,None], dim=1) # -> [n_nodes, n_FLR]
#get the max of the weighted neighbour features
max_feat = torch.max(neigh_feat * weights[:,:,None], dim=1) # -> [n_nodes, n_FLR]
#concatenate the input features with the sum and max features
out = torch.cat([x, sum_feat, max_feat], dim=-1)
out = torch.cat([x, sum_feat, max_feat.values], dim=-1)
#apply the final linear layer
out = self.out(out)
return out
class GlobalMeanMaxSumPool(nn.Module):
def __init__(self):
super(GlobalMeanMaxSumPool, self).__init__()
# Linear layer to reduce (3,) -> (1,)
self.linear = nn.Linear(3, 1)
def forward(self, x):
# x is expected to be of shape [N, 1] (e.g., [50, 1])
mean_val = torch.mean(x) # Shape: (1,)
max_val = torch.max(x) # Shape: (1,)
sum_val = torch.sum(x) # Shape: (1,)
# Concatenate the mean, max, and sum into a single tensor of shape (3,)
final_input = torch.cat([mean_val.unsqueeze(0), max_val.unsqueeze(0), sum_val.unsqueeze(0)], dim=0) # Shape: (3,)
# Pass through the linear layer to get the final output of shape (1,)
final_output = self.linear(final_input) # Shape: (1,)
return final_output
#implements a mean and pool layer over all nodes
class GlobalMeanMaxPool(torch.nn.Module):
......@@ -74,3 +85,111 @@ class GlobalMeanMaxPool(torch.nn.Module):
return torch.cat([mean, max], dim=-1) # -> [1, 2*n_output_filters]
class GNNmodel(torch.nn.Module):
"""
Simple example GNN model set up using the GravNet layer defined above
"""
def __init__(
self,
len_features,
dense_layer_dim=22,
feature_space_dim=16,
spatial_information_dim=6,
k=14,
n_gravblocks=4,
batch_norm_momentum=0.01,
):
"""
Args:
len_features (int): Number of features per node.
dense_layer_dim (int): Number of neurons in the dense layers.
feature_space_dim (int): Number of dimensions for the feature space.
spatial_information_dim (int): Number of dimensions for the spatial information.
k (int): Number of nearest neighbors in the GravNet layer.
n_gravblocks (int): Number of GravNet blocks.
batch_norm_momentum (float): Momentum for the batch normalization layers.
"""
super().__init__()
input_dim = len_features # input length can be inferred from features
# First block starts with the input dimension
self.blocks = torch.nn.ModuleList(
[
torch.nn.ModuleList(
[
torch.nn.Linear(2 * input_dim, dense_layer_dim),
torch.nn.Linear(dense_layer_dim, dense_layer_dim),
torch.nn.Linear(dense_layer_dim, dense_layer_dim),
GravNet_Layer(
n_input_dimensions=dense_layer_dim,
n_neighbours=k,
n_S_dimensions=spatial_information_dim,
n_FLR=feature_space_dim,
n_output_filters=dense_layer_dim,
),
BatchNorm1d(dense_layer_dim, momentum=batch_norm_momentum),
]
)
]
)
# Additional GravNet blocks
for _ in range(n_gravblocks - 1):
self.blocks.append(
torch.nn.ModuleList(
[
torch.nn.Linear(dense_layer_dim, dense_layer_dim),
torch.nn.Linear(dense_layer_dim, dense_layer_dim),
torch.nn.Linear(dense_layer_dim, dense_layer_dim),
GravNet_Layer(
n_input_dimensions=dense_layer_dim,
n_neighbours=k,
n_S_dimensions=spatial_information_dim,
n_FLR=feature_space_dim,
n_output_filters=dense_layer_dim,
),
BatchNorm1d(dense_layer_dim, momentum=batch_norm_momentum),
]
)
)
# Final fully connected layers
self.final1 = torch.nn.Linear(n_gravblocks * dense_layer_dim, 64)
self.final2 = torch.nn.Linear(64, 1)
self.final_layer=GlobalMeanMaxSumPool()
def forward(self, batch_x):
outputs = [] # To store output from each event
# Loop over each event in the batch
for x in batch_x:
feat = []
# Initial global exchange to append average features to each node
out = torch.mean(x, dim=0, keepdim=True)
x = torch.cat([x, out.repeat(x.shape[0], 1)], dim=-1)
# Process through each block
for block in self.blocks:
x = F.elu(block[0](x)) # Linear layer 1
x = F.elu(block[1](x)) # Linear layer 2
x = torch.tanh(block[2](x)) # Linear layer 3
x = block[3](x) # GravNet layer
x = block[4](x) # BatchNorm layer
feat.append(x)
# Concatenate skip connections from all blocks
x = torch.cat(feat, dim=1)
# Pass through final layers
x = F.relu(self.final1(x)) # Final linear layer 1
x = self.final2(x) # Final linear layer 2
# mean_val=torch.mean(x)
# max_val=torch.max(x)
# sum_val=torch.sum(x)
# final_input = torch.cat([mean_val.unsqueeze(0), max_val.unsqueeze(0), sum_val.unsqueeze(0)], dim=0)
x=self.final_layer(x)
outputs.append(x)
# Stack outputs for all events and return
return torch.stack(outputs)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment