import warnings
from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
from pytorch3d.ops import knn_points
from pytorch3d.structures import utils as strutil
from . import utils as oputil
from pytorch3d.structures.pointclouds import Pointclouds
# named tuples for inputs/outputs
class SimilarityTransform(NamedTuple):
    """
    Batched similarity transformation, applied to row-vector points as
    `s[i] X[i] R[i] + T[i]`.
    """

    # batch of orthonormal matrices of shape `(minibatch, d, d)`
    R: torch.Tensor
    # batch of translations of shape `(minibatch, d)`
    T: torch.Tensor
    # batch of scaling factors of shape `(minibatch,)`
    s: torch.Tensor
class ICPSolution(NamedTuple):
    """
    Result of `iterative_closest_point`.
    """

    # `True` if the relative rmse dropped below the threshold
    # before the iteration budget was exhausted
    converged: bool
    # attained root mean squared error; `None` if no iteration was executed
    rmse: Optional[torch.Tensor]
    # the input cloud `X` after applying the final transformation; a
    # `Pointclouds` instance if `X` was one, a `torch.Tensor` otherwise
    Xt: Union[torch.Tensor, "Pointclouds"]
    # the final similarity transformation
    RTs: SimilarityTransform
    # the transformation estimated after each ICP iteration
    t_history: List[SimilarityTransform]
def iterative_closest_point(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    init_transform: Optional[SimilarityTransform] = None,
    max_iterations: int = 100,
    relative_rmse_thr: float = 1e-6,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    verbose: bool = False,
) -> ICPSolution:
    """
    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
    a similarity transformation (rotation `R`, translation `T`, and
    optionally scale `s`) between two given differently-sized sets of
    `d`-dimensional points `X` and `Y`, such that:

    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,

    for all batch indices `i` in the least squares sense. Here, Y[NN[i]] stands
    for the nearest neighbors from `Y` of each point in `X`.
    Note, however, that the solution is only a local optimum.

    Args:
        **X**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
        **init_transform**: A named-tuple `SimilarityTransform` of tensors
            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
            shape `(minibatch, d, d)`, `T` is a batch of translations
            of shape `(minibatch, d)` and `s` is a batch of scaling factors
            of shape `(minibatch,)`.
        **max_iterations**: The maximum number of ICP iterations.
        **relative_rmse_thr**: A threshold on the relative root mean squared error
            used to terminate the algorithm.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes the identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **verbose**: If `True`, prints status messages during each ICP iteration.

    Returns:
        A named tuple `ICPSolution` with the following fields:
        **converged**: A boolean flag denoting whether the algorithm converged
            successfully (=`True`) or not (=`False`).
        **rmse**: Attained root mean squared error after termination of ICP.
        **Xt**: The point cloud `X` transformed with the final transformation
            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
        **RTs**: A named tuple `SimilarityTransform` containing
            a batch of similarity transforms with fields:
            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
            **T**: Batch of translations of shape `(minibatch, d)`.
            **s**: batch of scaling factors of shape `(minibatch, )`.
        **t_history**: A list of named tuples `SimilarityTransform`
            the transformation parameters after each ICP iteration.

    Raises:
        ValueError: If `X` and `Y` differ in batch size or point dimension,
            or if `init_transform` is not a valid `(R, T, s)` triple of the
            expected shapes.

    References:
        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
    """

    # make sure we convert input Pointclouds structures to
    # padded tensors of shape (N, P, 3)
    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    b, size_X, dim = Xt.shape

    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
        raise ValueError(
            "Point sets X and Y have to have the same "
            + "number of batches and data dimensions."
        )

    # Mask of the valid (non-padded) entries of Xt. It is always derived from
    # num_points_X: for dense inputs it is all ones, and for padded inputs it
    # excludes the padding even when X and Y happen to share the same
    # per-cloud lengths.
    mask_X = (
        torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
        < num_points_X[:, None]
    ).type_as(Xt)

    # clone the initial point cloud
    Xt_init = Xt.clone()

    if init_transform is not None:
        # parse the initial transform from the input and apply to Xt;
        # an explicit check is used instead of `assert` so that the
        # validation is still performed under `python -O`
        try:
            R, T, s = init_transform
            init_transform_ok = (
                R.shape == torch.Size((b, dim, dim))
                and T.shape == torch.Size((b, dim))
                and s.shape == torch.Size((b,))
            )
        except Exception:
            init_transform_ok = False
        if not init_transform_ok:
            raise ValueError(
                "The initial transformation init_transform has to be "
                "a named tuple SimilarityTransform with elements (R, T, s). "
                "R are dim x dim orthonormal matrices of shape "
                "(minibatch, dim, dim), T is a batch of dim-dimensional "
                "translations of shape (minibatch, dim) and s is a batch "
                "of scalars of shape (minibatch,)."
            )
        # apply the init transform to the input point cloud
        Xt = _apply_similarity_transform(Xt, R, T, s)
    else:
        # initialize the transformation with identity
        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
        T = Xt.new_zeros((b, dim))
        s = Xt.new_ones(b)

    prev_rmse = None
    rmse = None
    # keeps `iteration` defined for the final report when max_iterations == 0
    iteration = -1
    converged = False

    # initialize the transformation history
    t_history = []

    # the main loop over ICP iterations
    for iteration in range(max_iterations):
        # nearest neighbor in Yt of every point of the current Xt
        Xt_nn_points = knn_points(
            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
        ).knn[:, :, 0, :]

        # get the alignment of the nearest neighbors from Yt with Xt_init
        R, T, s = corresponding_points_alignment(
            Xt_init,
            Xt_nn_points,
            weights=mask_X,
            estimate_scale=estimate_scale,
            allow_reflection=allow_reflection,
        )

        # apply the estimated similarity transform to Xt_init
        Xt = _apply_similarity_transform(Xt_init, R, T, s)

        # add the current transformation to the history
        t_history.append(SimilarityTransform(R, T, s))

        # compute the root mean squared error
        # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

        # compute the relative rmse
        if prev_rmse is None:
            relative_rmse = rmse.new_ones(b)
        else:
            relative_rmse = (prev_rmse - rmse) / prev_rmse
            # a previous rmse of exactly zero would give 0/0 = NaN, which
            # never compares <= threshold; treat it as "no further
            # improvement possible" so a perfect fit reports convergence
            relative_rmse = torch.where(
                prev_rmse > 0, relative_rmse, torch.zeros_like(relative_rmse)
            )

        if verbose:
            rmse_msg = (
                f"ICP iteration {iteration}: mean/max rmse = "
                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
            )
            print(rmse_msg)

        # check for convergence
        if (relative_rmse <= relative_rmse_thr).all():
            converged = True
            break

        # update the previous rmse
        prev_rmse = rmse

    if verbose:
        if converged:
            print(f"ICP has converged in {iteration + 1} iterations.")
        else:
            print(f"ICP has not converged in {max_iterations} iterations.")

    if oputil.is_pointclouds(X):
        Xt = X.update_padded(Xt)  # type: ignore

    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
