# Source code for pytorch3d.ops.perspective_n_points

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
This file contains Efficient PnP algorithm for Perspective-n-Points problem.
It finds a camera position (defined by rotation `R` and translation `T`) that
minimizes re-projection error between the given 3D points `x` and
the corresponding uncalibrated 2D points `y`.
"""

import warnings
from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
from pytorch3d.ops import points_alignment, utils as oputil


class EpnpSolution(NamedTuple):
    """
    Container for the output of the EPnP solver in this file.

    Instances are built by `_compute_norm_sign_scaling_factor` (one per
    candidate solution) and by `efficient_pnp` (the selected solution).
    """

    # Points in camera coordinates that were used to fit the camera,
    # documented by `efficient_pnp` as `(minibatch, num_points, 3)`.
    x_cam: torch.Tensor
    # Rotation matrices, documented as `(minibatch, 3, 3)`.
    R: torch.Tensor
    # Translation vectors, documented as `(minibatch, 3)`.
    T: torch.Tensor
    # Per-batch-element 2D re-projection error (see `_reproj_error`).
    err_2d: torch.Tensor
    # Per-batch-element 3D algebraic error (see `_algebraic_error`).
    err_3d: torch.Tensor


def _define_control_points(x, weight, storage_opts=None):
    """
    Builds the four control points that define the barycentric frame.

    The control points are the three rows of a 3x3 identity matrix plus an
    all-zero row, each offset by the (optionally weighted) mean of `x`.

    Args:
        x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
        weight: Batch of non-negative weights of
            shape `(minibatch, num_point)`. `None` means equal weights.
        storage_opts: dict of keyword arguments forwarded to `torch.eye`
            (e.g. dtype and device). `None` or an empty dict means defaults.
    Returns:
        The offset control points; the un-offset basis is expanded to the
        shape of `x[:, :4, :]` before the mean is added.
    """
    eye_kwargs = storage_opts if storage_opts else {}
    centroid = oputil.wmean(x, weight)
    # rows are e_1, e_2, e_3 followed by the origin -> 4 x 3
    basis = F.pad(torch.eye(3, **eye_kwargs), (0, 0, 0, 1), value=0.0)
    return basis.expand_as(x[:, :4, :]) + centroid


def _compute_alphas(x, c_world):
    """
    Computes barycentric coordinates of x in the frame c_world.
    Args:
        x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
        c_world: control points in world coordinates.
    """
    x = F.pad(x, (0, 1), value=1.0)
    c = F.pad(c_world, (0, 1), value=1.0)
    return torch.matmul(x, torch.inverse(c))  # B x N x 4


def _build_M(y, alphas, weight):
    """Returns the matrix defining the reprojection equations.
    Args:
        y: projected points in camera coordinates of size B x N x 2
        alphas: barycentric coordinates of size B x N x 4
        weight: Batch of non-negative weights of
            shape `(minibatch, num_point)`. `None` means equal weights.
    """
    bs, n, _ = y.size()

    # prepend t with the column of v's
    def prepad(t, v):
        return F.pad(t, (1, 0), value=v)

    if weight is not None:
        # weight the alphas in order to get a correctly weighted version of M
        alphas = alphas * weight[:, :, None]

    # outer left-multiply by alphas
    def lm_alphas(t):
        return torch.matmul(alphas[..., None], t).reshape(bs, n, 12)

    M = torch.cat(
        (
            lm_alphas(
                prepad(prepad(-y[:, :, 0, None, None], 0.0), 1.0)
            ),  # u constraints
            lm_alphas(
                prepad(prepad(-y[:, :, 1, None, None], 1.0), 0.0)
            ),  # v constraints
        ),
        dim=-1,
    ).reshape(bs, -1, 12)

    return M


def _null_space(m, kernel_dim):
    """Finds the null space (kernel) basis of the matrix
    Args:
        m: the batch of input matrices, B x N x 12
        kernel_dim: number of dimensions to approximate the kernel
    Returns:
        * a batch of null space basis vectors
            of size B x 4 x 3 x kernel_dim
        * a batch of spectral values where near-0s correspond to actual
            kernel vectors, of size B x kernel_dim
    """
    mTm = torch.bmm(m.transpose(1, 2), m)
    s, v = torch.linalg.eigh(mTm)
    return v[:, :, :kernel_dim].reshape(-1, 4, 3, kernel_dim), s[:, :kernel_dim]


def _reproj_error(y_hat, y, weight, eps=1e-9):
    """Projects estimated 3D points and measures the 2D reprojection error.

    The homogeneous points are divided by their last coordinate (clamped from
    below by `eps`), and the Euclidean distance to `y` is computed per point.

    Args:
        y_hat: a batch of predicted 2D points in homogeneous coordinates
        y: a batch of ground-truth 2D points
        weight: Batch of non-negative weights of
            shape `(minibatch, num_point)`. `None` means equal weights.
        eps: lower bound applied to the divisor before the perspective division.
    Returns:
        Optionally weighted mean (via `oputil.wmean`) of the per-point
        Euclidean distances between `y` and the projected `y_hat`,
        one value per batch element.
    """
    depth = torch.clamp(y_hat[..., 2:], eps)
    projected = (y_hat / depth)[..., :2]
    residual = y - projected
    dist = (residual**2).sum(dim=-1, keepdim=True) ** 0.5
    return oputil.wmean(dist, weight)[:, 0, 0]


def _algebraic_error(x_w_rotated, x_cam, weight):
    """Measures the 3D residual of the Umeyama alignment.

    Args:
        x_w_rotated: The given 3D points rotated with the predicted camera.
        x_cam: the lifted 2D points y
        weight: Batch of non-negative weights of
            shape `(minibatch, num_point)`. `None` means equal weights.
    Returns:
        Optionally weighted mean (via `oputil.wmean`) of the squared
        distances between x_w_rotated and x_cam, one value per batch element.
    """
    diff = x_w_rotated - x_cam
    sq_dist = (diff**2).sum(dim=-1, keepdim=True)
    return oputil.wmean(sq_dist, weight)[:, 0, 0]


def _compute_norm_sign_scaling_factor(c_cam, alphas, x_world, y, weight, eps=1e-9):
    """Turns a candidate set of camera-space control points into a solution.

    The candidate is only defined up to sign and scale: the sign is fixed so
    that the (weighted) mean depth is non-negative, and the scale is taken
    from a similarity alignment between `x_world` and the camera-space points.

    Args:
        c_cam: control points in camera coordinates
        alphas: barycentric coordinates of the points
        x_world: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
        y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
        weight: Batch of non-negative weights of
            shape `(minibatch, num_point)`. `None` means equal weights.
        eps: epsilon to threshold negative `z` values; also the lower bound
            applied to the estimated scale before dividing by it.
    Returns:
        `EpnpSolution` with the rescaled camera-space points, the estimated
        `R` and `T`, and the 2D and 3D errors of the transformed `x_world`.
    """
    # position of reference points in camera coordinates
    x_cam = torch.matmul(alphas, c_cam)

    # flip the whole batch element if its mean depth is negative
    mean_depth_is_negative = oputil.wmean(x_cam[..., 2:], weight) < 0
    flip = 1.0 - 2.0 * mean_depth_is_negative.float()
    x_cam = x_cam * flip

    depth = x_cam[..., 2:]
    if torch.any(depth < -eps):
        # some points are still behind the camera after the flip
        neg_rate = oputil.wmean((depth < 0).float(), weight, dim=(0, 1)).item()
        warnings.warn("\nEPnP: %2.2f%% points have z<0." % (neg_rate * 100.0))

    R, T, s = points_alignment.corresponding_points_alignment(
        x_world, x_cam, weight, estimate_scale=True
    )

    # undo the estimated scale on the camera-space side
    scale = s.clamp(eps)
    x_cam = x_cam / scale[:, None, None]
    T = T / scale[:, None]

    x_w_rotated = torch.matmul(x_world, R) + T[:, None, :]
    err_2d = _reproj_error(x_w_rotated, y, weight)
    err_3d = _algebraic_error(x_w_rotated, x_cam, weight)
    return EpnpSolution(x_cam, R, T, err_2d, err_3d)


def _gen_pairs(input, dim=-2, reducer=lambda a, b: ((a - b) ** 2).sum(dim=-1)):
    """Generates all pairs of different rows and then applies the reducer
    Args:
        input: a tensor
        dim: a dimension to generate pairs across
        reducer: a function of generated pair of rows to apply (beyond just concat)
    Returns:
        for default args, for A x B x C input, will output A x (B choose 2)
    """
    n = input.size()[dim]
    range = torch.arange(n)
    idx = torch.combinations(range).to(input).long()
    left = input.index_select(dim, idx[:, 0])
    right = input.index_select(dim, idx[:, 1])
    return reducer(left, right)


def _kernel_vec_distances(v):
    """Builds the coefficients of the linearized quadratic system that matches
        all pairwise distances between the 4 control points (dim=1).
        The last dimension holds the coefficients of the quadratic terms
        Bij = Bi * Bj, where Bi and Bj correspond to kernel vectors.
    Arg:
        v: tensor of B x 4 x 3 x D, where D is dim(kernel), usually 4
    Returns:
        a tensor of B x 6 x [(D choose 2) + D];
        for D=4, the last dim means [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34].
    """

    def _difference(a, b):
        return a - b

    def _dot_over_xyz(a, b):
        return (a * b).sum(dim=-2)

    # differences between every pair of control points: B x 6 x 3 x D
    dv = _gen_pairs(v, dim=-3, reducer=_difference)

    # squared terms (i, i): B x 6 x D
    rows_ii = dv.pow(2).sum(dim=-2)

    # cross terms (i, j), i < j, enter with coefficient 2: B x 6 x (D choose 2)
    rows_2ij = 2.0 * _gen_pairs(dv, dim=-1, reducer=_dot_over_xyz)

    return torch.cat((rows_ii, rows_2ij), dim=-1)


def _solve_lstsq_subcols(rhs, lhs, lhs_col_idx):
    """Solves an over-determined linear system for selected LHS columns.
        A batched version of `torch.lstsq`.
    Args:
        rhs: right-hand side vectors
        lhs: left-hand side matrices
        lhs_col_idx: a slice of columns in lhs
    Returns:
        a least-squares solution for lhs * X = rhs
    """
    lhs = lhs.index_select(-1, torch.tensor(lhs_col_idx, device=lhs.device).long())
    return torch.matmul(torch.pinverse(lhs), rhs[:, :, None])


def _binary_sign(t):
    return (t >= 0).to(t) * 2.0 - 1.0


def _find_null_space_coords_1(kernel_dsts, cw_dst, eps=1e-9):
    """Solves case 1 from the paper [1]; solve for 4 coefficients:
       [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
         ^               ^   ^   ^
    Args:
        kernel_dsts: distances between kernel vectors
        cw_dst: distances between control points
        eps: lower bound for the divisor sqrt(|B11|)
    Returns:
        coefficients to weight kernel vectors
    [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
    EPnP: An Accurate O(n) solution to the PnP problem.
    International Journal of Computer Vision.
    https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
    """
    # least-squares estimate of [B11, B12, B13, B14]
    beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 5, 6])

    # flip all four so that the B11 estimate is non-negative
    beta = beta * _binary_sign(beta[:, :1, :])

    # dividing by sqrt(B11) turns [B11, B12, B13, B14] into [B1, B2, B3, B4]
    root_b11 = torch.clamp(beta[:, :1, :] ** 0.5, eps)
    return beta / root_b11


def _find_null_space_coords_2(kernel_dsts, cw_dst):
    """Solves case 2 from the paper; solve for 3 coefficients:
        [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
          ^   ^           ^
    Args:
        kernel_dsts: distances between kernel vectors
        cw_dst: distances between control points
    Returns:
        coefficients to weight kernel vectors; the last two are zero
    [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
    EPnP: An Accurate O(n) solution to the PnP problem.
    International Journal of Computer Vision.
    https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
    """
    # least-squares estimate, in column order [B11, B12, B22]
    beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1])
    b11 = beta[:, :1, :]
    b12 = beta[:, 1:2, :]
    b22 = beta[:, 2:3, :]

    # |B1| from B11, signed by B12
    coord_0 = (b11.abs() ** 0.5) * _binary_sign(b12)

    # |B2| from B22, zeroed out when B11 and B22 disagree in sign
    signs_agree = (b11 >= 0) == (b22 >= 0)
    coord_1 = (b22.abs() ** 0.5) * signs_agree.float()

    unused = torch.zeros_like(beta[:, :2, :])
    return torch.cat((coord_0, coord_1, unused), dim=1)


def _find_null_space_coords_3(kernel_dsts, cw_dst, eps=1e-9):
    """Solves case 3 from the paper; solve for 5 coefficients:
        [B11 B22 B33 B44 B12 B13 B14 B23 B24 B34]
          ^   ^           ^   ^       ^
    Args:
        kernel_dsts: distances between kernel vectors
        cw_dst: distances between control points
        eps: lower bound for the magnitude of the divisor `|B1|`
    Returns:
        coefficients to weight kernel vectors; the last one is zero
    [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
    EPnP: An Accurate O(n) solution to the PnP problem.
    International Journal of Computer Vision.
    https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
    """
    # least-squares estimate, in column order [B11, B12, B22, B13, B23]
    beta = _solve_lstsq_subcols(cw_dst, kernel_dsts, [0, 4, 1, 5, 7])

    # B1: magnitude from B11, sign from B12
    sign_0 = _binary_sign(beta[:, 1:2, :])
    magnitude_0 = beta[:, :1, :].abs() ** 0.5
    coord_0 = magnitude_0 * sign_0

    # B2: magnitude from B22, zeroed out when B11 and B22 disagree in sign
    coord_1 = (beta[:, 2:3, :].abs() ** 0.5) * (
        (beta[:, :1, :] >= 0) == (beta[:, 2:3, :] >= 0)
    ).float()

    # B3 = B13 / B1. Only the magnitude of B1 is clamped: clamping the signed
    # value from below would replace every negative B1 with `eps` and blow up
    # B3 whenever B12 < 0.
    coord_2 = beta[:, 3:4, :] / (sign_0 * torch.clamp(magnitude_0, eps))

    return torch.cat(
        (coord_0, coord_1, coord_2, torch.zeros_like(beta[:, :1, :])), dim=1
    )


def efficient_pnp(
    x: torch.Tensor,
    y: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    skip_quadratic_eq: bool = False,
) -> EpnpSolution:
    """
    Implements Efficient PnP algorithm [1] for Perspective-n-Points problem:
    finds a camera position (defined by rotation `R` and translation `T`) that
    minimizes re-projection error between the given 3D points `x` and
    the corresponding uncalibrated 2D points `y`, i.e. solves

    `y[i] = Proj(x[i] R[i] + T[i])`

    in the least-squares sense, where `i` are indices within the batch, and
    `Proj` is the perspective projection operator: `Proj([x y z]) = [x/z y/z]`.
    In the noise-less case, 4 points are enough to find the solution as long
    as they are not co-planar.

    Args:
        x: Batch of 3-dimensional points of shape `(minibatch, num_points, 3)`.
        y: Batch of 2-dimensional points of shape `(minibatch, num_points, 2)`.
        weights: Batch of non-negative weights of
            shape `(minibatch, num_point)`. `None` means equal weights.
        skip_quadratic_eq: If True, assumes the solution space for the
            linear system is one-dimensional, i.e. takes the scaled eigenvector
            that corresponds to the smallest eigenvalue as a solution.
            If False, finds the candidate coordinates in the potentially
            4D null space by approximately solving the systems of quadratic
            equations. The best candidate is chosen by examining the 2D
            re-projection error. While this option finds a better solution,
            especially when the number of points is small or perspective
            distortions are low (the points are far away), it may be more
            difficult to back-propagate through.

    Returns:
        `EpnpSolution` namedtuple containing elements:
        **x_cam**: Batch of transformed points `x` that is used to find
            the camera parameters, of shape `(minibatch, num_points, 3)`.
            In the general (noisy) case, they are not exactly equal to
            `x[i] R[i] + T[i]` but are some affine transform of `x[i]`s.
        **R**: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        **T**: Batch of translation vectors of shape `(minibatch, 3)`.
        **err_2d**: Batch of mean 2D re-projection errors of shape
            `(minibatch,)`. Specifically, if `yhat` is the re-projection for
            the `i`-th batch element, it is the (optionally weighted) mean
            over `j` of `norm(yhat_j - y_j)` where `j` iterates over points
            and `norm` denotes the L2 norm.
        **err_3d**: Batch of mean algebraic errors of shape `(minibatch,)`.
            Specifically, those are squared distances between the transformed
            `x` and the estimated points on the rays defined by `y`.

    [1] Moreno-Noguer, F., Lepetit, V., & Fua, P. (2009).
    EPnP: An Accurate O(n) solution to the PnP problem.
    International Journal of Computer Vision.
    https://www.epfl.ch/labs/cvlab/software/multi-view-stereo/epnp/
    """
    # define control points in a world coordinate system (centered on the 3d
    # points centroid); 4 x 3
    # TODO: more stable when initialised with the center and eigenvectors!
    c_world = _define_control_points(
        x.detach(), weights, storage_opts={"dtype": x.dtype, "device": x.device}
    )

    # find the linear combination of the control points to represent the 3d points
    alphas = _compute_alphas(x, c_world)

    M = _build_M(y, alphas, weights)

    # Compute kernel M; the spectral values are not needed here
    kernel, _ = _null_space(M, 4)

    c_world_distances = _gen_pairs(c_world)
    kernel_dsts = _kernel_vec_distances(kernel)

    # candidate coordinates in the null space, one per approximation case
    betas = (
        []
        if skip_quadratic_eq
        else [
            fnsc(kernel_dsts, c_world_distances)
            for fnsc in [
                _find_null_space_coords_1,
                _find_null_space_coords_2,
                _find_null_space_coords_3,
            ]
        ]
    )

    # the raw kernel is always a candidate: its first vector is used below
    c_cam_variants = [kernel] + [
        torch.matmul(kernel, beta[:, None, :, :]) for beta in betas
    ]

    solutions = [
        _compute_norm_sign_scaling_factor(c_cam[..., 0], alphas, x, y, weights)
        for c_cam in c_cam_variants
    ]

    # stack every field across candidates: dim 0 indexes the candidate
    sol_zipped = EpnpSolution(*(torch.stack(col) for col in zip(*solutions)))
    best = torch.argmin(sol_zipped.err_2d, dim=0)

    def gather1d(source, idx):
        # reduces dim=0 (the candidate dimension) by picking, for each batch
        # element, the slice selected by the 1D tensor idx;
        # in other words, it is batched index_select.
        trailing_ones = [1] * (len(source.shape) - 2)
        return source.gather(
            0,
            idx.reshape(1, -1, *trailing_ones).expand_as(source[:1]),
        )[0]

    return EpnpSolution(*[gather1d(sol_col, best) for sol_col in sol_zipped])