Vae

manify.embedders.vae

Variational autoencoder implementation for product manifold spaces.

This module provides a variational autoencoder (VAE) implementation specifically designed for learning representations in mixed-curvature product spaces. The implementation handles the complexities of sampling, KL divergence calculation, and reparameterization in curved spaces, supporting combinations of hyperbolic, Euclidean, and spherical geometries within a single latent space.

For more information, see Skopek et al (2020): Mixed Curvature Variational Autoencoders

ProductSpaceVAE(pm, encoder, decoder, random_state=None, device='cpu', beta=1.0, reconstruction_loss=None, n_samples=16)

Bases: BaseEmbedder, Module

Product Space Variational Autoencoder.

The probabilistic model is defined as:

  • Prior: \(p(z) = \mathcal{WN}(z; \mu_0, I)\) (wrapped normal distribution centered at manifold origin)
  • Likelihood: \(p_\theta(x|z) = \mathcal{N}(x; f_\theta(z), \sigma^2 I)\) or other reconstruction distribution
  • Posterior approximation: \(q_\phi(z|x) = \mathcal{WN}(z; \mu_\phi(x), \Sigma_\phi(x))\)

where \(\mathcal{WN}\) is a wrapped normal distribution on the manifold.

The model is trained by maximizing the evidence lower bound (ELBO):

\(\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z))\)

Attributes:
  • pm

    Product manifold defining the structure of the latent space.

  • random_state

    Random state for reproducibility.

  • encoder

    Neural network that outputs mean and log-variance parameters.

  • decoder

    Neural network that reconstructs inputs from latent embeddings.

  • beta

    Weight for the KL divergence term in the ELBO.

  • device

    Device for tensor computations.

  • n_samples

    Number of samples for Monte Carlo estimation of KL divergence.

  • reconstruction_loss

    Type of reconstruction loss to use.

  • loss_history_

    Dictionary to store the history of loss values during training.

  • is_fitted_

    Boolean flag indicating whether the model has been fitted.

Parameters:
  • pm (ProductManifold) –

    Product manifold defining the structure of the latent space.

  • encoder (Module) –

    Neural network module that produces mean (first half of output) and log-variance (second half of output) of the posterior distribution. The output dimension should match twice the intrinsic dimension of the product manifold.

  • decoder (Module) –

    Neural network module that maps latent representations back to the input space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str, default: 'cpu' ) –

    Optional device for tensor computations.

  • beta (float, default: 1.0 ) –

    Weight of the KL divergence term in the ELBO loss. Values < 1 give a \(\beta\)-VAE with a looser constraint on the latent space.

  • reconstruction_loss (_Loss | None, default: None ) –

    Type of reconstruction loss to use. Currently only "mse" (mean squared error) is supported.

  • n_samples (int, default: 16 ) –

    Number of Monte Carlo samples to use when estimating the KL divergence.

Source code in manify/embedders/vae.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def __init__(
    self,
    pm: ProductManifold,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    random_state: int | None = None,
    device: str = "cpu",
    beta: float = 1.0,
    reconstruction_loss: torch.nn.modules.loss._Loss | None = None,
    n_samples: int = 16,
):
    # Init both base classes
    torch.nn.Module.__init__(self)
    BaseEmbedder.__init__(self, pm=pm, random_state=random_state, device=device)

    # Now we assign
    self.encoder = encoder.to(device)
    self.decoder = decoder.to(device)
    self.beta = beta
    self.n_samples = n_samples
    self.reconstruction_loss = (
        reconstruction_loss if reconstruction_loss is not None else torch.nn.MSELoss(reduction="none")
    )
    self.model_ = None
    self.loss_history_ = {}
    self.is_fitted_ = False

    # Ensure encoder last dimension is 2 * pm.intrinsic_dim:
    assert encoder[-1].out_features == 2 * pm.dim, "Encoder output must match 2 * intrinsic dimension of manifold."

    # Ensure decoder input dimension is pm.intrinsic_dim
    assert decoder[0].in_features == pm.ambient_dim, "Decoder input must match ambient dimension of manifold."

encode(x)

Encodes input data to obtain latent means and log-variances in the manifold.

This method processes input data through the encoder network to obtain parameters of the approximate posterior distribution \(q(z|x)\) in the product manifold space. For non-Euclidean components, the method:

  1. Gets tangent space vectors and log-variances from the encoder,
  2. Projects tangent vectors to the ambient space by adding zeros in the right places, and
  3. Maps the ambient space vectors to the manifold using the exponential map
Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • z_mean_tangent( Float[Tensor, 'batch_size n_latent'] ) –

    Mean of the posterior distribution in the tangent plane at the origin.

  • z_logvar( Float[Tensor, 'batch_size n_latent'] ) –

    Log-variance of the posterior distribution, used for constructing the covariance matrices.

Source code in manify/embedders/vae.py
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def encode(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]:
    r"""Encodes input data to obtain latent means and log-variances in the manifold.

    This method processes input data through the encoder network to obtain parameters of the approximate posterior
    distribution $q(z|x)$ in the product manifold space. For non-Euclidean components, the method:

    1. Gets tangent space vectors and log-variances from the encoder,
    2. Projects tangent vectors to the ambient space by adding zeros in the right places, and
    3. Maps the ambient space vectors to the manifold using the exponential map

    Args:
        x: Input data tensor.

    Returns:
        z_mean_tangent: Mean of the posterior distribution in the tangent plane at the origin.
        z_logvar: Log-variance of the posterior distribution, used for constructing the covariance matrices.
    """
    z = self.encoder(x)
    z_mean_tangent, z_logvar = z[..., : self.pm.dim], z[..., self.pm.dim :]
    # z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
    # z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
    return z_mean_tangent, z_logvar

decode(z)

Decodes latent points from the manifold space back to the input space.

Takes points from the product manifold latent space and passes them through the decoder network to reconstruct the original input data.

Parameters:
  • z (Float[Tensor, 'batch_size n_ambient']) –

    Latent points in the product manifold

Returns:
  • reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Tensor containing the reconstructed input data, with shape (batch_size, n_features).

Source code in manify/embedders/vae.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def decode(self, z: Float[torch.Tensor, "batch_size n_ambient"]) -> Float[torch.Tensor, "batch_size n_features"]:
    """Decodes latent points from the manifold space back to the input space.

    Takes points from the product manifold latent space and passes them through
    the decoder network to reconstruct the original input data.

    Args:
        z: Latent points in the product manifold

    Returns:
        reconstructed: Tensor containing the reconstructed input data,
            with shape (batch_size, n_features).
    """
    return self.decoder(z)

forward(x)

Performs the forward pass of the VAE in product manifold space.

This method implements the complete VAE forward pass, with manifold projection:

  1. Encode the input to get posterior parameters (z_means, z_logvars)
  2. Project means onto the manifold using exponential map
  3. Factorize the log-variances for each manifold component and convert to covariance matrices
  4. Sample points from the posterior distributions in the product manifold
  5. Decode the sampled points to get reconstructions
Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • x_reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed data tensor with the same shape as the input.

  • z_means( Float[Tensor, 'batch_size n_ambient'] ) –

    Means of the posterior distributions in the manifold space.

  • sigmas( list[Float[Tensor, 'batch_size n_latent n_latent']] ) –

    List of covariance matrices for each manifold component.

Source code in manify/embedders/vae.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def forward(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_ambient"],
    list[Float[torch.Tensor, "batch_size n_latent n_latent"]],
]:
    r"""Performs the forward pass of the VAE in product manifold space.

    This method implements the complete VAE forward pass, with manifold projection:

    1. Encode the input to get posterior parameters (`z_means`, `z_logvars`)
    2. Project means onto the manifold using exponential map
    3. Factorize the log-variances for each manifold component and convert to covariance matrices
    4. Sample points from the posterior distributions in the product manifold
    5. Decode the sampled points to get reconstructions

    Args:
        x: Input data tensor.

    Returns:
        x_reconstructed: Reconstructed data tensor with the same shape as the input.
        z_means: Means of the posterior distributions in the manifold space.
        sigmas: List of covariance matrices for each manifold component.
    """
    z_mean_tangent, z_logvars = self.encode(x)

    # Need to convert from implicit parameterization to extrinsic coordinates
    z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
    z_means = self.pm.expmap(u=z_mean_ambient, base=None)

    # Factorize log-variances; convert to covariances
    sigma_factorized = self.pm.factorize(z_logvars, intrinsic=True)
    sigmas = [torch.diag_embed(torch.exp(z_logvar) + 1e-8) for z_logvar in sigma_factorized]

    # Sample and decode
    z = self.pm.sample(z_mean=z_means, sigma_factorized=sigmas)
    x_reconstructed = self.decode(z)
    return x_reconstructed, z_means, sigmas

kl_divergence(z_mean, sigma_factorized)

Computes the KL divergence between posterior and prior distributions in the manifold.

For distributions in Riemannian manifolds, computing the KL divergence analytically is often intractable. This method uses Monte Carlo sampling to approximate the KL divergence:

\[D_{KL}(q(z|x) || p(z)) \approx \frac{1}{N} \sum_{i=1}^{N} [\log q(z_i|x) - \log p(z_i)]\]

where \(z_i\) are samples from \(q(z|x)\).

This implementation follows the approach described in: http://joschu.net/blog/kl-approx.html