# threshold for checking that point cross-correlation
# is full rank in corresponding_points_alignment
# NOTE(review): the definition was not present next to this comment; the value
# below follows upstream PyTorch3D - confirm before merging.
AMBIGUOUS_ROT_SINGULAR_THR = 1e-15


def corresponding_points_alignment(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-9,
) -> SimilarityTransform:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`) between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the least squares sense.

    The algorithm is also known as Umeyama [1].

    Args:
        **X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        **weights**: Batch of non-negative weights of
            shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
            tensors that may have different shapes; in that case, the length of
            i-th tensor should be equal to the number of points in X_i and Y_i.
            Passing `None` means uniform weights.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **eps**: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Returns:
        3-element named tuple `SimilarityTransform` containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: batch of scaling factors of shape `(minibatch, )`.

    Raises:
        ValueError: If `X` and `Y` differ in shape or per-cloud point counts,
            or if `weights` does not match the point counts / the first two
            dimensions of `X`.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        # adjacent literals instead of a backslash continuation, so that no
        # source indentation leaks into the message
        raise ValueError(
            "Point sets X and Y have to have the same "
            "number of batches, points and dimensions."
        )

    if weights is not None:
        if isinstance(weights, list):
            if any(n_pts != w.shape[0] for n_pts, w in zip(num_points, weights)):
                raise ValueError(
                    "number of weights should equal to the "
                    + "number of points in the point cloud."
                )
            # pad the per-cloud weight vectors into one (minibatch, num_point) tensor
            weights = [w[..., None] for w in weights]
            weights = strutil.list_to_padded(weights)[..., 0]

        if Xt.shape[:2] != weights.shape:
            raise ValueError("weights should have the same first two dimensions as X.")

    b, n, dim = Xt.shape

    if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
        # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
            < num_points[:, None]
        ).type_as(Xt)
        weights = mask if weights is None else mask * weights.type_as(Xt)

    # compute the centroids of the point sets
    Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
    Ymu = oputil.wmean(Yt, weight=weights, eps=eps)

    # mean-center the point sets
    Xc = Xt - Xmu
    Yc = Yt - Ymu

    total_weight = torch.clamp(num_points, 1)
    # special handling for heterogeneous point clouds and/or input weights
    if weights is not None:
        # NOTE(review): both Xc and Yc are scaled, so each point enters the
        # cross-covariance with weights**2 while the normalizer below is
        # sum(weights) - confirm this is the intended weighting.
        Xc *= weights[:, :, None]
        Yc *= weights[:, :, None]
        total_weight = torch.clamp(weights.sum(1), eps)

    if (num_points < (dim + 1)).any():
        warnings.warn(
            "The size of one of the point clouds is <= dim+1. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
    XYcov = XYcov / total_weight[:, None, None]

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # catch ambiguous rotation by checking the magnitude of singular values;
    # skipped when the "too few points" warning above has already fired
    if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
        num_points < (dim + 1)
    ).any():
        warnings.warn(
            "Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "corresponding_points_alignment cannot return a unique rotation."
        )

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc * Xc).sum((1, 2)) / total_weight

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
    else:
        # translation component
        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return SimilarityTransform(R, T, s)
def _apply_similarity_transform(
X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
) -> torch.Tensor:
Applies a similarity transformation parametrized with a batch of orthonormal
matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
of shape `(minibatch, d)` and a batch of scaling factors `s`
of shape `(minibatch,)` to a given `d`-dimensional cloud `X`
of shape `(minibatch, num_points, d)`
X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
return X