# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch.nn.functional as F
return_packed: bool = False,
interp_mode: str = "bilinear",
padding_mode: str = "zeros",
align_corners: bool = True,
) -> torch.Tensor:
Sample vertex features from a feature map. This operation is called
"perceptual feature pooling" in [1] or "vert align" in [2].
[1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single
RGB Images", ECCV 2018.
[2] Gkioxari et al, "Mesh R-CNN", ICCV 2019
feats: FloatTensor of shape (N, C, H, W) representing image features
from which to sample or a list of features each with potentially
different C, H or W dimensions.
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds)
with `verts_padded` or `points_padded` as an attribute giving the (x, y, z)
vertex positions for which to sample. (x, y) verts should be normalized such
that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right
location in the input feature map.
return_packed: (bool) Indicates whether to return packed features
interp_mode: (str) Specifies how to interpolate features.
('bilinear' or 'nearest')
padding_mode: (str) Specifies how to handle vertices outside of the
[-1, 1] range. ('zeros', 'reflection', or 'border')
align_corners (bool): Geometrically, we consider the pixels of the
input as squares rather than points.
If set to ``True``, the extrema (``-1`` and ``1``) are considered as
referring to the center points of the input's corner pixels. If set
to ``False``, they are instead considered as referring to the corner
points of the input's corner pixels, making the sampling more
resolution agnostic. Default: ``True``
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each
vertex. If feats is a list, we return concatenated features in axis=2 of
shape (N, V, sum(C_n)) where C_n = feats[n].shape[1].
If return_packed = True, the features are transformed to a packed
representation of shape (sum(V), C)
if verts.dim() != 3:
raise ValueError("verts tensor should be 3 dimensional")
grid = verts
elif hasattr(verts, "verts_padded"):
grid = verts.verts_padded()
elif hasattr(verts, "points_padded"):
grid = verts.points_padded()
"verts must be a tensor or have a "
+ "`points_padded' or`verts_padded` attribute."
grid = grid[:, None, :, :2] # (N, 1, V, 2)
feats = [feats]
for feat in feats:
if feat.dim() != 4:
raise ValueError("feats must have shape (N, C, H, W)")
if grid.shape != feat.shape:
raise ValueError("inconsistent batch dimension")
feats_sampled = 
for feat in feats:
feat_sampled = F.grid_sample(
) # (N, C, 1, V)
feat_sampled = feat_sampled.squeeze(dim=2).transpose(1, 2) # (N, V, C)
feats_sampled = torch.cat(feats_sampled, dim=2) # (N, V, sum(C))
# flatten the first two dimensions: (N*V, C)
feats_sampled = feats_sampled.view(-1, feats_sampled.shape[-1])
if hasattr(verts, "verts_padded_to_packed_idx"):
idx = (
# pyre-fixme: `Tensor` has no attribute `gather`.
feats_sampled = feats_sampled.gather(0, idx) # (sum(V), C)