Source code for pytorch3d.ops.points_alignment

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import warnings
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union

import torch
from pytorch3d.ops import knn_points
from pytorch3d.structures import utils as strutil

from . import utils as oputil


if TYPE_CHECKING:
    from pytorch3d.structures.pointclouds import Pointclouds


# named tuples for inputs/outputs
class SimilarityTransform(NamedTuple):
    """
    Batched similarity transformation, applied to row-vector points as
    `s * X @ R + T` (see `_apply_similarity_transform`).
    """

    # batch of orthonormal matrices of shape (minibatch, d, d)
    R: torch.Tensor
    # batch of translations of shape (minibatch, d)
    T: torch.Tensor
    # batch of scaling factors of shape (minibatch,)
    s: torch.Tensor


class ICPSolution(NamedTuple):
    """
    Result bundle returned by `iterative_closest_point`.
    """

    # True if the relative-rmse stopping criterion was met
    # before the iteration budget ran out
    converged: bool
    # root mean squared error of the last executed iteration;
    # stays None if no iteration was executed
    rmse: Union[torch.Tensor, None]
    # the input cloud X after applying the final transformation
    # (a `Pointclouds` object if X was one, otherwise a tensor)
    Xt: torch.Tensor
    # the final similarity transformation (R, T, s)
    RTs: SimilarityTransform
    # the transformation estimated after each ICP iteration, in order
    t_history: List[SimilarityTransform]


def iterative_closest_point(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    init_transform: Optional[SimilarityTransform] = None,
    max_iterations: int = 100,
    relative_rmse_thr: float = 1e-6,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    verbose: bool = False,
) -> ICPSolution:
    """
    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
    a similarity transformation (rotation `R`, translation `T`, and
    optionally scale `s`) between two given differently-sized sets of
    `d`-dimensional points `X` and `Y`, such that:

    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,

    for all batch indices `i` in the least squares sense. Here, Y[NN[i]] stands
    for the indices of nearest neighbors from `Y` to each point in `X`.
    Note, however, that the solution is only a local optimum.

    Args:
        **X**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
        **init_transform**: A named-tuple `SimilarityTransform` of tensors
            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
            shape `(minibatch, d, d)`, `T` is a batch of translations
            of shape `(minibatch, d)` and `s` is a batch of scaling factors
            of shape `(minibatch,)`.
        **max_iterations**: The maximum number of ICP iterations.
        **relative_rmse_thr**: A threshold on the relative root mean squared error
            used to terminate the algorithm.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes the identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **verbose**: If `True`, prints status messages during each ICP iteration.

    Returns:
        A named tuple `ICPSolution` with the following fields:
        **converged**: A boolean flag denoting whether the algorithm converged
            successfully (=`True`) or not (=`False`).
        **rmse**: Attained root mean squared error after termination of ICP
            (`None` if no iteration was executed).
        **Xt**: The point cloud `X` transformed with the final transformation
            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
        **RTs**: A named tuple `SimilarityTransform` containing
        a batch of similarity transforms with fields:
            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
            **T**: Batch of translations of shape `(minibatch, d)`.
            **s**: batch of scaling factors of shape `(minibatch, )`.
        **t_history**: A list of named tuples `SimilarityTransform`
            the transformation parameters after each ICP iteration.

    Raises:
        ValueError: If `X` and `Y` differ in batch size or point dimension,
            or if `init_transform` is not a `(R, T, s)` triple with the
            shapes listed above.

    References:
        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
    """

    # make sure we convert input Pointclouds structures to
    # padded tensors of shape (N, P, 3)
    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    b, size_X, dim = Xt.shape

    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
        raise ValueError(
            "Point sets X and Y have to have the same "
            + "number of batches and data dimensions."
        )

    # Mask out the padded entries of X. The mask depends on X only, so it is
    # always derived from num_points_X: it is all ones when nothing is padded,
    # and it stays correct when X and Y are padded clouds that happen to
    # share the same per-cloud point counts.
    mask_X = (
        torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
        < num_points_X[:, None]
    ).type_as(Xt)

    # clone the initial point cloud
    Xt_init = Xt.clone()

    if init_transform is not None:
        # parse the initial transform from the input and apply to Xt;
        # an explicit check is used instead of `assert` so that the
        # validation is not stripped when running with `python -O`
        try:
            R, T, s = init_transform
            init_is_valid = (
                R.shape == torch.Size((b, dim, dim))
                and T.shape == torch.Size((b, dim))
                and s.shape == torch.Size((b,))
            )
        except Exception:
            # anything that cannot be unpacked / has no `.shape` is invalid
            init_is_valid = False
        if not init_is_valid:
            raise ValueError(
                "The initial transformation init_transform has to be "
                "a named tuple SimilarityTransform with elements (R, T, s). "
                "R are dim x dim orthonormal matrices of shape "
                "(minibatch, dim, dim), T is a batch of dim-dimensional "
                "translations of shape (minibatch, dim) and s is a batch "
                "of scalars of shape (minibatch,)."
            )
        # apply the init transform to the input point cloud
        Xt = _apply_similarity_transform(Xt, R, T, s)
    else:
        # initialize the transformation with identity
        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
        T = Xt.new_zeros((b, dim))
        s = Xt.new_ones(b)

    prev_rmse = None
    rmse = None
    # stays at -1 if the loop below does not execute at all
    iteration = -1
    converged = False

    # initialize the transformation history
    t_history = []

    # the main loop over ICP iterations
    for iteration in range(max_iterations):
        # nearest neighbor in Yt of every (currently transformed) point of Xt
        Xt_nn_points = knn_points(
            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
        ).knn[:, :, 0, :]

        # get the alignment of the nearest neighbors from Yt with Xt_init
        R, T, s = corresponding_points_alignment(
            Xt_init,
            Xt_nn_points,
            weights=mask_X,
            estimate_scale=estimate_scale,
            allow_reflection=allow_reflection,
        )

        # apply the estimated similarity transform to Xt_init
        Xt = _apply_similarity_transform(Xt_init, R, T, s)

        # add the current transformation to the history
        t_history.append(SimilarityTransform(R, T, s))

        # compute the root mean squared error
        # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

        # compute the relative rmse
        if prev_rmse is None:
            relative_rmse = rmse.new_ones(b)
        else:
            # clamp the denominator so that an exactly-zero previous rmse
            # yields 0 instead of NaN (NaN would never satisfy the
            # convergence test below)
            relative_rmse = (prev_rmse - rmse) / prev_rmse.clamp(
                min=torch.finfo(prev_rmse.dtype).tiny
            )

        if verbose:
            rmse_msg = (
                f"ICP iteration {iteration}: mean/max rmse = "
                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
            )
            print(rmse_msg)

        # check for convergence
        if (relative_rmse <= relative_rmse_thr).all():
            converged = True
            break

        # update the previous rmse
        prev_rmse = rmse

    if verbose:
        if converged:
            print(f"ICP has converged in {iteration + 1} iterations.")
        else:
            print(f"ICP has not converged in {max_iterations} iterations.")

    if oputil.is_pointclouds(X):
        Xt = X.update_padded(Xt)  # type: ignore

    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)


