# Source code for pytorch3d.datasets.shapenet.shapenet_core

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import json
import os
import warnings
from os import path
from pathlib import Path
from typing import Dict

from pytorch3d.datasets.shapenet_base import ShapeNetBase


# Directory holding the shapenet_synset_dict_v*.json mapping files
# (offset -> label); resolves to the directory containing this module.
SYNSET_DICT_DIR = Path(__file__).resolve().parent


class ShapeNetCore(ShapeNetBase):
    """
    This class loads ShapeNetCore from a given directory into a Dataset object.
    ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
    https://www.shapenet.org/.
    """

    def __init__(
        self,
        data_dir,
        synsets=None,
        version: int = 1,
        load_textures: bool = True,
        texture_resolution: int = 4,
    ) -> None:
        """
        Store each object's synset id and models id from data_dir.

        Args:
            data_dir: Path to ShapeNetCore data.
            synsets: List of synset categories to load from ShapeNetCore in the
                form of synset offsets or labels. A combination of both is also
                accepted. When no category is specified, all categories in
                data_dir are loaded.
            version: (int) version of ShapeNetCore data in data_dir, 1 or 2.
                Default is set to be 1. Version 1 has 57 categories and version 2
                has 55 categories.
                Note: version 1 has two categories 02858304(boat) and
                02992529(cellphone) that are hyponyms of categories
                04530566(watercraft) and 04401088(telephone) respectively. You can
                combine the categories manually if needed. Version 2 doesn't have
                02858304(boat) or 02834778(bicycle) compared to version 1.
            load_textures: Boolean indicating whether textures should be loaded
                for the model. Textures will be of type TexturesAtlas i.e. a
                texture map per face.
            texture_resolution: Int specifying the resolution of the texture map
                per face created using the textures in the obj file. A
                (texture_resolution, texture_resolution, 3) map is created
                per face.

        Raises:
            ValueError: if version is not 1 or 2.
        """
        super().__init__()
        self.shapenet_dir = data_dir
        self.load_textures = load_textures
        self.texture_resolution = texture_resolution

        if version not in [1, 2]:
            raise ValueError("Version number must be either 1 or 2.")
        # Version 1 stores one obj directly per model dir; version 2 nests it.
        self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"

        # Synset dictionary mapping synset offsets to corresponding labels.
        dict_file = "shapenet_synset_dict_v%d.json" % version
        with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
            self.synset_dict = json.load(read_dict)
        # Inverse dictionary mapping synset labels to corresponding offsets.
        self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}

        # If categories are specified, check if each category is in the form of
        # either synset offset or synset label, and if the category exists in
        # the given directory.
        if synsets is not None:
            # Set of categories to load in the form of synset offsets.
            synset_set = set()
            for synset in synsets:
                if (synset in self.synset_dict.keys()) and (
                    path.isdir(path.join(data_dir, synset))
                ):
                    synset_set.add(synset)
                elif (synset in self.synset_inv.keys()) and (
                    path.isdir(path.join(data_dir, self.synset_inv[synset]))
                ):
                    synset_set.add(self.synset_inv[synset])
                else:
                    msg = (
                        "Synset category %s either not part of ShapeNetCore dataset "
                        "or cannot be found in %s."
                    ) % (synset, data_dir)
                    warnings.warn(msg)
        # If no category is given, load every category in the given directory.
        # Ignore synset folders not included in the official mapping.
        else:
            synset_set = {
                synset
                for synset in os.listdir(data_dir)
                if path.isdir(path.join(data_dir, synset))
                and synset in self.synset_dict
            }

        # Check if there are any categories in the official mapping that are not
        # loaded. Update self.synset_inv so that it only includes the loaded
        # categories.
        synset_not_present = set(self.synset_dict.keys()).difference(synset_set)
        # Plain loop (not a comprehension): pop() is called for its side effect.
        for synset in synset_not_present:
            self.synset_inv.pop(self.synset_dict[synset])

        if len(synset_not_present) > 0:
            msg = (
                "The following categories are included in ShapeNetCore ver.%d's "
                "official mapping but not found in the dataset location %s: %s"
                ""
            ) % (version, data_dir, ", ".join(synset_not_present))
            warnings.warn(msg)

        # Extract model_id of each object from directory names.
        # Each grandchildren directory of data_dir contains an object, and the
        # name of the directory is the object's model_id.
        for synset in synset_set:
            self.synset_start_idxs[synset] = len(self.synset_ids)
            for model in os.listdir(path.join(data_dir, synset)):
                # Skip (with a warning) models whose obj file is missing.
                if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
                    msg = (
                        "Object file not found in the model directory %s "
                        "under synset directory %s."
                    ) % (model, synset)
                    warnings.warn(msg)
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)
            model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
            self.synset_num_models[synset] = model_count

    def __getitem__(self, idx: int) -> Dict:
        """
        Read a model by the given index.

        Args:
            idx: The idx of the model to be retrieved in the dataset.

        Returns:
            dictionary with following keys:

            - verts: FloatTensor of shape (V, 3).
            - faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
            - synset_id (str): synset id
            - model_id (str): model id
            - label (str): synset label.
        """
        model = self._get_item_ids(idx)
        model_path = path.join(
            self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
        )
        verts, faces, textures = self._load_mesh(model_path)
        model["verts"] = verts
        model["faces"] = faces
        model["textures"] = textures
        model["label"] = self.synset_dict[model["synset_id"]]
        return model