Source code for pytorch3d.implicitron.tools.eval_video_trajectory

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import math
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.tools import utils
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
from pytorch3d.transforms import Scale


# Module-level logger named after this module; used by the helpers below to
# report how many training cameras survive outlier filtering.
logger = logging.getLogger(__name__)


def generate_eval_video_cameras(
    train_cameras,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
    focal_length: Optional[torch.Tensor] = None,
    principal_point: Optional[torch.Tensor] = None,
    time: Optional[torch.Tensor] = None,
    infer_up_as_plane_normal: bool = True,
    traj_offset: Optional[Tuple[float, float, float]] = None,
    traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
    remove_outliers_rate: float = 0.0,
) -> PerspectiveCameras:
    """
    Generate a camera trajectory rendering a scene from multiple viewpoints.

    Args:
        train_cameras: The set of cameras from the training dataset object.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a
                shape of a trefoil knot
                (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has
                a shape of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory. For the knot
            trajectories it multiplies the mean per-axis standard deviation
            of the training camera centers; for `circular_lsq_fit` it scales
            the fitted circle about the mean of its generated points.
        scene_center: The center of the scene in world coordinates which
            all the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            If `infer_up_as_plane_normal` is False it is used directly as the
            up direction of the generated cameras; otherwise it only selects
            the sign of the plane normal fitted to the camera centers.
            For `trajectory_type="circular_lsq_fit"` it is additionally
            passed to the circle fit.
        focal_length: The focal length of the output cameras. If `None`,
            an average focal length of the train_cameras is used.
        principal_point: The principal point of the output cameras. If `None`,
            an average principal point of all train_cameras is used.
        time: Defines the total length of the generated camera trajectory.
            All possible trajectories (set with the `trajectory_type`
            argument) are periodic with the period of `time=2pi`.
            E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi`
            will generate a trajectory of camera poses rotating the total of
            720 deg around the object. For the knot trajectories `time` must
            contain exactly `n_eval_cams` elements.
        infer_up_as_plane_normal: Infer the camera `up` vector automatically
            as the normal of the plane fit to the optical centers of
            `train_cameras`.
        traj_offset: 3D offset vector added to each point of the trajectory.
        traj_offset_canonical: 3D offset vector expressed in the local
            coordinates of the estimated trajectory which is added to each
            point of the trajectory.
        remove_outliers_rate: the number between 0 and 1; if > 0,
            some outlier train_cameras will be removed from trajectory
            estimation; the filtering is based on camera center coordinates;
            top and bottom `remove_outliers_rate` cameras on each dimension
            are removed.

    Returns:
        Batch of camera instances which can be used as the test dataset

    Raises:
        ValueError: If `trajectory_type` is not one of the supported values,
            or if a knot trajectory is requested and `time` does not have
            `n_eval_cams` elements.
    """
    if remove_outliers_rate > 0.0:
        train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)

    if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
        cam_centers = train_cameras.get_camera_center()

        # get the nearest camera center to the mean of centers
        mean_camera_idx = (
            ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
            .sum(dim=1)
            .min(dim=0)
            .indices
        )

        # generate the knot trajectory in canonical coords
        if time is None:
            # drop the last sample so that the periodic trajectory does not
            # visit its starting point twice
            time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
        elif time.numel() != n_eval_cams:
            raise ValueError(
                f"`time` has {time.numel()} elements but n_eval_cams={n_eval_cams}."
            )

        knot_generators = {
            "trefoil_knot": _trefoil_knot,
            "figure_eight_knot": _figure_eight_knot,
            "figure_eight": _figure_eight,
        }
        traj = knot_generators[trajectory_type](time)

        # shift the knot along its canonical z axis so that its maximum z is 0
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        mean_camera = PerspectiveCameras(
            **{
                k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
                for k in ("focal_length", "principal_point", "R", "T")
            }
        )
        traj_trans = Scale(cam_centers.std(dim=0).mean() * trajectory_scale).compose(
            mean_camera.get_world_to_view_transform().inverse()
        )

        if traj_offset_canonical is not None:
            traj_trans = traj_trans.translate(
                torch.FloatTensor(traj_offset_canonical)[None].to(traj)
            )

        traj = traj_trans.transform_points(traj)

        # the first column returned by _fit_plane is taken as the plane normal
        plane_normal = _fit_plane(cam_centers)[:, 0]
        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    elif trajectory_type == "circular_lsq_fit":
        # fit a 3D circle to the camera centers and sample points along it
        cam_centers = train_cameras.get_camera_center()

        if time is not None:
            angle = time
        else:
            angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams).to(cam_centers)

        fit = fit_circle_in_3d(
            cam_centers,
            angles=angle,
            offset=(
                angle.new_tensor(traj_offset_canonical)
                if traj_offset_canonical is not None
                else None
            ),
            up=angle.new_tensor(up),
        )
        traj = fit.generated_points

        # scale the trajectory about its own mean
        _t_mu = traj.mean(dim=0, keepdim=True)
        traj = (traj - _t_mu) * trajectory_scale + _t_mu

        plane_normal = fit.normal

        if infer_up_as_plane_normal:
            up = _disambiguate_normal(plane_normal, up)

    else:
        raise ValueError(f"Unknown trajectory_type {trajectory_type}.")

    if traj_offset is not None:
        traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)

    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center,),  # (1, 3)
        up=(up,),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    if focal_length is None:
        focal_length = train_cameras.focal_length.mean(dim=0).repeat(n_eval_cams, 1)
    if principal_point is None:
        principal_point = train_cameras.principal_point.mean(dim=0).repeat(
            n_eval_cams, 1
        )

    test_cameras = PerspectiveCameras(
        focal_length=focal_length,
        principal_point=principal_point,
        R=R,
        T=T,
        device=focal_length.device,
    )

    # NOTE: for debugging, the train and test cameras can be visualized with
    # _visdom_plot_scene(train_cameras, test_cameras).

    return test_cameras
