Source code for pytorch3d.transforms.se3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from pytorch3d.common.compat import solve

from .so3 import _so3_exp_map, hat, so3_log_map


def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """
    Convert a batch of logarithmic representations of SE(3) matrices
    `log_transform` to a batch of 4x4 SE(3) matrices using the exponential map.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    A SE(3) matrix has the following form:
        ```
        [ R 0 ]
        [ T 1 ] ,
        ```
    where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector.
    SE(3) matrices are commonly used to represent rigid motions or camera
    extrinsics.

    In the SE(3) logarithmic representation SE(3) matrices are
    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
    i.e. a concatenation of two 3D vectors `log_translation` and
    `log_rotation`.

    The conversion from the 6D representation to a 4x4 SE(3) matrix
    `transform` is done as follows:
        ```
        transform = exp( [ hat(log_rotation) 0 ]
                         [   log_translation 1 ] ) ,
        ```
    where `exp` is the matrix exponential and `hat` is the Hat operator [2].

    Note that for any `log_transform` with `0 <= ||log_rotation|| < 2pi`
    (i.e. the rotation angle is between 0 and 2pi), the following identity
    holds:
        ```
        se3_log_map(se3_exp_map(log_transform)) == log_transform
        ```

    The conversion has a singularity around `||log_rotation|| = 0`
    which is handled by clamping controlled with the `eps` argument.

    Args:
        log_transform: Batch of vectors of shape `(minibatch, 6)`.
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid unstable gradients in the singular case.

    Returns:
        Batch of transformation matrices of shape `(minibatch, 4, 4)`.

    Raises:
        ValueError if `log_transform` is of incorrect shape.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """

    if log_transform.ndim != 2 or log_transform.shape[1] != 6:
        raise ValueError("Expected input to be of shape (N, 6).")

    N, _ = log_transform.shape

    # the 6D vector is laid out as [log_translation | log_rotation]
    log_translation = log_transform[..., :3]
    log_rotation = log_transform[..., 3:]

    # rotation is an exponential map of log_rotation
    (
        R,
        rotation_angles,
        log_rotation_hat,
        log_rotation_hat_square,
    ) = _so3_exp_map(log_rotation, eps=eps)

    # translation is V @ T
    V = _se3_V_matrix(
        log_rotation,
        log_rotation_hat,
        log_rotation_hat_square,
        rotation_angles,
        eps=eps,
    )
    T = torch.bmm(V, log_translation[:, :, None])[:, :, 0]

    # The matrix is assembled with the translation in the last column and
    # then transposed, so that the translation ends up in the last row as in
    # the layout documented above.
    transform = torch.zeros(
        N, 4, 4, dtype=log_transform.dtype, device=log_transform.device
    )
    transform[:, :3, :3] = R
    transform[:, :3, 3] = T
    transform[:, 3, 3] = 1.0

    return transform.permute(0, 2, 1)