Parameters:
  • z_mean (Float[Tensor, 'batch_size n_latent']) –

    Means of the posterior distributions in the manifold.

  • sigma_factorized (list[Float[Tensor, 'batch_size manifold_dim manifold_dim']]) –

    List of covariance matrices for each manifold component.

Returns:
  • kl_divergence( Float[Tensor, 'batch_size'] ) –

    KL divergence values for each data point in the batch.

Source code in manify/embedders/vae.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def kl_divergence(
    self,
    z_mean: Float[torch.Tensor, "batch_size n_latent"],
    sigma_factorized: list[Float[torch.Tensor, "batch_size manifold_dim manifold_dim"]],
) -> Float[torch.Tensor, "batch_size"]:
    r"""Computes the KL divergence between posterior and prior distributions in the manifold.

    For distributions in Riemannian manifolds, computing the KL divergence analytically
    is often intractable. This method uses Monte Carlo sampling to approximate the KL divergence:

    $$D_{KL}(q(z|x) || p(z)) \approx \frac{1}{N} \sum_{i=1}^{N} [\log q(z_i|x) - \log p(z_i)]$$

    where $z_i$ are samples from $q(z|x)$.

    This implementation follows the approach described in:
    http://joschu.net/blog/kl-approx.html

    Args:
        z_mean: Means of the posterior distributions in the manifold.
        sigma_factorized: List of covariance matrices for each manifold component.

    Returns:
        kl_divergence: KL divergence values for each data point in the batch.
    """
    # Get KL divergence as the average of log q(z|x) - log p(z)
    means = torch.repeat_interleave(z_mean, self.n_samples, dim=0)
    sigmas_factorized_interleaved = [
        torch.repeat_interleave(sigma, self.n_samples, dim=0) for sigma in sigma_factorized
    ]
    # We want to use n_samples = 1 here, since we'll need to pass the interleaved means/sigmas to the log-likelihood
    z_samples = self.pm.sample(z_mean=means, sigma_factorized=sigmas_factorized_interleaved)
    log_qz = self.pm.log_likelihood(z_samples, means, sigmas_factorized_interleaved)
    log_pz = self.pm.log_likelihood(z_samples)
    return (log_qz - log_pz).view(-1, self.n_samples).mean(dim=1)

elbo(x)

Computes the Evidence Lower Bound (ELBO) for the VAE objective.

The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term (log-likelihood) and a regularization term (KL divergence):

\[\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z)),\]

where:

  • \(\theta\) are the decoder parameters
  • \(\phi\) are the encoder parameters
  • \(\beta\) is a weight for the KL term (setting \(\beta < 1\) creates a \(\beta\)-VAE)
Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • elbo( Float[Tensor, ''] ) –

    Mean ELBO value across the batch (higher is better).

  • log_likelihood( Float[Tensor, ''] ) –

    Mean reconstruction log-likelihood across the batch.

  • kl_divergence( Float[Tensor, ''] ) –

    Mean KL divergence across the batch.

Source code in manify/embedders/vae.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
def elbo(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]:
    r"""Computes the Evidence Lower Bound (ELBO) for the VAE objective.

    The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term
    (log-likelihood) and a regularization term (KL divergence):

    $$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z)),$$

    where:

    - $\theta$ are the decoder parameters
    - $\phi$ are the encoder parameters
    - $\beta$ is a weight for the KL term (setting $\beta < 1$ creates a $\beta$-VAE)

    Args:
        x: Input data tensor.

    Returns:
        elbo: Mean ELBO value across the batch (higher is better).
        log_likelihood: Mean reconstruction log-likelihood across the batch.
        kl_divergence: Mean KL divergence across the batch.
    """
    x_reconstructed, z_means, sigma_factorized = self(x)
    kld = self.kl_divergence(z_means, sigma_factorized)
    ll = -self.reconstruction_loss(x_reconstructed.view(x.shape[0], -1), x.view(x.shape[0], -1)).sum(dim=1)
    return (ll - self.beta * kld).mean(), ll.mean(), kld.mean()

fit(X, D=None, lr=0.001, burn_in_lr=0.0001, curvature_lr=0.0, burn_in_iterations=1, training_iterations=9, loss_window_size=100, logging_interval=10, batch_size=32, clip_grad=True)

Trains the VAE model on the provided data.

The training process consists of two phases:

  1. Burn-in phase: Initial training with a lower learning rate for stability
  2. Main training phase: Training with the full learning rate and optional curvature optimization

