pytorch3d.ops

pytorch3d.ops.ball_query(p1: Tensor, p2: Tensor, lengths1: Tensor | None = None, lengths2: Tensor | None = None, K: int = 500, radius: float = 0.2, return_nn: bool = True)[source]

Ball Query is an alternative to KNN. It can be used to find all points in p2 that are within a specified radius of each query point in p1 (with an upper limit of K neighbors).

The neighbors returned are not necessarily the nearest to the point in p1, just the first K values in p2 which are within the specified radius.

This method is faster than kNN when p2 contains a large number of points and only membership within the radius matters, not the ordering of the neighbors by distance.

“Ball query’s local neighborhood guarantees a fixed region scale thus making local region features more generalizable across space, which is preferred for tasks requiring local pattern recognition (e.g. semantic point labeling)” [1].

[1] Charles R. Qi et al., “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space”, NeurIPS 2017.

Parameters:
  • p1 – Tensor of shape (N, P1, D) giving a batch of N point clouds, each containing up to P1 points of dimension D. These represent the centers of the ball queries.

  • p2 – Tensor of shape (N, P2, D) giving a batch of N point clouds, each containing up to P2 points of dimension D.

  • lengths1 – LongTensor of shape (N,) of values in the range [0, P1], giving the length of each pointcloud in p1. Or None to indicate that every cloud has length P1.

  • lengths2 – LongTensor of shape (N,) of values in the range [0, P2], giving the length of each pointcloud in p2. Or None to indicate that every cloud has length P2.

  • K – Integer giving the upper bound on the number of samples to take within the radius

  • radius – the radius around each point within which the neighbors need to be located

  • return_nn – If set to True returns the K neighbor points in p2 for each point in p1.

Returns:

dists: Tensor of shape (N, P1, K) giving the squared distances to the neighbors. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points and also if there are fewer than K points which satisfy the radius threshold.

idx: LongTensor of shape (N, P1, K) giving the indices of the K neighbors in p2 for points in p1. Concretely, if p1_idx[n, i, k] = j then p2[n, j] is the k-th neighbor to p1[n, i] in p2[n]. This is padded with -1 both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points and also if there are fewer than K points which satisfy the radius threshold.

nn: Tensor of shape (N, P1, K, D) giving the K neighbors in p2 for each point in p1. Concretely, p2_nn[n, i, k] gives the k-th neighbor for p1[n, i]. Returned if return_nn is True.
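The padding rules and the "first K hits, not the K nearest" behavior can be reproduced with a brute-force reference in pure Python. This is a sketch for a single, unbatched pair of clouds, not the library implementation; the name ball_query_reference is hypothetical, and excluding points that lie exactly on the sphere (strict inequality) is an assumption.

```python
def ball_query_reference(p1, p2, K, radius):
    """Brute-force ball query for one unbatched pair of point clouds.

    p1, p2: lists of D-dimensional points (tuples of floats).
    Returns (dists, idx): one row of length K per query point in p1.
    """
    r2 = radius * radius
    all_dists, all_idx = [], []
    for q in p1:
        dists, idx = [], []
        # Scan p2 in storage order and keep the FIRST K hits, not the K nearest.
        for j, p in enumerate(p2):
            d2 = sum((a - b) ** 2 for a, b in zip(q, p))
            if d2 < r2:  # assumption: points exactly on the sphere are excluded
                dists.append(d2)
                idx.append(j)
                if len(idx) == K:
                    break
        # Pad unused slots with a squared distance of 0.0 and an index of -1.
        pad = K - len(idx)
        all_dists.append(dists + [0.0] * pad)
        all_idx.append(idx + [-1] * pad)
    return all_dists, all_idx


p2 = [(0.5, 0.0), (0.1, 0.0), (3.0, 0.0), (0.0, 0.2)]
dists, idx = ball_query_reference([(0.0, 0.0)], p2, K=2, radius=1.0)
# idx == [[0, 1]]
```

In this call the point (0.0, 0.2) is closer to the query than (0.5, 0.0) but is not returned, because the scan stops as soon as K points inside the radius have been found.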

pytorch3d.ops.corresponding_cameras_alignment(cameras_src: CamerasBase, cameras_tgt: CamerasBase, estimate_scale: bool = True, mode: str = 'extrinsics', eps: float = 1e-09) CamerasBase[source]

Warning

The corresponding_cameras_alignment API is experimental and subject to change!

Estimates a single similarity transformation between two sets of cameras cameras_src and cameras_tgt and returns an aligned version of cameras_src.

Given source cameras [(R_1, T_1), (R_2, T_2), …, (R_N, T_N)] and target cameras [(R_1’, T_1’), (R_2’, T_2’), …, (R_N’, T_N’)], where (R_i, T_i) is a 2-tuple of the camera rotation matrix and translation vector respectively, the algorithm finds a global rotation, translation and scale (R_A, T_A, s_A) which aligns all source cameras with the target cameras such that the following holds:

Under the change of coordinates using a similarity transform (R_A, T_A, s_A) a 3D point X’ is mapped to X with:

X = (X' R_A + T_A) / s_A

Then, for all cameras i, we assume that the following holds:

X R_i + T_i = s' (X' R_i' + T_i'),

i.e. an adjusted point X’ is mapped by a camera (R_i’, T_i’) to the same point as imaged from camera (R_i, T_i) after resolving the scale ambiguity with a global scalar factor s’.

Substituting for X above gives rise to the following:

(X' R_A + T_A) / s_A R_i + T_i = s' (X' R_i' + T_i')       // · s_A
(X' R_A + T_A) R_i + T_i s_A = (s' s_A) (X' R_i' + T_i')
s' := 1 / s_A  # without loss of generality
(X' R_A + T_A) R_i + T_i s_A = X' R_i' + T_i'
X' R_A R_i + T_A R_i + T_i s_A = X' R_i' + T_i'
   ^^^^^^^   ^^^^^^^^^^^^^^^^^
   ~= R_i'        ~= T_i'

i.e. after estimating R_A, T_A, s_A, the aligned source cameras have extrinsics:

cameras_src_align = (R_A R_i, T_A R_i + T_i s_A) ~= (R_i', T_i')

We support two ways R_A, T_A, s_A can be estimated:
  1. mode=='centers'

    Estimates the similarity alignment between camera centers using Umeyama’s algorithm (see pytorch3d.ops.corresponding_points_alignment for details) and transforms camera extrinsics accordingly.

  2. mode=='extrinsics'

    Defines the alignment problem as a system of the following equations:

    for all i:
    [ R_A   0 ] x [ R_i         0 ] = [ R_i' 0 ]
    [ T_A^T 1 ]   [ (s_A T_i^T) 1 ]   [ T_i' 1 ]
    

    R_A, T_A and s_A are then obtained by solving the system in the least squares sense.

The estimated camera transformation is a true similarity transform, i.e. it cannot be a reflection.

Parameters:
  • cameras_src – N cameras to be aligned.

  • cameras_tgt – N target cameras.

  • estimate_scale – Controls whether the alignment transform is rigid (estimate_scale=False), or a similarity (estimate_scale=True). s_A is set to 1 if estimate_scale==False.

  • mode – Controls the alignment algorithm. Can be either ‘centers’ or ‘extrinsics’. Please refer to the description above for details.

  • eps – A scalar for clamping to avoid dividing by zero. Active when estimate_scale==True.

Returns:

cameras_src_aligned – cameras_src after applying the alignment transform.

pytorch3d.ops.cubify(voxels: Tensor, thresh: float, *, feats: Tensor | None = None, device=None, align: str = 'topleft') Meshes[source]

Converts a voxel to a mesh by replacing each occupied voxel with a cube consisting of 12 faces and 8 vertices. Shared vertices are merged, and internal faces are removed.

Parameters:
  • voxels – A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.

  • thresh – A scalar threshold. If a voxel occupancy is larger than thresh, the voxel is considered occupied.

  • feats – A FloatTensor of shape (N, K, D, H, W) containing the color information of each voxel. K is the number of channels. This is supported only when align == “center”.

  • device – The device of the output meshes.

  • align – Defines the alignment of the mesh vertices and the grid locations. Has to be one of {“topleft”, “corner”, “center”}. See below for explanation. Default is “topleft”.

Returns:

meshes – A Meshes object of the corresponding meshes.

The alignment between the vertices of the cubified mesh and the voxel locations (or pixels) is defined by the choice of align. We support three modes, as shown below for a 2x2 grid:

X---X----         X-------X        ---------
|   |   |         |   |   |        | X | X |
X---X----         ---------        ---------
|   |   |         |   |   |        | X | X |
---------         X-------X        ---------

 topleft           corner            center

In the figure, X denotes the grid locations and the squares represent the added cuboids. When align=“topleft”, the top left corner of each cuboid corresponds to the pixel coordinate of the input grid. When align=“corner”, the corners of the output mesh span the whole grid. When align=“center”, the grid locations form the centers of the cuboids.
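Because shared vertices are merged and internal faces removed, the size of the cubified mesh depends only on the exposed faces of the occupied voxels. The pure-Python sketch below counts the vertices and triangles such a mesh would have for one (D, H, W) grid. The name cubify_counts is hypothetical and this is not the library code; it assumes each exposed square face is split into two triangles, consistent with the 12 triangle faces of an isolated cube.

```python
from itertools import product


def cubify_counts(voxels, thresh):
    """Count (vertices, triangles) of the cubified mesh of one (D, H, W) grid.

    voxels: nested lists indexed as voxels[z][y][x] holding occupancy values.
    """
    D, H, W = len(voxels), len(voxels[0]), len(voxels[0][0])

    def occupied(z, y, x):
        return 0 <= z < D and 0 <= y < H and 0 <= x < W and voxels[z][y][x] > thresh

    verts, quads = set(), 0
    for cell in product(range(D), range(H), range(W)):
        if not occupied(*cell):
            continue
        for axis in range(3):
            for step in (-1, 1):
                neighbor = list(cell)
                neighbor[axis] += step
                if occupied(*neighbor):
                    continue  # internal face shared by two occupied voxels: removed
                quads += 1
                # The 4 corners of this exposed square face; the set merges shared ones.
                others = [a for a in range(3) if a != axis]
                for da, db in product((0, 1), repeat=2):
                    corner = list(cell)
                    corner[axis] += 1 if step == 1 else 0
                    corner[others[0]] += da
                    corner[others[1]] += db
                    verts.add(tuple(corner))
    return len(verts), 2 * quads  # each square face becomes two triangles
```

For example, two adjacent occupied voxels give 12 vertices and 20 triangles instead of 16 and 24, because the face they share is dropped.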

class pytorch3d.ops.GraphConv(input_dim: int, output_dim: int, init: str = 'normal', directed: bool = False)[source]

A single graph convolution layer.

__init__(input_dim: int, output_dim: int, init: str = 'normal', directed: bool = False) None[source]
Parameters:
  • input_dim – Number of input features per vertex.

  • output_dim – Number of output features per vertex.

  • init – Weight initialization method. Can be one of [‘zero’, ‘normal’].

  • directed – Bool indicating if edges in the graph are directed.

forward(verts, edges)[source]
Parameters:
  • verts – FloatTensor of shape (V, input_dim) where V is the number of vertices and input_dim is the number of input features per vertex. input_dim has to match the input_dim specified in __init__.

  • edges – LongTensor of shape (E, 2), where E is the number of edges; each row holds the indices of the two vertices that form the edge.

Returns:

out – FloatTensor of shape (V, output_dim) where output_dim is the number of output features per vertex.
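A graph convolution of this kind is commonly written as out_v = x_v W0 + sum over neighbors u of x_u W1. The pure-Python sketch below implements that formulation for illustration only: it is an assumption that GraphConv follows it exactly (the sketch omits bias terms, for example), and graph_conv, linear, w0 and w1 are hypothetical names. It also assumes that with directed=True an edge (i, j) only sends the features of vertex j to vertex i.

```python
def linear(x, w):
    """Row vector x (length input_dim) times w (input_dim rows of output_dim)."""
    return [sum(xi * w[i][o] for i, xi in enumerate(x)) for o in range(len(w[0]))]


def graph_conv(verts, edges, w0, w1, directed=False):
    """Sketch: out[v] = x_v w0 + sum over neighbours u of x_u w1 (no bias)."""
    out = [linear(v, w0) for v in verts]   # each vertex's own features
    msgs = [linear(v, w1) for v in verts]  # features sent to neighbours
    for i, j in edges:
        out[i] = [a + b for a, b in zip(out[i], msgs[j])]      # j -> i
        if not directed:
            out[j] = [a + b for a, b in zip(out[j], msgs[i])]  # i -> j
    return out


verts = [[1.0], [2.0], [3.0]]
edges = [(0, 1), (1, 2)]
out = graph_conv(verts, edges, w0=[[1.0]], w1=[[10.0]])
# out == [[21.0], [42.0], [23.0]]
```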

pytorch3d.ops.interpolate_face_attributes(pix_to_face: Tensor, barycentric_coords: Tensor, face_attributes: Tensor) Tensor[source]

Interpolate arbitrary face attributes using the barycentric coordinates for each pixel in the rasterized output.

Parameters:
  • pix_to_face – LongTensor of shape (…) specifying the indices of the faces (in the packed representation) which overlap each pixel in the image. A value < 0 indicates that the pixel does not overlap any face and should be skipped.

  • barycentric_coords – FloatTensor of shape (N, H, W, K, 3) specifying the barycentric coordinates of each pixel relative to the faces (in the packed representation) which overlap the pixel.

  • face_attributes – packed attributes of shape (total_faces, 3, D), specifying the value of the attribute for each vertex in the face.

Returns:

pixel_vals – tensor of shape (N, H, W, K, D) giving the interpolated value of the face attribute for each pixel.
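For a single pixel and one of its K faces, the interpolation is a barycentric-weighted sum of the three per-vertex attributes of that face. The sketch below shows this in pure Python; interpolate_pixel is a hypothetical helper, and returning zeros for pixels with a negative face index is an assumption about how skipped pixels are filled.

```python
def interpolate_pixel(face_idx, bary, face_attributes):
    """Interpolated attribute for one pixel and one of its K overlapping faces.

    face_idx: packed face index (a negative value means "no face").
    bary: the 3 barycentric coordinates of the pixel w.r.t. that face.
    face_attributes: list over faces of 3 per-vertex attribute vectors (length D).
    """
    D = len(face_attributes[0][0])
    if face_idx < 0:
        return [0.0] * D  # assumption: background pixels are filled with zeros
    corners = face_attributes[face_idx]
    return [sum(b * corners[v][d] for v, b in enumerate(bary)) for d in range(D)]


# One face whose vertices carry pure red, green and blue colors.
attrs = [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]
color = interpolate_pixel(0, (0.2, 0.3, 0.5), attrs)
```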

pytorch3d.ops.box3d_overlap(boxes1: Tensor, boxes2: Tensor, eps: float = 0.0001) Tuple[Tensor, Tensor][source]

Computes the intersection of 3D boxes1 and boxes2.

Inputs boxes1, boxes2 are tensors of shape (B, 8, 3) (where B doesn’t have to be the same for boxes1 and boxes2), containing the 8 corners of the boxes, as follows:

    (4) +---------+. (5)
        | ` .     |  ` .
        | (0) +---+-----+ (1)
        |     |   |     |
    (7) +-----+---+. (6)|
        ` .   |     ` . |
        (3) ` +---------+ (2)

NOTE: Throughout this implementation, we assume that boxes are defined by their 8 corners exactly in the order specified in the diagram above for the function to give correct results. In addition the vertices on each plane must be coplanar. As an alternative to the diagram, this is a unit bounding box which has the correct vertex ordering:

box_corner_vertices = [
    [0, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 1],
    [1, 1, 1],
    [0, 1, 1],
]

Parameters:
  • boxes1 – tensor of shape (N, 8, 3) of the coordinates of the 1st boxes

  • boxes2 – tensor of shape (M, 8, 3) of the coordinates of the 2nd boxes

Returns:

vol – (N, M) tensor of the volume of the intersecting convex shapes.

iou – (N, M) tensor of the intersection over union, which is defined as: iou = vol / (vol1 + vol2 - vol)
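For the special case of axis-aligned boxes, vol and iou reduce to interval arithmetic, which makes a handy sanity check for inputs in the corner order above. The sketch covers only that special case (box3d_overlap itself handles arbitrarily oriented boxes); aabb_iou is a hypothetical helper that uses only the per-axis minima and maxima of the 8 corners.

```python
def aabb_iou(box1, box2):
    """(intersection volume, IoU) of two axis-aligned boxes given by 8 corners each."""
    def bounds(box):
        return [(min(c[a] for c in box), max(c[a] for c in box)) for a in range(3)]

    def volume(b):
        v = 1.0
        for lo, hi in b:
            v *= hi - lo
        return v

    b1, b2 = bounds(box1), bounds(box2)
    inter = 1.0
    for (lo1, hi1), (lo2, hi2) in zip(b1, b2):
        inter *= max(0.0, min(hi1, hi2) - max(lo1, lo2))  # overlap along this axis
    vol1, vol2 = volume(b1), volume(b2)
    return inter, inter / (vol1 + vol2 - inter)


unit = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
        [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]]
shifted = [[x + 0.5, y, z] for x, y, z in unit]
vol, iou = aabb_iou(unit, shifted)  # half of each box overlaps: iou = 0.5 / 1.5
```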

pytorch3d.ops.knn_gather(x: Tensor, idx: Tensor, lengths: Tensor | None = None)[source]

A helper function for knn that allows indexing a tensor x with the indices idx returned by knn_points.

For example, if dists, idx, _ = knn_points(p, x, lengths_p, lengths, K=K) where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D), then one can compute the K nearest neighbors of p with p_nn = knn_gather(x, idx, lengths). It can also be applied for any tensor x of shape (N, M, U) where U != D.

Parameters:
  • x – Tensor of shape (N, M, U) containing U-dimensional features to be gathered.

  • idx – LongTensor of shape (N, L, K) giving the indices returned by knn_points.

  • lengths – LongTensor of shape (N,) of values in the range [0, M], giving the length of each example in the batch in x. Or None to indicate that every example has length M.

Returns:

x_out – Tensor of shape (N, L, K, U) resulting from gathering the elements of x with idx, s.t. x_out[n, l, k] = x[n, idx[n, l, k]]. If k >= lengths[n] (with k counted from 0), then x_out[n, l, k] is filled with 0.0.
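The gather rule can be written out for nested Python lists. knn_gather_reference is a hypothetical name; the sketch zero-fills neighbor slots with k >= lengths[n] (k counted from 0), which is our reading of the padding rule, since a cloud of lengths[n] points can fill at most slots 0 to lengths[n] - 1.

```python
def knn_gather_reference(x, idx, lengths=None):
    """x: N lists of M feature vectors (length U); idx: N x L x K indices.

    Returns x_out with x_out[n][l][k] = x[n][idx[n][l][k]], zero-filled
    wherever k >= lengths[n].
    """
    out = []
    for n, (xn, idx_n) in enumerate(zip(x, idx)):
        length = len(xn) if lengths is None else lengths[n]
        U = len(xn[0])
        out.append([
            [list(xn[j]) if k < length else [0.0] * U for k, j in enumerate(row)]
            for row in idx_n
        ])
    return out


x = [[[1.0], [2.0], [3.0]]]  # N=1 cloud of M=3 points with U=1 feature
idx = [[[2, 0, 1]]]          # L=1 query with K=3 neighbor indices
gathered = knn_gather_reference(x, idx, lengths=[2])
```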

pytorch3d.ops.knn_points(p1: Tensor, p2: Tensor, lengths1: Tensor | None = None, lengths2: Tensor | None = None, norm: int = 2, K: int = 1, version: int = -1, return_nn: bool = False, return_sorted: bool = True) KNN[source]

K-Nearest neighbors on point clouds.

Parameters:
  • p1 – Tensor of shape (N, P1, D) giving a batch of N point clouds, each containing up to P1 points of dimension D.

  • p2 – Tensor of shape (N, P2, D) giving a batch of N point clouds, each containing up to P2 points of dimension D.

  • lengths1 – LongTensor of shape (N,) of values in the range [0, P1], giving the length of each pointcloud in p1. Or None to indicate that every cloud has length P1.

  • lengths2 – LongTensor of shape (N,) of values in the range [0, P2], giving the length of each pointcloud in p2. Or None to indicate that every cloud has length P2.

  • norm – Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.

  • K – Integer giving the number of nearest neighbors to return.

  • version – Which KNN implementation to use in the backend. If version=-1, the correct implementation is selected based on the shapes of the inputs.

  • return_nn – If set to True returns the K nearest neighbors in p2 for each point in p1.

  • return_sorted – (bool) whether to return the nearest neighbors sorted in ascending order of distance.

Returns:

dists: Tensor of shape (N, P1, K) giving the squared distances to the nearest neighbors. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.

idx: LongTensor of shape (N, P1, K) giving the indices of the K nearest neighbors from points in p1 to points in p2. Concretely, if p1_idx[n, i, k] = j then p2[n, j] is the k-th nearest neighbor to p1[n, i] in p2[n]. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.

nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for each point in p1. Concretely, p2_nn[n, i, k] gives the k-th nearest neighbor for p1[n, i]. Returned if return_nn is True. The nearest neighbors are collected using knn_gather

p2_nn = knn_gather(p2, p1_idx, lengths2)

which is a helper function that allows indexing any tensor of shape (N, P2, U) with the indices p1_idx returned by knn_points. The output is a tensor of shape (N, P1, K, U).
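A brute-force reference makes the sorting and padding conventions concrete for a single, unbatched pair of clouds. knn_points_reference is a hypothetical helper, not the library's backend; it assumes that with norm=1 the returned values are plain L1 distances (the squared-distance wording above applies to norm=2), and it pads missing neighbors with distance 0.0 and index 0 as described above.

```python
def knn_points_reference(p1, p2, K, norm=2):
    """Brute-force KNN for one unbatched pair of clouds, sorted by distance."""
    dists, idx = [], []
    for q in p1:
        if norm == 2:
            d = [sum((a - b) ** 2 for a, b in zip(q, p)) for p in p2]  # squared L2
        elif norm == 1:
            d = [sum(abs(a - b) for a, b in zip(q, p)) for p in p2]    # L1
        else:
            raise ValueError("only norm=1 or norm=2 is supported")
        order = sorted(range(len(p2)), key=d.__getitem__)[:K]
        pad = K - len(order)  # p2 has fewer than K points
        dists.append([d[j] for j in order] + [0.0] * pad)
        idx.append(order + [0] * pad)
    return dists, idx


cloud = [(3.0, 0.0), (1.0, 0.0), (0.0, 2.0)]
d, i = knn_points_reference([(0.0, 0.0)], cloud, K=2)
# i == [[1, 2]], d == [[1.0, 4.0]]
```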

pytorch3d.ops.cot_laplacian(verts: Tensor, faces: Tensor, eps: float = 1e-12) Tuple[Tensor, Tensor][source]

Returns the Laplacian matrix with cotangent weights and the inverse of the face areas.

Parameters:
  • verts – tensor of shape (V, 3) containing the vertices of the graph

  • faces – tensor of shape (F, 3) containing the vertex indices of each face

Returns:

2-element tuple containing

  • L: Sparse FloatTensor of shape (V, V) for the Laplacian matrix. Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes, where a_ij and b_ij are the angles opposite the edge (i, j) in the two faces that share it.

  • inv_areas: FloatTensor of shape (V,) containing the inverse of the sum of the areas of the faces containing each vertex.
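The cotangent weights and inverse areas can be illustrated on a unit square split along its diagonal. The pure-Python sketch below is hypothetical helper code, not the library implementation: it returns the off-diagonal weights as a dict keyed by sorted vertex pairs instead of a sparse matrix, and a boundary edge, which has only one adjacent face, simply receives a single cotangent term.

```python
import math
from collections import defaultdict


def sub(a, b):
    return [x - y for x, y in zip(a, b)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cross(a, b):
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def cot_weights_and_inv_areas(verts, faces):
    """({(i, j): cot a_ij + cot b_ij} with i < j, inverse summed face area per vertex)."""
    weights = defaultdict(float)
    area_sum = [0.0] * len(verts)
    for face in faces:
        for c in range(3):
            # The angle at corner k is opposite the edge (i, j) formed by the other two.
            k, i, j = face[c], face[(c + 1) % 3], face[(c + 2) % 3]
            u, v = sub(verts[i], verts[k]), sub(verts[j], verts[k])
            n = cross(u, v)
            # cot(angle) = cos / sin = (u . v) / |u x v|
            weights[(min(i, j), max(i, j))] += dot(u, v) / math.sqrt(dot(n, n))
        n = cross(sub(verts[face[1]], verts[face[0]]), sub(verts[face[2]], verts[face[0]]))
        area = 0.5 * math.sqrt(dot(n, n))
        for vi in face:
            area_sum[vi] += area
    inv_areas = [1.0 / a if a > 0 else 0.0 for a in area_sum]
    return dict(weights), inv_areas


# Unit square split along the diagonal (0, 2) into two right triangles.
verts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0), (0.0, 1.0, 0.0)]
faces = [(0, 1, 2), (0, 2, 3)]
weights, inv_areas = cot_weights_and_inv_areas(verts, faces)
```

The diagonal (0, 2) gets a weight of 0 because both opposite angles are right angles, while each side of the square sees one 45° angle and gets cot 45° = 1.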

pytorch3d.ops.laplacian(verts: Tensor, edges: Tensor) Tensor[source]

Computes the laplacian matrix. The definition of the laplacian is

L[i, j] = -1, if i == j
L[i, j] = 1 / deg(i), if (i, j) is an edge
L[i, j] = 0, otherwise

where deg(i) is the degree of the i-th vertex in the graph.

Parameters:
  • verts – tensor of shape (V, 3) containing the vertices of the graph

  • edges – tensor of shape (E, 2) containing the vertex indices of each edge

Returns:

L – Sparse FloatTensor of shape (V, V)
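The definition above can be written as a dense matrix for a small graph. laplacian_dense is a hypothetical helper that treats every edge as undirected and only needs the vertex count, since this laplacian does not depend on vertex positions.

```python
def laplacian_dense(num_verts, edges):
    """Dense L with -1 on the diagonal and 1 / deg(i) for each neighbour j of i."""
    neighbors = [set() for _ in range(num_verts)]
    for i, j in edges:
        neighbors[i].add(j)
        neighbors[j].add(i)
    L = [[0.0] * num_verts for _ in range(num_verts)]
    for i in range(num_verts):
        L[i][i] = -1.0
        for j in neighbors[i]:
            L[i][j] = 1.0 / len(neighbors[i])
    return L


L = laplacian_dense(3, [(0, 1), (1, 2)])  # path graph 0 - 1 - 2
# L == [[-1.0, 1.0, 0.0], [0.5, -1.0, 0.5], [0.0, 1.0, -1.0]]
```

Every row of a vertex with at least one neighbor sums to zero, so multiplying L by the vertex positions gives, for each vertex, the offset from the vertex to the average of its neighbors.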

pytorch3d.ops.norm_laplacian(verts: Tensor, edges: Tensor, eps: float = 1e-12) Tensor[source]

Norm laplacian computes a variant of the laplacian matrix which weights each affinity by the inverse of the distance between the neighboring nodes. More concretely, L[i, j] = 1. / wij, where wij = ||vi - vj|| if (vi, vj) are neighboring nodes.

Parameters:
  • verts – tensor of shape (V, 3) containing the vertices of the graph

  • edges – tensor of shape (E, 2) containing the vertex indices of each edge

Returns:

L – Sparse FloatTensor of shape (V, V)
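A dense version of the formula above fits in a few lines of pure Python. norm_laplacian_dense is a hypothetical helper, and three details are assumptions to check against the library: eps is added to the edge length before inverting, the matrix is filled symmetrically, and the diagonal is left at zero.

```python
import math


def norm_laplacian_dense(verts, edges, eps=1e-12):
    """Dense L[i][j] = 1 / (||v_i - v_j|| + eps) for every edge (i, j)."""
    V = len(verts)
    L = [[0.0] * V for _ in range(V)]
    for i, j in edges:
        w = 1.0 / (math.dist(verts[i], verts[j]) + eps)
        L[i][j] += w  # each edge fills both L[i][j] and L[j][i]
        L[j][i] += w
    return L


verts = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (2.0, 3.0, 0.0)]
L = norm_laplacian_dense(verts, [(0, 1), (1, 2)])
```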

pytorch3d.ops.mesh_face_areas_normals(*args, **kwargs)

pytorch3d.ops.taubin_smoothing(meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10) Meshes[source]

Taubin smoothing [1] is an iterative smoothing operator for meshes. At each iteration

verts := (1 - λ) * verts + λ * L * verts
verts := (1 - μ) * verts + μ * L * verts

where λ = lambd and μ = mu. This function returns a new mesh with smoothed vertices.

Parameters:
  • meshes – Meshes input to be smoothed

  • lambd – float parameter for Taubin smoothing, lambd > 0

  • mu – float parameter for Taubin smoothing, mu < 0

  • num_iter – number of iterations to execute smoothing

Returns:

mesh – Smoothed input Meshes

[1] Gabriel Taubin: Curve and Surface Smoothing without Shrinkage. ICCV, 1995.
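The two-step update can be tried on a closed polyline. In this pure-Python sketch, L * verts is taken to be the uniform average of each point's two neighbors, an assumption made for simplicity (the library derives its weights from the mesh); taubin_ring is a hypothetical name.

```python
import math


def taubin_ring(points, lambd=0.53, mu=-0.53, num_iter=10):
    """Taubin smoothing of a closed polyline; L * verts = mean of the 2 neighbours."""
    n = len(points)

    def step(pts, w):
        out = []
        for i, p in enumerate(pts):
            prev_p, next_p = pts[i - 1], pts[(i + 1) % n]
            avg = [(a + b) / 2 for a, b in zip(prev_p, next_p)]
            out.append([(1 - w) * x + w * m for x, m in zip(p, avg)])
        return out

    for _ in range(num_iter):
        points = step(points, lambd)  # shrinking step, lambd > 0
        points = step(points, mu)     # inflating step, mu < 0
    return points


n = 8
octagon = [[math.cos(2 * math.pi * i / n), math.sin(2 * math.pi * i / n)] for i in range(n)]
smoothed = taubin_ring(octagon, num_iter=1)
```

For a regular octagon centered at the origin, the neighbor average of each vertex is cos(π/4) times the vertex, so one iteration with the default λ and μ scales the radius by 1 - (0.53 (1 - cos(π/4)))², about 0.976, whereas two plain λ steps would scale it by about 0.714.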

pytorch3d.ops.packed_to_padded(inputs: Tensor, first_idxs: LongTensor, max_size: int) Tensor[source]

Torch wrapper that handles allowed input shapes. See description below.

Parameters:
  • inputs – FloatTensor of shape (F,) or (F, …), representing the packed batch tensor, e.g. areas for faces in a batch of meshes.

  • first_idxs – LongTensor of shape (N,) where N is the number of elements in the batch and first_idxs[i] = f means that the inputs for batch element i begin at inputs[f].

  • max_size – Max length of an element in the batch.

Returns:

inputs_padded – FloatTensor of shape (N, max_size) or (N, max_size, …), where max_size is the maximum number of entries in any batch element. The values for batch element i which start at inputs[first_idxs[i]] will be copied to inputs_padded[i, :], with zeros padding out the extra inputs.

To handle the allowed input shapes, we convert the inputs tensor of shape (F,) to (F, 1). We reshape the output back to (N, max_size) from (N, max_size, 1).
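For the (F,) case, the conversion amounts to slicing the packed list at first_idxs and zero-padding each slice. packed_to_padded_1d is a hypothetical helper; it assumes the last batch element runs to the end of inputs and ignores trailing feature dimensions.

```python
def packed_to_padded_1d(inputs, first_idxs, max_size):
    """Scatter a packed list of F values into N zero-padded rows of length max_size."""
    ends = list(first_idxs[1:]) + [len(inputs)]  # element i ends where i + 1 starts
    padded = []
    for start, end in zip(first_idxs, ends):
        row = list(inputs[start:end])
        padded.append(row + [0.0] * (max_size - len(row)))
    return padded


# Two meshes with 3 and 2 face areas respectively.
padded = packed_to_padded_1d([1.0, 2.0, 3.0, 4.0, 5.0], [0, 3], max_size=3)
# padded == [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
```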

pytorch3d.ops.padded_to_packed(inputs: Tensor, first_idxs: LongTensor, num_inputs: int, max_size_dim: int = 1) Tensor[source]

Torch wrapper that handles allowed input shapes. See description below.

Parameters:
  • inputs – FloatTensor of shape (N, …, max_size) or (N, …, max_size, …), representing the padded tensor, e.g. areas for faces in a batch of meshes, where max_size occurs on max_size_dim-th position.

  • first_idxs – LongTensor of shape (N,) where N is the number of elements in the batch and first_idxs[i] = f means that the inputs for batch element i begin at inputs_packed[f].

  • num_inputs – Number of packed entries (= F)

  • max_size_dim – the dimension to be packed

Returns:

inputs_packed – FloatTensor of shape (F,) or (F, …) where inputs_packed[first_idxs[i]:first_idxs[i+1]] = inputs[i, …, :delta[i]], where delta[i] = first_idxs[i+1] - first_idxs[i] and, for the last batch element, first_idxs[N] is taken to be num_inputs.

To handle the allowed input shapes, we convert the inputs tensor of shape (N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from (F, 1).
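For a padded (N, max_size) input, packing keeps the first delta[i] values of each row. padded_to_packed_1d is a hypothetical helper, and taking num_inputs (F) as the end of the last element is an assumption consistent with the formula above.

```python
def padded_to_packed_1d(inputs, first_idxs, num_inputs):
    """Gather the first delta[i] values of each padded row into one packed list."""
    ends = list(first_idxs[1:]) + [num_inputs]  # the last element ends at F
    packed = []
    for row, start, end in zip(inputs, first_idxs, ends):
        packed.extend(row[: end - start])  # drop the padding of element i
    return packed


packed = padded_to_packed_1d([[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]], [0, 3], num_inputs=5)
# packed == [1.0, 2.0, 3.0, 4.0, 5.0]
```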

pytorch3d.ops.efficient_pnp(x: Tensor, y: Tensor, weights: Tensor | None = None, skip_quadratic_eq: bool = False) EpnpSolution[source]

Implements Efficient PnP algorithm [1] for Perspective-n-Points problem: finds a camera position (defined by rotation R and translation T) that minimizes re-projection error between the given 3D points x and the corresponding uncalibrated 2D points y, i.e. solves

y[i] = Proj(x[i] R[i] + T[i])

in the least-squares sense, where i are indices within the batch, and Proj is the perspective projection operator: Proj([x y z]) = [x/z y/z]. In the noise-less case, 4 points are enough to find the solution as long as they are not co-planar.
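The model being solved can be made concrete with a small pure-Python sketch that evaluates Proj(x R + T) for one batch element, using the row-vector convention above, and averages the L2 distances to the observed y. project and mean_reprojection_error are hypothetical helpers; the sketch ignores weights and does not solve for R and T.

```python
import math


def project(points_cam):
    """Perspective projection Proj([x y z]) = [x/z y/z]."""
    return [(x / z, y / z) for x, y, z in points_cam]


def mean_reprojection_error(x, y, R, T):
    """Mean L2 distance between Proj(x R + T) and the observed 2D points y."""
    cam = [
        tuple(sum(p[r] * R[r][c] for r in range(3)) + T[c] for c in range(3))
        for p in x
    ]
    return sum(math.dist(a, b) for a, b in zip(project(cam), y)) / len(y)


R = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
T = [0.0, 0.0, 5.0]
x = [(1.0, 2.0, 0.0), (-1.0, 0.0, 5.0)]
err = mean_reprojection_error(x, [(0.2, 0.4), (-0.1, 0.3)], R, T)
```

Here the second observed point is off by 0.3 in y, so the mean error over the two points is about 0.15.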

Parameters:
  • x – Batch of 3-dimensional points of shape (minibatch, num_points, 3).

  • y – Batch of 2-dimensional points of shape (minibatch, num_points, 2).

  • weights – Batch of non-negative weights of shape (minibatch, num_point). None means equal weights.

  • skip_quadratic_eq – If True, assumes the solution space for the linear system is one-dimensional, i.e. takes the scaled eigenvector that corresponds to the smallest eigenvalue as a solution. If False, finds the candidate coordinates in the potentially 4D null space by approximately solving the systems of quadratic equations. The best candidate is chosen by examining the 2D re-projection error. While this option finds a better solution, especially when the number of points is small or perspective distortions are low (the points are far away), it may be more difficult to back-propagate through.

Returns:

EpnpSolution namedtuple containing elements –

  • x_cam: Batch of transformed points x that is used to find the camera parameters, of shape (minibatch, num_points, 3). In the general (noisy) case, they are not exactly equal to x[i] R[i] + T[i] but are some affine transform of x[i].

  • R: Batch of rotation matrices of shape (minibatch, 3, 3).

  • T: Batch of translation vectors of shape (minibatch, 3).

  • err_2d: Batch of mean 2D re-projection errors of shape (minibatch,). Specifically, if yhat is the re-projection for the i-th batch element, it returns the mean over points j of norm(yhat_j - y_j), where norm denotes the L2 norm.

  • err_3d: Batch of mean algebraic errors of shape (minibatch,). Specifically, those are squared distances between x_world and estimated points on the rays defined by y.

[1] Lepetit, V., Moreno-Noguer, F., & Fua, P. (2009). EPnP: An Accurate O(n) Solution to the PnP Problem. International Journal of Computer Vision. https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/

pytorch3d.ops.corresponding_points_alignment(X: Tensor | Pointclouds, Y: Tensor | Pointclouds, weights: Tensor | List[Tensor] | None = None, estimate_scale: bool = False, allow_reflection: bool = False, eps: float = 1e-09) SimilarityTransform[source]

Finds a similarity transformation (rotation R, translation T and optionally scale s) between two given sets of corresponding d-dimensional points X and Y such that:

s[i] X[i] R[i] + T[i] = Y[i],

for all batch indexes i in the least squares sense.

The algorithm is also known as Umeyama [1].

Parameters:
  • X – Batch of d-dimensional points of shape (minibatch, num_point, d) or a Pointclouds object.

  • Y – Batch of d-dimensional points of shape (minibatch, num_point, d) or a Pointclouds object.

  • weights – Batch of non-negative weights of shape (minibatch, num_point) or list of minibatch 1-dimensional tensors that may have different shapes; in that case, the length of i-th tensor should be equal to the number of points in X_i and Y_i. Passing None means uniform weights.

  • estimate_scale – If True, also estimates a scaling component s of the transformation. Otherwise assumes an identity scale and returns a tensor of ones.

  • allow_reflection – If True, allows the algorithm to return R which is orthonormal but has determinant==-1.

  • eps – A scalar for clamping to avoid dividing by zero. Active for the code that estimates the output scale s.

Returns:

3-element named tuple SimilarityTransform containing

  • R: Batch of orthonormal matrices of shape (minibatch, d, d).

  • T: Batch of translations of shape (minibatch, d).

  • s: batch of scaling factors of shape (minibatch, ).

References

[1] Shinji Umeyama: Least-Squares Estimation of Transformation Parameters Between Two Point Patterns. TPAMI, 1991.
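In two dimensions the least-squares similarity has a closed form that is handy for checking results: after centering both point sets, the optimal rotation angle is atan2 of the summed cross and dot products, and the scale follows from the same sums. The sketch below is this 2D special case only, written by us for illustration; it is not the library's d-dimensional SVD-based implementation, and similarity_2d is a hypothetical name. Because it parameterizes R by an angle, it never returns a reflection.

```python
import math


def similarity_2d(X, Y, estimate_scale=True):
    """Least-squares s, R, T with s * x R + T ~= y (row vectors), 2D only."""
    n = len(X)
    mx = [sum(p[a] for p in X) / n for a in range(2)]
    my = [sum(p[a] for p in Y) / n for a in range(2)]
    A = [(p[0] - mx[0], p[1] - mx[1]) for p in X]  # centred sources
    B = [(q[0] - my[0], q[1] - my[1]) for q in Y]  # centred targets
    dot = sum(a[0] * b[0] + a[1] * b[1] for a, b in zip(A, B))
    crs = sum(a[0] * b[1] - a[1] * b[0] for a, b in zip(A, B))
    theta = math.atan2(crs, dot)  # optimal proper rotation angle
    c, sn = math.cos(theta), math.sin(theta)
    R = [[c, sn], [-sn, c]]  # row-vector form: x R rotates x by +theta
    s = math.hypot(crs, dot) / sum(a[0] ** 2 + a[1] ** 2 for a in A) if estimate_scale else 1.0
    # T = mean(Y) - s * mean(X) R
    T = [my[k] - s * (mx[0] * R[0][k] + mx[1] * R[1][k]) for k in range(2)]
    return s, R, T


# Points rotated by 30 degrees, scaled by 2 and shifted by (1, -1).
th = math.radians(30)
X = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 1.0)]
Y = [(2 * (x * math.cos(th) - y * math.sin(th)) + 1.0,
      2 * (x * math.sin(th) + y * math.cos(th)) - 1.0) for x, y in X]
s, R, T = similarity_2d(X, Y)
```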

pytorch3d.ops.iterative_closest_point(X: Tensor | Pointclouds, Y: Tensor | Pointclouds, init_transform: SimilarityTransform | None = None, max_iterations: int = 100, relative_rmse_thr: float = 1e-06, estimate_scale: bool = False, allow_reflection: bool = False, verbose: bool = False) ICPSolution[source]

Executes the iterative closest point (ICP) algorithm [1, 2] in order to find a similarity transformation (rotation R, translation T, and optionally scale s) between two given differently-sized sets of d-dimensional points X and Y, such that:

s[i] X[i] R[i] + T[i] = Y[NN[i]],

for all batch indices i in the least squares sense. Here, Y[NN[i]] stands for the nearest neighbors in Y of each point in X. Note, however, that the solution is only a local optimum.

Parameters:
  • X – Batch of d-dimensional points of shape (minibatch, num_points_X, d) or a Pointclouds object.

  • Y – Batch of d-dimensional points of shape (minibatch, num_points_Y, d) or a Pointclouds object.

  • init_transform – A named-tuple SimilarityTransform of tensors R, T, s, where R is a batch of orthonormal matrices of shape (minibatch, d, d), T is a batch of translations of shape (minibatch, d) and s is a batch of scaling factors of shape (minibatch,).

  • max_iterations – The maximum number of ICP iterations.

  • relative_rmse_thr – A threshold on the relative root mean squared error used to terminate the algorithm.

  • estimate_scale – If True, also estimates a scaling component s of the transformation. Otherwise assumes the identity scale and returns a tensor of ones.

  • allow_reflection – If True, allows the algorithm to return R which is orthonormal but has determinant==-1.

  • verbose – If True, prints status messages during each ICP iteration.

Returns:

A named tuple ICPSolution with the following fields –

  • converged: A boolean flag denoting whether the algorithm converged successfully (=True) or not (=False).

  • rmse: Attained root mean squared error after termination of ICP.

  • Xt: The point cloud X transformed with the final transformation (R, T, s). If X is a Pointclouds object, returns an instance of Pointclouds, otherwise returns torch.Tensor.

  • RTs: A named tuple SimilarityTransform containing a batch of similarity transforms with fields R (batch of orthonormal matrices of shape (minibatch, d, d)), T (batch of translations of shape (minibatch, d)) and s (batch of scaling factors of shape (minibatch, )).

  • t_history: A list of named tuples SimilarityTransform holding the transformation parameters after each ICP iteration.

References

[1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.

[2] https://en.wikipedia.org/wiki/Iterative_closest_point
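The loop structure of ICP can be sketched in pure Python for 2D rigid transforms (no scale, no reflections). Each iteration matches every transformed point of X to its nearest neighbor in Y, re-estimates the transform from the original X to those neighbors with the closed-form 2D least-squares fit, and stops once the relative RMSE improvement drops below relative_rmse_thr. The helper names are hypothetical, the sketch uses the column-vector convention R x + t rather than the row-vector form above, and the exact stopping rule of iterative_closest_point may differ.

```python
import math


def fit_rigid_2d(X, Y):
    """Closed-form least-squares angle and translation with R x + t ~= y."""
    n = len(X)
    mx = [sum(p[k] for p in X) / n for k in range(2)]
    my = [sum(p[k] for p in Y) / n for k in range(2)]
    dot = crs = 0.0
    for p, q in zip(X, Y):
        a0, a1 = p[0] - mx[0], p[1] - mx[1]
        b0, b1 = q[0] - my[0], q[1] - my[1]
        dot += a0 * b0 + a1 * b1
        crs += a0 * b1 - a1 * b0
    theta = math.atan2(crs, dot)
    c, s = math.cos(theta), math.sin(theta)
    t = (my[0] - (c * mx[0] - s * mx[1]), my[1] - (s * mx[0] + c * mx[1]))
    return theta, t


def transform(theta, t, p):
    c, s = math.cos(theta), math.sin(theta)
    return (c * p[0] - s * p[1] + t[0], s * p[0] + c * p[1] + t[1])


def icp_2d(X, Y, max_iterations=100, relative_rmse_thr=1e-6):
    """Rigid 2D ICP sketch: returns (converged, rmse, theta, t)."""
    Xt, prev_rmse = list(X), None
    for _ in range(max_iterations):
        # 1. Correspondences: nearest neighbour in Y of every transformed point.
        nn = [min(Y, key=lambda q, p=p: math.dist(p, q)) for p in Xt]
        # 2. Re-estimate the full transform from the ORIGINAL X to those neighbours.
        theta, t = fit_rigid_2d(X, nn)
        Xt = [transform(theta, t, p) for p in X]
        rmse = math.sqrt(sum(math.dist(p, q) ** 2 for p, q in zip(Xt, nn)) / len(X))
        # 3. Stop once the relative RMSE improvement falls below the threshold.
        if prev_rmse is not None and prev_rmse - rmse <= relative_rmse_thr * prev_rmse:
            return True, rmse, theta, t
        prev_rmse = rmse
    return False, rmse, theta, t


X = [(0.0, 0.0), (4.0, 0.0), (0.0, 3.0), (7.0, 5.0), (-3.0, 6.0)]
Y = [transform(math.radians(5), (0.2, -0.1), p) for p in X]
converged, rmse, theta, t = icp_2d(X, Y)
```

In this example every point of X moves by less than half the smallest spacing between points, so the first nearest-neighbor matching is already correct and the fit is exact; with larger misalignments ICP can settle in a wrong local optimum, as noted above.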

pytorch3d.ops.estimate_pointcloud_local_coord_frames(pointclouds: Tensor | Pointclouds, neighborhood_size: int = 50, disambiguate_directions: bool = True, *, use_symeig_workaround: bool = True) Tuple[Tensor, Tensor][source]

Estimates the principal directions of curvature (which includes normals) of a batch of pointclouds.

The algorithm first finds neighborhood_size nearest neighbors for each point of the point clouds, followed by obtaining principal vectors of covariance matrices of each of the point neighborhoods. The principal vector with the smallest eigenvalue corresponds to the normal, while the other 2 are the directions of the highest curvature and the 2nd highest curvature.

Note that each principal direction is given up to a sign. Hence, the function implements a disambiguate_directions switch that ensures consistency of the sign of neighboring normals. The implementation follows the sign disambiguation from SHOT descriptors [1].

The algorithm also returns the curvature values themselves. These are the eigenvalues of the estimated covariance matrices of each point neighborhood.

Parameters:
  • pointclouds – Batch of 3-dimensional points of shape (minibatch, num_point, 3) or a Pointclouds object.

  • neighborhood_size – The size of the neighborhood used to estimate the geometry around each point.

  • disambiguate_directions – If True, uses the algorithm from [1] to ensure sign consistency of the normals of neighboring points.

  • use_symeig_workaround – If True, uses a custom eigenvalue calculation.

Returns:

curvatures: The three principal curvatures of each point of shape (minibatch, num_point, 3). If pointclouds are of Pointclouds class, returns a padded tensor.

local_coord_frames: The three principal directions of the curvature around each point of shape (minibatch, num_point, 3, 3). The principal directions are stored in columns of the output. E.g. local_coord_frames[i, j, :, 0] is the normal of j-th point in the i-th pointcloud. If pointclouds are of Pointclouds class, returns a padded tensor.

References

[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for Local Surface Description, ECCV 2010.
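The normal-estimation step can be sketched for a single query point in pure Python: gather the nearest neighbors, form their covariance matrix, and take the eigenvector with the smallest eigenvalue as the normal. To stay dependency-free, the sketch swaps the eigendecomposition for power iteration on the shifted matrix trace(C) I - C (whose dominant eigenvector is that smallest-eigenvalue eigenvector of C), omits the sign disambiguation, and returns only the normal. All helper names are hypothetical.

```python
import math


def knn(points, query, k):
    """Brute-force k nearest neighbours of `query` (the query itself included)."""
    return sorted(points, key=lambda p: math.dist(p, query))[:k]


def covariance(pts):
    n = len(pts)
    mean = [sum(p[a] for p in pts) / n for a in range(3)]
    return [[sum((p[a] - mean[a]) * (p[b] - mean[b]) for p in pts) / n
             for b in range(3)] for a in range(3)]


def estimate_normal(points, query, neighborhood_size=50, iters=500):
    C = covariance(knn(points, query, neighborhood_size))
    # Power iteration on (trace(C) * I - C): its dominant eigenvector is the
    # eigenvector of C with the SMALLEST eigenvalue, i.e. the surface normal.
    shift = C[0][0] + C[1][1] + C[2][2]
    M = [[(shift if a == b else 0.0) - C[a][b] for b in range(3)] for a in range(3)]
    v = [0.36, 0.48, 0.8]  # generic unit start vector
    for _ in range(iters):
        w = [sum(M[a][b] * v[b] for b in range(3)) for a in range(3)]
        norm = math.sqrt(sum(x * x for x in w))
        v = [x / norm for x in w]
    return v  # the sign is arbitrary, as noted above


# A 5 x 5 grid on the plane z = x; its normal is parallel to (1, 0, -1).
tilted = [(float(x), float(y), float(x)) for x in range(-2, 3) for y in range(-2, 3)]
normal = estimate_normal(tilted, (0.0, 0.0, 0.0))
```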

pytorch3d.ops.estimate_pointcloud_normals(pointclouds: Tensor | Pointclouds, neighborhood_size: int = 50, disambiguate_directions: bool = True, *, use_symeig_workaround: bool = True) Tensor[source]

Estimates the normals of a batch of pointclouds.

The function uses estimate_pointcloud_local_coord_frames to estimate the normals. Please refer to that function for more detailed information.

Parameters:
  • pointclouds – Batch of 3-dimensional points of shape (minibatch, num_point, 3) or a Pointclouds object.

  • neighborhood_size – The size of the neighborhood used to estimate the geometry around each point.

  • disambiguate_directions – If True, uses the algorithm from [1] to ensure sign consistency of the normals of neighboring points.

  • use_symeig_workaround – If True, uses a custom eigenvalue calculation.

Returns:

normals: A tensor of normals for each input point of shape (minibatch, num_point, 3). If pointclouds are of Pointclouds class, returns a padded tensor.

References

[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for Local Surface Description, ECCV 2010.

pytorch3d.ops.add_pointclouds_to_volumes(pointclouds: Pointclouds, initial_volumes: Volumes, mode: str = 'trilinear', min_weight: float = 0.0001, rescale_features: bool = True, _python: bool = False) Volumes[source]

Add a batch of point clouds represented with a Pointclouds structure pointclouds to a batch of existing volumes represented with a Volumes structure initial_volumes.

More specifically, the method casts a set of weighted votes (the weights are determined by mode, which is either “trilinear” or “nearest”) into the pre-initialized features and densities fields of initial_volumes.

The method returns an updated Volumes object that contains a copy of initial_volumes with its features and densities updated with the result of the pointcloud addition.

Example:

import torch
from pytorch3d.ops import add_pointclouds_to_volumes
from pytorch3d.structures import Pointclouds, Volumes

# init a random point cloud
pointclouds = Pointclouds(
    points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5)
)
# init an empty volume centered around [0.5, 0.5, 0.5] in world coordinates
# with a voxel size of 1.0.
initial_volumes = Volumes(
    features=torch.zeros(4, 5, 25, 25, 25),
    densities=torch.zeros(4, 1, 25, 25, 25),
    volume_translation=[-0.5, -0.5, -0.5],
    voxel_size=1.0,
)
# add the pointcloud to the 'initial_volumes' buffer using
# trilinear splatting
updated_volumes = add_pointclouds_to_volumes(
    pointclouds=pointclouds,
    initial_volumes=initial_volumes,
    mode="trilinear",
)

Parameters:
  • pointclouds – Batch of 3D pointclouds represented with a Pointclouds structure. Note that pointclouds.features have to be defined.

  • initial_volumes – Batch of initial Volumes with pre-initialized 1-dimensional densities which contain non-negative numbers corresponding to the opaqueness of each voxel (the higher, the less transparent).

  • mode – The mode of the conversion of individual points into the volume. Set either to nearest or trilinear:

    nearest: Each 3D point is first rounded to the volumetric lattice. Each voxel is then labeled with the average over features that fall into the given voxel. The gradients of nearest neighbor conversion w.r.t. the 3D locations of the points in pointclouds are not defined.

    trilinear: Each 3D point casts 8 weighted votes to the 8-neighborhood of its floating point coordinate. The weights are determined using a trilinear interpolation scheme. Trilinear splatting is fully differentiable w.r.t. all input arguments.

  • min_weight – A scalar controlling the lowest possible total per-voxel weight used to normalize the features accumulated in a voxel. Only active for mode==trilinear.

  • rescale_features – If False, output features are just the sum of input and added points. If True, they are averaged. In both cases, output densities are just summed without rescaling, so you may need to rescale them afterwards.

  • _python – Set to True to use a pure Python implementation, e.g. for test purposes, which requires more memory and may be slower.

Returns:

updated_volumes – Output Volumes structure containing the conversion result.
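The nearest mode described above reduces to "round, then average per voxel". The following is a minimal pure-Python sketch of that averaging for a single cloud with one scalar feature per point. It assumes the points are already expressed in voxel-index coordinates (not the world or local coordinates the real function handles), and `splat_nearest` is a hypothetical helper, not part of PyTorch3D.

```python
import math
from collections import defaultdict
from typing import Dict, Sequence, Tuple

Voxel = Tuple[int, int, int]


def splat_nearest(
    points: Sequence[Tuple[float, float, float]],
    features: Sequence[float],
    grid_size: Tuple[int, int, int],
) -> Tuple[Dict[Voxel, float], Dict[Voxel, int]]:
    """Round each point to its nearest lattice voxel and average the features there.

    Points are assumed to be in voxel-index coordinates (0 .. size - 1 per axis).
    Returns (mean feature per occupied voxel, number of points per occupied voxel).
    """
    sums: Dict[Voxel, float] = defaultdict(float)
    counts: Dict[Voxel, int] = defaultdict(int)
    for (x, y, z), f in zip(points, features):
        # Round half up to the nearest lattice vertex.
        voxel = tuple(math.floor(c + 0.5) for c in (x, y, z))
        # Points that round to a location outside the grid are ignored in this sketch.
        if not all(0 <= v < s for v, s in zip(voxel, grid_size)):
            continue
        sums[voxel] += f
        counts[voxel] += 1
    means = {v: sums[v] / counts[v] for v in sums}
    return means, dict(counts)


means, counts = splat_nearest(
    points=[(0.1, 0.2, 0.0), (-0.3, 0.4, 0.1), (2.9, 1.0, 1.2)],
    features=[1.0, 3.0, 5.0],
    grid_size=(4, 4, 4),
)
# means == {(0, 0, 0): 2.0, (3, 1, 1): 5.0}
```

Because rounding is piecewise constant, the output does not change smoothly as a point moves, which is consistent with the statement that gradients w.r.t. point locations are not defined in this mode.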

pytorch3d.ops.add_points_features_to_volume_densities_features(points_3d: Tensor, points_features: Tensor, volume_densities: Tensor, volume_features: Tensor | None, mode: str = 'trilinear', min_weight: float = 0.0001, mask: Tensor | None = None, grid_sizes: LongTensor | None = None, rescale_features: bool = True, _python: bool = False, align_corners: bool = True) Tuple[Tensor, Tensor][source]

Convert a batch of point clouds represented with tensors of per-point 3d coordinates and their features to a batch of volumes represented with tensors of densities and features.

Parameters:
  • points_3d – Batch of 3D point cloud coordinates of shape (minibatch, N, 3) where N is the number of points in each point cloud. Coordinates have to be specified in the local volume coordinates (ranging in [-1, 1]).

  • points_features – Features of shape (minibatch, N, feature_dim) corresponding to the points of the input point clouds.

  • volume_densities – Batch of input feature volume densities of shape (minibatch, 1, D, H, W). Each voxel should contain a non-negative number corresponding to its opaqueness (the higher, the less transparent).

  • volume_features – Batch of input feature volumes of shape (minibatch, feature_dim, D, H, W). If set to None, the volume_features will be automatically instantiated with the correct size and filled with 0s.

  • mode – The mode of the conversion of individual points into the volume. Set either to nearest or trilinear:

    nearest: Each 3D point is first rounded to the volumetric lattice. Each voxel is then labeled with the average over features that fall into the given voxel. The gradients of nearest neighbor rounding w.r.t. the input point locations points_3d are not defined.

    trilinear: Each 3D point casts 8 weighted votes to the 8-neighborhood of its floating point coordinate. The weights are determined using a trilinear interpolation scheme. Trilinear splatting is fully differentiable w.r.t. all input arguments.

  • min_weight – A scalar controlling the lowest possible total per-voxel weight used to normalize the features accumulated in a voxel. Only active for mode==trilinear.

  • mask – A binary mask of shape (minibatch, N) determining which 3D points are going to be converted to the resulting volume. Set to None if all points are valid.

  • grid_sizes – LongTensor of shape (minibatch, 3) representing the spatial resolutions of each of the non-flattened volumes tensors, or None to indicate the whole volume is used for every batch element.

  • rescale_features – If False, output features are just the sum of input and added points. If True, they are averaged. In both cases, output densities are just summed without rescaling, so you may need to rescale them afterwards.

  • _python – Set to True to use a pure Python implementation.

  • align_corners – as for grid_sample.

Returns:

volume_features – Output volume of shape (minibatch, feature_dim, D, H, W).

volume_densities – Occupancy volume of shape (minibatch, 1, D, H, W) containing the total amount of votes cast to each of the voxels.
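For the trilinear mode, the "8 weighted votes" of a single point can be written down directly: each vote weight is a product of per-axis linear weights, and the eight weights sum to 1. The following pure-Python sketch computes only those weights. It assumes the point is given in voxel-index coordinates rather than the [-1, 1] local coordinates that points_3d uses, ignores grid bounds, and omits the feature normalization controlled by min_weight; `trilinear_votes` is a hypothetical name.

```python
import math
from itertools import product
from typing import Dict, Tuple


def trilinear_votes(x: float, y: float, z: float) -> Dict[Tuple[int, int, int], float]:
    """Return the 8 trilinear weights that a point at (x, y, z) casts to the
    lattice vertices surrounding it. Coordinates are in voxel-index units."""
    base = [math.floor(c) for c in (x, y, z)]
    frac = [c - b for c, b in zip((x, y, z), base)]
    votes: Dict[Tuple[int, int, int], float] = {}
    for offset in product((0, 1), repeat=3):
        weight = 1.0
        for axis, o in enumerate(offset):
            # Per-axis linear weight: (1 - t) for the lower vertex, t for the upper one.
            weight *= frac[axis] if o else 1.0 - frac[axis]
        voxel = (base[0] + offset[0], base[1] + offset[1], base[2] + offset[2])
        votes[voxel] = weight
    return votes


votes = trilinear_votes(1.25, 2.0, 0.5)
# votes[(1, 2, 0)] == 0.375, votes[(2, 2, 1)] == 0.125, sum of all 8 weights == 1.0
```

The weights vary continuously with the point position, which is what allows this mode, unlike nearest, to pass gradients back to the point coordinates.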

pytorch3d.ops.sample_farthest_points(points: Tensor, lengths: Tensor | None = None, K: int | List | Tensor = 50, random_start_point: bool = False) Tuple[Tensor, Tensor][source]

Iterative farthest point sampling algorithm [1] to subsample a set of K points from a given pointcloud. At each iteration, the point with the largest distance to its nearest already-selected point is added to the selection.

Farthest point sampling provides more uniform coverage of the input point cloud compared to uniform random sampling.

[1] Charles R. Qi et al., “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space”, NeurIPS 2017.

Parameters:
  • points – (N, P, D) array containing the batch of pointclouds

  • lengths – (N,) number of points in each pointcloud (to support heterogeneous batches of pointclouds)

  • K – samples required in each sampled point cloud (this is typically << P). If K is an int, then the same number of samples is selected for each pointcloud in the batch. If K is a list or tensor, it should have length N, giving the number of samples to select for each element in the batch.

  • random_start_point – bool, if True, a random point is selected as the starting point for iterative sampling.

Returns:

selected_points – (N, K, D) array of selected values from points. If the input K is a list or tensor, then the shape will be (N, max(K), D), padded with 0.0 for batch elements where k_i < max(K).

selected_indices – (N, K) array of selected indices. If the input K is a list or tensor, then the shape will be (N, max(K)), padded with -1 for batch elements where k_i < max(K).
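The greedy loop described above can be sketched in pure Python for a single cloud. This is an illustration, not the batched PyTorch3D kernel: it keeps, for every point, the squared distance to its nearest selected point and picks the maximum each round. Starting from index 0 when random_start is False is an assumption of this sketch, and `farthest_point_sampling` and `pad_indices` are hypothetical helpers; the latter mirrors the documented -1 padding for per-cloud K.

```python
import random
from typing import List, Sequence


def sq_dist(p: Sequence[float], q: Sequence[float]) -> float:
    """Squared Euclidean distance between two points of equal dimension."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def farthest_point_sampling(
    points: Sequence[Sequence[float]], k: int, random_start: bool = False
) -> List[int]:
    """Greedy farthest point sampling on one cloud; returns the selected indices."""
    n = len(points)
    k = min(k, n)
    if k <= 0:
        return []
    first = random.randrange(n) if random_start else 0
    selected = [first]
    # min_dist[i]: squared distance from point i to its nearest selected point.
    min_dist = [sq_dist(p, points[first]) for p in points]
    for _ in range(k - 1):
        # The next sample is the point farthest from everything chosen so far.
        nxt = max(range(n), key=min_dist.__getitem__)
        selected.append(nxt)
        for i, p in enumerate(points):
            min_dist[i] = min(min_dist[i], sq_dist(p, points[nxt]))
    return selected


def pad_indices(batch: List[List[int]]) -> List[List[int]]:
    """Pad per-cloud index lists to the longest length (max K) with -1."""
    width = max((len(row) for row in batch), default=0)
    return [row + [-1] * (width - len(row)) for row in batch]


pts = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (10.0, 0.0), (5.0, 0.0)]
idx = farthest_point_sampling(pts, 3)  # [0, 3, 4]: the two ends first, then the middle
```

Maintaining the running minimum distance makes each iteration O(P), so sampling K points costs O(K·P) rather than recomputing all pairwise distances.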

pytorch3d.ops.sample_points_from_meshes(meshes, num_samples: int = 10000, return_normals: bool = False, return_textures: bool = False) Tensor | Tuple[Tensor, Tensor] | Tuple[Tensor, Tensor, Tensor][source]

Convert a batch of meshes to a batch of pointclouds by uniformly sampling points on the surface of the mesh with probability proportional to the face area.

Parameters:
  • meshes – A Meshes object with a batch of N meshes.

  • num_samples – Integer giving the number of point samples per mesh.

  • return_normals – If True, return normals for the sampled points.

  • return_textures – If True, return textures for the sampled points.

Returns:

Either samples alone, or a tuple of up to 3 elements depending on return_normals and return_textures, containing

  • samples: FloatTensor of shape (N, num_samples, 3) giving the coordinates of sampled points for each mesh in the batch. For empty meshes the corresponding row in the samples array will be filled with 0.

  • normals: FloatTensor of shape (N, num_samples, 3) giving a normal vector to each sampled point. Only returned if return_normals is True. For empty meshes the corresponding row in the normals array will be filled with 0.

  • textures: FloatTensor of shape (N, num_samples, C) giving a C-dimensional texture vector to each sampled point. Only returned if return_textures is True. For empty meshes the corresponding row in the textures array will be filled with 0.

Note that in a future release, we will replace the tuple output with a Pointclouds data structure, as follows

Pointclouds(samples, normals=normals, features=textures)
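Area-proportional surface sampling has two steps: choose a face with probability proportional to its area, then draw a uniformly distributed point inside that triangle. Below is a minimal pure-Python sketch for one mesh. The barycentric formula is the standard square-root trick for uniform sampling in a triangle; whether PyTorch3D uses exactly this formula is not stated above, so treat it as an assumption. `triangle_area` and `sample_surface` are hypothetical helpers, and normals and textures are omitted.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def triangle_area(a: Vec3, b: Vec3, c: Vec3) -> float:
    """Half the norm of the cross product of two triangle edges."""
    ab = [b[i] - a[i] for i in range(3)]
    ac = [c[i] - a[i] for i in range(3)]
    cross = (
        ab[1] * ac[2] - ab[2] * ac[1],
        ab[2] * ac[0] - ab[0] * ac[2],
        ab[0] * ac[1] - ab[1] * ac[0],
    )
    return 0.5 * math.sqrt(sum(x * x for x in cross))


def sample_surface(
    verts: Sequence[Vec3],
    faces: Sequence[Tuple[int, int, int]],
    num_samples: int,
    rng: Optional[random.Random] = None,
) -> List[Vec3]:
    rng = rng or random.Random()
    areas = [triangle_area(*(verts[i] for i in f)) for f in faces]
    if not faces or sum(areas) == 0.0:
        # Empty (or degenerate) mesh: fill with zeros, as documented above.
        return [(0.0, 0.0, 0.0)] * num_samples
    # Pick faces with probability proportional to their area.
    chosen = rng.choices(range(len(faces)), weights=areas, k=num_samples)
    samples = []
    for fi in chosen:
        a, b, c = (verts[i] for i in faces[fi])
        su, v = math.sqrt(rng.random()), rng.random()
        # These barycentric weights give points uniformly distributed over the triangle.
        w0, w1, w2 = 1.0 - su, su * (1.0 - v), su * v
        samples.append(tuple(w0 * a[i] + w1 * b[i] + w2 * c[i] for i in range(3)))
    return samples


square = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (1.0, 1.0, 0.0), (0.0, 1.0, 0.0)]
pts = sample_surface(square, [(0, 1, 2), (0, 2, 3)], 100, random.Random(0))
```

Without the square root on the first uniform variable, samples would cluster near vertex a; weighting faces by area is what makes the density uniform over the whole surface rather than per face.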

class pytorch3d.ops.SubdivideMeshes(meshes=None)[source]

Subdivide a triangle mesh by adding a new vertex at the center of each edge and dividing each face into four new faces. Vectors of vertex attributes can also be subdivided by averaging the values of the attributes at the two vertices which form each edge. This implementation preserves face orientation: if the vertices of a face are all ordered counter-clockwise, then the faces in the subdivided meshes will also have their vertices ordered counter-clockwise.

If meshes is provided as an input, the initializer performs the relatively expensive computation of determining the new face indices. This one-time computation can be reused for all meshes with the same face topology but different vertex positions.

__init__(meshes=None) None[source]
Parameters:

meshes – Meshes object or None. If a meshes object is provided, the first mesh is used to compute the new faces of the subdivided topology which can be reused for meshes with the same input topology.

subdivide_faces(meshes)[source]
Parameters:

meshes – a Meshes object.

Returns:

subdivided_faces_packed – (4*sum(F_n), 3) shape LongTensor of original and new faces.

Refer to pytorch3d.structures.meshes.py for more details on packed representations of faces.

Each face is split into 4 faces. For example, given the input face

         v0
         /\
        /  \
       /    \
   e1 /      \ e0
     /        \
    /          \
   /            \
  /______________\
v2       e2       v1

faces_packed = [[0, 1, 2]]
faces_packed_to_edges_packed = [[2, 1, 0]]

faces_packed_to_edges_packed is used to represent all the new vertex indices corresponding to the mid-points of edges in the mesh. The actual vertex coordinates will be computed in the forward function. To get the indices of the new vertices, offset faces_packed_to_edges_packed by the total number of vertices.

faces_packed_to_edges_packed = [[2, 1, 0]] + 3 = [[5, 4, 3]]

The resulting subdivided face is:

        v0
        /\
       /  \
      / f0 \
  v4 /______\ v3
    /\      /\
   /  \ f3 /  \
  / f2 \  / f1 \
 /______\/______\
v2       v5       v1

f0 = [0, 3, 4]
f1 = [1, 5, 3]
f2 = [2, 4, 5]
f3 = [5, 4, 3]
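The index bookkeeping in this example can be reproduced in a few lines. Reading the convention off the example, position j of faces_packed_to_edges_packed holds the edge opposite vertex j (edge 2, the bottom edge, is opposite v0), and adding the vertex count turns an edge index into the index of its midpoint vertex. The sketch below is pure Python with a hypothetical `subdivide_face_indices` helper; the order in which it emits the four faces is chosen to match f0 to f3 above and may differ from the library's packed ordering.

```python
from typing import List, Sequence


def subdivide_face_indices(
    faces: Sequence[Sequence[int]],
    faces_to_edges: Sequence[Sequence[int]],
    num_verts: int,
) -> List[List[int]]:
    """Split every face (v0, v1, v2) into four faces.

    faces_to_edges[f][j] is assumed to be the edge opposite vertex j of face f;
    offsetting it by num_verts gives the index of that edge's midpoint vertex.
    """
    new_faces = []
    for (v0, v1, v2), edges in zip(faces, faces_to_edges):
        # m0, m1, m2: midpoint vertices opposite v0, v1, v2 respectively.
        m0, m1, m2 = (e + num_verts for e in edges)
        new_faces.append([v0, m2, m1])  # corner face at v0
        new_faces.append([v1, m0, m2])  # corner face at v1
        new_faces.append([v2, m1, m0])  # corner face at v2
        new_faces.append([m0, m1, m2])  # central face
    return new_faces


print(subdivide_face_indices([[0, 1, 2]], [[2, 1, 0]], num_verts=3))
# [[0, 3, 4], [1, 5, 3], [2, 4, 5], [5, 4, 3]]
```

Each corner face keeps its original vertex first and walks around in the same rotational sense as the parent face, which is how orientation is preserved.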
forward(meshes, feats=None)[source]

Subdivide a batch of meshes by adding a new vertex on each edge, and dividing each face into four new faces. The new meshes contain two types of vertices:

  1. Vertices that appear in the input meshes. Data for these vertices is copied from the input meshes.

  2. New vertices at the midpoint of each edge. Data for these vertices is the average of the data for the two vertices that make up the edge.

Parameters:
  • meshes – Meshes object representing a batch of meshes.

  • feats – Per-vertex features to be subdivided along with the verts. Should be parallel to the packed vert representation of the input meshes; so it should have shape (V, D) where V is the total number of verts in the input meshes. Default: None.

Returns:

2-element tuple containing

  • new_meshes: Meshes object of a batch of subdivided meshes.

  • new_feats: (optional) Tensor of subdivided feats, parallel to the (packed) vertices of the subdivided meshes. Only returned if feats is not None.
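The two vertex types listed for forward translate into a simple rule that applies equally to positions and to feats: copy the existing rows, then append one averaged row per edge. The following pure-Python sketch uses a hypothetical `subdivide_vertex_data` helper and assumes `edges` is listed in edge-index order, so that edge e yields new vertex V + e, matching the offset used in subdivide_faces.

```python
from typing import List, Sequence, Tuple


def subdivide_vertex_data(
    data: Sequence[Sequence[float]], edges: Sequence[Tuple[int, int]]
) -> List[List[float]]:
    """Per-vertex data (positions or features) for the subdivided mesh.

    The first len(data) rows are copied from the input vertices; one row per edge
    follows, holding the average of the edge's two endpoint rows.
    """
    out = [list(row) for row in data]
    for a, b in edges:
        out.append([(x + y) / 2.0 for x, y in zip(data[a], data[b])])
    return out


# Edges e0 = (v0, v1), e1 = (v0, v2), e2 = (v1, v2), as in the subdivide_faces example.
verts = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
new_verts = subdivide_vertex_data(verts, [(0, 1), (0, 2), (1, 2)])
# new_verts[3:] == [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
```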

subdivide_homogeneous(meshes, feats=None)[source]

Subdivide verts (and optionally features) of a batch of meshes where each mesh has the same topology of faces. The subdivided faces are precomputed in the initializer.

Parameters:
  • meshes – Meshes object representing a batch of meshes.

  • feats – Per-vertex features to be subdivided along with the verts.

Returns:

2-element tuple containing

  • new_meshes: Meshes object of a batch of subdivided meshes.

  • new_feats: (optional) Tensor of subdivided feats, parallel to the (packed) vertices of the subdivided meshes. Only returned if feats is not None.

subdivide_heterogenerous(meshes, feats=None)[source]

Subdivide faces, verts (and optionally features) of a batch of meshes where each mesh can have different face topologies.

Parameters:
  • meshes – Meshes object representing a batch of meshes.

  • feats – Per-vertex features to be subdivided along with the verts.

Returns:

2-element tuple containing

  • new_meshes: Meshes object of a batch of subdivided meshes.

  • new_feats: (optional) Tensor of subdivided feats, parallel to the (packed) vertices of the subdivided meshes. Only returned if feats is not None.

pytorch3d.ops.convert_pointclouds_to_tensor(pcl: Tensor | Pointclouds)[source]

If type(pcl)==Pointclouds, converts a pcl object to a padded representation and returns it together with the number of points per batch. Otherwise, returns the input itself with the number of points set to the size of the second dimension of pcl.

pytorch3d.ops.eyes(dim: int, N: int, device: device | None = None, dtype: dtype = torch.float32) Tensor[source]

Generates a batch of N identity matrices of shape (N, dim, dim).

Parameters:
  • dim – The dimensionality of the identity matrices.

  • N – The number of identity matrices.

  • device – The device to be used for allocating the matrices.

  • dtype – The datatype of the matrices.

Returns:

identities – A batch of identity matrices of shape (N, dim, dim).

pytorch3d.ops.get_point_covariances(points_padded: Tensor, num_points_per_cloud: Tensor, neighborhood_size: int) Tuple[Tensor, Tensor][source]

Computes the per-point covariance matrices of the 3D locations of the K nearest neighbors of each point, where K = neighborhood_size.

Parameters:
  • points_padded – Input point clouds as a padded tensor of shape (minibatch, num_points, dim).

  • num_points_per_cloud – Number of points per cloud of shape (minibatch,).

  • neighborhood_size – Number of nearest neighbors for each point used to estimate the covariance matrices.

Returns:

covariances – A batch of per-point covariance matrices of shape (minibatch, num_points, dim, dim).

k_nearest_neighbors – A batch of neighborhood_size nearest neighbors for each of the point cloud points, of shape (minibatch, num_points, neighborhood_size, dim).
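To make the computation concrete, here is a brute-force pure-Python sketch for a single cloud: find each point's K nearest neighbors in the same cloud (the point itself, at distance 0, is among them), center them on their mean, and average the outer products of the centered offsets. Dividing by K (a biased estimate) is an assumption of this sketch, and `knn_covariances` is a hypothetical helper; the real function works on padded batches via a KNN search.

```python
from typing import List, Sequence


def knn_covariances(points: Sequence[Sequence[float]], k: int) -> List[List[List[float]]]:
    """Per-point (dim x dim) covariance of each point's k nearest neighbors."""
    dim = len(points[0])
    k = min(k, len(points))
    covs = []
    for p in points:
        # Brute-force k nearest neighbors of p within the same cloud (p included).
        nbrs = sorted(points, key=lambda q: sum((qi - pi) ** 2 for qi, pi in zip(q, p)))[:k]
        mean = [sum(q[d] for q in nbrs) / k for d in range(dim)]
        cov = [[0.0] * dim for _ in range(dim)]
        for q in nbrs:
            diff = [q[d] - mean[d] for d in range(dim)]
            for i in range(dim):
                for j in range(dim):
                    cov[i][j] += diff[i] * diff[j] / k
        covs.append(cov)
    return covs


line = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
covs = knn_covariances(line, 3)
# For collinear points only the x-x entry is non-zero, e.g. covs[1][0][0] == 2/3.
```

The covariance of a local neighborhood encodes its shape: a near-zero variance along some direction indicates the neighborhood is flat (or, as here, degenerate) in that direction.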

pytorch3d.ops.is_pointclouds(pcl: Tensor | Pointclouds) bool[source]

Checks whether the input pcl is an instance of Pointclouds by checking the existence of points_padded and num_points_per_cloud functions.

pytorch3d.ops.wmean(x: Tensor, weight: Tensor | None = None, dim: int | Tuple[int] = -2, keepdim: bool = True, eps: float = 1e-09) Tensor[source]

Finds the mean of the input tensor across the specified dimension. If the weight argument is provided, computes the weighted mean.

Parameters:
  • x – tensor of shape (*, D), where D is assumed to be spatial.

  • weight – if given, a non-negative tensor of shape (*,). It must be broadcastable to x.shape[:-1]. Note that the weights for the last (spatial) dimension are assumed to be the same.

  • dim – dimension(s) in x to average over.

  • keepdim – whether to keep the resulting singleton dimension.

  • eps – minimum clamping value in the denominator.

Returns:

the mean tensor

  • if weight is None => mean(x, dim),

  • otherwise => sum(x*w, dim) / max{sum(w, dim), eps}.
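The two formulas above can be checked on a small example. The sketch below is a pure-Python illustration for the default case dim=-2 on a single (M, D) array given as nested lists, returning a length-D result (i.e. without the kept singleton dimension); `weighted_row_mean` is a hypothetical helper, not PyTorch3D code.

```python
from typing import List, Optional, Sequence


def weighted_row_mean(
    x: Sequence[Sequence[float]],
    weight: Optional[Sequence[float]] = None,
    eps: float = 1e-9,
) -> List[float]:
    """Mean over the rows of an (M, D) array; one weight per row if given."""
    d = len(x[0])
    if weight is None:
        return [sum(row[j] for row in x) / len(x) for j in range(d)]
    # sum(x * w) / max(sum(w), eps): the clamp avoids division by zero.
    denom = max(sum(weight), eps)
    return [sum(w * row[j] for w, row in zip(weight, x)) / denom for j in range(d)]


x = [[0.0, 0.0], [2.0, 4.0]]
weighted_row_mean(x)              # [1.0, 2.0]
weighted_row_mean(x, [1.0, 3.0])  # [1.5, 3.0]
weighted_row_mean(x, [0.0, 0.0])  # [0.0, 0.0]: the eps clamp keeps the result finite
```

Note that the same weight applies to every coordinate of a row, matching the statement that the weights for the last (spatial) dimension are assumed to be the same.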

pytorch3d.ops.vert_align(feats, verts, return_packed: bool = False, interp_mode: str = 'bilinear', padding_mode: str = 'zeros', align_corners: bool = True) Tensor[source]

Sample vertex features from a feature map. This operation is called “perceptual feature pooling” in [1] or “vert align” in [2].

[1] Wang et al., “Pixel2Mesh: Generating 3D Mesh Models from Single RGB Images”, ECCV 2018.

[2] Gkioxari et al., “Mesh R-CNN”, ICCV 2019.

Parameters:
  • feats – FloatTensor of shape (N, C, H, W) representing image features from which to sample or a list of features each with potentially different C, H or W dimensions.

  • verts – FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds) with verts_padded or points_padded as an attribute giving the (x, y, z) vertex positions for which to sample. (x, y) verts should be normalized such that (-1, -1) corresponds to the top-left and (+1, +1) to the bottom-right location in the input feature map.

  • return_packed – (bool) Indicates whether to return packed features

  • interp_mode – (str) Specifies how to interpolate features. (‘bilinear’ or ‘nearest’)

  • padding_mode – (str) Specifies how to handle vertices outside of the [-1, 1] range. (‘zeros’, ‘reflection’, or ‘border’)

  • align_corners (bool) – Geometrically, we consider the pixels of the input as squares rather than points. If set to True, the extrema (-1 and 1) are considered as referring to the center points of the input’s corner pixels. If set to False, they are instead considered as referring to the corner points of the input’s corner pixels, making the sampling more resolution agnostic. Default: True

Returns:

feats_sampled – FloatTensor of shape (N, V, C) giving sampled features for each vertex. If feats is a list, we return concatenated features in axis=2 of shape (N, V, sum(C_n)) where C_n = feats[n].shape[1]. If return_packed = True, the features are transformed to a packed representation of shape (sum(V), C).
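The sampling step behind vert_align can be illustrated for one vertex and one feature channel. The following is a pure-Python sketch of bilinear interpolation with padding_mode='zeros' and align_corners=True, where -1 and +1 map to the centers of the corner pixels; only the (x, y) components of a vertex are used as the location, per the verts description above. `sample_bilinear` is a hypothetical helper; the real function samples all channels and vertices in batch.

```python
import math
from typing import Sequence


def sample_bilinear(feat: Sequence[Sequence[float]], x: float, y: float) -> float:
    """Bilinearly sample an H x W map at normalized (x, y) in [-1, 1]."""
    h, w = len(feat), len(feat[0])
    # align_corners=True: -1 and +1 refer to the centers of the corner pixels.
    px = (x + 1.0) / 2.0 * (w - 1)
    py = (y + 1.0) / 2.0 * (h - 1)
    x0, y0 = math.floor(px), math.floor(py)
    tx, ty = px - x0, py - y0

    def at(r: int, c: int) -> float:
        # padding_mode='zeros': taps outside the map contribute 0.
        return feat[r][c] if 0 <= r < h and 0 <= c < w else 0.0

    return (
        (1 - tx) * (1 - ty) * at(y0, x0)
        + tx * (1 - ty) * at(y0, x0 + 1)
        + (1 - tx) * ty * at(y0 + 1, x0)
        + tx * ty * at(y0 + 1, x0 + 1)
    )


fmap = [[0.0, 1.0],
        [2.0, 3.0]]
sample_bilinear(fmap, -1.0, -1.0)  # 0.0, top-left pixel
sample_bilinear(fmap, 1.0, 1.0)    # 3.0, bottom-right pixel
sample_bilinear(fmap, 0.0, 0.0)    # 1.5, average of the four pixels
```

With align_corners=False the mapping would instead be px = ((x + 1) * w - 1) / 2, which places -1 and +1 on the outer edges of the corner pixels.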