def se3_log_map(
    transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
) -> torch.Tensor:
    """
    Convert a batch of 4x4 transformation matrices `transform`
    to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    A SE(3) matrix has the following form:
        ```
        [ R 0 ]
        [ T 1 ] ,
        ```
    where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D
    translation vector. SE(3) matrices are commonly used to represent rigid
    motions or camera extrinsics.

    In the SE(3) logarithmic representation SE(3) matrices are
    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
    i.e. a concatenation of two 3D vectors `log_translation` and
    `log_rotation`.

    The conversion from the 4x4 SE(3) matrix `transform` to the
    6D representation `log_transform = [log_translation | log_rotation]`
    is done as follows:
        ```
        log_transform = log(transform)
        log_translation = log_transform[3, :3]
        log_rotation = inv_hat(log_transform[:3, :3])
        ```
    where `log` is the matrix logarithm
    and `inv_hat` is the inverse of the Hat operator [2].

    Note that for any valid 4x4 `transform` matrix, the following identity
    holds:
        ```
        se3_exp_map(se3_log_map(transform)) == transform
        ```

    The conversion has a singularity around `(transform=I)` which is handled
    by clamping controlled with the `eps` and `cos_bound` arguments.

    Args:
        transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid division by zero in the singular case.
            It is used both by `so3_log_map` and when building the
            "V" matrix that relates `T` to `log_translation`.
        cos_bound: Forwarded unchanged to `so3_log_map`, where it bounds the
            cosine of the rotation angle to avoid non-finite outputs of the
            `acos` call (see `so3_log_map` for the exact clamping interval).

    Returns:
        Batch of logarithms of input SE(3) matrices
        of shape `(minibatch, 6)`.

    Raises:
        ValueError if `transform` is of incorrect shape.
        ValueError if `transform[:, :3, 3]` is not (close to) all zeros.
        ValueError if `R` has an unexpected trace.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """

    if transform.ndim != 3 or tuple(transform.shape[1:]) != (4, 4):
        raise ValueError("Input tensor shape has to be (N, 4, 4).")

    # In the documented layout the translation lives in the last row, so the
    # last column of the top three rows has to vanish.
    last_column = transform[:, :3, 3]
    if not torch.allclose(last_column, torch.zeros_like(last_column)):
        raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")

    # log_rot is just so3_log_map of the (transposed) upper left 3x3 block
    R = transform[:, :3, :3].permute(0, 2, 1)
    log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)

    # log_translation is V^-1 @ T
    T = transform[:, 3, :3]
    # Forward `eps` so that the rotation angles inside V are clamped with the
    # caller-supplied threshold instead of the helper's default value.
    V = _se3_V_matrix(*_get_se3_V_input(log_rotation, eps=eps), eps=eps)
    log_translation = solve(V, T[:, :, None])[:, :, 0]

    return torch.cat((log_translation, log_rotation), dim=1)
def _se3_V_matrix(
    log_rotation: torch.Tensor,
    log_rotation_hat: torch.Tensor,
    log_rotation_hat_square: torch.Tensor,
    rotation_angles: torch.Tensor,
    eps: float = 1e-4,
) -> torch.Tensor:
    """
    A helper function that computes the "V" matrix from [1], Sec 9.4.2.

    The result is
        ```
        V = I + a * log_rotation_hat + b * log_rotation_hat_square
        a = (1 - cos(angle)) / angle**2
        b = (angle - sin(angle)) / angle**3
        ```
    evaluated independently for every element of the batch.

    Args:
        log_rotation: Batch of rotation logarithms; only its dtype and
            device are read here (to build the identity matrix).
        log_rotation_hat: Batch of matrices scaled by the coefficient `a`.
        log_rotation_hat_square: Batch of matrices scaled by the
            coefficient `b`.
        rotation_angles: 1-D batch of angles; they are divided by without
            any further clamping, so callers must pass non-zero values.
        eps: Accepted for interface compatibility; it is not read by this
            function.

    Returns:
        Batch of "V" matrices.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    """
    identity = torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[
        None
    ]

    # per-element scalar coefficients of the two matrix terms
    hat_coeff = (1 - torch.cos(rotation_angles)) / (rotation_angles ** 2)
    hat_square_coeff = (rotation_angles - torch.sin(rotation_angles)) / (
        rotation_angles ** 3
    )

    hat_term = log_rotation_hat * hat_coeff[:, None, None]
    hat_square_term = log_rotation_hat_square * hat_square_coeff[:, None, None]

    return identity + hat_term + hat_square_term


def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
    """
    A helper function that computes the input variables to the
    `_se3_V_matrix` function.

    Args:
        log_rotation: Batch of rotation logarithms; the squared norm is
            taken over the last dimension.
        eps: Lower bound applied to the squared norm before the square
            root is taken, which keeps the returned angles away from zero.

    Returns:
        The tuple `(log_rotation, log_rotation_hat, log_rotation_hat_square,
        rotation_angles)`, ordered like the positional arguments of
        `_se3_V_matrix`.
    """
    # hat matrix of the rotation logarithm and its square
    skew = hat(log_rotation)
    skew_square = torch.bmm(skew, skew)

    # rotation angle = norm of the logarithm, with the squared norm clamped
    squared_norms = (log_rotation ** 2).sum(-1)
    angles = torch.sqrt(torch.clamp(squared_norms, eps))

    return log_rotation, skew, skew_square, angles