def _remove_outlier_cameras(
    cameras: PerspectiveCameras, outlier_rate: float
) -> PerspectiveCameras:
    """
    Drop cameras whose centers are flagged as outliers.

    Args:
        cameras: The cameras to filter.
        outlier_rate: Forwarded to `utils.get_inlier_indicators` together
            with the camera centers.

    Returns:
        The subset of `cameras` selected by the inlier indicators.
    """
    keep_indices = utils.get_inlier_indicators(
        cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
    )
    # pyre-fixme[6]: For 1st param expected `Union[List[int], int, BoolTensor,
    #  LongTensor]` but got `Tensor`.
    clean_cameras = cameras[keep_indices]
    # report the camera count before and after filtering
    logger.info(
        "Filtered outlier cameras when estimating the trajectory: "
        f"{len(cameras)} -> {len(clean_cameras)}"
    )
    # pyre-fixme[7]: Expected `PerspectiveCameras` but got `CamerasBase`.
    return clean_cameras


def _disambiguate_normal(normal, up):
    """
    Orient `normal` so that it does not point away from `up`.

    Args:
        normal: Tensor holding the plane normal.
        up: Sequence of floats with the reference up direction.

    Returns:
        `normal`, negated if its dot product with `up` is negative,
        converted to a plain Python list.
    """
    up_t = torch.tensor(up).to(normal)
    flip = (up_t * normal).sum().sign()
    # sign() is 0 when `up` is orthogonal to `normal`; multiplying by it would
    # collapse the result to an all-zero up vector, so keep the normal as is.
    flip = torch.where(flip == 0, torch.ones_like(flip), flip)
    return (normal * flip).tolist()


def _fit_plane(x):
    """
    Compute the eigenvectors of the covariance of the points `x`.

    Args:
        x: Tensor of points, one point per row.

    Returns:
        The eigenvector matrix from `torch.linalg.eigh` (eigenvalues in
        ascending order), so column 0 is the direction of least variance,
        i.e. the normal of the least-squares plane through the points.
    """
    x = x - x.mean(dim=0)[None]
    cov = (x.t() @ x) / x.shape[0]
    _, e_vec = torch.linalg.eigh(cov)
    return e_vec


def _visdom_plot_scene(
    train_cameras,
    test_cameras,
) -> None:
    """
    Debug helper: plot the train and test cameras in a Visdom window.

    plotly_vis and visdom are imported lazily so that they are only required
    when this helper is actually called.
    """
    from pytorch3d.vis.plotly_vis import plot_scene

    p = plot_scene(
        {
            "scene": {
                "train_cams": train_cameras,
                "test_cams": test_cameras,
            }
        }
    )
    from visdom import Visdom

    viz = Visdom()
    viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")


def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
    """
    Evaluate a figure-eight knot curve at the parameter values `t`.

    Args:
        t: Tensor of curve parameters.
        z_scale: Multiplier of the z coordinate.

    Returns:
        Tensor with the x, y, z coordinates stacked along a new last dim.
    """
    x = (2 + (2 * t).cos()) * (3 * t).cos()
    y = (2 + (2 * t).cos()) * (3 * t).sin()
    z = (4 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)


def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
    """
    Evaluate a trefoil knot curve at the parameter values `t`.

    Args:
        t: Tensor of curve parameters.
        z_scale: Multiplier of the z coordinate.

    Returns:
        Tensor with the x, y, z coordinates stacked along a new last dim.
    """
    x = t.sin() + 2 * (2 * t).sin()
    y = t.cos() - 2 * (2 * t).cos()
    z = -(3 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)


def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
    """
    Evaluate a figure-of-8 curve at the parameter values `t`.

    Args:
        t: Tensor of curve parameters.
        z_scale: Multiplier of the z coordinate.

    Returns:
        Tensor with the x, y, z coordinates stacked along a new last dim.
    """
    x = t.cos()
    y = (2 * t).sin() / 2
    z = t.sin() * z_scale
    return torch.stack((x, y, z), dim=-1)