# threshold for checking that point crosscorelation # is full rank in corresponding_points_alignment AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
def corresponding_points_alignment(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-9,
) -> SimilarityTransform:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`) between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the least squares sense.

    The algorithm is also known as Umeyama [1].

    Args:
        **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **weights**: Batch of non-negative weights of
            shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
            tensors that may have different shapes; in that case, the length of
            i-th tensor should be equal to the number of points in X_i and Y_i.
            Passing `None` means uniform weights.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **eps**: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Returns:
        3-element named tuple `SimilarityTransform` containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: batch of scaling factors of shape `(minibatch, )`.

    Raises:
        ValueError: If `X` and `Y` differ in shape or per-cloud point counts,
            or if `weights` does not match the first two dimensions of `X`.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        # implicit string concatenation keeps the message free of the stray
        # backslash/whitespace that an in-string line continuation embeds
        raise ValueError(
            "Point sets X and Y have to have the same "
            "number of batches, points and dimensions."
        )

    if weights is not None:
        if isinstance(weights, list):
            if any(count != w.shape[0] for count, w in zip(num_points, weights)):
                raise ValueError(
                    "number of weights should equal to the "
                    + "number of points in the point cloud."
                )
            # pad the per-cloud weight vectors into one (minibatch, num_point)
            # tensor; the trailing singleton dim is added for the padding
            # helper and dropped again afterwards
            weights = [w[..., None] for w in weights]
            weights = strutil.list_to_padded(weights)[..., 0]

        if Xt.shape[:2] != weights.shape:
            raise ValueError("weights should have the same first two dimensions as X.")

    b, n, dim = Xt.shape

    if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
        # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
            < num_points[:, None]
        ).type_as(Xt)
        weights = mask if weights is None else mask * weights.type_as(Xt)

    # compute the centroids of the point sets
    Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
    Ymu = oputil.wmean(Yt, weight=weights, eps=eps)

    # mean-center the point sets
    Xc = Xt - Xmu
    Yc = Yt - Ymu

    total_weight = torch.clamp(num_points, 1)
    # special handling for heterogeneous point clouds and/or input weights
    if weights is not None:
        # NOTE(review): both Xc and Yc are scaled by the weights, so every
        # weight enters the cross-covariance (and Xcov below) squared while
        # the centroids use it linearly. This is a no-op for 0/1 masks;
        # confirm that it is intended for general weights.
        Xc *= weights[:, :, None]
        Yc *= weights[:, :, None]
        total_weight = torch.clamp(weights.sum(1), eps)

    if (num_points < (dim + 1)).any():
        warnings.warn(
            "The size of one of the point clouds is <= dim+1. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
    XYcov = XYcov / total_weight[:, None, None]

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # catch ambiguous rotation by checking the magnitude of singular values
    # (skipped when the too-few-points warning above has already fired)
    if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
        num_points < (dim + 1)
    ).any():
        warnings.warn(
            "Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc * Xc).sum((1, 2)) / total_weight

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
    else:
        # translation component
        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return SimilarityTransform(R, T, s)


def _apply_similarity_transform(
    X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
    """
    Maps a batch of `d`-dimensional point clouds `X` of shape
    `(minibatch, num_points, d)` through a similarity transformation,
    i.e. computes `s * (X @ R) + T` independently for every batch element.

    The transformation is parametrized with a batch of orthonormal
    matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
    of shape `(minibatch, d)` and a batch of scaling factors `s`
    of shape `(minibatch,)`.

    Returns a new tensor; the input `X` is not modified.
    """
    # per-batch scalar, broadcast over points and coordinates
    scale = s[:, None, None]
    # right-multiply the row-vector points by the rotation
    rotated = torch.bmm(X, R)
    # per-batch offset, broadcast over points
    offset = T[:, None, :]
    return scale * rotated + offset