Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # pyre-unsafe | |
| from typing import Optional | |
| import torch | |
class HarmonicEmbedding(torch.nn.Module):
    """Sinusoidal (harmonic) positional encoding of the last tensor dimension."""

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        The harmonic embedding layer supports the classical
        NeRF positional encoding described in
        `NeRF <https://arxiv.org/abs/2003.08934>`_
        and the integrated position encoding in
        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

        At call time you can provide the extra argument `diag_cov`.

        If `diag_cov is None`, each feature (i.e. vector along the last
        dimension) of `x` is converted into a series of harmonic features
        `embedding`. Along the last dimension the output is laid out as
        all sine terms, then all cosine terms, then (optionally) the raw
        input; within the sine and cosine groups the terms are ordered by
        input coordinate `i` first and frequency second::

            [
                sin(f_1 * x[..., 0]), ..., sin(f_N * x[..., 0]),
                ...
                sin(f_1 * x[..., dim-1]), ..., sin(f_N * x[..., dim-1]),
                cos(f_1 * x[..., 0]), ..., cos(f_N * x[..., 0]),
                ...
                cos(f_1 * x[..., dim-1]), ..., cos(f_N * x[..., dim-1]),
                x[..., 0], ..., x[..., dim-1],  # only if append_input is True
            ]

        where N equals `n_harmonic_functions`, and f_k is a scalar
        denoting the k-th frequency of the harmonic embedding.

        If `diag_cov is not None`, `x` is interpreted as the means and
        `diag_cov` as the diagonal covariances of Gaussians (as used by
        MIP-NeRF to approximate conical frustums). Every sine and cosine
        term above is then attenuated by the matching factor::

            sin(f_k * x[..., i]) * exp(-0.5 * f_k**2 * diag_cov[..., i])
            cos(f_k * x[..., i]) * exp(-0.5 * f_k**2 * diag_cov[..., i])

        The appended input (if any) is not attenuated.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `f_1, ..., f_N = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`

        Note that the frequencies are also premultiplied by the base
        frequency `omega_0` before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic
                features (frequencies)
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions, dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        # Non-persistent buffers: they follow the module across devices but
        # are not written to the state dict.
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase offsets selecting sin (offset 0) and cos (offset pi/2).
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self,
        x: torch.Tensor,
        diag_cov: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our
                Gaussians, joined with x as means of the Gaussians.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., dim * (n_harmonic_functions * 2 + int(append_input))]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1]
        #   => [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        # Use the trig identity cos(x) = sin(x + pi/2)
        # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
        embed = embed.sin()

        if diag_cov is not None:
            # Per-coordinate, per-frequency variance of the scaled input.
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # Broadcast the attenuation over the sin/cos axis:
            # [..., 1, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        # Flatten (2, dim, n_harmonic_functions) into a single feature axis.
        embed = embed.reshape(*x.shape[:-1], -1)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding

        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as `get_output_dim_static`, using this module's own number of
        frequencies and `append_input` setting. The default for input_dims
        is 3 for 3D applications which use harmonic embedding for
        positional encoding, so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )