Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
G
GNNs4ObjectReco
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Lehre
GNNs4ObjectReco
Commits
c9da4177
Commit
c9da4177
authored
7 months ago
by
Jan Kieseler
Browse files
Options
Downloads
Patches
Plain Diff
single-batch wise gravnet without geometric
parent
68538d50
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
examples/models/GravNet_jk.py
+76
-0
76 additions, 0 deletions
examples/models/GravNet_jk.py
with
76 additions
and
0 deletions
examples/models/GravNet_jk.py
0 → 100644
+
76
−
0
View file @
c9da4177
import
torch
# implements the gravnet layer without torch geometric
# this expects single-event batches
class
GravNet_Layer
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
n_input_dimensions
,
n_neighbours
,
n_S_dimensions
,
n_FLR
,
n_output_filters
):
super
(
GravNet_Layer
,
self
).
__init__
()
self
.
n_neighbours
=
n_neighbours
self
.
n_S_dimensions
=
n_S_dimensions
self
.
n_FLR
=
n_FLR
self
.
n_output_filters
=
n_output_filters
self
.
S
=
torch
.
nn
.
Linear
(
n_input_dimensions
,
n_S_dimensions
)
self
.
FLR
=
torch
.
nn
.
Linear
(
n_input_dimensions
,
n_FLR
)
self
.
out
=
torch
.
nn
.
Linear
(
2
*
n_FLR
+
n_input_dimensions
,
n_output_filters
)
def
forward
(
self
,
x
):
# x is of shape [n_nodes, n_dimensions]
# batch is of shape [n_nodes]
#assert that x is n_nodes x F, and not further nested
assert
len
(
x
.
shape
)
==
2
S_coords
=
self
.
S
(
x
)
FLR_feat
=
self
.
FLR
(
x
)
#build the distance matrix in the S space
S_coords
=
S_coords
.
unsqueeze
(
0
)
S_coords_T
=
S_coords
.
transpose
(
0
,
1
)
S_diff
=
S_coords
-
S_coords_T
S_dist_sq
=
torch
.
sum
(
S_diff
**
2
,
dim
=-
1
)
#find the n_neighbours neighbour indices for each node
neigh_dist_sq
,
neigh_idx
=
torch
.
topk
(
S_dist_sq
,
self
.
n_neighbours
,
largest
=
False
)
#get the neighbour features
neigh_feat
=
torch
.
gather
(
FLR_feat
,
0
,
neigh_idx
)
# -> [n_nodes, n_neighbours, n_FLR]
#weight by exp(-10*dist_sq)
weights
=
torch
.
exp
(
-
10
*
neigh_dist_sq
)
# -> [n_nodes, n_neighbours]
#get the sum of the weighted neighbour features
sum_feat
=
torch
.
sum
(
neigh_feat
*
weights
[:,:,
None
],
dim
=
1
)
# -> [n_nodes, n_FLR]
#get the max of the weighted neighbour features
max_feat
=
torch
.
max
(
neigh_feat
*
weights
[:,:,
None
],
dim
=
1
)
# -> [n_nodes, n_FLR]
#concatenate the input features with the sum and max features
out
=
torch
.
cat
([
x
,
sum_feat
,
max_feat
],
dim
=-
1
)
#apply the final linear layer
out
=
self
.
out
(
out
)
return
out
#implements a mean and pool layer over all nodes
class
GlobalMeanMaxPool
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
(
GlobalMeanMaxPool
,
self
).
__init__
()
def
forward
(
self
,
x
):
#man and keep 0 dimension
mean
=
torch
.
mean
(
x
,
dim
=
0
,
keepdim
=
True
)
max
=
torch
.
max
(
x
,
dim
=
0
,
keepdim
=
True
)
return
torch
.
cat
([
mean
,
max
],
dim
=-
1
)
# -> [1, 2*n_output_filters]
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment