# Source code for pytorch3d.renderer.lighting

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn.functional as F

from ..common.datatypes import Device
from .utils import TensorProperties, convert_to_tensors_and_broadcast


def diffuse(normals, color, direction) -> torch.Tensor:
    """
    Compute the diffuse term of light reflection using Lambert's cosine law.

    Each output color is ``color * max(0, cos(theta))``, where theta is the
    angle between the surface normal and the light direction.

    Args:
        normals: (N, ..., 3) xyz normal vectors. They are renormalized inside
            this function, so interpolated (non-unit) normals are fine.
        color: (1, 3) or (N, 3) RGB color of the diffuse component of the light.
        direction: (x, y, z) direction of the light, given per batch element,
            or a tensor with exactly the same shape as ``normals``.

    Returns:
        colors: (N, ..., 3), the same shape as ``normals``.

    The normals and the light direction must live in the same coordinate
    frame, i.e. if the normals have been transformed from world -> view space
    then the direction must be expressed in view space as well.

    NOTE: to use with packed inputs (i.e. no batch dimension), reindex the
    per-batch tensors into per-point tensors:

    .. code-block:: python

        Args:
            normals: (P, 3)
            color: (N, 3)[batch_idx, :] -> (P, 3)
            direction: (N, 3)[batch_idx, :] -> (P, 3)
        Returns:
            colors: (P, 3)

    where batch_idx is of shape (P). For meshes, batch_idx can be
    meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx(),
    depending on whether the normals belong to vertices or to
    average/interpolated face points.
    """
    # TODO: handle multiple directional lights per batch element.
    # TODO: handle attenuation.

    # Give color and direction the same batch dimension as normals.
    normals, color, direction = convert_to_tensors_and_broadcast(
        normals, color, direction, device=normals.device
    )

    # Per-batch tensors get one singleton axis per intermediate axis of
    # normals (first dim = batch, last dim = 3) so they broadcast per point.
    n_middle = len(normals.shape[1:-1])
    per_batch_shape = (-1,) + (1,) * n_middle + (3,)
    direction = (
        direction
        if direction.shape == normals.shape
        else direction.view(per_batch_shape)
    )
    color = color if color.shape == normals.shape else color.view(per_batch_shape)

    # Renormalize: normals may have been interpolated. A variant built on
    # F.cosine_similarity was tried and was not faster.
    unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    unit_direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)

    # Clamp back-facing contributions to zero.
    cos_theta = F.relu(torch.sum(unit_normals * unit_direction, dim=-1))
    return color * cos_theta.unsqueeze(-1)
def specular(
    points, normals, direction, color, camera_position, shininess
) -> torch.Tensor:
    """
    Compute the specular term of light reflection.

    The light direction is mirrored about the normal, and the cosine between
    that reflected ray and the direction to the camera is raised to the power
    ``shininess``. Points whose normal faces away from the light get no
    highlight.

    Args:
        points: (N, ..., 3) xyz coordinates of the points.
        normals: (N, ..., 3) xyz normal vectors, one per point. Must have the
            same shape as ``points``.
        direction: (N, 3) vector direction of the light.
        color: (N, 3) RGB color of the specular component of the light.
        camera_position: (N, 3) xyz position of the camera.
        shininess: (N) specular exponent of the material.

    Returns:
        colors: (N, ..., 3), the same shape as ``points``.

    Raises:
        ValueError: if ``points`` and ``normals`` have different shapes.

    The points, normals, camera_position and light direction must all be in
    the same coordinate frame, i.e. if the points have been transformed from
    world -> view space then the other inputs must be in view space as well.

    To use with a batch of packed points, reindex in the following way:

    .. code-block:: python

        Args:
            points: (P, 3)
            normals: (P, 3)
            direction: (N, 3)[batch_idx] -> (P, 3)
            color: (N, 3)[batch_idx] -> (P, 3)
            camera_position: (N, 3)[batch_idx] -> (P, 3)
            shininess: (N)[batch_idx] -> (P)
        Returns:
            colors: (P, 3)

    where batch_idx is of shape (P). For meshes, batch_idx can be
    meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx().
    """
    # TODO: handle multiple directional lights
    # TODO: attenuate based on inverse squared distance to the light source
    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Give every light/camera input the same batch dimension as points.
    broadcast = convert_to_tensors_and_broadcast(
        points, color, direction, camera_position, shininess, device=points.device
    )
    color, direction, camera_position, shininess = broadcast[1:]

    # Per-batch tensors get one singleton axis per intermediate axis of points
    # (first dim = batch, last dim = 3) so they broadcast per point.
    scalar_shape = (-1,) + (1,) * len(points.shape[1:-1])
    vector_shape = scalar_shape + (3,)
    target_shape = normals.shape
    if direction.shape != target_shape:
        direction = direction.view(vector_shape)
    if color.shape != target_shape:
        color = color.view(vector_shape)
    if camera_position.shape != target_shape:
        camera_position = camera_position.view(vector_shape)
    if shininess.shape != target_shape:
        shininess = shininess.view(scalar_shape)

    # Renormalize: normals may have been interpolated. A variant built on
    # F.cosine_similarity was tried and was slower.
    unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    unit_light = F.normalize(direction, p=2, dim=-1, eps=1e-6)
    cos_angle = torch.sum(unit_normals * unit_light, dim=-1)

    # Surfaces facing away from the light produce no highlight.
    facing_light = (cos_angle > 0).to(torch.float32)

    to_camera = F.normalize(camera_position - points, p=2, dim=-1, eps=1e-6)
    # Mirror the light direction about the normal: R = 2 (N . L) N - L.
    reflected = 2 * (cos_angle[..., None] * unit_normals) - unit_light

    # Cosine between the reflected ray and the view ray, clamped at zero.
    alpha = F.relu(torch.sum(to_camera * reflected, dim=-1)) * facing_light
    return color * torch.pow(alpha, shininess).unsqueeze(-1)
class DirectionalLights(TensorProperties):
    """
    Light coming from a fixed direction (one direction per batch element).

    The diffuse and specular terms depend only on ``direction``, not on the
    position of the shaded points.
    """

    def __init__(
        self,
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
        direction=((0, 1, 0),),
        device: Device = "cpu",
    ) -> None:
        """
        Args:
            ambient_color: RGB color of the ambient component.
            diffuse_color: RGB color of the diffuse component.
            specular_color: RGB color of the specular component.
            direction: (x, y, z) direction vector of the light.
            device: Device (as str or torch.device) on which the tensors should
                be located.

        The inputs can each be
            - 3 element tuple/list or list of lists
            - torch tensor of shape (1, 3)
            - torch tensor of shape (N, 3)
        The inputs are broadcast against each other so they all have batch
        dimension N.

        Raises:
            ValueError: if any color or the direction does not end in a size-3
                dimension.
        """
        super().__init__(
            device=device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
            direction=direction,
        )
        _validate_light_properties(self)
        if self.direction.shape[-1] != 3:
            msg = "Expected direction to have shape (N, 3); got %r"
            # Wrap the shape in a 1-tuple: torch.Size is a tuple subclass, so
            # passing it directly would be unpacked as format arguments. The
            # previous repr(...) workaround double-quoted the shape in the
            # message.
            raise ValueError(msg % (self.direction.shape,))

    def clone(self):
        """Return a copy of these lights (attributes copied by TensorProperties)."""
        other = self.__class__(device=self.device)
        return super().clone(other)

    def diffuse(self, normals, points=None) -> torch.Tensor:
        """
        Diffuse term of this light for the given normals.

        Args:
            normals: normal vectors, forwarded to the module-level diffuse().
            points: unused; accepted so directional and point lights share the
                same call signature and call sites need not know the light type.
        """
        return diffuse(
            normals=normals,
            color=self.diffuse_color,
            direction=self.direction,
        )

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        """
        Specular term of this light.

        Forwards ``points``, ``normals``, ``camera_position`` and ``shininess``
        to the module-level specular(), using this light's ``specular_color``
        and ``direction``.
        """
        return specular(
            points=points,
            normals=normals,
            color=self.specular_color,
            direction=self.direction,
            camera_position=camera_position,
            shininess=shininess,
        )
