# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import json
import warnings
from os import path
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from pytorch3d.common.datatypes import Device
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.renderer import HardPhongShader
from tabulate import tabulate
from .utils import (
align_bbox,
BlenderCamera,
compute_extrinsic_matrix,
read_binvox_coords,
voxelize,
)
# Directory that contains this module. The synset dictionary file
# "r2n2_synset_dict.json" is opened relative to this directory.
SYNSET_DICT_DIR = Path(__file__).resolve().parent
# Multiplied by the per-view distance ratio read from the rendering metadata
# to obtain the camera distance of a view.
MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
# Passed to `voxelize` as its size argument when voxel coordinates are
# converted into a voxel grid.
VOXEL_SIZE = 128
# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase:
# https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py
# It is multiplied with a view's extrinsic matrix to build the projection
# matrix used for voxelization. The K matrix built in R2N2.__getitem__ has the
# same entries except that the (3, 2) entry is 1.0 there and -1.0 here.
BLENDER_INTRINSIC = torch.tensor(
    [
        [2.1875, 0.0, 0.0, 0.0],
        [0.0, 2.1875, 0.0, 0.0],
        [0.0, 0.0, -1.002002, -0.2002002],
        [0.0, 0.0, -1.0, 0.0],
    ]
)
class R2N2(ShapeNetBase):  # pragma: no cover
    """
    This class loads the R2N2 dataset from a given directory into a Dataset object.
    The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
    dataset. The R2N2 dataset also contains its own 24 renderings of each object and
    voxelized models. Most of the models have all 24 views in the same split, but there
    are eight of them that divide their views between train and test splits.
    """

    def __init__(
        self,
        split: str,
        shapenet_dir: str,
        r2n2_dir: str,
        splits_file: str,
        return_all_views: bool = True,
        return_voxels: bool = False,
        views_rel_path: str = "ShapeNetRendering",
        voxels_rel_path: str = "ShapeNetVoxels",
        load_textures: bool = True,
        texture_resolution: int = 4,
    ) -> None:
        """
        Store each object's synset id and model id from the given directories.

        Args:
            split (str): One of (train, val, test).
            shapenet_dir (str): Path to ShapeNet core v1.
            r2n2_dir (str): Path to the R2N2 dataset.
            splits_file (str): File containing the train/val/test splits.
            return_all_views (bool): Indicator of whether or not to load all the views in
                the split. If set to False, one of the views in the split will be randomly
                selected and loaded.
            return_voxels (bool): Indicator of whether or not to return voxels as a tensor
                of shape (D, D, D) where D is the number of voxels along each dimension.
            views_rel_path: path to rendered views within the r2n2_dir. If not specified,
                the renderings are assumed to be at
                os.path.join(r2n2_dir, "ShapeNetRendering").
            voxels_rel_path: path to voxels within the r2n2_dir. If not specified,
                the voxels are assumed to be at os.path.join(r2n2_dir, "ShapeNetVoxels").
            load_textures: Boolean indicating whether textures should be loaded for the
                model. Textures will be of type TexturesAtlas i.e. a texture map per face.
            texture_resolution: Int specifying the resolution of the texture map per face
                created using the textures in the obj file. A
                (texture_resolution, texture_resolution, 3) map is created per face.

        Raises:
            ValueError: If split is not one of (train, val, test).
        """
        super().__init__()
        self.shapenet_dir = shapenet_dir
        self.r2n2_dir = r2n2_dir
        self.views_rel_path = views_rel_path
        self.voxels_rel_path = voxels_rel_path
        self.load_textures = load_textures
        self.texture_resolution = texture_resolution

        # Examine if split is valid.
        if split not in ("train", "val", "test"):
            raise ValueError("split has to be one of (train, val, test).")

        # Synset dictionary mapping synset offsets in R2N2 to corresponding labels.
        with open(
            path.join(SYNSET_DICT_DIR, "r2n2_synset_dict.json"), "r"
        ) as read_dict:
            self.synset_dict = json.load(read_dict)
        # Inverse dictionary mapping synset labels to corresponding offsets.
        self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}

        # Store synset and model ids of objects mentioned in the splits_file.
        with open(splits_file) as splits:
            split_dict = json.load(splits)[split]

        self.return_images = True
        # Check if the folder containing R2N2 renderings is included in r2n2_dir.
        if not path.isdir(path.join(r2n2_dir, views_rel_path)):
            self.return_images = False
            msg = (
                "%s not found in %s. R2N2 renderings will "
                "be skipped when returning models."
            ) % (views_rel_path, r2n2_dir)
            warnings.warn(msg)

        self.return_voxels = return_voxels
        # Check if the folder containing voxel coordinates is included in r2n2_dir.
        if not path.isdir(path.join(r2n2_dir, voxels_rel_path)):
            self.return_voxels = False
            msg = (
                "%s not found in %s. Voxel coordinates will "
                "be skipped when returning models."
            ) % (voxels_rel_path, r2n2_dir)
            warnings.warn(msg)

        synset_set = set()
        # Store lists of views of each model in a list.
        self.views_per_model_list = []
        # Store tuples of synset label and total number of views in each category in a list.
        synset_num_instances = []
        for synset, models in split_dict.items():
            # Examine if the given synset is present in the ShapeNetCore dataset
            # and is also part of the standard R2N2 dataset.
            if not (
                path.isdir(path.join(shapenet_dir, synset))
                and synset in self.synset_dict
            ):
                msg = (
                    "Synset category %s from the splits file is either not "
                    "present in %s or not part of the standard R2N2 dataset."
                ) % (synset, shapenet_dir)
                warnings.warn(msg)
                continue

            synset_set.add(synset)
            self.synset_start_idxs[synset] = len(self.synset_ids)
            # Start counting total number of views in the current category.
            synset_view_count = 0
            for model, model_views in models.items():
                # Examine if the given model is present in the ShapeNetCore path.
                shapenet_path = path.join(shapenet_dir, synset, model)
                if not path.isdir(shapenet_path):
                    msg = "Model %s from category %s is not present in %s." % (
                        model,
                        synset,
                        shapenet_dir,
                    )
                    warnings.warn(msg)
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)

                # Randomly select a view index if return_all_views set to False.
                if not return_all_views:
                    # Convert the 1-element tensor to a plain int before using
                    # it as a list index.
                    rand_idx = int(torch.randint(len(model_views), (1,)))
                    model_views = [model_views[rand_idx]]
                self.views_per_model_list.append(model_views)
                synset_view_count += len(model_views)
            synset_num_instances.append((self.synset_dict[synset], synset_view_count))
            model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
            self.synset_num_models[synset] = model_count

        headers = ["category", "#instances"]
        synset_num_instances.append(("total", sum(n for _, n in synset_num_instances)))
        print(
            tabulate(synset_num_instances, headers, numalign="left", stralign="center")
        )

        # Examine if all the synsets in the standard R2N2 mapping are present.
        # Update self.synset_inv so that it only includes the loaded categories.
        synset_not_present = []
        for synset, label in self.synset_dict.items():
            if synset not in synset_set:
                # pop() both removes the label from the inverse mapping and
                # returns the synset offset that is reported in the warning.
                synset_not_present.append(self.synset_inv.pop(label))
        if synset_not_present:
            msg = (
                "The following categories are included in R2N2's "
                "official mapping but not found in the dataset location %s: %s"
            ) % (shapenet_dir, ", ".join(synset_not_present))
            warnings.warn(msg)

    def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
        """
        Read a model by the given index.

        Args:
            model_idx: The idx of the model to be retrieved in the dataset. May also
                be a tuple (model_idx, view_idxs), which is how both values are
                supplied through the indexing syntax ``dataset[model_idx, view_idxs]``.
            view_idxs: List of indices of the views to be returned (a single int is
                also accepted and wrapped into a list). Each index needs to be
                contained in the loaded split (always between 0 and 23, inclusive). If
                an invalid index is supplied, view_idxs will be ignored, a warning
                is issued and all the loaded views will be returned.

        Returns:
            dictionary with following keys:
            - verts: FloatTensor of shape (V, 3).
            - faces: faces.verts_idx, LongTensor of shape (F, 3).
            - textures: textures as returned by the mesh loader of the base class.
            - synset_id (str): synset id.
            - model_id (str): model id.
            - label (str): synset label.
            - images: FloatTensor of shape (V, H, W, C), where V is number of views
                returned. Returns a batch of the renderings of the models from the
                R2N2 dataset. None if the renderings folder was not found.
            - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
            - T: Translation matrix of shape (V, 3), where V is number of views returned.
            - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
            - voxels: Voxel grids stacked over the returned views, i.e. one grid of
                shape (D, D, D) per view, where D is the number of voxels along each
                dimension.
            R, T and K are only present when renderings are available; voxels is
            only present when voxels are returned.

        Raises:
            TypeError: If view_idxs is neither an int, a list nor a tensor.
            FileNotFoundError: If voxels are requested but the binvox file of the
                model does not exist.
        """
        if isinstance(model_idx, tuple):
            model_idx, view_idxs = model_idx
        if view_idxs is not None:
            if isinstance(view_idxs, int):
                view_idxs = [view_idxs]
            if not isinstance(view_idxs, list) and not torch.is_tensor(view_idxs):
                raise TypeError(
                    "view_idxs is of type %s but it needs to be a list."
                    % type(view_idxs)
                )

        # Default to every view loaded for this model; narrow it down only if
        # all requested views are available.
        available_views = self.views_per_model_list[model_idx]
        model_views = available_views
        if view_idxs is not None:
            if any(idx not in available_views for idx in view_idxs):
                msg = (
                    "At least one of the indices in view_idxs is not available.\n"
                    "Specified view of the model needs to be contained in the\n"
                    "loaded split. If return_all_views is set to False, only one\n"
                    "random view is loaded. Try accessing the specified view(s)\n"
                    "after loading the dataset with self.return_all_views set to True.\n"
                    "Now returning all view(s) in the loaded dataset."
                )
                warnings.warn(msg)
            else:
                model_views = view_idxs

        model = self._get_item_ids(model_idx)
        model_path = path.join(
            self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
        )
        verts, faces, textures = self._load_mesh(model_path)
        model["verts"] = verts
        model["faces"] = faces
        model["textures"] = textures
        model["label"] = self.synset_dict[model["synset_id"]]

        model["images"] = None
        images, Rs, Ts, voxel_RTs = [], [], [], []
        # Retrieve R2N2's renderings if required.
        if self.return_images:
            rendering_path = path.join(
                self.r2n2_dir,
                self.views_rel_path,
                model["synset_id"],
                model["model_id"],
                "rendering",
            )
            # Read metadata file to obtain params for calibration matrices.
            with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
                metadata_lines = f.readlines()
            for i in model_views:
                # Read image. The context manager releases the file handle as
                # soon as the pixel data has been copied into the array.
                image_path = path.join(rendering_path, "%02d.png" % i)
                with Image.open(image_path) as raw_img:
                    image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
                images.append(image.to(dtype=torch.float32))

                # Get camera calibration. Each metadata line holds five values;
                # only azimuth, elevation and the distance ratio are used.
                azim, elev, _yaw, dist_ratio, _fov = (
                    float(v) for v in metadata_lines[i].strip().split(" ")
                )
                dist = dist_ratio * MAX_CAMERA_DISTANCE
                # Extrinsic matrix before transformation to PyTorch3D world space.
                RT = compute_extrinsic_matrix(azim, elev, dist)
                R, T = self._compute_camera_calibration(RT)
                Rs.append(R)
                Ts.append(T)
                voxel_RTs.append(RT)

            # Intrinsic matrix extracted from the Blender with slight modification to work with
            # PyTorch3D world space. Taken from meshrcnn codebase:
            # https://github.com/facebookresearch/meshrcnn/blob/main/shapenet/utils/coords.py
            K = torch.tensor(
                [
                    [2.1875, 0.0, 0.0, 0.0],
                    [0.0, 2.1875, 0.0, 0.0],
                    [0.0, 0.0, -1.002002, -0.2002002],
                    [0.0, 0.0, 1.0, 0.0],
                ]
            )
            model["images"] = torch.stack(images)
            model["R"] = torch.stack(Rs)
            model["T"] = torch.stack(Ts)
            model["K"] = K.expand(len(model_views), 4, 4)

        # Read voxels if required.
        if self.return_voxels:
            voxel_path = path.join(
                self.r2n2_dir,
                self.voxels_rel_path,
                model["synset_id"],
                model["model_id"],
                "model.binvox",
            )
            if not path.isfile(voxel_path):
                msg = "Voxel file not found for model %s from category %s."
                raise FileNotFoundError(msg % (model["model_id"], model["synset_id"]))

            with open(voxel_path, "rb") as f:
                # Read voxel coordinates as a tensor of shape (N, 3).
                voxel_coords = read_binvox_coords(f)
            # Align voxels to the same coordinate system as mesh verts.
            voxel_coords = align_bbox(voxel_coords, model["verts"])
            # NOTE(review): voxel_RTs is only filled while renderings are read, so
            # when renderings are unavailable this list is empty and torch.stack
            # below fails. Behavior kept as is; confirm whether voxels without
            # renderings should be supported.
            voxels_list = []
            for RT in voxel_RTs:
                # Compute projection matrix.
                P = BLENDER_INTRINSIC.mm(RT)
                # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D).
                voxels_list.append(voxelize(voxel_coords, P, VOXEL_SIZE))
            model["voxels"] = torch.stack(voxels_list)

        return model

    def _compute_camera_calibration(self, RT):
        """
        Helper function for calculating rotation and translation matrices from ShapeNet
        to camera transformation and ShapeNet to PyTorch3D transformation.

        Args:
            RT: Extrinsic matrix that performs ShapeNet world view to camera view
                transformation.

        Returns:
            R: Rotation matrix of shape (3, 3).
            T: Translation matrix of shape (3).
        """
        # Transform the mesh vertices from shapenet world to pytorch3d world.
        shapenet_to_pytorch3d = torch.tensor(
            [
                [-1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            dtype=torch.float32,
        )
        RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
        # Extract rotation and translation matrices from RT. After the transpose
        # the translation sits in the last row rather than the last column.
        R = RT[:3, :3]
        T = RT[3, :3]
        return R, T

    def render(
        self,
        model_ids: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        sample_nums: Optional[List[int]] = None,
        idxs: Optional[List[int]] = None,
        view_idxs: Optional[List[int]] = None,
        shader_type=HardPhongShader,
        device: Device = "cpu",
        **kwargs,
    ) -> torch.Tensor:
        """
        Render models with BlenderCamera by default to achieve the same orientations as the
        R2N2 renderings. Also accepts other types of cameras and any of the args that the
        render function in the ShapeNetBase class accepts.

        Args:
            view_idxs: each model will be rendered with the orientation(s) of the specified
                views. Only render by view_idxs if no camera or args for BlenderCamera is
                supplied.
            Accepts any of the args of the render function in ShapeNetBase:
            model_ids: List[str] of model_ids of models intended to be rendered.
            categories: List[str] of categories intended to be rendered. categories
                and sample_nums must be specified at the same time. categories can be given
                in the form of synset offsets or labels, or a combination of both.
            sample_nums: List[int] of number of models to be randomly sampled from
                each category. Could also contain one single integer, in which case it
                will be broadcasted for every category.
            idxs: List[int] of indices of models to be rendered in the dataset.
            shader_type: Shader to use for rendering. Examples include HardPhongShader
                (default), SoftPhongShader etc or any other type of valid Shader class.
            device: Device (as str or torch.device) on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports and any of the
                args that BlenderCamera supports. If "cameras" is given it is used
                directly; otherwise "R", "T" and "K" override the calibration of the
                specified views.

        Returns:
            Batch of rendered images of shape (N, H, W, 3).
        """
        idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)

        # A camera passed by the caller takes precedence over the default one.
        cameras = kwargs.pop("cameras", None)
        if cameras is None:
            calib_keys = ("R", "T", "K")
            calib = {key: kwargs[key] for key in calib_keys if key in kwargs}
            missing = [key for key in calib_keys if key not in calib]
            if missing:
                # Load every model exactly once and keep only the calibration
                # tensors that were not supplied through kwargs.
                per_model = {key: [] for key in missing}
                for idx in idxs:
                    item = self[idx, view_idxs]
                    for key in missing:
                        per_model[key].append(item[key])
                for key in missing:
                    calib[key] = torch.cat(per_model[key])
            # Initialize default camera using R, T, K from kwargs or R, T, K of
            # the specified views.
            cameras = BlenderCamera(
                R=calib["R"],
                T=calib["T"],
                K=calib["K"],
                device=device,
            )
        cameras = cameras.to(device)

        # pass down all the same inputs
        return super().render(
            idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
        )