pytorch3d.transforms

pytorch3d.transforms.euler_angles_to_matrix(euler_angles, convention: str)[source]

Convert rotations given as Euler angles in radians to rotation matrices.

Parameters:
  • euler_angles – Euler angles in radians as tensor of shape (…, 3).
  • convention – Convention string of three uppercase letters, each one of “X”, “Y”, or “Z”.
Returns:

Rotation matrices as tensor of shape (…, 3, 3).
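
As a rough illustration of what this conversion computes, the plain-Python sketch below builds one rotation matrix per axis and multiplies them in the order given by the convention string. The multiplication order (first letter leftmost) is an assumption, and axis_rotation, matmul3 and euler_to_matrix are hypothetical helpers written for this example; this is not the library's implementation.

```python
import math

def axis_rotation(axis, angle):
    """Rotation matrix about one coordinate axis (right-handed, angle in radians)."""
    c, s = math.cos(angle), math.sin(angle)
    if axis == "X":
        return [[1, 0, 0], [0, c, -s], [0, s, c]]
    if axis == "Y":
        return [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    if axis == "Z":
        return [[c, -s, 0], [s, c, 0], [0, 0, 1]]
    raise ValueError(f"invalid axis {axis!r}")

def matmul3(a, b):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def euler_to_matrix(angles, convention):
    """Compose three axis rotations; assumes the order R_0 @ R_1 @ R_2."""
    if len(convention) != 3 or len(angles) != 3:
        raise ValueError("need three angles and a three-letter convention")
    m0, m1, m2 = (axis_rotation(ax, ang) for ax, ang in zip(convention, angles))
    return matmul3(matmul3(m0, m1), m2)

# A quarter turn about Z only: the X and Y rotations are identities.
R = euler_to_matrix([0.0, 0.0, math.pi / 2], "XYZ")
```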

pytorch3d.transforms.matrix_to_euler_angles(matrix, convention: str)[source]

Convert rotations given as rotation matrices to Euler angles in radians.

Parameters:
  • matrix – Rotation matrices as tensor of shape (…, 3, 3).
  • convention – Convention string of three uppercase letters.
Returns:

Euler angles in radians as tensor of shape (…, 3).

pytorch3d.transforms.matrix_to_quaternion(matrix)[source]

Convert rotations given as rotation matrices to quaternions.

Parameters: matrix – Rotation matrices as tensor of shape (…, 3, 3).
Returns: Quaternions with real part first, as tensor of shape (…, 4).
pytorch3d.transforms.quaternion_apply(quaternion, point)[source]

Apply the rotation given by a quaternion to a 3D point. Usual torch rules for broadcasting apply.

Parameters:
  • quaternion – Tensor of quaternions, real part first, of shape (…, 4).
  • point – Tensor of 3D points of shape (…, 3).
Returns:

Tensor of rotated points of shape (…, 3).
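
To make the operation concrete, here is a small plain-Python sketch of rotating one point by one unit quaternion. It uses the vector identity p' = p + 2w (v × p) + 2 v × (v × p), which is equivalent to q p q⁻¹ for a unit quaternion q = (w, v). The helper names cross and quat_rotate are hypothetical, and the library may compute the result differently.

```python
import math

def cross(a, b):
    """Cross product of two 3-vectors."""
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )

def quat_rotate(q, p):
    """Rotate point p by the unit quaternion q = (w, x, y, z), real part first."""
    w, v = q[0], q[1:]
    t = cross(v, p)    # v x p
    u = cross(v, t)    # v x (v x p)
    return tuple(p[i] + 2.0 * w * t[i] + 2.0 * u[i] for i in range(3))

# Quaternion for a 90 degree rotation about the Z axis.
half = math.pi / 4
q = (math.cos(half), 0.0, 0.0, math.sin(half))
rotated = quat_rotate(q, (1.0, 0.0, 0.0))  # the X axis maps to the Y axis
```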

pytorch3d.transforms.quaternion_invert(quaternion)[source]

Given a quaternion representing rotation, get the quaternion representing its inverse.

Parameters: quaternion – Quaternions as tensor of shape (…, 4), with real part first, which must be versors (unit quaternions).
Returns: The inverse, a tensor of quaternions of shape (…, 4).
pytorch3d.transforms.quaternion_multiply(a, b)[source]

Multiply two quaternions representing rotations, returning the quaternion representing their composition, i.e. the versor with nonnegative real part. Usual torch rules for broadcasting apply.

Parameters:
  • a – Quaternions as tensor of shape (…, 4), real part first.
  • b – Quaternions as tensor of shape (…, 4), real part first.
Returns:

The product of a and b, a tensor of quaternions of shape (…, 4).

pytorch3d.transforms.quaternion_raw_multiply(a, b)[source]

Multiply two quaternions. Usual torch rules for broadcasting apply.

Parameters:
  • a – Quaternions as tensor of shape (…, 4), real part first.
  • b – Quaternions as tensor of shape (…, 4), real part first.
Returns:

The product of a and b, a tensor of quaternions of shape (…, 4).
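
The difference between the raw product and the standardized product can be seen in a plain-Python sketch of the Hamilton product on single quaternions. The names hamilton_product and standardize are hypothetical helpers for this example only; the sketch illustrates the arithmetic and is not the library code.

```python
def hamilton_product(a, b):
    """Hamilton product of two quaternions stored as (w, x, y, z)."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )

def standardize(q):
    """Flip the sign if needed so the real part is non-negative.

    q and -q describe the same rotation, so this only picks a representative.
    """
    return q if q[0] >= 0 else tuple(-c for c in q)

i = (0.0, 1.0, 0.0, 0.0)
j = (0.0, 0.0, 1.0, 0.0)

k = hamilton_product(i, j)           # i * j = k
raw = hamilton_product(i, i)         # i * i = -1, real part is negative
canonical = standardize(raw)         # same rotation, real part made non-negative
```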

pytorch3d.transforms.quaternion_to_matrix(quaternions)[source]

Convert rotations given as quaternions to rotation matrices.

Parameters: quaternions – Quaternions with real part first, as tensor of shape (…, 4).
Returns: Rotation matrices as tensor of shape (…, 3, 3).
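
For reference, the standard closed-form mapping from a quaternion (w, x, y, z) to a rotation matrix can be sketched in plain Python as follows. The helper name quat_to_matrix is hypothetical, and dividing by the squared norm (so that non-unit inputs still give a rotation) is an assumption of this sketch rather than something stated above.

```python
import math

def quat_to_matrix(q):
    """Rotation matrix for the quaternion q = (w, x, y, z), real part first."""
    w, x, y, z = q
    s = 2.0 / (w * w + x * x + y * y + z * z)  # equals 2 for a unit quaternion
    return [
        [1 - s * (y * y + z * z), s * (x * y - z * w), s * (x * z + y * w)],
        [s * (x * y + z * w), 1 - s * (x * x + z * z), s * (y * z - x * w)],
        [s * (x * z - y * w), s * (y * z + x * w), 1 - s * (x * x + y * y)],
    ]

# 90 degree rotation about Z.
half = math.pi / 4
R = quat_to_matrix((math.cos(half), 0.0, 0.0, math.sin(half)))
```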
pytorch3d.transforms.random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False)[source]

Generate random quaternions representing rotations, i.e. versors with nonnegative real part.

Parameters:
  • n – Number of quaternions in a batch to return.
  • dtype – Type to return.
  • device – Desired device of returned tensor. Default: uses the current device for the default tensor type.
  • requires_grad – Whether the resulting tensor should have the gradient flag set.
Returns:

Quaternions as tensor of shape (n, 4).
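
One common way to draw such quaternions, shown here as a plain-Python sketch, is to normalise a 4D Gaussian sample and flip its sign when the real part is negative. Whether the library samples this way is an assumption; random_unit_quaternions and its seed argument are hypothetical and exist only for this example.

```python
import math
import random

def random_unit_quaternions(n, seed=None):
    """Return n unit quaternions (w, x, y, z) with w >= 0.

    A normalised 4D Gaussian sample is uniformly distributed on the unit
    3-sphere; negating a quaternion does not change the rotation it encodes.
    """
    rng = random.Random(seed)
    out = []
    while len(out) < n:
        q = [rng.gauss(0.0, 1.0) for _ in range(4)]
        norm = math.sqrt(sum(c * c for c in q))
        if norm == 0.0:
            continue  # vanishingly unlikely; resample instead of dividing by zero
        if q[0] < 0:
            norm = -norm  # dividing by a negative norm flips every component
        out.append(tuple(c / norm for c in q))
    return out

quats = random_unit_quaternions(5, seed=0)
```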

pytorch3d.transforms.random_rotation(dtype: Optional[torch.dtype] = None, device=None, requires_grad=False)[source]

Generate a single random 3x3 rotation matrix.

Parameters:
  • dtype – Type to return.
  • device – Device of returned tensor. Default: if None, uses the current device for the default tensor type.
  • requires_grad – Whether the resulting tensor should have the gradient flag set.
Returns:

Rotation matrix as tensor of shape (3, 3).

pytorch3d.transforms.random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False)[source]

Generate random rotations as 3x3 rotation matrices.

Parameters:
  • n – Number of rotation matrices in a batch to return.
  • dtype – Type to return.
  • device – Device of returned tensor. Default: if None, uses the current device for the default tensor type.
  • requires_grad – Whether the resulting tensor should have the gradient flag set.
Returns:

Rotation matrices as tensor of shape (n, 3, 3).

pytorch3d.transforms.standardize_quaternion(quaternions)[source]

Convert a unit quaternion to a standard form: one in which the real part is non-negative.

Parameters: quaternions – Quaternions with real part first, as tensor of shape (…, 4).
Returns: Standardized quaternions as tensor of shape (…, 4).
pytorch3d.transforms.so3_exponential_map(log_rot, eps: float = 0.0001)[source]

Convert a batch of logarithmic representations of rotation matrices log_rot to a batch of 3x3 rotation matrices using Rodrigues' formula [1].

In the logarithmic representation, each rotation matrix is represented as a 3-dimensional vector (log_rot) whose L2 norm and direction correspond to the magnitude of the rotation angle and the axis of rotation, respectively.

The conversion has a singularity around log(R) = 0 which is handled by clamping controlled with the eps argument.

Parameters:
  • log_rot – Batch of vectors of shape (minibatch, 3).
  • eps – A float constant handling the conversion singularity.
Returns:

Batch of rotation matrices of shape (minibatch, 3, 3).

Raises:

ValueError if log_rot is of incorrect shape.

[1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
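
A minimal plain-Python sketch of Rodrigues' formula for a single vector is given below: R = I + (sin θ / θ) K + ((1 − cos θ) / θ²) K², where K is the skew-symmetric matrix of log_rot and θ its norm. Clamping the squared norm at eps is an assumed way of handling the singularity, and hat, matmul3 and so3_exp are hypothetical helper names; the library's own clamping may differ.

```python
import math

def hat(v):
    """Skew-symmetric (cross-product) matrix of a 3-vector."""
    x, y, z = v
    return [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]

def matmul3(a, b):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def so3_exp(log_rot, eps=1e-4):
    """Rotation matrix for one axis-angle vector via Rodrigues' formula."""
    if len(log_rot) != 3:
        raise ValueError("log_rot must have exactly 3 components")
    # Assumed clamping: keep the squared norm at least eps so the
    # divisions below stay finite when log_rot is close to zero.
    theta = math.sqrt(max(sum(c * c for c in log_rot), eps))
    K = hat(log_rot)
    K2 = matmul3(K, K)
    a = math.sin(theta) / theta
    b = (1.0 - math.cos(theta)) / (theta * theta)
    return [[(1.0 if i == j else 0.0) + a * K[i][j] + b * K2[i][j]
             for j in range(3)] for i in range(3)]

R = so3_exp((0.0, 0.0, math.pi / 2))  # quarter turn about Z
R_identity = so3_exp((0.0, 0.0, 0.0))  # zero vector maps to the identity
```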

pytorch3d.transforms.so3_log_map(R, eps: float = 0.0001)[source]

Convert a batch of 3x3 rotation matrices R to a batch of 3-dimensional matrix logarithms of the rotation matrices. The conversion has a singularity around R = I, which is handled by clamping controlled with the eps argument.

Parameters:
  • R – batch of rotation matrices of shape (minibatch, 3, 3).
  • eps – A float constant handling the conversion singularity.
Returns:

Batch of logarithms of input rotation matrices of shape (minibatch, 3).

Raises:
  • ValueError if R is of incorrect shape.
  • ValueError if R has an unexpected trace.
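
The inverse direction can be sketched for a single matrix in plain Python: the angle comes from the trace, and the axis from the antisymmetric part of R. The helper name so3_log, the trace tolerance, and the way the denominator is clamped are assumptions made for this example. Rotations by angles close to π are not handled accurately by this simple sketch.

```python
import math

def so3_log(R, eps=1e-4):
    """Axis-angle vector (3 components) of one 3x3 rotation matrix."""
    if len(R) != 3 or any(len(row) != 3 for row in R):
        raise ValueError("R must be a 3x3 matrix")
    trace = R[0][0] + R[1][1] + R[2][2]
    if not (-1.0 - eps <= trace <= 3.0 + eps):
        raise ValueError("R has an unexpected trace")
    # Clamp the cosine into [-1, 1] to absorb rounding error before acos.
    angle = math.acos(min(1.0, max(-1.0, 0.5 * (trace - 1.0))))
    # angle / (2 sin(angle)) tends to 1/2 as angle -> 0; the denominator is
    # clamped (assumed scheme) so the identity maps to the zero vector.
    scale = 0.5 * angle / max(math.sin(angle), eps)
    return (
        scale * (R[2][1] - R[1][2]),
        scale * (R[0][2] - R[2][0]),
        scale * (R[1][0] - R[0][1]),
    )

Rz90 = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
log_rot = so3_log(Rz90)  # approximately (0, 0, pi / 2)
```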
pytorch3d.transforms.so3_relative_angle(R1, R2, cos_angle: bool = False)[source]

Calculates the relative angle (in radians) between pairs of rotation matrices R1 and R2, with angle = acos(0.5 * (Trace(R1 R2^T) - 1)).

Note

This corresponds to a geodesic distance on the 3D manifold of rotation matrices.

Parameters:
  • R1 – Batch of rotation matrices of shape (minibatch, 3, 3).
  • R2 – Batch of rotation matrices of shape (minibatch, 3, 3).
  • cos_angle – If True, return the cosine of the relative angle rather than the angle itself. This can avoid the unstable calculation of acos.
Returns:

Corresponding rotation angles of shape (minibatch,). If cos_angle==True, returns the cosine of the angles.

Raises:
  • ValueError if R1 or R2 is of incorrect shape.
  • ValueError if R1 or R2 has an unexpected trace.
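
The formula above can be checked on a single pair of matrices with a short plain-Python sketch. It uses the fact that Trace(R1 R2^T) equals the sum of the elementwise products of R1 and R2. The helper name relative_angle is hypothetical, and clamping the cosine into [-1, 1] before acos is an assumption added to absorb rounding error.

```python
import math

def relative_angle(R1, R2, cos_angle=False):
    """Angle (radians) of the rotation taking R2 to R1, or its cosine."""
    # Trace(R1 @ R2^T) = sum over i, j of R1[i][j] * R2[i][j].
    trace = sum(R1[i][j] * R2[i][j] for i in range(3) for j in range(3))
    cos_value = 0.5 * (trace - 1.0)
    if cos_angle:
        return cos_value
    return math.acos(min(1.0, max(-1.0, cos_value)))

identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
Rz90 = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

angle = relative_angle(Rz90, identity)                    # quarter turn apart
cosine = relative_angle(Rz90, identity, cos_angle=True)   # cosine of that angle
```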
pytorch3d.transforms.so3_rotation_angle(R, eps: float = 0.0001, cos_angle: bool = False)[source]

Calculates angles (in radians) of a batch of rotation matrices R with angle = acos(0.5 * (Trace(R)-1)). The trace of the input matrices is checked to be in the valid range [-1-eps,3+eps]. The eps argument is a small constant that allows for small errors caused by limited machine precision.

Parameters:
  • R – Batch of rotation matrices of shape (minibatch, 3, 3).
  • eps – Tolerance for the valid trace check.
  • cos_angle – If True, return the cosine of the rotation angles rather than the angles themselves. This can avoid the unstable calculation of acos.
Returns:

Corresponding rotation angles of shape (minibatch,). If cos_angle==True, returns the cosine of the angles.

Raises:
  • ValueError if R is of incorrect shape.
  • ValueError if R has an unexpected trace.
class pytorch3d.transforms.Rotate(R, dtype=torch.float32, device: str = 'cpu', orthogonal_tol: float = 1e-05)[source]

Bases: pytorch3d.transforms.transform3d.Transform3d

__init__(R, dtype=torch.float32, device: str = 'cpu', orthogonal_tol: float = 1e-05)[source]

Create a new Transform3d representing 3D rotation using a rotation matrix as the input.

Parameters:
  • R – a tensor of shape (3, 3) or (N, 3, 3)
  • orthogonal_tol – tolerance for the test of the orthogonality of R
class pytorch3d.transforms.RotateAxisAngle(angle, axis: str = 'X', degrees: bool = True, dtype=torch.float64, device: str = 'cpu')[source]

Bases: pytorch3d.transforms.transform3d.Rotate

__init__(angle, axis: str = 'X', degrees: bool = True, dtype=torch.float64, device: str = 'cpu')[source]

Create a new Transform3d representing 3D rotation about an axis by an angle.

Assuming a right-handed coordinate system, positive rotation angles result in a counterclockwise rotation.

Parameters:
  • angle – The rotation angle, given as one of:
    • A torch tensor of shape (N,)
    • A python scalar
    • A torch scalar
  • axis – string: one of [“X”, “Y”, “Z”] indicating the axis about which to rotate. NOTE: All batch elements are rotated about the same axis.
class pytorch3d.transforms.Scale(x, y=None, z=None, dtype=torch.float32, device: str = 'cpu')[source]

Bases: pytorch3d.transforms.transform3d.Transform3d

__init__(x, y=None, z=None, dtype=torch.float32, device: str = 'cpu')[source]

A Transform3d representing a scaling operation, with different scale factors along each coordinate axis.

Option I: Scale(s, dtype=torch.float32, device='cpu')
s can be one of
  • Python scalar or torch scalar: Single uniform scale
  • 1D torch tensor of shape (N,): A batch of uniform scales
  • 2D torch tensor of shape (N, 3): Scale differently along each axis
Option II: Scale(x, y, z, dtype=torch.float32, device='cpu')
Each of x, y, and z can be one of
  • python scalar
  • torch scalar
  • 1D torch tensor
class pytorch3d.transforms.Transform3d(dtype: torch.dtype = torch.float32, device='cpu', matrix: Optional[torch.Tensor] = None)[source]

Bases: object

A Transform3d object encapsulates a batch of N 3D transformations, and knows how to transform points and normal vectors. Suppose that t is a Transform3d; then we can do the following:

N = len(t)
points = torch.randn(N, P, 3)
normals = torch.randn(N, P, 3)
points_transformed = t.transform_points(points)    # => (N, P, 3)
normals_transformed = t.transform_normals(normals)  # => (N, P, 3)

BROADCASTING Transform3d objects support broadcasting. Suppose that t1 and tN are Transform3d objects with len(t1) == 1 and len(tN) == N respectively. Then we can broadcast transforms like this:

t1.transform_points(torch.randn(P, 3))     # => (P, 3)
t1.transform_points(torch.randn(1, P, 3))  # => (1, P, 3)
t1.transform_points(torch.randn(M, P, 3))  # => (M, P, 3)
tN.transform_points(torch.randn(P, 3))     # => (N, P, 3)
tN.transform_points(torch.randn(1, P, 3))  # => (N, P, 3)

COMBINING TRANSFORMS Transform3d objects can be combined in two ways: composing and stacking. Composing is function composition. Given Transform3d objects t1, t2, t3, the following all compute the same thing:

y1 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
y2 = t1.compose(t2).compose(t3).transform_points(x)
y3 = t1.compose(t2, t3).transform_points(x)

Composing transforms should broadcast.

If len(t1) == 1 and len(t2) == N, then len(t1.compose(t2)) == N.

We can also stack a sequence of Transform3d objects, which represents composition along the batch dimension. The following should then compute the same thing:

N, M = len(tN), len(tM)
xN = torch.randn(N, P, 3)
xM = torch.randn(M, P, 3)
y1 = torch.cat([tN.transform_points(xN), tM.transform_points(xM)], dim=0)
y2 = tN.stack(tM).transform_points(torch.cat([xN, xM], dim=0))

BUILDING TRANSFORMS We provide convenience methods for easily building Transform3d objects as compositions of basic transforms.

# Scale by 0.5, then translate by (1, 2, 3)
t1 = Transform3d().scale(0.5).translate(1, 2, 3)

# Scale each axis by a different amount, then translate, then scale
t2 = Transform3d().scale(1, 3, 3).translate(2, 3, 1).scale(2.0)

t3 = t1.compose(t2)
tN = t1.stack(t3, t3)

BACKPROP THROUGH TRANSFORMS When building transforms, we can also parameterize them by torch tensors; in this case we can backpropagate through the construction and application of Transform3d objects, so the parameters can be learned via gradient descent or predicted by a neural network.

s1_params = torch.randn(N, requires_grad=True)
t_params = torch.randn(N, 3, requires_grad=True)
s2_params = torch.randn(N, 3, requires_grad=True)

t = Transform3d().scale(s1_params).translate(t_params).scale(s2_params)
x = torch.randn(N, 3)
y = t.transform_points(x)
loss = compute_loss(y)
loss.backward()

with torch.no_grad():
    s1_params -= lr * s1_params.grad
    t_params -= lr * t_params.grad
    s2_params -= lr * s2_params.grad

CONVENTIONS We adopt a right-handed coordinate system, meaning that rotation about an axis with a positive angle results in a counterclockwise rotation.

This class assumes that transformations are applied on inputs which are row vectors. The internal representation of the Nx4x4 transformation matrix is of the form:

M = [
        [Rxx, Ryx, Rzx, 0],
        [Rxy, Ryy, Rzy, 0],
        [Rxz, Ryz, Rzz, 0],
        [Tx,  Ty,  Tz,  1],
    ]

To apply the transformation to points given as row vectors, the M matrix is pre-multiplied by the points:

points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
transformed_points = points * M
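
The row-vector convention can be made concrete with a plain-Python sketch on nested lists. It appends a homogeneous coordinate of 1 to each point, multiplies by M from the left, and divides by the last coordinate. It also shows that, under this convention, "scale, then translate" corresponds to the matrix product S @ T. The helpers matmul, scale_matrix, translate_matrix and apply are hypothetical names for this example and are not part of the library.

```python
def matmul(a, b):
    """Product of two matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def scale_matrix(s):
    """4x4 uniform scale in the row-vector convention."""
    return [[s, 0.0, 0.0, 0.0], [0.0, s, 0.0, 0.0],
            [0.0, 0.0, s, 0.0], [0.0, 0.0, 0.0, 1.0]]

def translate_matrix(tx, ty, tz):
    """4x4 translation: the offset sits in the last row, as in M above."""
    return [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0], [tx, ty, tz, 1.0]]

def apply(points, M):
    """Transform (P, 3) points: append 1, multiply by M, divide by last coord."""
    homogeneous = [[x, y, z, 1.0] for x, y, z in points]
    out = matmul(homogeneous, M)
    return [[r[0] / r[3], r[1] / r[3], r[2] / r[3]] for r in out]

S = scale_matrix(0.5)
T = translate_matrix(1.0, 2.0, 3.0)
points = [[2.0, 4.0, 6.0]]

# With row vectors, the transform applied first is the leftmost factor.
scale_then_translate = apply(points, matmul(S, T))  # [[2.0, 4.0, 6.0]]
translate_then_scale = apply(points, matmul(T, S))  # [[1.5, 3.0, 4.5]]
```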
__init__(dtype: torch.dtype = torch.float32, device='cpu', matrix: Optional[torch.Tensor] = None)[source]
Parameters:
  • dtype – The data type of the transformation matrix, used if matrix is None.
  • device – The device for storing the implemented transformation. If matrix != None, uses the device of input matrix.
  • matrix – A tensor of shape (4, 4) or of shape (minibatch, 4, 4) representing the 4x4 3D transformation matrix. If None, initializes with identity using the specified device and dtype.
compose(*others)[source]

Return a new Transform3d with the transforms to compose stored as an internal list.

Parameters: *others – Any number of Transform3d objects.
Returns: A new Transform3d with the stored transforms.
get_matrix()[source]

Return a matrix which is the result of composing this transform with the others stored in self.transforms. Where necessary, transforms are broadcast against each other. For example, if self.transforms contains transforms t1, t2, and t3, and given a set of points x, the following should be true:

y1 = t1.compose(t2, t3).transform_points(x)
y2 = t3.transform_points(t2.transform_points(t1.transform_points(x)))
torch.allclose(y1, y2)  # expected to hold up to floating-point error

Returns: A transformation matrix representing the composed inputs.
inverse(invert_composed: bool = False)[source]

Returns a new Transform3d object that represents the inverse of the current transformation.

Parameters: invert_composed
  • True: First compose the list of stored transformations and then apply inverse to the result. This is potentially slower for classes of transformations with inverses that can be computed efficiently (e.g. rotations and translations).
  • False: Invert the individual stored transformations independently without composing them.
Returns: A new Transform3d object containing the inverse of the original transformation.
stack(*others)[source]
transform_points(points, eps: Optional[float] = None)[source]

Use this transform to transform a set of 3D points. Assumes row major ordering of the input points.

Parameters:
  • points – Tensor of shape (P, 3) or (N, P, 3)
  • eps – If eps is not None, it is used to clamp the last (homogeneous) coordinate before performing the final division. The clamping corresponds to: last_coord := (last_coord.sign() + (last_coord==0)) * torch.clamp(last_coord.abs(), eps), i.e. last coordinates that are exactly 0 are clamped to +eps.
Returns:

points_out – points of shape (N, P, 3) or (P, 3) depending on the dimensions of the transform
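
The clamping rule quoted for eps can be traced on scalars with a plain-Python sketch. It mirrors the expression (sign + (value == 0)) * clamp(abs(value), eps) for one homogeneous coordinate and then performs the division. The helpers clamp_last_coord and homogeneous_divide are hypothetical names for this example; the library applies the same idea to whole tensors.

```python
def clamp_last_coord(w, eps):
    """Push |w| up to at least eps while keeping its sign; exact 0 becomes +eps."""
    sign = (w > 0) - (w < 0)  # -1, 0 or +1, like torch.sign
    return (sign + (w == 0)) * max(abs(w), eps)

def homogeneous_divide(row, eps=None):
    """Divide (x, y, z, w) by w, optionally clamping w away from zero first."""
    x, y, z, w = row
    if eps is not None:
        w = clamp_last_coord(w, eps)
    return (x / w, y / w, z / w)

zero_case = clamp_last_coord(0.0, 1e-4)       # exactly zero is moved to +eps
negative_case = clamp_last_coord(-1e-9, 1e-4)  # tiny negative keeps its sign
large_case = clamp_last_coord(2.0, 1e-4)       # values beyond eps are unchanged
divided = homogeneous_divide((1.0, 2.0, 3.0, 0.0), eps=0.5)
```

Without eps, a last coordinate of exactly 0 would make the division in this sketch fail, which is the situation the clamp is meant to avoid.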

transform_normals(normals)[source]

Use this transform to transform a set of normal vectors.

Parameters: normals – Tensor of shape (P, 3) or (N, P, 3).
Returns: normals_out – Tensor of shape (P, 3) or (N, P, 3) depending on the dimensions of the transform.
translate(*args, **kwargs)[source]
scale(*args, **kwargs)[source]
rotate_axis_angle(*args, **kwargs)[source]
clone()[source]

Deep copy of the Transform3d object. All internal tensors are cloned individually.

Returns: A new Transform3d object.
to(device, copy: bool = False, dtype=None)[source]

Match the functionality of torch.Tensor.to(). If copy = True or the self Tensor is on a different device, the returned tensor is a copy of self with the desired torch.device. If copy = False and the self Tensor already has the correct torch.device, then self is returned.

Parameters:
  • device – Device id for the new tensor.
  • copy – Boolean indicator whether or not to clone self. Default False.
  • dtype – If not None, casts the internal tensor variables to a given torch.dtype.
Returns:

Transform3d object.

cpu()[source]
cuda()[source]
class pytorch3d.transforms.Translate(x, y=None, z=None, dtype=torch.float32, device: str = 'cpu')[source]

Bases: pytorch3d.transforms.transform3d.Transform3d

__init__(x, y=None, z=None, dtype=torch.float32, device: str = 'cpu')[source]

Create a new Transform3d representing 3D translations.

Option I: Translate(xyz, dtype=torch.float32, device='cpu')
xyz should be a tensor of shape (N, 3)
Option II: Translate(x, y, z, dtype=torch.float32, device='cpu')

Here x, y, and z will be broadcast against each other and concatenated to form the translation. Each can be:

  • A python scalar
  • A torch scalar
  • A 1D torch tensor