Training uses Adam optimizer with gradient clipping to prevent exploding gradients. During training, the model maximizes the Evidence Lower Bound (ELBO).

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Training data tensor.

  • D (None, default: None ) –

    Ignored.

  • lr (float, default: 0.001 ) –

    Learning rate for the main training phase.

  • burn_in_lr (float, default: 0.0001 ) –

    Learning rate for the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for optimizing manifold scale factors. Off (no learning) by default.

  • burn_in_iterations (int, default: 1 ) –

    Number of iterations for the burn-in phase.

  • training_iterations (int, default: 9 ) –

    Number of iterations for the main training phase.

  • loss_window_size (int, default: 100 ) –

    Window size for computing moving average loss.

  • logging_interval (int, default: 10 ) –

    Interval for logging training progress.

  • batch_size (int, default: 32 ) –

    Batch size for training.

  • clip_grad (bool, default: True ) –

    Whether to apply gradient clipping.

Returns:
  • losses( 'ProductSpaceVAE' ) –

    List of loss values recorded during training.

Source code in manify/embedders/vae.py
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def fit(  # type: ignore[override]
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    D: None = None,
    lr: float = 1e-3,
    burn_in_lr: float = 1e-4,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 1,
    training_iterations: int = 9,
    loss_window_size: int = 100,
    logging_interval: int = 10,
    batch_size: int = 32,
    clip_grad: bool = True,
) -> "ProductSpaceVAE":
    """Trains the VAE model on the provided data.

    The training process consists of two phases:

    1. Burn-in phase: Initial training with a lower learning rate for stability
    2. Main training phase: Training with the full learning rate and optional curvature optimization

    Training uses Adam optimizer with gradient clipping to prevent exploding gradients. During training, the model
    maximizes the Evidence Lower Bound (ELBO).

    Args:
        X: Training data tensor.
        D: Ignored.
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of iterations for the burn-in phase.
        training_iterations: Number of iterations for the main training phase.
        loss_window_size: Window size for computing moving average loss.
        logging_interval: Interval for logging training progress.
        batch_size: Batch size for training.
        clip_grad: Whether to apply gradient clipping.

    Returns:
        losses: List of loss values recorded during training.
    """
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    my_tqdm = tqdm(total=(burn_in_iterations + training_iterations) * len(X))
    opt = torch.optim.Adam(
        [
            {"params": [p for p in self.parameters() if p not in set(self.pm.parameters())], "lr": burn_in_lr},
            {"params": self.pm.parameters(), "lr": 0},
        ]
    )
    losses: Dict[str, List[float]] = {"elbo": [], "ll": [], "kl": []}
    for epoch in range(burn_in_iterations + training_iterations):
        if epoch == burn_in_iterations:
            opt.param_groups[0]["lr"] = lr
            opt.param_groups[1]["lr"] = curvature_lr

        for i in range(0, len(X), batch_size):
            opt.zero_grad()
            X_batch = X[i : i + batch_size]
            elbo, ll, kl = self.elbo(X_batch)
            L = -elbo
            L.backward()

            # Add to losses
            losses["elbo"].append(elbo.item())
            losses["ll"].append(ll.item())
            losses["kl"].append(kl.item())

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(self.pm.parameters(), max_norm=1.0)
            if torch.isnan(L) or torch.isinf(L):
                print(f"Invalid loss detected at epoch {epoch}, batch {i}")
                continue
            elif self._grads_ok():
                opt.step()

            # TQDM management
            my_tqdm.update(batch_size)
            my_tqdm.set_description(f"L: {L.item():.3e}, ll: {ll.item():.3e}, kl: {kl.item():.3e}")

            # Logging
            if i % logging_interval == 0:
                d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
                # d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
                d["L_avg"] = f"{np.mean(losses['elbo'][-loss_window_size:]):.3e}"
                d["ll_avg"] = f"{np.mean(losses['ll'][-loss_window_size:]):.3e}"
                d["kl_avg"] = f"{np.mean(losses['kl'][-loss_window_size:]):.3e}"
                my_tqdm.set_postfix(d)

    # Final maintenance: update attributes
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self

transform(X, D=None, batch_size=32, expmap=True)

Transform data using the trained VAE. Outputs means of the variational distribution.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Features to embed with VAE.

  • D (None, default: None ) –

    Ignored.

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • expmap (bool, default: True ) –

    Whether to use exponential map for embedding.

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.

Source code in manify/embedders/vae.py
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
def transform(
    self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using the trained VAE. Outputs means of the variational distribution.

    Args:
        X: Features to embed with VAE.
        D: Ignored.
        batch_size: Number of samples per batch.
        expmap: Whether to use exponential map for embedding.

    Returns:
        embeddings: Learned embeddings.
    """
    # Set random state
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Save the test embeddings
    embeddings_list = []
    for i in range(0, len(X), batch_size):
        x_batch = X[i : i + batch_size]
        z_mean_tangent, _ = self.encode(x_batch)
        if expmap:
            z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
            z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
        else:
            z_mean = z_mean_tangent
        embeddings_list.append(z_mean.detach().cpu())

    embeddings = torch.cat(embeddings_list, dim=0)

    return embeddings