# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .autodecoder import Autodecoder
class GlobalEncoderBase(ReplaceableBase):
    """
    A base class for implementing encoders of global frame-specific quantities.

    The latter includes e.g. the harmonic encoding of a frame timestamp
    (`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
    (`SequenceAutodecoder`).
    """

    def get_encoding_dim(self):
        """
        Returns the dimensionality of the returned encoding.
        """
        raise NotImplementedError()

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Calculates the squared norm of the encoding to report as the
        `autodecoder_norm` loss of the model, as a zero dimensional tensor.
        """
        raise NotImplementedError()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Given a set of inputs to encode, generates a tensor containing the encoding.

        Args:
            frame_timestamp: An optional tensor containing frame timestamps.
            sequence_name: An optional tensor of integer sequence ids, or a
                list of string sequence names, identifying each frame's sequence.
            **kwargs: Additional inputs that concrete encoders may consume.

        Returns:
            encoding: The tensor containing the global encoding.
        """
        raise NotImplementedError()
# TODO: probabilistic embeddings?
@registry.register
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides an autodecoder encoding
    of the frame's sequence identifier.

    Members:
        autodecoder: The `Autodecoder` instance (auto-created from config)
            that maps sequence names to learnt embedding vectors.
    """

    # pyre-fixme[13]: Attribute `autodecoder` is never initialized.
    autodecoder: Autodecoder

    def __post_init__(self):
        # Instantiates `self.autodecoder` from the configuration system.
        run_auto_creation(self)

    def get_encoding_dim(self):
        """
        Returns the dimensionality of the autodecoder's encoding.
        """
        return self.autodecoder.get_encoding_dim()

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encodes each frame's sequence identifier with the autodecoder.

        Args:
            frame_timestamp: Unused by this encoder.
            sequence_name: A tensor of integer sequence ids, or a list of
                string sequence names; required.

        Returns:
            The autodecoder embedding of `sequence_name`.

        Raises:
            ValueError: If `sequence_name` is not provided.
        """
        if sequence_name is None:
            raise ValueError("sequence_name must be provided.")
        # run dtype checks and pass sequence_name to self.autodecoder
        return self.autodecoder(sequence_name)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        Delegates to the autodecoder's squared-norm computation, reported
        as the model's `autodecoder_norm` loss.
        """
        return self.autodecoder.calculate_squared_encoding_norm()
@registry.register
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
    """
    A global encoder implementation which provides harmonic embeddings
    of each frame's timestamp.

    Members:
        n_harmonic_functions: Number of harmonic functions used to embed
            the timestamp.
        append_input: Whether the raw (scaled) timestamp is appended to
            its harmonic embedding.
        time_divisor: The timestamp is divided by this value before
            embedding.
    """

    n_harmonic_functions: int = 10
    append_input: bool = True
    time_divisor: float = 1.0

    def __post_init__(self):
        self._harmonic_embedding = HarmonicEmbedding(
            n_harmonic_functions=self.n_harmonic_functions,
            append_input=self.append_input,
        )

    def get_encoding_dim(self):
        """
        Returns the dimensionality of the embedding of a scalar timestamp.
        """
        return self._harmonic_embedding.get_output_dim(1)

    def forward(
        self,
        *,
        frame_timestamp: Optional[torch.Tensor] = None,
        sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Harmonically embeds each frame's timestamp.

        Args:
            frame_timestamp: A tensor of frame timestamps whose last
                dimension equals 1; required.
            sequence_name: Unused by this encoder.

        Returns:
            The harmonic embedding of `frame_timestamp / time_divisor`.

        Raises:
            ValueError: If `frame_timestamp` is missing or its last
                dimension is not 1.
        """
        if frame_timestamp is None:
            raise ValueError("frame_timestamp must be provided.")
        # The embedding expects one scalar timestamp per frame.
        if frame_timestamp.shape[-1] != 1:
            raise ValueError("Frame timestamp's last dimension must be one.")
        time = frame_timestamp / self.time_divisor
        # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._harmonic_embedding(time)

    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
        """
        This encoder has no learnt embedding, so there is no norm to
        penalize; returns None.
        """
        return None