class PointLights(TensorProperties):
    """
    Light emitted from a point location (one location per batch element).

    The light direction at each shaded point is ``location - point``.
    """

    def __init__(
        self,
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
        location=((0, 1, 0),),
        device: Device = "cpu",
    ) -> None:
        """
        Args:
            ambient_color: RGB color of the ambient component
            diffuse_color: RGB color of the diffuse component
            specular_color: RGB color of the specular component
            location: xyz position of the light.
            device: Device (as str or torch.device) on which the tensors should
                be located

        The inputs can each be
            - 3 element tuple/list or list of lists
            - torch tensor of shape (1, 3)
            - torch tensor of shape (N, 3)
        The inputs are broadcast against each other so they all have batch
        dimension N.

        Raises:
            ValueError: if any color or the location does not end in a size-3
                dimension.
        """
        super().__init__(
            device=device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
            location=location,
        )
        _validate_light_properties(self)
        if self.location.shape[-1] != 3:
            msg = "Expected location to have shape (N, 3); got %r"
            # Wrap the shape in a 1-tuple: torch.Size is a tuple subclass, so
            # passing it directly would be unpacked as format arguments. The
            # previous repr(...) workaround double-quoted the shape in the
            # message.
            raise ValueError(msg % (self.location.shape,))

    def clone(self):
        """Return a copy of these lights (attributes copied by TensorProperties)."""
        other = self.__class__(device=self.device)
        return super().clone(other)

    def reshape_location(self, points) -> torch.Tensor:
        """
        Reshape the location tensor so that it broadcasts against ``points``.

        ``points`` may be packed, e.g. (P, 3), or batched with any number of
        intermediate dimensions, e.g. (N, H, W, K, 3) or (N, P, 3).

        If ``location`` already has as many dimensions as ``points``, it is
        returned unchanged. Otherwise singleton dimensions are inserted right
        after the batch dimension, e.g. (N, 3) -> (N, 1, 1, 1, 3) for 5-D
        points.
        """
        if self.location.ndim >= points.ndim:
            # pyre-fixme[7]
            return self.location
        # One singleton axis per extra intermediate dimension of points.
        # (The previous hard-coded [:, None, None, None, :] only fit 5-D
        # points and silently mis-broadcast against e.g. (N, P, 3).)
        n_middle = points.ndim - self.location.ndim
        index = (slice(None),) + (None,) * n_middle + (Ellipsis,)
        # pyre-fixme[29]
        return self.location[index]

    def diffuse(self, normals, points) -> torch.Tensor:
        """Diffuse term, with the light direction taken as location - points."""
        location = self.reshape_location(points)
        direction = location - points
        return diffuse(normals=normals, color=self.diffuse_color, direction=direction)

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        """Specular term, with the light direction taken as location - points."""
        location = self.reshape_location(points)
        direction = location - points
        return specular(
            points=points,
            normals=normals,
            color=self.specular_color,
            direction=direction,
            camera_position=camera_position,
            shininess=shininess,
        )
class AmbientLights(TensorProperties):
    """
    Light of one uniform color everywhere, with no diffuse or specular term.

    The default color is white, which effectively means that lighting has no
    effect on rendering.
    """

    def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
        """
        Args:
            ambient_color: RGB color. If omitted, white ((1.0, 1.0, 1.0),) is
                used. When provided it should be
                - a 3 element tuple/list or list of lists
                - a torch tensor of shape (1, 3)
                - a torch tensor of shape (N, 3)
            device: Device (as str or torch.device) on which the tensors should
                be located
        """
        color = ((1.0, 1.0, 1.0),) if ambient_color is None else ambient_color
        super().__init__(ambient_color=color, device=device)

    def clone(self):
        """Return a copy of these lights (attributes copied by TensorProperties)."""
        return super().clone(self.__class__(device=self.device))

    def diffuse(self, normals, points) -> torch.Tensor:
        """No diffuse contribution: a zero tensor shaped like ``points``."""
        return torch.zeros_like(points)

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        """No specular contribution: a zero tensor shaped like ``points``."""
        return torch.zeros_like(points)
def _validate_light_properties(obj) -> None: props = ("ambient_color", "diffuse_color", "specular_color") for n in props: t = getattr(obj, n) if t.shape[-1] != 3: msg = "Expected %s to have shape (N, 3); got %r" raise ValueError(msg % (n, t.shape))