# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch
import torch.nn as nn
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable
class GraphConv(nn.Module):
    """A single graph convolution layer."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        init: str = "normal",
        directed: bool = False,
    ) -> None:
        """
        Args:
            input_dim: Number of input features per vertex.
            output_dim: Number of output features per vertex.
            init: Weight initialization method. Can be one of ['zero', 'normal'].
            directed: Bool indicating if edges in the graph are directed.

        Raises:
            ValueError: if init is not one of the supported methods.
        """
        # nn.Module must be initialized before submodules can be assigned.
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.directed = directed
        # w0 transforms a vertex's own features; w1 transforms the features
        # that are summed over its neighbors in forward().
        self.w0 = nn.Linear(input_dim, output_dim)
        self.w1 = nn.Linear(input_dim, output_dim)

        if init == "normal":
            nn.init.normal_(self.w0.weight, mean=0, std=0.01)
            nn.init.normal_(self.w1.weight, mean=0, std=0.01)
        elif init == "zero":
            nn.init.zeros_(self.w0.weight)
            nn.init.zeros_(self.w1.weight)
        else:
            raise ValueError('Invalid GraphConv initialization "%s"' % init)

    def forward(self, verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
        """
        Args:
            verts: FloatTensor of shape (V, input_dim) where V is the number of
                vertices and input_dim is the number of input features
                per vertex. input_dim has to match the input_dim specified
                in __init__.
            edges: LongTensor of shape (E, 2) where E is the number of edges
                where each edge has the indices of the two vertices which
                form the edge.

        Returns:
            out: FloatTensor of shape (V, output_dim) where output_dim is the
            number of output features per vertex.

        Raises:
            ValueError: if verts and edges are not both on CPU or both on CUDA.
        """
        if verts.is_cuda != edges.is_cuda:
            raise ValueError("verts and edges tensors must be on the same device.")
        if verts.shape[0] == 0:
            # empty graph. Multiplying by verts.sum() makes the empty output
            # depend on verts (presumably to keep it attached to autograd).
            return verts.new_zeros((0, self.output_dim)) * verts.sum()

        verts_w0 = self.w0(verts)  # (V, output_dim)
        verts_w1 = self.w1(verts)  # (V, output_dim)

        # Use the compiled kernel only for CUDA tensors; everything else goes
        # through the pure-PyTorch implementation.
        if torch.cuda.is_available() and verts.is_cuda and edges.is_cuda:
            neighbor_sums = gather_scatter(verts_w1, edges, self.directed)
        else:
            neighbor_sums = gather_scatter_python(
                verts_w1, edges, self.directed
            )  # (V, output_dim)

        # Add neighbor features to each vertex's features.
        out = verts_w0 + neighbor_sums
        return out

    def __repr__(self):
        """Return a short description with the layer's dimensions and mode."""
        Din, Dout, directed = self.input_dim, self.output_dim, self.directed
        return "GraphConv(%d -> %d, directed=%r)" % (Din, Dout, directed)
def gather_scatter_python(input, edges, directed: bool = False):
    """
    Python implementation of gather_scatter for aggregating features of
    neighbor nodes in a graph.

    Given a directed graph: v0 -> v1 -> v2 the updated feature for v1 depends
    on v2 in order to be consistent with Morris et al. AAAI 2019
    (https://arxiv.org/abs/1810.02244). This only affects
    directed graphs; for undirected graphs v1 will depend on both v0 and v2,
    no matter which way the edges are physically stored.

    Args:
        input: Tensor of shape (num_vertices, input_dim).
        edges: Tensor of edge indices of shape (num_edges, 2).
        directed: bool indicating if edges are directed.

    Returns:
        output: Tensor of same shape as input.

    Raises:
        ValueError: if input or edges is not 2-dimensional, or if edges does
            not have exactly two columns.
    """
    if input.dim() != 2:
        raise ValueError("input can only have 2 dimensions.")
    if edges.dim() != 2:
        raise ValueError("edges can only have 2 dimensions.")
    if edges.shape[1] != 2:
        raise ValueError("edges must be of shape (num_edges, 2).")

    input_feature_dim = input.shape[1]
    num_edges = edges.shape[0]

    # Broadcast each endpoint index across the feature dimension so that it
    # can index whole feature rows in gather / scatter_add.
    idx0 = edges[:, 0].view(num_edges, 1).expand(num_edges, input_feature_dim)
    idx1 = edges[:, 1].view(num_edges, 1).expand(num_edges, input_feature_dim)

    # For every edge (a, b): accumulate the features of b into row a.
    output = torch.zeros_like(input)
    output = output.scatter_add(0, idx0, input.gather(0, idx1))
    if not directed:
        # Undirected edges also contribute in the reverse direction.
        output = output.scatter_add(0, idx1, input.gather(0, idx0))
    return output
class GatherScatter(Function):
    """
    Torch autograd Function wrapper for gather_scatter C++/CUDA implementations.
    """

    @staticmethod
    def forward(ctx, input, edges, directed=False):
        """
        Args:
            ctx: Context object used to calculate gradients.
            input: Tensor of shape (num_vertices, input_dim)
            edges: Tensor of edge indices of shape (num_edges, 2)
            directed: Bool indicating if edges are directed.

        Returns:
            output: Tensor of same shape as input.

        Raises:
            ValueError: if input or edges is not 2-dimensional, if edges does
                not have exactly two columns, or if input is not float32.
        """
        if input.dim() != 2:
            raise ValueError("input can only have 2 dimensions.")
        if edges.dim() != 2:
            raise ValueError("edges can only have 2 dimensions.")
        if edges.shape[1] != 2:
            raise ValueError("edges must be of shape (num_edges, 2).")
        if input.dtype != torch.float32:
            raise ValueError("input has to be of type torch.float32.")

        ctx.directed = directed
        input, edges = input.contiguous(), edges.contiguous()
        # backward() reads the edges back from ctx.saved_tensors, so they
        # have to be stored here.
        ctx.save_for_backward(edges)
        backward = False
        output = _C.gather_scatter(input, edges, directed, backward)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """
        Args:
            ctx: Context object populated by forward().
            grad_output: Gradient with respect to the output of forward().

        Returns:
            Tuple (grad_input, None, None): the gradient with respect to
            input; edges and directed receive no gradient.
        """
        # NOTE(review): the gradient comes from the opaque extension call
        # below, so backward is marked once-differentiable on the assumption
        # that the kernel does not support double backward -- confirm.
        grad_output = grad_output.contiguous()
        edges = ctx.saved_tensors[0]
        directed = ctx.directed
        backward = True
        grad_input = _C.gather_scatter(grad_output, edges, directed, backward)
        grad_edges = None
        grad_directed = None
        return grad_input, grad_edges, grad_directed
# Functional entry point: calling gather_scatter(input, edges, directed)
# invokes GatherScatter.apply with the same arguments.
gather_scatter = GatherScatter.apply