Embedders

manify.embedders

Tools for embedding data into Riemannian manifolds and product spaces.

The embedders module provides various ways to embed data into manifolds of constant or mixed curvature. The module includes:

  • coordinate_learning: Direct optimization of coordinates in a product manifold.
  • siamese: Siamese network-based embedding for metric learning.
  • vae: Variational autoencoders for learning representations in product manifolds.
  • _losses: Loss functions for measuring embedding quality.
  • _base: Base class for embedders.

CoordinateLearning(pm, random_state=None, device=None)

Bases: BaseEmbedder

Coordinate learning method class.

This embedder implements the approach described in Gu et al., "Learning Mixed-Curvature Representations in Product Spaces". It directly optimizes point coordinates to preserve a given distance matrix, using Riemannian optimization techniques.

Trains point coordinates in a product manifold to match target distances.

This class optimizes the coordinates of points in a product manifold to match a given distance matrix. The optimization is performed in two phases:

  1. Burn-in phase: Initial optimization with a smaller learning rate to find a good starting configuration.
  2. Training phase: Fine-tuning of the coordinates with a larger learning rate, and optionally optimizing the scale factors (curvatures) of the manifold components.
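The two phases amount to a switch in learning rates at a fixed iteration. The sketch below is a plain-Python illustration of that schedule, using the default values from the `fit` signature. The helper name `phase_learning_rates` is hypothetical and not part of manify.

```python
def phase_learning_rates(
    iteration: int,
    burn_in_iterations: int = 2_000,
    lr: float = 1e-2,
    burn_in_lr: float = 1e-3,
    curvature_lr: float = 0.0,
) -> tuple[float, float]:
    """Return (coordinate_lr, scale_factor_lr) in effect at a given iteration.

    Hypothetical helper for illustration only; not part of manify.
    """
    if iteration < burn_in_iterations:
        # Burn-in: coordinates move with the smaller rate, scale factors stay frozen.
        return burn_in_lr, 0.0
    # Training: coordinates use the main rate; scale factors move only if curvature_lr > 0.
    return lr, curvature_lr


for step in (0, 1_999, 2_000):
    print(step, phase_learning_rates(step))
```

With the default `curvature_lr=0.0`, the scale factors stay fixed in both phases. Passing a positive value lets them change only once burn-in has finished.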

The optimization uses the Riemannian Adam optimizer to respect the manifold structure during gradient updates. The loss measures the distortion between the pairwise distances in the embedding and the target distances.
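For intuition, the distortion objective in Gu et al. compares squared distance ratios against 1. The sketch below writes that idea in plain Python over nested lists. Averaging over pairs is an assumption made here, and the actual `distortion_loss` used by this class may reduce or scale the terms differently.

```python
from itertools import combinations


def distortion(D_est, D_true):
    """Mean of |(d_est / d_true)^2 - 1| over all pairs i < j.

    Illustrative stand-in; the mean reduction is an assumption, not manify's code.
    """
    n = len(D_true)
    terms = [
        abs((D_est[i][j] / D_true[i][j]) ** 2 - 1.0)
        for i, j in combinations(range(n), 2)
    ]
    return sum(terms) / len(terms)


D_true = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]
D_same = [row[:] for row in D_true]                     # perfect embedding
D_double = [[2.0 * d for d in row] for row in D_true]   # every distance doubled

print(distortion(D_same, D_true))    # no distortion
print(distortion(D_double, D_true))  # each ratio is 2, so each term is 2**2 - 1
```

Because the terms are ratios, the objective penalizes relative rather than absolute error, so short and long target distances contribute on the same scale.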

For non-transductive settings, the class supports a split between training and test points and optimizes three groups of distances separately: train-train, test-test, and train-test.
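The split can be pictured as cutting the target distance matrix into three blocks. The following sketch does this with nested lists. The helper name `split_distance_blocks` is hypothetical, whereas `fit` itself builds boolean masks over a tensor.

```python
def split_distance_blocks(D, test_indices):
    """Partition a square distance matrix into train/test blocks.

    Hypothetical helper for illustration; fit() uses boolean tensor masks instead.
    """
    test_set = set(test_indices)
    train = [i for i in range(len(D)) if i not in test_set]
    test = [i for i in range(len(D)) if i in test_set]

    def block(rows, cols):
        return [[D[i][j] for j in cols] for i in rows]

    return {
        "train_train": block(train, train),
        "test_test": block(test, test),
        "train_test": block(train, test),
    }


D = [
    [0.0, 1.0, 2.0, 3.0],
    [1.0, 0.0, 1.0, 2.0],
    [2.0, 1.0, 0.0, 1.0],
    [3.0, 2.0, 1.0, 0.0],
]
blocks = split_distance_blocks(D, test_indices=[3])
print(blocks["train_test"])
```

In the `fit` listing, the train-test term is computed against detached training coordinates, so that term only moves the test points.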

Attributes:
  • pm

    Product manifold defining the target embedding space.

  • embeddings_

    Optimized point coordinates after fitting.

  • loss_history_ (dict[str, list[float]]) –

    Training loss history.

  • is_fitted_ (bool) –

    Boolean flag indicating if the embedder has been fitted.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object defining the target embedding space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str | None, default: None ) –

    Optional device for tensor computations.

Source code in manify/embedders/coordinate_learning.py
def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
    super().__init__(pm=pm, random_state=random_state, device=device)

fit(X, D, test_indices=None, lr=0.01, burn_in_lr=0.001, curvature_lr=0.0, burn_in_iterations=2000, training_iterations=18000, loss_window_size=100, logging_interval=10)

Fit the Coordinate Learning Embedder. Sets attributes embeddings_, loss_history_, and is_fitted_.

Parameters:
  • X (None) –

    Ignored.

  • D (Float[Tensor, 'n_points n_points']) –

    Tensor representing the target pairwise distance matrix between points.

  • test_indices (Int[Tensor, 'n_test'] | None, default: None ) –

    Tensor containing indices of test points for transductive learning. Defaults to an empty tensor (all points are used for training).

  • lr (float, default: 0.01 ) –

    Learning rate for the main training phase.

  • burn_in_lr (float, default: 0.001 ) –

    Learning rate for the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for optimizing manifold scale factors. Off (no learning) by default.

  • burn_in_iterations (int, default: 2000 ) –

    Number of iterations for the burn-in phase.

  • training_iterations (int, default: 18000 ) –

    Number of iterations for the main training phase.

  • loss_window_size (int, default: 100 ) –

    Window size for computing moving average loss.

  • logging_interval (int, default: 10 ) –

    Interval for logging training progress.
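The `loss_window_size` and `logging_interval` parameters only affect what the progress bar reports. A minimal sketch of the two mechanisms, with hypothetical helper names, might look like this:

```python
def windowed_mean(losses, loss_window_size=100):
    """Average of the most recent `loss_window_size` entries (fewer if not yet available)."""
    window = losses[-loss_window_size:]
    return sum(window) / len(window)


def logged_iterations(n_iterations, logging_interval=10):
    """Iterations at which the progress-bar postfix would be refreshed."""
    return [i for i in range(n_iterations) if i % logging_interval == 0]


losses = [4.0, 3.0, 2.0, 1.0]
print(windowed_mean(losses, loss_window_size=2))    # last two entries only
print(windowed_mean(losses, loss_window_size=100))  # window larger than history
print(logged_iterations(25))
```

The sketch assumes at least one loss has been recorded. In the `fit` listing this always holds, because the loss is appended before the logging step runs.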

Returns:
  • self( 'CoordinateLearning' ) –

    Fitted embedder instance.

Raises:
  • ValueError

    If the distance matrix D is None, or if the loss becomes NaN during training.

  • Warning

    If X is provided, it will be ignored during fitting.

Source code in manify/embedders/coordinate_learning.py
def fit(  # type: ignore[override]
    self,
    X: None,
    D: Float[torch.Tensor, "n_points n_points"],
    test_indices: Int[torch.Tensor, "n_test"] | None = None,
    lr: float = 1e-2,
    burn_in_lr: float = 1e-3,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 2_000,
    training_iterations: int = 18_000,
    loss_window_size: int = 100,
    logging_interval: int = 10,
) -> "CoordinateLearning":
    """Fit the Coordinate Learning Embedder. Sets attributes `embeddings_`, `loss_history_`, and `is_fitted_`.

    Args:
        X: Ignored.
        D: Tensor representing the target pairwise distance matrix between points.
        test_indices: Tensor containing indices of test points for transductive learning.
            Defaults to an empty tensor (all points are used for training).
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of iterations for the burn-in phase.
        training_iterations: Number of iterations for the main training phase.
        loss_window_size: Window size for computing moving average loss.
        logging_interval: Interval for logging training progress.

    Returns:
        self: Fitted embedder instance.

    Raises:
        ValueError: If the distance matrix D is None, or if the loss becomes NaN during training.
        Warning: If X is provided, it will be ignored during fitting.
    """
    # Input validation
    if D is None:
        raise ValueError("Distance matrix D is needed for coordinate learning")
    if X is not None:
        warnings.warn(
            "Input X has been given. This will be ignored during fitting. If you have provided a distance matrix, please run embedder.fit(None, D) instead.",
            stacklevel=2,
        )

    # Set random seed if provided
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Move everything to the device; initialize random embeddings
    n = D.shape[0]
    covs = [torch.stack([torch.eye(M.dim) / self.pm.dim] * n).to(self.device) for M in self.pm.P]
    means = torch.vstack([self.pm.mu0] * n).to(self.device)
    X_embed = self.pm.sample(z_mean=means, sigma_factorized=covs)
    D = D.to(self.device)

    # Get train and test indices set up
    test_indices = test_indices if test_indices is not None else torch.tensor([])
    use_test = len(test_indices) > 0
    test = torch.tensor([i in test_indices for i in range(len(D))]).to(self.device)
    train = ~test

    # Initialize optimizer
    X_embed = geoopt.ManifoldParameter(X_embed, manifold=self.pm.manifold)
    ropt = geoopt.optim.RiemannianAdam(
        [{"params": [X_embed], "lr": burn_in_lr}, {"params": self.pm.parameters(), "lr": 0}]
    )

    # Init TQDM
    my_tqdm = tqdm(total=burn_in_iterations + training_iterations, leave=False)

    # Outer training loop - mostly setting optimizer learning rates up here
    losses: dict[str, list[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}

    # Actual training loop
    for i in range(burn_in_iterations + training_iterations):
        if i == burn_in_iterations:
            # Optimize curvature by changing lr
            ropt.param_groups[0]["lr"] = lr
            ropt.param_groups[1]["lr"] = curvature_lr

        # Zero grad
        ropt.zero_grad()

        # 1. Train-train loss
        X_t = X_embed[train]
        D_tt = self.pm.pdist(X_t)
        L_tt = distortion_loss(D_tt, D[train][:, train], pairwise=True)
        L_tt.backward(retain_graph=True)
        losses["train_train"].append(L_tt.item())

        if use_test:
            # 2. Test-test loss
            X_q = X_embed[test]
            D_qq = self.pm.pdist(X_q)
            L_qq = distortion_loss(D_qq, D[test][:, test], pairwise=True)
            L_qq.backward(retain_graph=True)
            losses["test_test"].append(L_qq.item())

            # 3. Train-test loss
            X_t_detached = X_embed[train].detach()
            D_tq = self.pm.dist(X_t_detached, X_q)  # Note 'dist' not 'pdist', as we're comparing different sets
            L_tq = distortion_loss(D_tq, D[train][:, test], pairwise=False)
            L_tq.backward()
            losses["train_test"].append(L_tq.item())
        else:
            L_qq = 0
            L_tq = 0

        # Step
        ropt.step()
        L = L_tt + L_qq + L_tq
        losses["total"].append(L.item())

        # TQDM management
        my_tqdm.update(1)
        my_tqdm.set_description(f"Loss: {L.item():.3e}")

        # Logging
        if i % logging_interval == 0:
            d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
            d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
            d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
            my_tqdm.set_postfix(d)

        # Early stopping for errors
        if torch.isnan(L):
            raise ValueError("Loss is NaN")

    # Final maintenance: update attributes
    self.embeddings_ = X_embed.data.detach()
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self

transform(X=None)

Transform data using learned embedding. This is not meaningful for new data during coordinate learning.

Parameters:
  • X (None, default: None ) –

    Ignored.

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.

Raises:
  • ValueError

    If the embedder has not been fitted yet.

  • Warning

    If X is provided, as it will be ignored.

Source code in manify/embedders/coordinate_learning.py
def transform(self, X: None = None) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using learned embedding. This is not meaningful for new data during coordinate learning.

    Args:
        X: Ignored.

    Returns:
        embeddings: Learned embeddings.

    Raises:
        ValueError: If the embedder has not been fitted yet.
        Warning: If X is provided, as it will be ignored.
    """
    if not self.is_fitted_:
        raise ValueError("The embedder has not been fitted yet.")

    if X is not None:
        warnings.warn("Coordinate learning can only return trained embeddings. X will be ignored.", stacklevel=2)

    return self.embeddings_

fit_transform(X, D, **fit_kwargs)

Fit the embedder to the provided distance matrix D and return the learned embeddings.

This method overrides the base class method BaseEmbedder.fit_transform() to not use the input data X.

Parameters:
  • X (None) –

    Ignored.

  • D (Float[Tensor, 'n_points n_points']) –

    Distance matrix for the points.

  • fit_kwargs (Any, default: {} ) –

    Additional keyword arguments passed to the model.fit() method.

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.

Source code in manify/embedders/coordinate_learning.py
def fit_transform(  # type: ignore[override]
    self, X: None, D: Float[torch.Tensor, "n_points n_points"], **fit_kwargs: Any
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using learned embedding based on the provided distance matrix D.

    This method overrides the base class method `BaseEmbedder.fit_transform()` to not use the input data X.

    Args:
        X: Ignored.
        D: Distance matrix for the points.
        fit_kwargs: Additional keyword arguments passed to the `model.fit()` method.

    Returns:
        embeddings: Learned embeddings.
    """
    return self.fit(X=None, D=D, **fit_kwargs).transform(X=None)

SiameseNetwork(pm, encoder, decoder=None, reconstruction_loss='mse', beta=1.0, random_state=None, device='cpu')

Bases: BaseEmbedder, Module

Siamese network for embedding data into a product manifold space.

A Siamese network consists of an encoder network that maps input data to a latent representation in a product manifold, and optionally a decoder network that maps the latent representation back to the original feature space.

Attributes:
  • pm

    Product manifold defining the structure of the latent space.

  • random_state

    Random state for reproducibility.

  • encoder

    Neural network that maps inputs to latent embeddings.

  • decoder

    Neural network that reconstructs inputs from latent embeddings.

  • beta

    Weight for the distortion term in the loss function.

  • device

    Device for tensor computations.

  • reconstruction_loss

    Type of reconstruction loss to use.

Parameters:
  • pm (ProductManifold) –

    Product manifold defining the structure of the latent space.

  • encoder (Module) –

    Neural network module that maps inputs to the manifold's intrinsic dimension. The output dimension should match the intrinsic dimension of the product manifold.

  • decoder (Module | None, default: None ) –

    Neural network module that maps latent representations back to the input space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str, default: 'cpu' ) –

    Optional device for tensor computations.

  • beta (float, default: 1.0 ) –

    Weight of the distortion term in the loss function.

  • reconstruction_loss (str, default: 'mse' ) –

    Type of reconstruction loss to use.

Source code in manify/embedders/siamese.py
def __init__(
    self,
    pm: ProductManifold,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module | None = None,
    reconstruction_loss: str = "mse",
    beta: float = 1.0,
    random_state: int | None = None,
    device: str = "cpu",
):
    # Init both base classes
    torch.nn.Module.__init__(self)
    BaseEmbedder.__init__(self, pm=pm, random_state=random_state, device=device)

    # Now we assign
    self.pm = pm
    self.encoder = encoder
    self.beta = beta

    if decoder is not None:
        self.decoder = decoder
    else:
        self.decoder = torch.nn.Identity()
        self.decoder.requires_grad_(False)
        self.decoder.to(pm.device)

    if reconstruction_loss == "mse":
        self.reconstruction_loss = torch.nn.MSELoss(reduction="none")
    else:
        raise ValueError(f"Unknown reconstruction loss: {reconstruction_loss}")

encode(x)

Encodes input data into the manifold embedding space.

Takes a batch of input data and passes it through the encoder network to obtain embeddings in the manifold.

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • embeddings( Float[Tensor, 'batch_size n_latent'] ) –

    Tensor containing the embeddings in the manifold space.

Source code in manify/embedders/siamese.py
def encode(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Float[torch.Tensor, "batch_size n_latent"]:
    """Encodes input data into the manifold embedding space.

    Takes a batch of input data and passes it through the encoder network to obtain embeddings in the manifold.

    Args:
        x: Input data tensor.

    Returns:
        embeddings: Tensor containing the embeddings in the manifold space.
    """
    return self.encoder(x)

decode(z)

Decodes manifold embeddings back to the original input space.

Takes a batch of embeddings from the manifold space and passes them through the decoder network to reconstruct the original input data.

Parameters:
  • z (Float[Tensor, 'batch_size n_latent']) –

    Embedding tensor from the manifold space.

Returns:
  • reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Tensor containing the reconstructed input data.

Source code in manify/embedders/siamese.py
def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.Tensor, "batch_size n_features"]:
    """Decodes manifold embeddings back to the original input space.

    Takes a batch of embeddings from the manifold space and passes them through
    the decoder network to reconstruct the original input data.

    Args:
        z: Embedding tensor from the manifold space.

    Returns:
        reconstructed: Tensor containing the reconstructed input data.
    """
    return self.decoder(z)

forward(x1, x2)

Given two points, return their encodings, reconstructions, and embedding distance.

Parameters:
  • x1 (Float[Tensor, 'batch_size n_features']) –

    First input tensor.

  • x2 (Float[Tensor, 'batch_size n_features']) –

    Second input tensor.

Returns:
  • z1( Float[Tensor, 'batch_size n_latent'] ) –

    Encoded representation of the first input.

  • z2( Float[Tensor, 'batch_size n_latent'] ) –

    Encoded representation of the second input.

  • D_hat( Float[Tensor, 'batch_size'] ) –

    Estimated distance between the two embeddings.

  • reconstructed1( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed input from the first embedding.

  • reconstructed2( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed input from the second embedding.
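The essential pattern is that one encoder is applied to both inputs and the two embeddings are then compared. The sketch below shows that pattern for single points in plain Python. It substitutes Euclidean distance for the manifold distance and skips the projection and exponential map that the real `forward` applies, and the names `siamese_forward` and `toy_encoder` are hypothetical.

```python
import math


def siamese_forward(encode, x1, x2, decode=lambda z: z):
    """Encode both inputs with the same encoder, then compare the embeddings.

    Illustrative stand-in: Euclidean distance replaces the manifold distance,
    and the default decoder is the identity (as when no decoder is supplied).
    """
    z1, z2 = encode(x1), encode(x2)   # shared weights: one encoder, two inputs
    d_hat = math.dist(z1, z2)         # stand-in for the distance on the manifold
    return z1, z2, d_hat, decode(z1), decode(z2)


def toy_encoder(x):
    # Hypothetical encoder: keeps the first two features.
    return [x[0], x[1]]


z1, z2, d_hat, r1, r2 = siamese_forward(toy_encoder, [0.0, 0.0, 5.0], [3.0, 4.0, 9.0])
print(d_hat)
```

The return order mirrors the five outputs listed above: both embeddings, the estimated distance, then both reconstructions.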

Source code in manify/embedders/siamese.py
def forward(
    self, x1: Float[torch.Tensor, "batch_size n_features"], x2: Float[torch.Tensor, "batch_size n_features"]
) -> Tuple[
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size"],
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_features"],
]:
    """Given two points, return their encodings, reconstructions, and embedding distance.

    Args:
        x1: First input tensor.
        x2: Second input tensor.

    Returns:
        z1: Encoded representation of the first input.
        z2: Encoded representation of the second input.
        D_hat: Estimated distance between the two embeddings.
        reconstructed1: Reconstructed input from the first embedding.
        reconstructed2: Reconstructed input from the second embedding.
    """
    z1 = self.pm.expmap(self.encode(x1) @ self.pm.projection_matrix)
    z2 = self.pm.expmap(self.encode(x2) @ self.pm.projection_matrix)
    D_hat = self.pm.manifold.dist(z1, z2)  # use manifold dist to get (batch_size, ) vector of dists
    reconstructed1 = self.decode(z1)
    reconstructed2 = self.decode(z2)
    return z1, z2, D_hat, reconstructed1, reconstructed2

fit(X, D, lr=0.001, burn_in_lr=0.0001, curvature_lr=0.0, burn_in_iterations=1, training_iterations=9, loss_window_size=100, logging_interval=10, batch_size=32, clip_grad=True)

Fit the SiameseNetwork embedder.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Input data features to encode.

  • D (Float[Tensor, 'n_points n_points']) –

    Pairwise distances to emulate.

  • lr (float, default: 0.001 ) –

    Learning rate for the optimizer.

  • burn_in_lr (float, default: 0.0001 ) –

    Learning rate during burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for curvature updates.

  • burn_in_iterations (int, default: 1 ) –

    Number of burn-in epochs. Each epoch is one full pass over all point pairs in mini-batches.

  • training_iterations (int, default: 9 ) –

    Number of training epochs, each again a full pass over all point pairs.

  • loss_window_size (int, default: 100 ) –

    Size of the window for loss averaging.

  • logging_interval (int, default: 10 ) –

    Interval for logging progress.

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • clip_grad (bool, default: True ) –

    Whether to clip gradients.

Returns:
  • self( 'SiameseNetwork' ) –

    Fitted SiameseNetwork instance.
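Because training runs over all unordered pairs of points, the number of optimizer steps grows quadratically with the number of samples. The sketch below reproduces the counting logic from the `fit` listing in plain Python. The helper name `pair_schedule` is hypothetical.

```python
from itertools import combinations


def pair_schedule(n_samples, batch_size=32, burn_in_iterations=1, training_iterations=9):
    """Return (n_pairs, n_batches_per_epoch, total_steps) for a given dataset size.

    Hypothetical helper mirroring the bookkeeping at the top of fit().
    """
    pairs = list(combinations(range(n_samples), 2))  # all (i, j) with i < j
    n_pairs = len(pairs)
    n_batches_per_epoch = (n_pairs + batch_size - 1) // batch_size  # ceiling division
    total_steps = (burn_in_iterations + training_iterations) * n_batches_per_epoch
    return n_pairs, n_batches_per_epoch, total_steps


print(pair_schedule(100))
print(pair_schedule(5, batch_size=4))
```

For 100 points and the default settings this gives 4,950 pairs, 155 batches per epoch, and 1,550 optimizer steps in total, which is worth keeping in mind when choosing `batch_size` for larger datasets.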

Source code in manify/embedders/siamese.py
def fit(  # type: ignore[override]
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    D: Float[torch.Tensor, "n_points n_points"],
    lr: float = 1e-3,
    burn_in_lr: float = 1e-4,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 1,
    training_iterations: int = 9,
    loss_window_size: int = 100,
    logging_interval: int = 10,
    batch_size: int = 32,
    clip_grad: bool = True,
) -> "SiameseNetwork":
    """Fit the SiameseNetwork embedder.

    Args:
        X: Input data features to encode.
        D: Pairwise distances to emulate.
        lr: Learning rate for the optimizer.
        burn_in_lr: Learning rate during burn-in phase.
        curvature_lr: Learning rate for curvature updates.
        burn_in_iterations: Number of iterations for burn-in phase.
        training_iterations: Number of iterations for training phase.
        loss_window_size: Size of the window for loss averaging.
        logging_interval: Interval for logging progress.
        batch_size: Number of samples per batch.
        clip_grad: Whether to clip gradients.

    Returns:
        self: Fitted SiameseNetwork instance.
    """
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    n_samples = len(X)

    # Generate all upper triangular pairs using torch
    indices = torch.triu_indices(n_samples, n_samples, offset=1)
    pairs = torch.hstack([indices]).T  # (n_pairs, 2)

    # Number of pairs and batches
    n_pairs = len(pairs)
    n_batches_per_epoch = (n_pairs + batch_size - 1) // batch_size  # Ceiling division
    total_iterations = (burn_in_iterations + training_iterations) * n_batches_per_epoch

    my_tqdm = tqdm(total=total_iterations)

    opt = torch.optim.Adam(
        [
            {"params": [p for p in self.parameters() if p not in set(self.pm.parameters())], "lr": burn_in_lr},
            {"params": self.pm.parameters(), "lr": 0},
        ]
    )
    losses: Dict[str, List[float]] = {"total": [], "reconstruction": [], "distortion": []}

    for epoch in range(burn_in_iterations + training_iterations):
        if epoch == burn_in_iterations:
            opt.param_groups[0]["lr"] = lr
            opt.param_groups[1]["lr"] = curvature_lr

        # Shuffle all pairs
        shuffle_idx = torch.randperm(n_pairs)
        shuffled_pairs = pairs[shuffle_idx]

        for batch_start in range(0, n_pairs, batch_size):
            batch_end = min(batch_start + batch_size, n_pairs)
            batch_pairs = shuffled_pairs[batch_start:batch_end]

            # Extract indices for this batch
            batch_indices1 = batch_pairs[:, 0]
            batch_indices2 = batch_pairs[:, 1]

            # Get data for these indices
            X1 = X[batch_indices1]
            X2 = X[batch_indices2]

            # Extract the corresponding distances from D using advanced indexing
            D_batch = D[batch_indices1, batch_indices2]

            # Forward pass
            opt.zero_grad()
            _, _, D_hat, Y1, Y2 = self(X1, X2)
            mse1 = torch.nn.functional.mse_loss(Y1, X1)
            mse2 = torch.nn.functional.mse_loss(Y2, X2)

            # D_hat and D_batch are now 1D tensors of pairwise distances
            distortion = distortion_loss(D_hat, D_batch, pairwise=False)
            L = mse1 + mse2 + self.beta * distortion
            L.backward()

            # Add to losses
            losses["total"].append(L.item())
            losses["reconstruction"].append(mse1.item() + mse2.item())
            losses["distortion"].append(distortion.item())

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(self.pm.parameters(), max_norm=1.0)

            opt.step()

            # TQDM management
            my_tqdm.update(1)
            my_tqdm.set_description(
                f"L: {L.item():.3e}, recon: {mse1.item() + mse2.item():.3e}, dist: {distortion.item():.3e}"
            )

            # Logging
            if my_tqdm.n % logging_interval == 0:
                d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
                d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
                d["recon_avg"] = f"{np.mean(losses['reconstruction'][-loss_window_size:]):.3e}"
                d["dist_avg"] = f"{np.mean(losses['distortion'][-loss_window_size:]):.3e}"
                my_tqdm.set_postfix(d)

    # Final maintenance: update attributes
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self

transform(X, D=None, batch_size=32, expmap=True)

Transforms input data into manifold embeddings.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Features to embed with SiameseNetwork.

  • D (None, default: None ) –

    Ignored.

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • expmap (bool, default: True ) –

    Whether to use exponential map for embedding.

Returns:
  • embeddings( Float[Tensor, 'n_points n_latent'] ) –

    Embeddings produced by forward pass of trained SiameseNetwork model.

Source code in manify/embedders/siamese.py
def transform(
    self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
) -> Float[torch.Tensor, "n_points n_latent"]:
    """Transforms input data into manifold embeddings.

    Args:
        X: Features to embed with SiameseNetwork.
        D: Ignored.
        batch_size: Number of samples per batch.
        expmap: Whether to use exponential map for embedding.

    Returns:
        embeddings: Embeddings produced by forward pass of trained SiameseNetwork model.
    """
    # Set random state
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Save the embeddings
    embeddings_list = []
    for i in range(0, len(X), batch_size):
        batch = X[i : i + batch_size]
        embeddings = self.encode(batch)
        if expmap:
            embeddings = self.pm.expmap(embeddings @ self.pm.projection_matrix)
        embeddings_list.append(embeddings)
    embeddings = torch.cat(embeddings_list, dim=0)

    return embeddings

ProductSpaceVAE(pm, encoder, decoder, random_state=None, device='cpu', beta=1.0, reconstruction_loss=None, n_samples=16)

Bases: BaseEmbedder, Module

Product Space Variational Autoencoder.

The probabilistic model is defined as:

  • Prior: \(p(z) = \mathcal{WN}(z; \mu_0, I)\) (wrapped normal distribution centered at manifold origin)
  • Likelihood: \(p_\theta(x|z) = \mathcal{N}(x; f_\theta(z), \sigma^2 I)\) or other reconstruction distribution
  • Posterior approximation: \(q_\phi(z|x) = \mathcal{WN}(z; \mu_\phi(x), \Sigma_\phi(x))\)

where \(\mathcal{WN}\) is a wrapped normal distribution on the manifold.

The model is trained by maximizing the evidence lower bound (ELBO):

\(\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z))\)
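The KL term is estimated by Monte Carlo sampling, with the constructor's `n_samples` argument controlling the number of draws. As a rough illustration of that estimator, the sketch below uses ordinary one-dimensional Gaussians in place of wrapped normals, which is a simplifying assumption, and compares the estimate with the closed form available in the Gaussian case. All function names are hypothetical.

```python
import math
import random


def log_normal_pdf(z, mu, sigma):
    """Log-density of a 1-D Gaussian N(mu, sigma^2) at z."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((z - mu) / sigma) ** 2


def mc_kl(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0, n_samples=16, rng=None):
    """Estimate KL(q || p) as the average of log q(z) - log p(z) over z ~ q."""
    rng = rng or random.Random()
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mu_q, sigma_q)  # sample from the approximate posterior q
        total += log_normal_pdf(z, mu_q, sigma_q) - log_normal_pdf(z, mu_p, sigma_p)
    return total / n_samples


def exact_kl(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0):
    """Closed-form KL between two 1-D Gaussians, for comparison only."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q**2 + (mu_q - mu_p) ** 2) / (2 * sigma_p**2)
        - 0.5
    )


estimate = mc_kl(1.0, 0.5, n_samples=20_000, rng=random.Random(0))
print(estimate, exact_kl(1.0, 0.5))
```

With only a handful of samples the estimate is noisy, so `n_samples` trades gradient variance against compute per step.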

Attributes:
  • pm

    Product manifold defining the structure of the latent space.

  • random_state

    Random state for reproducibility.

  • encoder

    Neural network that outputs mean and log-variance parameters.

  • decoder

    Neural network that reconstructs inputs from latent embeddings.

  • beta

    Weight for the KL divergence term in the ELBO.

  • device

    Device for tensor computations.

  • n_samples

    Number of samples for Monte Carlo estimation of KL divergence.

  • reconstruction_loss

    Type of reconstruction loss to use.

  • loss_history_

    Dictionary to store the history of loss values during training.

  • is_fitted_

    Boolean flag indicating whether the model has been fitted.

Parameters:
  • pm (ProductManifold) –

    Product manifold defining the structure of the latent space.

  • encoder (Module) –

    Neural network module that produces mean (first half of output) and log-variance (second half of output) of the posterior distribution. The output dimension should match twice the intrinsic dimension of the product manifold.

  • decoder (Module) –

    Neural network module that maps latent representations back to the input space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str, default: 'cpu' ) –

    Optional device for tensor computations.

  • beta (float, default: 1.0 ) –

    Weight of the KL divergence term in the ELBO loss. Values < 1 give a \(\beta\)-VAE with a looser constraint on the latent space.

  • reconstruction_loss (_Loss | None, default: None ) –

    Loss module used to score reconstructions. If None, torch.nn.MSELoss(reduction="none") is used.

  • n_samples (int, default: 16 ) –

    Number of Monte Carlo samples to use when estimating the KL divergence.

Source code in manify/embedders/vae.py
def __init__(
    self,
    pm: ProductManifold,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    random_state: int | None = None,
    device: str = "cpu",
    beta: float = 1.0,
    reconstruction_loss: torch.nn.modules.loss._Loss | None = None,
    n_samples: int = 16,
):
    # Init both base classes
    torch.nn.Module.__init__(self)
    BaseEmbedder.__init__(self, pm=pm, random_state=random_state, device=device)

    # Now we assign
    self.encoder = encoder.to(device)
    self.decoder = decoder.to(device)
    self.beta = beta
    self.n_samples = n_samples
    self.reconstruction_loss = (
        reconstruction_loss if reconstruction_loss is not None else torch.nn.MSELoss(reduction="none")
    )
    self.model_ = None
    self.loss_history_ = {}
    self.is_fitted_ = False

    # Ensure encoder output dimension is 2 * pm.dim (mean and log-variance halves)
    assert encoder[-1].out_features == 2 * pm.dim, "Encoder output must match 2 * intrinsic dimension of manifold."

    # Ensure decoder input dimension is pm.ambient_dim
    assert decoder[0].in_features == pm.ambient_dim, "Decoder input must match ambient dimension of manifold."

encode(x)

Encodes input data to obtain latent means and log-variances in the manifold.

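This method passes input data through the encoder network to obtain the parameters of the approximate posterior distribution \(q(z|x)\) in the product manifold space. For non-Euclidean components, placing the mean on the manifold involves three steps:

  1. Get tangent space vectors and log-variances from the encoder,
  2. Project the tangent vectors to the ambient space by adding zeros in the right places, and
  3. Map the ambient space vectors to the manifold using the exponential map.

In the source listed for this method, only step 1 is executed: steps 2 and 3 are commented out, and the method returns the tangent-space mean together with the log-variances.
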
Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • z_mean_tangent( Float[Tensor, 'batch_size n_latent'] ) –

    Mean of the posterior distribution in the tangent plane at the origin.

  • z_logvar( Float[Tensor, 'batch_size n_latent'] ) –

    Log-variance of the posterior distribution, used for constructing the covariance matrices.
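
The split performed by encode() can be illustrated without any tensor library. The following is a minimal sketch, not the library's implementation: split_encoder_output is a hypothetical helper, and plain Python lists stand in for one row of the encoder output.

```python
def split_encoder_output(row, intrinsic_dim):
    """Split one encoder output row into (tangent mean, log-variance)."""
    if len(row) != 2 * intrinsic_dim:
        raise ValueError("encoder output must have width 2 * intrinsic_dim")
    # First half: tangent-space mean; second half: log-variance.
    return row[:intrinsic_dim], row[intrinsic_dim:]


# Hypothetical 3-dimensional latent space, so the encoder emits 6 values per input.
z_mean_tangent, z_logvar = split_encoder_output(
    [0.1, -0.2, 0.3, -1.0, 0.0, 1.0], intrinsic_dim=3
)
```

This is why the constructor asserts that the encoder's last layer has 2 * pm.dim outputs: each half must have exactly the intrinsic dimension of the product manifold.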

Source code in manify/embedders/vae.py
def encode(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]:
    r"""Encodes input data to obtain latent means and log-variances in the manifold.

    This method passes the input through the encoder network to obtain the parameters of the approximate posterior
    distribution $q(z|x)$. The encoder output is split into tangent-space means (the first `pm.dim` columns) and
    log-variances (the remaining `pm.dim` columns).

    The means are returned in the tangent plane at the origin. Projecting them to the ambient space and applying
    the exponential map happens later, in `forward()` and `transform()`.

    Args:
        x: Input data tensor.

    Returns:
        z_mean_tangent: Mean of the posterior distribution in the tangent plane at the origin.
        z_logvar: Log-variance of the posterior distribution, used for constructing the covariance matrices.
    """
    z = self.encoder(x)
    z_mean_tangent, z_logvar = z[..., : self.pm.dim], z[..., self.pm.dim :]
    # z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
    # z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
    return z_mean_tangent, z_logvar

decode(z)

Decodes latent points from the manifold space back to the input space.

Takes points from the product manifold latent space and passes them through the decoder network to reconstruct the original input data.

Parameters:
  • z (Float[Tensor, 'batch_size n_ambient']) –

    Latent points in the product manifold, given in ambient coordinates.

Returns:
  • reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Tensor containing the reconstructed input data, with shape (batch_size, n_features).

Source code in manify/embedders/vae.py
def decode(self, z: Float[torch.Tensor, "batch_size n_ambient"]) -> Float[torch.Tensor, "batch_size n_features"]:
    """Decodes latent points from the manifold space back to the input space.

    Takes points from the product manifold latent space and passes them through
    the decoder network to reconstruct the original input data.

    Args:
        z: Latent points in the product manifold

    Returns:
        reconstructed: Tensor containing the reconstructed input data,
            with shape (batch_size, n_features).
    """
    return self.decoder(z)

forward(x)

Performs the forward pass of the VAE in product manifold space.

This method implements the complete VAE forward pass, with manifold projection:

  1. Encode the input to get the posterior parameters in the tangent space (z_mean_tangent, z_logvars)
  2. Project the means onto the manifold using the exponential map
  3. Factorize the log-variances for each manifold component and convert them to covariance matrices
  4. Sample points from the posterior distributions in the product manifold
  5. Decode the sampled points to get reconstructions

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • x_reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed data tensor with the same shape as the input.

  • z_means( Float[Tensor, 'batch_size n_ambient'] ) –

    Means of the posterior distributions in the manifold space.

  • sigmas( list[Float[Tensor, 'batch_size n_latent n_latent']] ) –

    List of covariance matrices for each manifold component.
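
Steps 2 and 3 can be sketched for a single hyperbolic component in plain Python. This is an illustration under stated assumptions, not manify's code: it assumes the hyperboloid model with curvature -1 and origin (1, 0, ..., 0), and the helper names expmap_origin_hyperboloid, diagonal_covariance, and minkowski_inner are hypothetical. The 1e-8 offset mirrors the one in the source listing below.

```python
import math


def expmap_origin_hyperboloid(v):
    """Exponential map at the origin (1, 0, ..., 0) of the unit hyperboloid.

    `v` holds the spatial part of a tangent vector at the origin; its
    time-like coordinate is zero, which is the "added zero" of step 2.
    """
    norm = math.sqrt(sum(c * c for c in v))
    if norm == 0.0:
        return [1.0] + [0.0] * len(v)
    scale = math.sinh(norm) / norm
    return [math.cosh(norm)] + [scale * c for c in v]


def diagonal_covariance(logvar, eps=1e-8):
    """Turn log-variances into a diagonal covariance matrix (nested lists)."""
    n = len(logvar)
    return [
        [math.exp(logvar[i]) + eps if i == j else 0.0 for j in range(n)]
        for i in range(n)
    ]


def minkowski_inner(a, b):
    """Minkowski inner product with signature (-, +, ..., +)."""
    return -a[0] * b[0] + sum(x * y for x, y in zip(a[1:], b[1:]))


z_mean = expmap_origin_hyperboloid([0.3, -0.4])  # point on the hyperboloid
sigma = diagonal_covariance([0.0, -1.0])  # 2 x 2 diagonal covariance
```

A point lies on the unit hyperboloid exactly when its Minkowski inner product with itself is -1, which gives a quick sanity check for the mapped mean.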

Source code in manify/embedders/vae.py
def forward(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_ambient"],
    list[Float[torch.Tensor, "batch_size n_latent n_latent"]],
]:
    r"""Performs the forward pass of the VAE in product manifold space.

    This method implements the complete VAE forward pass, with manifold projection:

    1. Encode the input to get the posterior parameters in the tangent space (`z_mean_tangent`, `z_logvars`)
    2. Project means onto the manifold using exponential map
    3. Factorize the log-variances for each manifold component and convert to covariance matrices
    4. Sample points from the posterior distributions in the product manifold
    5. Decode the sampled points to get reconstructions

    Args:
        x: Input data tensor.

    Returns:
        x_reconstructed: Reconstructed data tensor with the same shape as the input.
        z_means: Means of the posterior distributions in the manifold space.
        sigmas: List of covariance matrices for each manifold component.
    """
    z_mean_tangent, z_logvars = self.encode(x)

    # Need to convert from implicit parameterization to extrinsic coordinates
    z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
    z_means = self.pm.expmap(u=z_mean_ambient, base=None)

    # Factorize log-variances; convert to covariances
    sigma_factorized = self.pm.factorize(z_logvars, intrinsic=True)
    sigmas = [torch.diag_embed(torch.exp(z_logvar) + 1e-8) for z_logvar in sigma_factorized]

    # Sample and decode
    z = self.pm.sample(z_mean=z_means, sigma_factorized=sigmas)
    x_reconstructed = self.decode(z)
    return x_reconstructed, z_means, sigmas

kl_divergence(z_mean, sigma_factorized)

Computes the KL divergence between posterior and prior distributions in the manifold.

For distributions in Riemannian manifolds, computing the KL divergence analytically is often intractable. This method uses Monte Carlo sampling to approximate the KL divergence:

\[D_{KL}(q(z|x) || p(z)) \approx \frac{1}{N} \sum_{i=1}^{N} [\log q(z_i|x) - \log p(z_i)]\]

where \(z_i\) are samples from \(q(z|x)\).

This implementation follows the approach described in: http://joschu.net/blog/kl-approx.html

Parameters:
  • z_mean (Float[Tensor, 'batch_size n_latent']) –

    Means of the posterior distributions in the manifold.

  • sigma_factorized (list[Float[Tensor, 'batch_size manifold_dim manifold_dim']]) –

    List of covariance matrices for each manifold component.

Returns:
  • kl_divergence( Float[Tensor, 'batch_size'] ) –

    KL divergence values for each data point in the batch.
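
The Monte Carlo estimator above can be checked in a setting where the KL divergence has a closed form. The sketch below uses one-dimensional Euclidean Gaussians rather than manifold distributions, which is an assumption made purely for illustration; log_normal_pdf, monte_carlo_kl, and analytic_kl are hypothetical helpers.

```python
import math
import random


def log_normal_pdf(z, mean, std):
    """Log-density of a univariate normal distribution."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((z - mean) / std) ** 2


def monte_carlo_kl(mean_q, std_q, n_samples, rng):
    """Estimate KL(q || p) for q = N(mean_q, std_q^2) and p = N(0, 1)."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mean_q, std_q)  # z_i ~ q(z|x)
        total += log_normal_pdf(z, mean_q, std_q) - log_normal_pdf(z, 0.0, 1.0)
    return total / n_samples


def analytic_kl(mean_q, std_q):
    """Closed-form KL(N(mean_q, std_q^2) || N(0, 1))."""
    return -math.log(std_q) + 0.5 * (std_q**2 + mean_q**2) - 0.5


rng = random.Random(0)
estimate = monte_carlo_kl(1.0, 0.5, n_samples=20_000, rng=rng)
exact = analytic_kl(1.0, 0.5)
```

With enough samples the estimate approaches the closed-form value. With few samples (the constructor's default is n_samples=16) individual estimates are noisy and can even be negative, although the estimator is unbiased.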

Source code in manify/embedders/vae.py
def kl_divergence(
    self,
    z_mean: Float[torch.Tensor, "batch_size n_latent"],
    sigma_factorized: list[Float[torch.Tensor, "batch_size manifold_dim manifold_dim"]],
) -> Float[torch.Tensor, "batch_size"]:
    r"""Computes the KL divergence between posterior and prior distributions in the manifold.

    For distributions in Riemannian manifolds, computing the KL divergence analytically
    is often intractable. This method uses Monte Carlo sampling to approximate the KL divergence:

    $$D_{KL}(q(z|x) || p(z)) \approx \frac{1}{N} \sum_{i=1}^{N} [\log q(z_i|x) - \log p(z_i)]$$

    where $z_i$ are samples from $q(z|x)$.

    This implementation follows the approach described in:
    http://joschu.net/blog/kl-approx.html

    Args:
        z_mean: Means of the posterior distributions in the manifold.
        sigma_factorized: List of covariance matrices for each manifold component.

    Returns:
        kl_divergence: KL divergence values for each data point in the batch.
    """
    # Get KL divergence as the average of log q(z|x) - log p(z)
    means = torch.repeat_interleave(z_mean, self.n_samples, dim=0)
    sigmas_factorized_interleaved = [
        torch.repeat_interleave(sigma, self.n_samples, dim=0) for sigma in sigma_factorized
    ]
    # We want to use n_samples = 1 here, since we'll need to pass the interleaved means/sigmas to the log-likelihood
    z_samples = self.pm.sample(z_mean=means, sigma_factorized=sigmas_factorized_interleaved)
    log_qz = self.pm.log_likelihood(z_samples, means, sigmas_factorized_interleaved)
    log_pz = self.pm.log_likelihood(z_samples)
    return (log_qz - log_pz).view(-1, self.n_samples).mean(dim=1)

elbo(x)

Computes the Evidence Lower Bound (ELBO) for the VAE objective.

The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term (log-likelihood) and a regularization term (KL divergence):

\[\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z)),\]

where:

  • \(\theta\) are the decoder parameters
  • \(\phi\) are the encoder parameters
  • \(\beta\) is a weight for the KL term (\(\beta = 1\) recovers the standard VAE objective; \(\beta \neq 1\) gives a \(\beta\)-VAE)

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • elbo( Float[Tensor, ''] ) –

    Mean ELBO value across the batch (higher is better).

  • log_likelihood( Float[Tensor, ''] ) –

    Mean reconstruction log-likelihood across the batch.

  • kl_divergence( Float[Tensor, ''] ) –

    Mean KL divergence across the batch.
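
The arithmetic behind the three returned values can be worked through on a tiny batch. This sketch assumes a squared-error reconstruction loss summed over features and takes the per-sample KL values as given; beta_elbo is a hypothetical helper and plain lists stand in for tensors.

```python
def beta_elbo(x, x_reconstructed, kld, beta=1.0):
    """Return (mean ELBO, mean log-likelihood, mean KL) for a small batch.

    The log-likelihood is the negative squared error summed over features,
    mirroring an element-wise MSE loss followed by a sum over the feature axis.
    """
    ll = [
        -sum((a - b) ** 2 for a, b in zip(row, rec))
        for row, rec in zip(x, x_reconstructed)
    ]
    per_sample = [l - beta * k for l, k in zip(ll, kld)]
    n = len(per_sample)
    return sum(per_sample) / n, sum(ll) / n, sum(kld) / n


elbo, ll, kl = beta_elbo(
    x=[[1.0, 2.0], [3.0, 4.0]],
    x_reconstructed=[[1.0, 1.0], [2.0, 4.0]],
    kld=[0.5, 1.5],
    beta=2.0,
)
```

Here each sample has a log-likelihood of -1, so the per-sample ELBO values are -1 - 2(0.5) = -2 and -1 - 2(1.5) = -4, giving a mean ELBO of -3, a mean log-likelihood of -1, and a mean KL of 1.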

Source code in manify/embedders/vae.py
def elbo(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]:
    r"""Computes the Evidence Lower Bound (ELBO) for the VAE objective.

    The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term
    (log-likelihood) and a regularization term (KL divergence):

    $$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z)),$$

    where:

    - $\theta$ are the decoder parameters
    - $\phi$ are the encoder parameters
    - $\beta$ is a weight for the KL term ($\beta = 1$ recovers the standard VAE objective; $\beta \neq 1$ gives a $\beta$-VAE)

    Args:
        x: Input data tensor.

    Returns:
        elbo: Mean ELBO value across the batch (higher is better).
        log_likelihood: Mean reconstruction log-likelihood across the batch.
        kl_divergence: Mean KL divergence across the batch.
    """
    x_reconstructed, z_means, sigma_factorized = self(x)
    kld = self.kl_divergence(z_means, sigma_factorized)
    ll = -self.reconstruction_loss(x_reconstructed.view(x.shape[0], -1), x.view(x.shape[0], -1)).sum(dim=1)
    return (ll - self.beta * kld).mean(), ll.mean(), kld.mean()

fit(X, D=None, lr=0.001, burn_in_lr=0.0001, curvature_lr=0.0, burn_in_iterations=1, training_iterations=9, loss_window_size=100, logging_interval=10, batch_size=32, clip_grad=True)

Trains the VAE model on the provided data.

The training process consists of two phases:

  1. Burn-in phase: Initial training with a lower learning rate for stability
  2. Main training phase: Training with the full learning rate and optional curvature optimization

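Training uses the Adam optimizer and maximizes the Evidence Lower Bound (ELBO). If clip_grad is set, gradient norms are clipped to a maximum of 1.0 to prevent exploding gradients. Batches whose loss is NaN or infinite are skipped without an optimizer step.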
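
The two-phase schedule amounts to switching learning rates once the burn-in epochs are over. The generator below is a minimal sketch of that logic only; learning_rate_schedule is a hypothetical helper, not part of the library, and the arguments reuse the default values of fit().

```python
def learning_rate_schedule(burn_in_iterations, training_iterations, burn_in_lr, lr, curvature_lr):
    """Yield (epoch, model_lr, curvature_lr) for the two training phases."""
    for epoch in range(burn_in_iterations + training_iterations):
        if epoch < burn_in_iterations:
            # Burn-in: small model learning rate, curvature frozen.
            yield epoch, burn_in_lr, 0.0
        else:
            # Main phase: full learning rate, optional curvature updates.
            yield epoch, lr, curvature_lr


schedule = list(
    learning_rate_schedule(1, 9, burn_in_lr=1e-4, lr=1e-3, curvature_lr=0.0)
)
```

With the defaults, only the first epoch runs at the burn-in rate, and the curvature learning rate stays at zero throughout unless curvature_lr is set explicitly.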

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Training data tensor.

  • D (None, default: None ) –

    Ignored.

  • lr (float, default: 0.001 ) –

    Learning rate for the main training phase.

  • burn_in_lr (float, default: 0.0001 ) –

    Learning rate for the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for optimizing manifold scale factors. Off (no learning) by default.

  • burn_in_iterations (int, default: 1 ) –

    Number of epochs (full passes over X) in the burn-in phase.

  • training_iterations (int, default: 9 ) –

    Number of epochs (full passes over X) in the main training phase.

  • loss_window_size (int, default: 100 ) –

    Window size for computing moving average loss.

  • logging_interval (int, default: 10 ) –

    Interval for logging training progress.

  • batch_size (int, default: 32 ) –

    Batch size for training.

  • clip_grad (bool, default: True ) –

    Whether to apply gradient clipping.

Returns:
  • self( 'ProductSpaceVAE' ) –

    Fitted model. Per-batch ELBO, log-likelihood, and KL values are stored in loss_history_.

Source code in manify/embedders/vae.py
def fit(  # type: ignore[override]
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    D: None = None,
    lr: float = 1e-3,
    burn_in_lr: float = 1e-4,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 1,
    training_iterations: int = 9,
    loss_window_size: int = 100,
    logging_interval: int = 10,
    batch_size: int = 32,
    clip_grad: bool = True,
) -> "ProductSpaceVAE":
    """Trains the VAE model on the provided data.

    The training process consists of two phases:

    1. Burn-in phase: Initial training with a lower learning rate for stability
    2. Main training phase: Training with the full learning rate and optional curvature optimization

    Training uses the Adam optimizer and maximizes the Evidence Lower Bound (ELBO). If `clip_grad` is set,
    gradient norms are clipped to 1.0. Batches whose loss is NaN or infinite are skipped.

    Args:
        X: Training data tensor.
        D: Ignored.
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of epochs (full passes over `X`) in the burn-in phase.
        training_iterations: Number of epochs (full passes over `X`) in the main training phase.
        loss_window_size: Window size for computing moving average loss.
        logging_interval: Interval for logging training progress.
        batch_size: Batch size for training.
        clip_grad: Whether to apply gradient clipping.

    Returns:
        self: Fitted model. Per-batch ELBO, log-likelihood, and KL values are stored in `loss_history_`.
    """
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    my_tqdm = tqdm(total=(burn_in_iterations + training_iterations) * len(X))
    opt = torch.optim.Adam(
        [
            {"params": [p for p in self.parameters() if p not in set(self.pm.parameters())], "lr": burn_in_lr},
            {"params": self.pm.parameters(), "lr": 0},
        ]
    )
    losses: dict[str, list[float]] = {"elbo": [], "ll": [], "kl": []}
    for epoch in range(burn_in_iterations + training_iterations):
        if epoch == burn_in_iterations:
            opt.param_groups[0]["lr"] = lr
            opt.param_groups[1]["lr"] = curvature_lr

        for i in range(0, len(X), batch_size):
            opt.zero_grad()
            X_batch = X[i : i + batch_size]
            elbo, ll, kl = self.elbo(X_batch)
            L = -elbo
            L.backward()

            # Add to losses
            losses["elbo"].append(elbo.item())
            losses["ll"].append(ll.item())
            losses["kl"].append(kl.item())

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(self.pm.parameters(), max_norm=1.0)
            if torch.isnan(L) or torch.isinf(L):
                print(f"Invalid loss detected at epoch {epoch}, batch {i}")
                continue
            elif self._grads_ok():
                opt.step()

            # TQDM management
            my_tqdm.update(batch_size)
            my_tqdm.set_description(f"L: {L.item():.3e}, ll: {ll.item():.3e}, kl: {kl.item():.3e}")

            # Logging
            if i % logging_interval == 0:
                d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
                # d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
                d["L_avg"] = f"{np.mean(losses['elbo'][-loss_window_size:]):.3e}"
                d["ll_avg"] = f"{np.mean(losses['ll'][-loss_window_size:]):.3e}"
                d["kl_avg"] = f"{np.mean(losses['kl'][-loss_window_size:]):.3e}"
                my_tqdm.set_postfix(d)

    # Final maintenance: update attributes
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self

transform(X, D=None, batch_size=32, expmap=True)

Transform data using the trained VAE. Outputs means of the variational distribution.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Features to embed with VAE.

  • D (None, default: None ) –

    Ignored.

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • expmap (bool, default: True ) –

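    If True, map the tangent-space means onto the manifold with the exponential map and return ambient coordinates. If False, return the tangent-space means exactly as produced by the encoder.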

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.
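
The source comments describe multiplying by pm.projection_matrix as adding "zeros in the right places". The sketch below shows one way such a matrix could be laid out. It rests on an assumption that is not confirmed by this page: each curved component carries one extra leading ambient coordinate, while Euclidean components carry none. The helpers projection_matrix and matvec_left are hypothetical.

```python
def projection_matrix(components):
    """Build an (intrinsic_dim x ambient_dim) 0/1 matrix for a product manifold.

    `components` is a list of (intrinsic_dim, has_extra_coordinate) pairs.
    """
    n_intrinsic = sum(dim for dim, _ in components)
    n_ambient = sum(dim + (1 if extra else 0) for dim, extra in components)
    matrix = [[0.0] * n_ambient for _ in range(n_intrinsic)]
    row, col = 0, 0
    for dim, extra in components:
        if extra:
            col += 1  # skip the extra leading coordinate; it stays zero
        for _ in range(dim):
            matrix[row][col] = 1.0
            row += 1
            col += 1
    return matrix


def matvec_left(v, matrix):
    """Compute the row-vector product v @ matrix."""
    return [
        sum(v[i] * matrix[i][j] for i in range(len(v)))
        for j in range(len(matrix[0]))
    ]


# Hypothetical product of a curved 2-D component and a Euclidean 2-D component.
P = projection_matrix([(2, True), (2, False)])
z_mean_ambient = matvec_left([0.3, -0.4, 1.5, 2.5], P)
```

Under this layout a 4-dimensional tangent vector becomes a 5-dimensional ambient vector whose first entry is zero, which is the input the exponential map expects when expmap=True.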

Source code in manify/embedders/vae.py
def transform(
    self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using the trained VAE. Outputs means of the variational distribution.

    Args:
        X: Features to embed with VAE.
        D: Ignored.
        batch_size: Number of samples per batch.
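        expmap: If True, map the tangent-space means onto the manifold with the exponential map. If False,
            return the tangent-space means as produced by the encoder.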

    Returns:
        embeddings: Learned embeddings.
    """
    # Set random state
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Save the test embeddings
    embeddings_list = []
    for i in range(0, len(X), batch_size):
        x_batch = X[i : i + batch_size]
        z_mean_tangent, _ = self.encode(x_batch)
        if expmap:
            z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
            z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
        else:
            z_mean = z_mean_tangent
        embeddings_list.append(z_mean.detach().cpu())

    embeddings = torch.cat(embeddings_list, dim=0)

    return embeddings

coordinate_learning

Implementation for direct coordinate optimization in Riemannian manifolds.

This module provides tools for learning embeddings in product manifolds by directly optimizing point coordinates with Riemannian optimization. The approach is particularly useful for embedding graphs, where the goal is to preserve pairwise distances in the target space. It also supports non-transductive training, in which gradients from the test set to the training set are masked out.

CoordinateLearning(pm, random_state=None, device=None)

Bases: BaseEmbedder

Coordinate learning method class.

This embedder implements the approach described in Gu et al., "Learning Mixed-Curvature Representations in Product Spaces". It directly optimizes point coordinates to preserve a given distance matrix, using Riemannian optimization techniques.

Trains point coordinates in a product manifold to match target distances.

This class optimizes the coordinates of points in a product manifold to match a given distance matrix. The optimization is performed in two phases:

  1. Burn-in phase: Initial optimization with a smaller learning rate to find a good starting configuration.
  2. Training phase: Fine-tuning of the coordinates with a larger learning rate, and optionally optimizing the scale factors (curvatures) of the manifold components.

The optimization uses the Riemannian Adam optimizer to respect the manifold structure during gradient updates. The loss measures the distortion between the pairwise distances in the embedding and the target distances.

For non-transductive settings, the class supports a split between training and test points and optimizes three groups of distances separately: train-train, test-test, and train-test. For the train-test term, the training coordinates are detached so that test points do not influence them.
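
The three distance groups correspond to blocks of the target distance matrix. The following sketch shows the block selection in plain Python; split_distance_blocks is a hypothetical helper that mirrors indexing such as D[train][:, train] in the source listing, with nested lists standing in for tensors.

```python
def split_distance_blocks(D, test_indices):
    """Split a square distance matrix into train-train, test-test, and train-test blocks."""
    test_set = set(test_indices)
    n = len(D)
    train = [i for i in range(n) if i not in test_set]
    test = [i for i in range(n) if i in test_set]

    def block(rows, cols):
        return [[D[i][j] for j in cols] for i in rows]

    return block(train, train), block(test, test), block(train, test)


# Toy path graph on four nodes; node 3 is held out as a test point.
D = [
    [0.0, 1.0, 2.0, 3.0],
    [1.0, 0.0, 1.0, 2.0],
    [2.0, 1.0, 0.0, 1.0],
    [3.0, 2.0, 1.0, 0.0],
]
D_tt, D_qq, D_tq = split_distance_blocks(D, test_indices=[3])
```

With no test indices, the train-train block is the whole matrix and the other two blocks are empty, which matches the default behavior of using all points for training.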

Attributes:
  • pm

    Product manifold defining the target embedding space.

  • embeddings_

    Optimized point coordinates after fitting.

  • loss_history_ (dict[str, list[float]]) –

    Training loss history.

  • is_fitted_ (bool) –

    Boolean flag indicating if the embedder has been fitted.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object defining the target embedding space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str | None, default: None ) –

    Optional device for tensor computations.

Source code in manify/embedders/coordinate_learning.py
def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
    super().__init__(pm=pm, random_state=random_state, device=device)

fit(X, D, test_indices=None, lr=0.01, burn_in_lr=0.001, curvature_lr=0.0, burn_in_iterations=2000, training_iterations=18000, loss_window_size=100, logging_interval=10)

Fit the Coordinate Learning Embedder. Sets attributes embeddings_, loss_history_, and is_fitted_.

Parameters:
  • X (None) –

    Ignored.

  • D (Float[Tensor, 'n_points n_points']) –

    Tensor representing the target pairwise distance matrix between points.

  • test_indices (Int[Tensor, 'n_test'] | None, default: None ) –

    Tensor containing indices of test points for non-transductive training. Defaults to None, in which case all points are used for training.

  • lr (float, default: 0.01 ) –

    Learning rate for the main training phase.

  • burn_in_lr (float, default: 0.001 ) –

    Learning rate for the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for optimizing manifold scale factors. Off (no learning) by default.

  • burn_in_iterations (int, default: 2000 ) –

    Number of iterations for the burn-in phase.

  • training_iterations (int, default: 18000 ) –

    Number of iterations for the main training phase.

  • loss_window_size (int, default: 100 ) –

    Window size for computing moving average loss.

  • logging_interval (int, default: 10 ) –

    Interval for logging training progress.

Returns:
  • self( 'CoordinateLearning' ) –

    Fitted embedder instance.

Raises:
  • ValueError

    If the distance matrix D is None, or if the loss becomes NaN during training.

  • Warning

    If X is provided; it is ignored during fitting.

Source code in manify/embedders/coordinate_learning.py
def fit(  # type: ignore[override]
    self,
    X: None,
    D: Float[torch.Tensor, "n_points n_points"],
    test_indices: Int[torch.Tensor, "n_test"] | None = None,
    lr: float = 1e-2,
    burn_in_lr: float = 1e-3,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 2_000,
    training_iterations: int = 18_000,
    loss_window_size: int = 100,
    logging_interval: int = 10,
) -> "CoordinateLearning":
    """Fit the Coordinate Learning Embedder. Sets attributes `embeddings_`, `loss_history_`, and `is_fitted_`.

    Args:
        X: Ignored.
        D: Tensor representing the target pairwise distance matrix between points.
        test_indices: Tensor containing indices of test points for non-transductive training.
            Defaults to None, in which case all points are used for training.
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of iterations for the burn-in phase.
        training_iterations: Number of iterations for the main training phase.
        loss_window_size: Window size for computing moving average loss.
        logging_interval: Interval for logging training progress.

    Returns:
        self: Fitted embedder instance.

    Raises:
        ValueError: If the distance matrix D is None, or if the loss becomes NaN during training.
        Warning: If X is provided; it is ignored during fitting.
    """
    # Input validation
    if D is None:
        raise ValueError("Distance matrix D is needed for coordinate learning")
    if X is not None:
        warnings.warn(
            "Input X has been given. This will be ignored during fitting. If you have provided a distance matrix, please run embedder.fit(None, D) instead.",
            stacklevel=2,
        )

    # Set random seed if provided
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Move everything to the device; initialize random embeddings
    n = D.shape[0]
    covs = [torch.stack([torch.eye(M.dim) / self.pm.dim] * n).to(self.device) for M in self.pm.P]
    means = torch.vstack([self.pm.mu0] * n).to(self.device)
    X_embed = self.pm.sample(z_mean=means, sigma_factorized=covs)
    D = D.to(self.device)

    # Get train and test indices set up
    test_indices = test_indices if test_indices is not None else torch.tensor([])
    use_test = len(test_indices) > 0
    test = torch.tensor([i in test_indices for i in range(len(D))]).to(self.device)
    train = ~test

    # Initialize optimizer
    X_embed = geoopt.ManifoldParameter(X_embed, manifold=self.pm.manifold)
    ropt = geoopt.optim.RiemannianAdam(
        [{"params": [X_embed], "lr": burn_in_lr}, {"params": self.pm.parameters(), "lr": 0}]
    )

    # Init TQDM
    my_tqdm = tqdm(total=burn_in_iterations + training_iterations, leave=False)

    # Loss history: one list per distance block, plus the total
    losses: dict[str, list[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}

    # Actual training loop
    for i in range(burn_in_iterations + training_iterations):
        if i == burn_in_iterations:
            # Optimize curvature by changing lr
            ropt.param_groups[0]["lr"] = lr
            ropt.param_groups[1]["lr"] = curvature_lr

        # Zero grad
        ropt.zero_grad()

        # 1. Train-train loss
        X_t = X_embed[train]
        D_tt = self.pm.pdist(X_t)
        L_tt = distortion_loss(D_tt, D[train][:, train], pairwise=True)
        L_tt.backward(retain_graph=True)
        losses["train_train"].append(L_tt.item())

        if use_test:
            # 2. Test-test loss
            X_q = X_embed[test]
            D_qq = self.pm.pdist(X_q)
            L_qq = distortion_loss(D_qq, D[test][:, test], pairwise=True)
            L_qq.backward(retain_graph=True)
            losses["test_test"].append(L_qq.item())

            # 3. Train-test loss
            X_t_detached = X_embed[train].detach()
            D_tq = self.pm.dist(X_t_detached, X_q)  # Note 'dist' not 'pdist', as we're comparing different sets
            L_tq = distortion_loss(D_tq, D[train][:, test], pairwise=False)
            L_tq.backward()
            losses["train_test"].append(L_tq.item())
        else:
            L_qq = 0
            L_tq = 0

        # Step
        ropt.step()
        L = L_tt + L_qq + L_tq
        losses["total"].append(L.item())

        # TQDM management
        my_tqdm.update(1)
        my_tqdm.set_description(f"Loss: {L.item():.3e}")

        # Logging
        if i % logging_interval == 0:
            d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
            d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
            d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
            my_tqdm.set_postfix(d)

        # Early stopping for errors
        if torch.isnan(L):
            raise ValueError("Loss is NaN")

    # Final maintenance: update attributes
    self.embeddings_ = X_embed.data.detach()
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self

transform(X=None)

Return the embeddings learned during fit(). Coordinate learning cannot embed new data, so X is ignored.

Parameters:
  • X (None, default: None ) –

    Ignored.

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.

Raises:
  • ValueError

    If the embedder has not been fitted yet.

  • Warning

    If X is provided, as it will be ignored.

Source code in manify/embedders/coordinate_learning.py
def transform(self, X: None = None) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using learned embedding. This is not meaningful for new data during coordinate learning.

    Args:
        X: Ignored.

    Returns:
        embeddings: Learned embeddings.

    Raises:
        ValueError: If the embedder has not been fitted yet.
        Warning: If X is provided, as it will be ignored.
    """
    if not self.is_fitted_:
        raise ValueError("The embedder has not been fitted yet.")

    if X is not None:
        warnings.warn("Coordinate learning can only return trained embeddings. X will be ignored.", stacklevel=2)

    return self.embeddings_

fit_transform(X, D, **fit_kwargs)

Fit the embedder to the distance matrix D and return the learned embeddings.

This method overrides the base class method BaseEmbedder.fit_transform() to not use the input data X.

Parameters:
  • X (None) –

    Ignored.

  • D (Float[Tensor, 'n_points n_points']) –

    Distance matrix for the points.

  • fit_kwargs (Any, default: {} ) –

    Additional keyword arguments passed to the model.fit() method.

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.

Source code in manify/embedders/coordinate_learning.py
def fit_transform(  # type: ignore[override]
    self, X: None, D: Float[torch.Tensor, "n_points n_points"], **fit_kwargs: Any
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using learned embedding based on the provided distance matrix D.

    This method overrides the base class method `BaseEmbedder.fit_transform()` to not use the input data X.

    Args:
        X: Ignored.
        D: Distance matrix for the points.
        fit_kwargs: Additional keyword arguments passed to the `model.fit()` method.

    Returns:
        embeddings: Learned embeddings.
    """
    return self.fit(X=None, D=D, **fit_kwargs).transform(X=None)

siamese

Siamese network implementation for manifold embedding.

This module provides a Siamese network architecture that can be used for embedding data into product manifolds. Siamese networks are particularly useful for metric learning tasks, where the goal is to learn a distance-preserving embedding, while also encoding a set of features.

The SiameseNetwork class supports both encoding (embedding) data into a manifold space and optionally decoding (reconstructing) from the embedding space back to the original data space.

SiameseNetwork(pm, encoder, decoder=None, reconstruction_loss='mse', beta=1.0, random_state=None, device='cpu')

Bases: BaseEmbedder, Module

Siamese network for embedding data into a product manifold space.

A Siamese network consists of an encoder network that maps input data to a latent representation in a product manifold, and optionally a decoder network that maps the latent representation back to the original feature space.

Attributes:
  • pm

    Product manifold defining the structure of the latent space.

  • random_state

    Random state for reproducibility.

  • encoder

    Neural network that maps inputs to latent embeddings.

  • decoder

    Neural network that reconstructs inputs from latent embeddings.

  • beta

    Weight for the distortion term in the loss function.

  • device

    Device for tensor computations.

  • reconstruction_loss

    Type of reconstruction loss to use.

Parameters:
  • pm (ProductManifold) –

    Product manifold defining the structure of the latent space.

  • encoder (Module) –

    Neural network module that maps inputs to the manifold's intrinsic dimension. The output dimension should match the intrinsic dimension of the product manifold.

  • decoder (Module | None, default: None ) –

    Neural network module that maps latent representations back to the input space. If None, torch.nn.Identity() is used.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str, default: 'cpu' ) –

    Optional device for tensor computations.

  • beta (float, default: 1.0 ) –

    Weight of the distortion term in the loss function.

  • reconstruction_loss (str, default: 'mse' ) –

    Type of reconstruction loss to use.

Source code in manify/embedders/siamese.py
def __init__(
    self,
    pm: ProductManifold,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module | None = None,
    reconstruction_loss: str = "mse",
    beta: float = 1.0,
    random_state: int | None = None,
    device: str = "cpu",
):
    # Init both base classes
    torch.nn.Module.__init__(self)
    BaseEmbedder.__init__(self, pm=pm, random_state=random_state, device=device)

    # Now we assign
    self.pm = pm
    self.encoder = encoder
    self.beta = beta

    if decoder is not None:
        self.decoder = decoder
    else:
        self.decoder = torch.nn.Identity()
        self.decoder.requires_grad_(False)
        self.decoder.to(pm.device)

    if reconstruction_loss == "mse":
        self.reconstruction_loss = torch.nn.MSELoss(reduction="none")
    else:
        raise ValueError(f"Unknown reconstruction loss: {reconstruction_loss}")
encode(x)

Encodes input data into the manifold embedding space.

Takes a batch of input data and passes it through the encoder network to obtain embeddings in the manifold.

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • embeddings( Float[Tensor, 'batch_size n_latent'] ) –

    Tensor containing the embeddings in the manifold space.

Source code in manify/embedders/siamese.py
def encode(self, x: Float[torch.Tensor, "batch_size n_features"]) -> Float[torch.Tensor, "batch_size n_latent"]:
    """Encodes input data into the manifold embedding space.

    Takes a batch of input data and passes it through the encoder network to obtain embeddings in the manifold.

    Args:
        x: Input data tensor.

    Returns:
        embeddings: Tensor containing the embeddings in the manifold space.
    """
    return self.encoder(x)
decode(z)

Decodes manifold embeddings back to the original input space.

Takes a batch of embeddings from the manifold space and passes them through the decoder network to reconstruct the original input data.

Parameters:
  • z (Float[Tensor, 'batch_size n_latent']) –

    Embedding tensor from the manifold space.

Returns:
  • reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Tensor containing the reconstructed input data.

Source code in manify/embedders/siamese.py
def decode(self, z: Float[torch.Tensor, "batch_size n_latent"]) -> Float[torch.Tensor, "batch_size n_features"]:
    """Decodes manifold embeddings back to the original input space.

    Takes a batch of embeddings from the manifold space and passes them through
    the decoder network to reconstruct the original input data.

    Args:
        z: Embedding tensor from the manifold space.

    Returns:
        reconstructed: Tensor containing the reconstructed input data.
    """
    return self.decoder(z)
forward(x1, x2)

Given two points, return their encodings, reconstructions, and embedding distance.

Parameters:
  • x1 (Float[Tensor, 'batch_size n_features']) –

    First input tensor.

  • x2 (Float[Tensor, 'batch_size n_features']) –

    Second input tensor.

Returns:
  • z1( Float[Tensor, 'batch_size n_latent'] ) –

    Encoded representation of the first input.

  • z2( Float[Tensor, 'batch_size n_latent'] ) –

    Encoded representation of the second input.

  • D_hat( Float[Tensor, 'batch_size'] ) –

    Estimated distance between the two embeddings.

  • reconstructed1( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed input from the first embedding.

  • reconstructed2( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed input from the second embedding.

Source code in manify/embedders/siamese.py
def forward(
    self, x1: Float[torch.Tensor, "batch_size n_features"], x2: Float[torch.Tensor, "batch_size n_features"]
) -> Tuple[
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size n_latent"],
    Float[torch.Tensor, "batch_size"],
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_features"],
]:
    """Given two points, return their encodings, reconstructions, and embedding distance.

    Args:
        x1: First input tensor.
        x2: Second input tensor.

    Returns:
        z1: Encoded representation of the first input.
        z2: Encoded representation of the second input.
        D_hat: Estimated distance between the two embeddings.
        reconstructed1: Reconstructed input from the first embedding.
        reconstructed2: Reconstructed input from the second embedding.
    """
    z1 = self.pm.expmap(self.encode(x1) @ self.pm.projection_matrix)
    z2 = self.pm.expmap(self.encode(x2) @ self.pm.projection_matrix)
    D_hat = self.pm.manifold.dist(z1, z2)  # use manifold dist to get (batch_size, ) vector of dists
    reconstructed1 = self.decode(z1)
    reconstructed2 = self.decode(z2)
    return z1, z2, D_hat, reconstructed1, reconstructed2
fit(X, D, lr=0.001, burn_in_lr=0.0001, curvature_lr=0.0, burn_in_iterations=1, training_iterations=9, loss_window_size=100, logging_interval=10, batch_size=32, clip_grad=True)

Fit the SiameseNetwork embedder.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Input data features to encode.

  • D (Float[Tensor, 'n_points n_points']) –

    Pairwise distances to emulate.

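  • lr (float, default: 0.001 ) –

    Learning rate for the network parameters during the main training phase.

  • burn_in_lr (float, default: 0.0001 ) –

    Learning rate for the network parameters during the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for the product manifold parameters (curvature scales), applied only after burn-in. The default of 0.0 leaves them unchanged.

  • burn_in_iterations (int, default: 1 ) –

    Number of burn-in epochs. Each epoch is one pass over all point pairs.

  • training_iterations (int, default: 9 ) –

    Number of main training epochs. Each epoch is one pass over all point pairs.

  • loss_window_size (int, default: 100 ) –

    Number of most recent batch losses averaged in the progress-bar summary.

  • logging_interval (int, default: 10 ) –

    Number of batches between updates of the progress-bar summary.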

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • clip_grad (bool, default: True ) –

    Whether to clip gradients.

Returns:
  • self( 'SiameseNetwork' ) –

    Fitted SiameseNetwork instance.
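
As a rough illustration of the bookkeeping in fit, the sketch below enumerates the index pairs (i, j) with i < j and counts the batches per epoch in plain Python. It uses itertools.combinations as a stand-in for torch.triu_indices(n, n, offset=1) and omits the per-epoch shuffling; the helper name pair_batches is hypothetical and not part of manify.

```python
from itertools import combinations


def pair_batches(n_points: int, batch_size: int) -> list[list[tuple[int, int]]]:
    """Split all index pairs (i, j) with i < j into consecutive batches.

    Hypothetical helper for illustration only; it is not part of manify.
    """
    # Same set of pairs as the strict upper triangle of an n x n matrix
    pairs = list(combinations(range(n_points), 2))
    return [pairs[start : start + batch_size] for start in range(0, len(pairs), batch_size)]


batches = pair_batches(n_points=5, batch_size=4)
n_pairs = sum(len(batch) for batch in batches)

# Ceiling division, as used for the progress-bar total
n_batches_per_epoch = (n_pairs + 4 - 1) // 4

print(n_pairs, n_batches_per_epoch, [len(batch) for batch in batches])
```

Since n points yield n(n-1)/2 pairs, the cost of one epoch grows quadratically with the number of points.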

Source code in manify/embedders/siamese.py
def fit(  # type: ignore[override]
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    D: Float[torch.Tensor, "n_points n_points"],
    lr: float = 1e-3,
    burn_in_lr: float = 1e-4,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 1,
    training_iterations: int = 9,
    loss_window_size: int = 100,
    logging_interval: int = 10,
    batch_size: int = 32,
    clip_grad: bool = True,
) -> "SiameseNetwork":
    """Fit the SiameseNetwork embedder.

    Args:
        X: Input data features to encode.
        D: Pairwise distances to emulate.
        lr: Learning rate for the optimizer.
        burn_in_lr: Learning rate during burn-in phase.
        curvature_lr: Learning rate for curvature updates.
        burn_in_iterations: Number of iterations for burn-in phase.
        training_iterations: Number of iterations for training phase.
        loss_window_size: Size of the window for loss averaging.
        logging_interval: Interval for logging progress.
        batch_size: Number of samples per batch.
        clip_grad: Whether to clip gradients.

    Returns:
        self: Fitted SiameseNetwork instance.
    """
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    n_samples = len(X)

    # Generate all upper triangular pairs using torch
    indices = torch.triu_indices(n_samples, n_samples, offset=1)
    pairs = indices.T  # (n_pairs, 2)

    # Number of pairs and batches
    n_pairs = len(pairs)
    n_batches_per_epoch = (n_pairs + batch_size - 1) // batch_size  # Ceiling division
    total_iterations = (burn_in_iterations + training_iterations) * n_batches_per_epoch

    my_tqdm = tqdm(total=total_iterations)

    opt = torch.optim.Adam(
        [
            {"params": [p for p in self.parameters() if p not in set(self.pm.parameters())], "lr": burn_in_lr},
            {"params": self.pm.parameters(), "lr": 0},
        ]
    )
    losses: Dict[str, List[float]] = {"total": [], "reconstruction": [], "distortion": []}

    for epoch in range(burn_in_iterations + training_iterations):
        if epoch == burn_in_iterations:
            opt.param_groups[0]["lr"] = lr
            opt.param_groups[1]["lr"] = curvature_lr

        # Shuffle all pairs
        shuffle_idx = torch.randperm(n_pairs)
        shuffled_pairs = pairs[shuffle_idx]

        for batch_start in range(0, n_pairs, batch_size):
            batch_end = min(batch_start + batch_size, n_pairs)
            batch_pairs = shuffled_pairs[batch_start:batch_end]

            # Extract indices for this batch
            batch_indices1 = batch_pairs[:, 0]
            batch_indices2 = batch_pairs[:, 1]

            # Get data for these indices
            X1 = X[batch_indices1]
            X2 = X[batch_indices2]

            # Extract the corresponding distances from D using advanced indexing
            D_batch = D[batch_indices1, batch_indices2]

            # Forward pass
            opt.zero_grad()
            _, _, D_hat, Y1, Y2 = self(X1, X2)
            mse1 = torch.nn.functional.mse_loss(Y1, X1)
            mse2 = torch.nn.functional.mse_loss(Y2, X2)

            # D_hat and D_batch are now 1D tensors of pairwise distances
            distortion = distortion_loss(D_hat, D_batch, pairwise=False)
            L = mse1 + mse2 + self.beta * distortion
            L.backward()

            # Add to losses
            losses["total"].append(L.item())
            losses["reconstruction"].append(mse1.item() + mse2.item())
            losses["distortion"].append(distortion.item())

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(self.pm.parameters(), max_norm=1.0)

            opt.step()

            # TQDM management
            my_tqdm.update(1)
            my_tqdm.set_description(
                f"L: {L.item():.3e}, recon: {mse1.item() + mse2.item():.3e}, dist: {distortion.item():.3e}"
            )

            # Logging
            if my_tqdm.n % logging_interval == 0:
                d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
                d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
                d["recon_avg"] = f"{np.mean(losses['reconstruction'][-loss_window_size:]):.3e}"
                d["dist_avg"] = f"{np.mean(losses['distortion'][-loss_window_size:]):.3e}"
                my_tqdm.set_postfix(d)

    # Final maintenance: update attributes
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self
transform(X, D=None, batch_size=32, expmap=True)

Transforms input data into manifold embeddings.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Features to embed with SiameseNetwork.

  • D (None, default: None ) –

    Ignored.

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • expmap (bool, default: True ) –

    Whether to apply the exponential map so that the returned embeddings lie on the manifold. If False, the raw encoder outputs are returned.

Returns:
  • embeddings( Float[Tensor, 'n_points n_latent'] ) –

    Embeddings produced by forward pass of trained SiameseNetwork model.
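
The batching in transform can be pictured with a plain-Python sketch. Here embed_in_batches and toy_encode are hypothetical stand-ins for the method and the trained encoder; only the slicing and concatenation pattern mirrors the source.

```python
from typing import Callable, Sequence


def embed_in_batches(
    rows: Sequence[float], encode: Callable[[Sequence[float]], list[float]], batch_size: int = 32
) -> list[float]:
    """Apply `encode` to consecutive slices of `rows` and concatenate the results."""
    out: list[float] = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start : start + batch_size]  # the last batch may be shorter
        out.extend(encode(batch))
    return out


calls: list[int] = []


def toy_encode(batch: Sequence[float]) -> list[float]:
    # Stand-in for the encoder: record the batch size and double each value
    calls.append(len(batch))
    return [2.0 * value for value in batch]


embedded = embed_in_batches([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], toy_encode, batch_size=3)
print(calls, embedded)
```

The output order matches the input order, so row i of the result corresponds to row i of X regardless of batch_size.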

Source code in manify/embedders/siamese.py
def transform(
    self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
) -> Float[torch.Tensor, "n_points n_latent"]:
    """Transforms input data into manifold embeddings.

    Args:
        X: Features to embed with SiameseNetwork.
        D: Ignored.
        batch_size: Number of samples per batch.
        expmap: Whether to use exponential map for embedding.

    Returns:
        embeddings: Embeddings produced by forward pass of trained SiameseNetwork model.
    """
    # Set random state
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Save the embeddings
    embeddings_list = []
    for i in range(0, len(X), batch_size):
        batch = X[i : i + batch_size]
        embeddings = self.encode(batch)
        if expmap:
            embeddings = self.pm.expmap(embeddings @ self.pm.projection_matrix)
        embeddings_list.append(embeddings)
    embeddings = torch.cat(embeddings_list, dim=0)

    return embeddings

vae

Variational autoencoder implementation for product manifold spaces.

This module provides a variational autoencoder (VAE) implementation specifically designed for learning representations in mixed-curvature product spaces. The implementation handles the complexities of sampling, KL divergence calculation, and reparameterization in curved spaces, supporting combinations of hyperbolic, Euclidean, and spherical geometries within a single latent space.

For more information, see Skopek et al. (2020), "Mixed-Curvature Variational Autoencoders".

ProductSpaceVAE(pm, encoder, decoder, random_state=None, device='cpu', beta=1.0, reconstruction_loss=None, n_samples=16)

Bases: BaseEmbedder, Module

Product Space Variational Autoencoder.

The probabilistic model is defined as:

  • Prior: \(p(z) = \mathcal{WN}(z; \mu_0, I)\) (wrapped normal distribution centered at manifold origin)
  • Likelihood: \(p_\theta(x|z) = \mathcal{N}(x; f_\theta(z), \sigma^2 I)\) or other reconstruction distribution
  • Posterior approximation: \(q_\phi(z|x) = \mathcal{WN}(z; \mu_\phi(x), \Sigma_\phi(x))\)

where \(\mathcal{WN}\) is a wrapped normal distribution on the manifold.

The model is trained by maximizing the evidence lower bound (ELBO):

\(\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z))\)

Attributes:
  • pm

    Product manifold defining the structure of the latent space.

  • random_state

    Random state for reproducibility.

  • encoder

    Neural network that outputs mean and log-variance parameters.

  • decoder

    Neural network that reconstructs inputs from latent embeddings.

  • beta

    Weight for the KL divergence term in the ELBO.

  • device

    Device for tensor computations.

  • n_samples

    Number of samples for Monte Carlo estimation of KL divergence.

  • reconstruction_loss

    Type of reconstruction loss to use.

  • loss_history_

    Dictionary to store the history of loss values during training.

  • is_fitted_

    Boolean flag indicating whether the model has been fitted.

Parameters:
  • pm (ProductManifold) –

    Product manifold defining the structure of the latent space.

  • encoder (Module) –

    Neural network module that produces mean (first half of output) and log-variance (second half of output) of the posterior distribution. The output dimension should match twice the intrinsic dimension of the product manifold.

  • decoder (Module) –

    Neural network module that maps latent representations back to the input space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str, default: 'cpu' ) –

    Optional device for tensor computations.

  • beta (float, default: 1.0 ) –

    Weight of the KL divergence term in the ELBO loss. Values < 1 give a \(\beta\)-VAE with a looser constraint on the latent space.

  • reconstruction_loss (_Loss | None, default: None ) –

    Loss module used for the reconstruction term. If None, torch.nn.MSELoss(reduction="none") is used.

  • n_samples (int, default: 16 ) –

    Number of Monte Carlo samples to use when estimating the KL divergence.

Source code in manify/embedders/vae.py
def __init__(
    self,
    pm: ProductManifold,
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    random_state: int | None = None,
    device: str = "cpu",
    beta: float = 1.0,
    reconstruction_loss: torch.nn.modules.loss._Loss | None = None,
    n_samples: int = 16,
):
    # Init both base classes
    torch.nn.Module.__init__(self)
    BaseEmbedder.__init__(self, pm=pm, random_state=random_state, device=device)

    # Now we assign
    self.encoder = encoder.to(device)
    self.decoder = decoder.to(device)
    self.beta = beta
    self.n_samples = n_samples
    self.reconstruction_loss = (
        reconstruction_loss if reconstruction_loss is not None else torch.nn.MSELoss(reduction="none")
    )
    self.model_ = None
    self.loss_history_ = {}
    self.is_fitted_ = False

    # Ensure the encoder's last layer outputs 2 * pm.dim features (means and log-variances)
    assert encoder[-1].out_features == 2 * pm.dim, "Encoder output must match 2 * intrinsic dimension of manifold."

    # Ensure the decoder's first layer takes pm.ambient_dim features
    assert decoder[0].in_features == pm.ambient_dim, "Decoder input must match ambient dimension of manifold."
encode(x)

Encodes input data to obtain latent means and log-variances in the manifold.

This method passes the input through the encoder network and splits the output into the parameters of the approximate posterior distribution \(q(z|x)\): the mean, expressed in the tangent plane at the manifold origin, and the log-variance. The mean is not mapped onto the manifold here; forward() handles that by:

  1. Projecting the tangent vectors to the ambient space by adding zeros in the right places, and
  2. Mapping the ambient space vectors to the manifold using the exponential map.

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • z_mean_tangent( Float[Tensor, 'batch_size n_latent'] ) –

    Mean of the posterior distribution in the tangent plane at the origin.

  • z_logvar( Float[Tensor, 'batch_size n_latent'] ) –

    Log-variance of the posterior distribution, used for constructing the covariance matrices.
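
A minimal sketch of the split applied to the encoder output, assuming a single row of length 2 * dim laid out as means followed by log-variances. The function names are illustrative, and the variance computation mirrors the exp(logvar) + 1e-8 step that forward applies afterwards.

```python
import math


def split_posterior_params(row: list[float], dim: int) -> tuple[list[float], list[float]]:
    """Split one encoder output row into (tangent-space mean, log-variance)."""
    if len(row) != 2 * dim:
        raise ValueError(f"expected {2 * dim} values, got {len(row)}")
    return row[:dim], row[dim:]


def variances(logvar: list[float], eps: float = 1e-8) -> list[float]:
    """Diagonal covariance entries: exp(logvar) plus a small constant for stability."""
    return [math.exp(value) + eps for value in logvar]


mean, logvar = split_posterior_params([0.5, -1.0, 0.0, math.log(4.0)], dim=2)
var = variances(logvar)
print(mean, logvar, var)
```

Predicting log-variances rather than variances lets the encoder output any real number while the resulting variance stays positive.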

Source code in manify/embedders/vae.py
def encode(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[Float[torch.Tensor, "batch_size n_latent"], Float[torch.Tensor, "batch_size n_latent"]]:
    r"""Encodes input data to obtain latent means and log-variances in the manifold.

    This method passes the input through the encoder network and splits the output into the parameters of the
    approximate posterior distribution $q(z|x)$: the mean, expressed in the tangent plane at the manifold origin,
    and the log-variance. The mean is not mapped onto the manifold here; `forward()` handles that by:

    1. Projecting the tangent vectors to the ambient space by adding zeros in the right places, and
    2. Mapping the ambient space vectors to the manifold using the exponential map.

    Args:
        x: Input data tensor.

    Returns:
        z_mean_tangent: Mean of the posterior distribution in the tangent plane at the origin.
        z_logvar: Log-variance of the posterior distribution, used for constructing the covariance matrices.
    """
    z = self.encoder(x)
    z_mean_tangent, z_logvar = z[..., : self.pm.dim], z[..., self.pm.dim :]
    # z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
    # z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
    return z_mean_tangent, z_logvar
decode(z)

Decodes latent points from the manifold space back to the input space.

Takes points from the product manifold latent space and passes them through the decoder network to reconstruct the original input data.

Parameters:
  • z (Float[Tensor, 'batch_size n_ambient']) –

    Latent points in the product manifold

Returns:
  • reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Tensor containing the reconstructed input data, with shape (batch_size, n_features).

Source code in manify/embedders/vae.py
def decode(self, z: Float[torch.Tensor, "batch_size n_ambient"]) -> Float[torch.Tensor, "batch_size n_features"]:
    """Decodes latent points from the manifold space back to the input space.

    Takes points from the product manifold latent space and passes them through
    the decoder network to reconstruct the original input data.

    Args:
        z: Latent points in the product manifold

    Returns:
        reconstructed: Tensor containing the reconstructed input data,
            with shape (batch_size, n_features).
    """
    return self.decoder(z)
forward(x)

Performs the forward pass of the VAE in product manifold space.

This method implements the complete VAE forward pass, with manifold projection:

  1. Encode the input to get posterior parameters (z_means, z_logvars)
  2. Project means onto the manifold using exponential map
  3. Factorize the log-variances for each manifold component and convert to covariance matrices
  4. Sample points from the posterior distributions in the product manifold
  5. Decode the sampled points to get reconstructions

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • x_reconstructed( Float[Tensor, 'batch_size n_features'] ) –

    Reconstructed data tensor with the same shape as the input.

  • z_means( Float[Tensor, 'batch_size n_ambient'] ) –

    Means of the posterior distributions in the manifold space.

  • sigmas( list[Float[Tensor, 'batch_size n_latent n_latent']] ) –

    List of covariance matrices for each manifold component.
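
The projection of the means onto the manifold can be sketched for a single hyperbolic component. This assumes the hyperboloid model with curvature -1 and origin (1, 0, ..., 0), where "adding zeros in the right places" amounts to prepending a zero first coordinate. manify's own projection_matrix and expmap cover general product manifolds and curvatures, so the function names and conventions below are assumptions made for illustration.

```python
import math


def to_ambient(tangent: list[float]) -> list[float]:
    """Embed an intrinsic tangent vector at the origin by prepending a zero coordinate."""
    return [0.0, *tangent]


def expmap_origin(ambient_tangent: list[float]) -> list[float]:
    """Exponential map at the origin of the unit hyperboloid (curvature -1)."""
    spatial = ambient_tangent[1:]
    norm = math.sqrt(sum(value * value for value in spatial))
    if norm == 0.0:
        return [1.0, *spatial]  # the zero tangent vector maps to the origin itself
    scale = math.sinh(norm) / norm
    return [math.cosh(norm), *(scale * value for value in spatial)]


def minkowski_norm_sq(point: list[float]) -> float:
    """Lorentzian inner product <x, x> = -x_0^2 + sum_i x_i^2."""
    return -point[0] ** 2 + sum(value * value for value in point[1:])


z_mean = expmap_origin(to_ambient([0.3, -0.4]))
print(z_mean, minkowski_norm_sq(z_mean))
```

Points on this hyperboloid satisfy <x, x> = -1, and the geodesic distance from the origin equals the norm of the tangent vector, which gives two quick sanity checks on the mapped mean.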

Source code in manify/embedders/vae.py
def forward(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[
    Float[torch.Tensor, "batch_size n_features"],
    Float[torch.Tensor, "batch_size n_ambient"],
    list[Float[torch.Tensor, "batch_size n_latent n_latent"]],
]:
    r"""Performs the forward pass of the VAE in product manifold space.

    This method implements the complete VAE forward pass, with manifold projection:

    1. Encode the input to get posterior parameters (`z_means`, `z_logvars`)
    2. Project means onto the manifold using exponential map
    3. Factorize the log-variances for each manifold component and convert to covariance matrices
    4. Sample points from the posterior distributions in the product manifold
    5. Decode the sampled points to get reconstructions

    Args:
        x: Input data tensor.

    Returns:
        x_reconstructed: Reconstructed data tensor with the same shape as the input.
        z_means: Means of the posterior distributions in the manifold space.
        sigmas: List of covariance matrices for each manifold component.
    """
    z_mean_tangent, z_logvars = self.encode(x)

    # Need to convert from implicit parameterization to extrinsic coordinates
    z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
    z_means = self.pm.expmap(u=z_mean_ambient, base=None)

    # Factorize log-variances; convert to covariances
    sigma_factorized = self.pm.factorize(z_logvars, intrinsic=True)
    sigmas = [torch.diag_embed(torch.exp(z_logvar) + 1e-8) for z_logvar in sigma_factorized]

    # Sample and decode
    z = self.pm.sample(z_mean=z_means, sigma_factorized=sigmas)
    x_reconstructed = self.decode(z)
    return x_reconstructed, z_means, sigmas
kl_divergence(z_mean, sigma_factorized)

Computes the KL divergence between posterior and prior distributions in the manifold.

For distributions in Riemannian manifolds, computing the KL divergence analytically is often intractable. This method uses Monte Carlo sampling to approximate the KL divergence:

\[D_{KL}(q(z|x) || p(z)) \approx \frac{1}{N} \sum_{i=1}^{N} [\log q(z_i|x) - \log p(z_i)]\]

where \(z_i\) are samples from \(q(z|x)\).

This implementation follows the approach described in: http://joschu.net/blog/kl-approx.html

Parameters:
  • z_mean (Float[Tensor, 'batch_size n_latent']) –

    Means of the posterior distributions in the manifold.

  • sigma_factorized (list[Float[Tensor, 'batch_size manifold_dim manifold_dim']]) –

    List of covariance matrices for each manifold component.

Returns:
  • kl_divergence( Float[Tensor, 'batch_size'] ) –

    KL divergence values for each data point in the batch.
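
For intuition, the Monte Carlo estimator can be checked on a case with a known answer. The sketch below uses one-dimensional Euclidean Gaussians instead of wrapped normals on a manifold, so the closed-form KL divergence is available for comparison; all function names are illustrative.

```python
import math
import random


def log_normal_pdf(z: float, mean: float, std: float) -> float:
    """Log-density of a univariate Gaussian."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((z - mean) / std) ** 2


def mc_kl(mean_q: float, std_q: float, n_samples: int, rng: random.Random) -> float:
    """Estimate KL(q || p) for q = N(mean_q, std_q^2), p = N(0, 1) by sampling from q."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(mean_q, std_q)
        total += log_normal_pdf(z, mean_q, std_q) - log_normal_pdf(z, 0.0, 1.0)
    return total / n_samples


def exact_kl(mean_q: float, std_q: float) -> float:
    """Closed-form KL(N(mean_q, std_q^2) || N(0, 1))."""
    return -math.log(std_q) + 0.5 * (std_q**2 + mean_q**2) - 0.5


rng = random.Random(0)
estimate = mc_kl(mean_q=1.0, std_q=0.5, n_samples=20_000, rng=rng)
print(estimate, exact_kl(1.0, 0.5))
```

A small sample count such as the default n_samples=16 gives a much noisier per-point estimate than the 20,000 samples used here.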

Source code in manify/embedders/vae.py
def kl_divergence(
    self,
    z_mean: Float[torch.Tensor, "batch_size n_latent"],
    sigma_factorized: list[Float[torch.Tensor, "batch_size manifold_dim manifold_dim"]],
) -> Float[torch.Tensor, "batch_size"]:
    r"""Computes the KL divergence between posterior and prior distributions in the manifold.

    For distributions in Riemannian manifolds, computing the KL divergence analytically
    is often intractable. This method uses Monte Carlo sampling to approximate the KL divergence:

    $$D_{KL}(q(z|x) || p(z)) \approx \frac{1}{N} \sum_{i=1}^{N} [\log q(z_i|x) - \log p(z_i)]$$

    where $z_i$ are samples from $q(z|x)$.

    This implementation follows the approach described in:
    http://joschu.net/blog/kl-approx.html

    Args:
        z_mean: Means of the posterior distributions in the manifold.
        sigma_factorized: List of covariance matrices for each manifold component.

    Returns:
        kl_divergence: KL divergence values for each data point in the batch.
    """
    # Get KL divergence as the average of log q(z|x) - log p(z)
    means = torch.repeat_interleave(z_mean, self.n_samples, dim=0)
    sigmas_factorized_interleaved = [
        torch.repeat_interleave(sigma, self.n_samples, dim=0) for sigma in sigma_factorized
    ]
    # We want to use n_samples = 1 here, since we'll need to pass the interleaved means/sigmas to the log-likelihood
    z_samples = self.pm.sample(z_mean=means, sigma_factorized=sigmas_factorized_interleaved)
    log_qz = self.pm.log_likelihood(z_samples, means, sigmas_factorized_interleaved)
    log_pz = self.pm.log_likelihood(z_samples)
    return (log_qz - log_pz).view(-1, self.n_samples).mean(dim=1)
elbo(x)

Computes the Evidence Lower Bound (ELBO) for the VAE objective.

The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term (log-likelihood) and a regularization term (KL divergence):

\[\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z)),\]

where:

  • \(\theta\) are the decoder parameters
  • \(\phi\) are the encoder parameters
  • \(\beta\) is a weight for the KL term (setting \(\beta < 1\) creates a \(\beta\)-VAE)

Parameters:
  • x (Float[Tensor, 'batch_size n_features']) –

    Input data tensor.

Returns:
  • elbo( Float[Tensor, ''] ) –

    Mean ELBO value across the batch (higher is better).

  • log_likelihood( Float[Tensor, ''] ) –

    Mean reconstruction log-likelihood across the batch.

  • kl_divergence( Float[Tensor, ''] ) –

    Mean KL divergence across the batch.
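
The arithmetic behind the three return values can be traced on a tiny batch. This plain-Python sketch assumes a squared-error reconstruction loss summed over features, matching the default MSELoss(reduction="none") followed by a sum over the feature dimension. Treating the negative squared error as a log-likelihood corresponds to a Gaussian likelihood up to constants. The KL values are made-up inputs rather than outputs of kl_divergence, and elbo_terms is a hypothetical name.

```python
def elbo_terms(
    x: list[list[float]], x_reconstructed: list[list[float]], kld: list[float], beta: float = 1.0
) -> tuple[float, float, float]:
    """Return (mean ELBO, mean log-likelihood, mean KL) for a batch."""
    # Per-sample log-likelihood: negative squared error summed over features
    ll = [-sum((a - b) ** 2 for a, b in zip(rec, row)) for rec, row in zip(x_reconstructed, x)]
    elbo = [l_i - beta * k_i for l_i, k_i in zip(ll, kld)]
    n = len(x)
    return sum(elbo) / n, sum(ll) / n, sum(kld) / n


x = [[1.0, 2.0], [0.0, 0.0]]
x_rec = [[1.0, 1.0], [0.0, 1.0]]
kld = [0.5, 0.25]  # illustrative values, not computed from a posterior

print(elbo_terms(x, x_rec, kld, beta=1.0))
print(elbo_terms(x, x_rec, kld, beta=0.5))
```

Lowering beta leaves the log-likelihood and KL terms unchanged and only reduces how strongly the KL term is subtracted from the ELBO.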

Source code in manify/embedders/vae.py
def elbo(
    self, x: Float[torch.Tensor, "batch_size n_features"]
) -> tuple[Float[torch.Tensor, ""], Float[torch.Tensor, ""], Float[torch.Tensor, ""]]:
    r"""Computes the Evidence Lower Bound (ELBO) for the VAE objective.

    The ELBO is the standard objective function for variational autoencoders, consisting of a reconstruction term
    (log-likelihood) and a regularization term (KL divergence):

    $$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta \cdot D_{KL}(q_\phi(z|x) || p(z)),$$

    where:

    - $\theta$ are the decoder parameters
    - $\phi$ are the encoder parameters
    - $\beta$ is a weight for the KL term (setting $\beta < 1$ creates a $\beta$-VAE)

    Args:
        x: Input data tensor.

    Returns:
        elbo: Mean ELBO value across the batch (higher is better).
        log_likelihood: Mean reconstruction log-likelihood across the batch.
        kl_divergence: Mean KL divergence across the batch.
    """
    x_reconstructed, z_means, sigma_factorized = self(x)
    kld = self.kl_divergence(z_means, sigma_factorized)
    ll = -self.reconstruction_loss(x_reconstructed.view(x.shape[0], -1), x.view(x.shape[0], -1)).sum(dim=1)
    return (ll - self.beta * kld).mean(), ll.mean(), kld.mean()
fit(X, D=None, lr=0.001, burn_in_lr=0.0001, curvature_lr=0.0, burn_in_iterations=1, training_iterations=9, loss_window_size=100, logging_interval=10, batch_size=32, clip_grad=True)

Trains the VAE model on the provided data.

The training process consists of two phases:

  1. Burn-in phase: Initial training with a lower learning rate for stability
  2. Main training phase: Training with the full learning rate and optional curvature optimization

Training uses the Adam optimizer, with optional gradient clipping (clip_grad) to prevent exploding gradients. During training, the model maximizes the Evidence Lower Bound (ELBO).

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Training data tensor.

  • D (None, default: None ) –

    Ignored.

  • lr (float, default: 0.001 ) –

    Learning rate for the main training phase.

  • burn_in_lr (float, default: 0.0001 ) –

    Learning rate for the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for optimizing manifold scale factors. Off (no learning) by default.

  • burn_in_iterations (int, default: 1 ) –

    Number of epochs (full passes over X) for the burn-in phase.

  • training_iterations (int, default: 9 ) –

    Number of epochs (full passes over X) for the main training phase.

  • loss_window_size (int, default: 100 ) –

    Number of most recent batch losses averaged for the progress display.

  • logging_interval (int, default: 10 ) –

    Progress statistics are refreshed whenever a batch's starting sample index is a multiple of this value.

  • batch_size (int, default: 32 ) –

    Batch size for training.

  • clip_grad (bool, default: True ) –

    Whether to apply gradient clipping.

Returns:
  • self( 'ProductSpaceVAE' ) –

    The fitted model. Per-batch ELBO, log-likelihood, and KL values are stored in loss_history_ under the keys "elbo", "ll", and "kl".

Source code in manify/embedders/vae.py
def fit(  # type: ignore[override]
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    D: None = None,
    lr: float = 1e-3,
    burn_in_lr: float = 1e-4,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 1,
    training_iterations: int = 9,
    loss_window_size: int = 100,
    logging_interval: int = 10,
    batch_size: int = 32,
    clip_grad: bool = True,
) -> "ProductSpaceVAE":
    """Trains the VAE model on the provided data.

    The training process consists of two phases:

    1. Burn-in phase: Initial training with a lower learning rate for stability; scale factors are held fixed
    2. Main training phase: Training with the full learning rate and optional curvature optimization

    Training uses the Adam optimizer, with optional gradient clipping (max norm 1.0) to prevent exploding gradients.
    During training, the model maximizes the Evidence Lower Bound (ELBO).

    Args:
        X: Training data tensor.
        D: Ignored.
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of epochs (full passes over X) for the burn-in phase.
        training_iterations: Number of epochs (full passes over X) for the main training phase.
        loss_window_size: Number of most recent batch losses averaged for the progress display.
        logging_interval: Progress statistics are refreshed when a batch's starting index is a multiple of this.
        batch_size: Batch size for training.
        clip_grad: Whether to apply gradient clipping.

    Returns:
        self: The fitted model. Per-batch ELBO, log-likelihood, and KL values are stored in loss_history_.
    """
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    my_tqdm = tqdm(total=(burn_in_iterations + training_iterations) * len(X))
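    # Manifold parameters (scale factors) get their own group so they stay frozen during burn-in
    pm_param_ids = {id(p) for p in self.pm.parameters()}
    opt = torch.optim.Adam(
        [
            {"params": [p for p in self.parameters() if id(p) not in pm_param_ids], "lr": burn_in_lr},
            {"params": list(self.pm.parameters()), "lr": 0},
        ]
    )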
    losses: Dict[str, List[float]] = {"elbo": [], "ll": [], "kl": []}
    for epoch in range(burn_in_iterations + training_iterations):
        if epoch == burn_in_iterations:
            opt.param_groups[0]["lr"] = lr
            opt.param_groups[1]["lr"] = curvature_lr

        for i in range(0, len(X), batch_size):
            opt.zero_grad()
            X_batch = X[i : i + batch_size]
            elbo, ll, kl = self.elbo(X_batch)
            L = -elbo
            L.backward()

            # Add to losses
            losses["elbo"].append(elbo.item())
            losses["ll"].append(ll.item())
            losses["kl"].append(kl.item())

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(self.pm.parameters(), max_norm=1.0)
            if torch.isnan(L) or torch.isinf(L):
                print(f"Invalid loss detected at epoch {epoch}, batch {i}")
                continue
            elif self._grads_ok():
                opt.step()

            # TQDM management
            my_tqdm.update(len(X_batch))  # The last batch may be smaller than batch_size
            my_tqdm.set_description(f"L: {L.item():.3e}, ll: {ll.item():.3e}, kl: {kl.item():.3e}")

            # Logging
            if i % logging_interval == 0:
                d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
                # d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
                d["L_avg"] = f"{np.mean(losses['elbo'][-loss_window_size:]):.3e}"
                d["ll_avg"] = f"{np.mean(losses['ll'][-loss_window_size:]):.3e}"
                d["kl_avg"] = f"{np.mean(losses['kl'][-loss_window_size:]):.3e}"
                my_tqdm.set_postfix(d)

    # Final maintenance: update attributes
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self
transform(X, D=None, batch_size=32, expmap=True)

Transform data using the trained VAE. Outputs means of the variational distribution.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Features to embed with VAE.

  • D (None, default: None ) –

    Ignored.

  • batch_size (int, default: 32 ) –

    Number of samples per batch.

  • expmap (bool, default: True ) –

    If True, map the tangent-space means onto the manifold with the exponential map; if False, return the tangent-space means unchanged.

Returns:
  • embeddings( Float[Tensor, 'n_points embedding_dim'] ) –

    Learned embeddings.
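To make the expmap option concrete, here is a minimal sketch for a single hyperbolic component of curvature -1, assuming the exponential map is taken at the hyperboloid origin (1, 0, ..., 0). The function name and the explicit projection matrix are illustrative assumptions; the library's own pm.projection_matrix and pm.expmap cover general product manifolds and are not reproduced here.

```python
import torch


def expmap_origin_hyperboloid(v_tangent):
    """Map tangent vectors at the origin onto the unit hyperboloid (curvature -1).

    v_tangent: (n, d) tangent coordinates; returns (n, d + 1) ambient coordinates.
    """
    n, d = v_tangent.shape
    # Projection that inserts a zero time-like coordinate in front of each vector.
    projection = torch.cat([torch.zeros(d, 1), torch.eye(d)], dim=1)  # shape (d, d + 1)
    v_ambient = v_tangent @ projection
    # Clamp the norm so the zero vector maps cleanly to the origin.
    norm = v_tangent.norm(dim=1, keepdim=True).clamp_min(1e-8)
    origin = torch.zeros(1, d + 1)
    origin[0, 0] = 1.0
    # exp_o(v) = cosh(|v|) * o + sinh(|v|) * v / |v|
    return torch.cosh(norm) * origin + torch.sinh(norm) * v_ambient / norm


z = expmap_origin_hyperboloid(torch.tensor([[0.3, -0.4], [0.0, 0.0]]))
# Points on the hyperboloid satisfy -x0^2 + x1^2 + ... + xd^2 = -1.
minkowski = -z[:, 0] ** 2 + (z[:, 1:] ** 2).sum(dim=1)
```

With expmap=False, the analogous output would simply be `v_tangent` itself, which has one fewer coordinate than the ambient representation in this sketch.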

Source code in manify/embedders/vae.py
def transform(
    self, X: Float[torch.Tensor, "n_points n_features"], D: None = None, batch_size: int = 32, expmap: bool = True
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using the trained VAE. Outputs means of the variational distribution.

    Args:
        X: Features to embed with VAE.
        D: Ignored.
        batch_size: Number of samples per batch.
        expmap: If True, map the tangent-space means onto the manifold with the exponential map; otherwise return them as-is.

    Returns:
        embeddings: Learned embeddings.
    """
    # Set random state
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Save the test embeddings
    embeddings_list = []
    for i in range(0, len(X), batch_size):
        x_batch = X[i : i + batch_size]
        z_mean_tangent, _ = self.encode(x_batch)
        if expmap:
            z_mean_ambient = z_mean_tangent @ self.pm.projection_matrix  # Adds zeros in the right places
            z_mean = self.pm.expmap(u=z_mean_ambient, base=None)
        else:
            z_mean = z_mean_tangent
        embeddings_list.append(z_mean.detach().cpu())

    embeddings = torch.cat(embeddings_list, dim=0)

    return embeddings