Coordinate Learning

manify.embedders.coordinate_learning

Implementation for direct coordinate optimization in Riemannian manifolds.

This module provides an embedder that learns embeddings in product manifolds by directly optimizing point coordinates with Riemannian optimization. The approach is particularly useful for graph embedding, where a metric-learning objective preserves pairwise distances in the target space. Optimization uses Riemannian Adam and supports non-transductive training, in which gradients from the test set to the training set are masked out.

CoordinateLearning(pm, random_state=None, device=None)

Bases: BaseEmbedder

Coordinate learning method class.

This embedder implements the approach described in Gu et al., "Learning Mixed-Curvature Representations in Product Spaces". It directly optimizes point coordinates to preserve a given distance matrix, using Riemannian optimization techniques.

Trains point coordinates in a product manifold to match target distances.

This class optimizes the coordinates of points in a product manifold to match a given distance matrix. The optimization is performed in two phases:

  1. Burn-in phase: Initial optimization with a smaller learning rate (burn_in_lr) to find a good starting configuration. Curvatures are frozen during this phase.
  2. Training phase: Continued optimization of the coordinates with a larger learning rate (lr), optionally also optimizing the scale factors (curvatures) of the manifold components when curvature_lr > 0.
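The phase switch can be sketched as a pure function of the iteration index. This is an illustrative helper (phase_learning_rates is a hypothetical name, not part of manify) that reproduces how fit() reassigns the learning rates of its two optimizer parameter groups once i reaches burn_in_iterations, using fit()'s default values.

```python
def phase_learning_rates(
    i: int,
    burn_in_iterations: int = 2_000,
    burn_in_lr: float = 1e-3,
    lr: float = 1e-2,
    curvature_lr: float = 0.0,
) -> tuple[float, float]:
    """Return (coordinate_lr, curvature_lr) in effect at iteration i (hypothetical helper)."""
    if i < burn_in_iterations:
        # Burn-in: small coordinate step size, curvature parameters frozen (lr 0)
        return burn_in_lr, 0.0
    # Main phase: larger coordinate step size; curvatures move only if curvature_lr > 0
    return lr, curvature_lr


print(phase_learning_rates(0))                           # (0.001, 0.0)
print(phase_learning_rates(2_000, curvature_lr=1e-4))    # (0.01, 0.0001)
```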

The Riemannian Adam optimizer (geoopt.optim.RiemannianAdam) keeps gradient updates consistent with the manifold structure. The loss measures the distortion between the pairwise distances in the embedding and the target distances.
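The exact form of manify's distortion_loss is not shown on this page. As an assumption, the sketch below uses a squared-ratio distortion in the style of Gu et al., summing |(d_est / d_true)^2 - 1| over point pairs, plus the average relative distortion commonly reported alongside it. The library's function may differ in form or normalization; the pairwise flag here only mimics the argument of the same name in fit() (square matrices over one point set versus rectangular train-by-test matrices).

```python
def distortion_loss_sketch(
    D_est: list[list[float]], D_true: list[list[float]], pairwise: bool = True
) -> float:
    """Sum of |(d_est / d_true)^2 - 1| over point pairs (assumed squared-ratio distortion).

    pairwise=True: square matrices over one point set; only pairs i < j count,
    which also skips the zero diagonal.
    pairwise=False: rectangular matrices between two different point sets;
    every entry counts.
    """
    total = 0.0
    for i, row in enumerate(D_true):
        start = i + 1 if pairwise else 0
        for j in range(start, len(row)):
            total += abs((D_est[i][j] / D_true[i][j]) ** 2 - 1.0)
    return total


def average_distortion_sketch(D_est: list[list[float]], D_true: list[list[float]]) -> float:
    """Mean relative error |d_est - d_true| / d_true over pairs i < j."""
    n = len(D_true)
    terms = [abs(D_est[i][j] - D_true[i][j]) / D_true[i][j] for i in range(n) for j in range(i + 1, n)]
    return sum(terms) / len(terms)


D_true = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]
D_est = [[2.0 * d for d in row] for row in D_true]  # every distance stretched 2x
print(distortion_loss_sketch(D_est, D_true))     # 9.0 (3 pairs, each |2^2 - 1| = 3)
print(average_distortion_sketch(D_est, D_true))  # 1.0
```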

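For non-transductive settings, the class supports splitting points into training and test sets and optimizes separate loss terms for train-train, test-test, and train-test distances. In the train-test term the training coordinates are detached, so test points never move training points.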
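The split logic can be sketched without any tensors. The two helpers below are hypothetical (not manify functions): the first builds boolean train/test masks from test_indices the way fit() does, and the second lists which points each loss term updates, given that the train-test term uses detached training coordinates.

```python
def split_masks(n: int, test_indices: list[int]) -> tuple[list[bool], list[bool]]:
    """Boolean train/test masks over n points, built the way fit() builds them."""
    test_set = set(test_indices)
    test = [i in test_set for i in range(n)]
    train = [not t for t in test]
    return train, test


def gradient_receivers(train: list[bool], test: list[bool]) -> dict[str, list[int]]:
    """Indices of the points whose coordinates each loss term updates."""
    train_idx = [i for i, is_train in enumerate(train) if is_train]
    test_idx = [i for i, is_test in enumerate(test) if is_test]
    return {
        "train_train": train_idx,  # both sides are training points
        "test_test": test_idx,     # both sides are test points
        "train_test": test_idx,    # training side is detached, so only test points move
    }


train, test = split_masks(5, test_indices=[1, 3])
print(gradient_receivers(train, test))
# {'train_train': [0, 2, 4], 'test_test': [1, 3], 'train_test': [1, 3]}
```

With an empty test_indices every point lands in the training set, which matches the default behaviour of fit().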

Attributes:
  • pm

    Product manifold defining the target embedding space.

  • embeddings_

    Optimized point coordinates after fitting.

  • loss_history_ (dict[str, list[float]]) –

    Training loss history.

  • is_fitted_ (bool) –

    Boolean flag indicating if the embedder has been fitted.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object defining the target embedding space.

  • random_state (int | None, default: None ) –

    Optional random state for reproducibility.

  • device (str | None, default: None ) –

    Optional device for tensor computations.
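A product manifold combines its components' geodesic distances in the l2 sense: d(x, y) = sqrt(sum_k d_k(x_k, y_k)^2). The sketch below illustrates this with one Euclidean, one hyperbolic (hyperboloid model) and one spherical factor, with curvatures fixed at 0, -1 and +1. These plain-Python helpers are illustrative assumptions; in manify, the ProductManifold passed as pm provides the distances (pm.dist, pm.pdist) and handles the per-component scale factors.

```python
import math
from collections.abc import Callable, Sequence


def euclidean_dist(x: list[float], y: list[float]) -> float:
    return math.dist(x, y)


def hyperbolic_dist(x: list[float], y: list[float]) -> float:
    """Hyperboloid model, curvature -1: arccosh(-<x, y>_L) with the Minkowski inner product."""
    minkowski = -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))
    return math.acosh(max(1.0, -minkowski))  # clamp guards against rounding below 1


def spherical_dist(x: list[float], y: list[float]) -> float:
    """Unit sphere, curvature +1: arccos(<x, y>)."""
    dot = sum(a * b for a, b in zip(x, y))
    return math.acos(min(1.0, max(-1.0, dot)))


def lift_to_hyperboloid(v: list[float]) -> list[float]:
    """Prepend the time-like coordinate so that -x0^2 + |v|^2 = -1."""
    return [math.sqrt(1.0 + sum(a * a for a in v)), *v]


def product_dist(
    x_parts: Sequence[list[float]],
    y_parts: Sequence[list[float]],
    dist_fns: Sequence[Callable[[list[float], list[float]], float]],
) -> float:
    """l2 combination of the per-component geodesic distances."""
    return math.sqrt(sum(fn(x, y) ** 2 for fn, x, y in zip(dist_fns, x_parts, y_parts)))


fns = (euclidean_dist, hyperbolic_dist, spherical_dist)
x = ([0.0, 0.0], lift_to_hyperboloid([0.0]), [1.0, 0.0, 0.0])
y = ([3.0, 4.0], lift_to_hyperboloid([math.sinh(1.0)]), [0.0, 1.0, 0.0])
print(product_dist(x, y, fns))  # sqrt(5^2 + 1^2 + (pi/2)^2), about 5.34
```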

Source code in manify/embedders/coordinate_learning.py
def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
    super().__init__(pm=pm, random_state=random_state, device=device)

fit(X, D, test_indices=None, lr=0.01, burn_in_lr=0.001, curvature_lr=0.0, burn_in_iterations=2000, training_iterations=18000, loss_window_size=100, logging_interval=10)

Fit the Coordinate Learning Embedder. Sets attributes embeddings_, loss_history_, and is_fitted_.

Parameters:
  • X (None) –

    Ignored.

  • D (Float[Tensor, 'n_points n_points']) –

    Tensor representing the target pairwise distance matrix between points.

  • test_indices (Int[Tensor, 'n_test'] | None, default: None ) –

    Tensor containing indices of held-out test points for non-transductive training. If None (default), no points are held out and all points are used for training.

  • lr (float, default: 0.01 ) –

    Learning rate for the main training phase.

  • burn_in_lr (float, default: 0.001 ) –

    Learning rate for the burn-in phase.

  • curvature_lr (float, default: 0.0 ) –

    Learning rate for optimizing manifold scale factors. Off (no learning) by default.

  • burn_in_iterations (int, default: 2000 ) –

    Number of iterations for the burn-in phase.

  • training_iterations (int, default: 18000 ) –

    Number of iterations for the main training phase.

  • loss_window_size (int, default: 100 ) –

    Number of most recent iterations averaged to report the moving-average loss (L_avg) in the progress bar.

  • logging_interval (int, default: 10 ) –

    Number of iterations between updates of the progress-bar statistics.
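The two reporting parameters can be sketched as follows. Both helpers are hypothetical illustrations of what fit() computes: the moving average is the mean of the last loss_window_size total losses, and statistics are refreshed on iterations where i % logging_interval == 0.

```python
def moving_average(history: list[float], window: int) -> float:
    """Mean of the last `window` losses (fewer if the history is shorter); assumes history is non-empty."""
    recent = history[-window:]
    return sum(recent) / len(recent)


def logging_steps(total_iterations: int, logging_interval: int) -> list[int]:
    """Iterations at which the progress-bar statistics are refreshed (i % logging_interval == 0)."""
    return [i for i in range(total_iterations) if i % logging_interval == 0]


print(moving_average([4.0, 3.0, 2.0, 1.0], window=2))  # 1.5
print(len(logging_steps(2_000 + 18_000, 10)))  # 2000 refreshes with the default settings
```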

Returns:
  • self (CoordinateLearning) –

    Fitted embedder instance.

Raises:
  • ValueError

    If the distance matrix D is None, or if the loss becomes NaN during training.

  • UserWarning

    Issued (not raised) if X is provided; X is ignored during fitting.
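The distinction between raising and warning can be sketched in isolation. Both helpers are hypothetical and only mirror the checks in fit(): a missing D and a NaN loss raise ValueError, while a non-None X triggers a UserWarning (the default category of warnings.warn) and fitting continues.

```python
import math
import warnings


def validate_fit_inputs(X, D) -> None:
    """Hypothetical helper mirroring the input checks at the top of fit()."""
    if D is None:
        raise ValueError("Distance matrix D is needed for coordinate learning")
    if X is not None:
        # Only a warning: fitting proceeds and X is ignored
        warnings.warn("Input X has been given and will be ignored; call fit(None, D) instead.", stacklevel=2)


def check_loss(loss: float) -> None:
    """Hypothetical helper mirroring the per-iteration NaN guard."""
    if math.isnan(loss):
        raise ValueError("Loss is NaN")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validate_fit_inputs(X=[[0.0]], D=[[0.0]])  # warns, does not raise
print(caught[0].category.__name__)  # UserWarning
```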

Source code in manify/embedders/coordinate_learning.py
def fit(  # type: ignore[override]
    self,
    X: None,
    D: Float[torch.Tensor, "n_points n_points"],
    test_indices: Int[torch.Tensor, "n_test"] | None = None,
    lr: float = 1e-2,
    burn_in_lr: float = 1e-3,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 2_000,
    training_iterations: int = 18_000,
    loss_window_size: int = 100,
    logging_interval: int = 10,
) -> "CoordinateLearning":
    """Fit the Coordinate Learning Embedder. Sets attributes `embeddings_`, `loss_history_`, and `is_fitted_`.

    Args:
        X: Ignored.
        D: Tensor representing the target pairwise distance matrix between points.
        test_indices: Tensor containing indices of held-out test points for non-transductive training.
            If None (default), no points are held out and all points are used for training.
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of iterations for the burn-in phase.
        training_iterations: Number of iterations for the main training phase.
        loss_window_size: Number of most recent iterations averaged to report the moving-average loss.
        logging_interval: Number of iterations between progress-bar statistic updates.

    Returns:
        self: Fitted embedder instance.

    Raises:
        ValueError: If the distance matrix D is None, or if the loss becomes NaN during training.
        UserWarning: Issued (not raised) if X is provided; X is ignored during fitting.
    """
    # Input validation
    if D is None:
        raise ValueError("Distance matrix D is needed for coordinate learning")
    if X is not None:
        warnings.warn(
            "Input X has been given. This will be ignored during fitting. If you have provided a distance matrix, please run embedder.fit(None, D) instead.",
            stacklevel=2,
        )

    # Set random seed if provided
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Move everything to the device; initialize random embeddings
    n = D.shape[0]
    covs = [torch.stack([torch.eye(M.dim) / self.pm.dim] * n).to(self.device) for M in self.pm.P]
    means = torch.vstack([self.pm.mu0] * n).to(self.device)
    X_embed = self.pm.sample(z_mean=means, sigma_factorized=covs)
    D = D.to(self.device)

    # Get train and test indices set up
    test_indices = test_indices if test_indices is not None else torch.tensor([])
    use_test = len(test_indices) > 0
    test = torch.tensor([i in test_indices for i in range(len(D))]).to(self.device)
    train = ~test

    # Initialize optimizer
    X_embed = geoopt.ManifoldParameter(X_embed, manifold=self.pm.manifold)
    ropt = geoopt.optim.RiemannianAdam(
        [{"params": [X_embed], "lr": burn_in_lr}, {"params": self.pm.parameters(), "lr": 0}]
    )

    # Init TQDM
    my_tqdm = tqdm(total=burn_in_iterations + training_iterations, leave=False)

    # Per-term loss history (train_train, test_test, train_test) plus the total
    losses: dict[str, list[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}

    # Actual training loop
    for i in range(burn_in_iterations + training_iterations):
        if i == burn_in_iterations:
            # End of burn-in: switch to the main learning rate and enable curvature learning (if curvature_lr > 0)
            ropt.param_groups[0]["lr"] = lr
            ropt.param_groups[1]["lr"] = curvature_lr

        # Zero grad
        ropt.zero_grad()

        # 1. Train-train loss
        X_t = X_embed[train]
        D_tt = self.pm.pdist(X_t)
        L_tt = distortion_loss(D_tt, D[train][:, train], pairwise=True)
        L_tt.backward(retain_graph=True)
        losses["train_train"].append(L_tt.item())

        if use_test:
            # 2. Test-test loss
            X_q = X_embed[test]
            D_qq = self.pm.pdist(X_q)
            L_qq = distortion_loss(D_qq, D[test][:, test], pairwise=True)
            L_qq.backward(retain_graph=True)
            losses["test_test"].append(L_qq.item())

            # 3. Train-test loss
            X_t_detached = X_embed[train].detach()
            D_tq = self.pm.dist(X_t_detached, X_q)  # Note 'dist' not 'pdist', as we're comparing different sets
            L_tq = distortion_loss(D_tq, D[train][:, test], pairwise=False)
            L_tq.backward()
            losses["train_test"].append(L_tq.item())
        else:
            L_qq = 0
            L_tq = 0

        # Step
        ropt.step()
        L = L_tt + L_qq + L_tq
        losses["total"].append(L.item())

        # TQDM management
        my_tqdm.update(1)
        my_tqdm.set_description(f"Loss: {L.item():.3e}")

        # Logging
        if i % logging_interval == 0:
            d = {f"r{j}": f"{logscale.item():.3f}" for j, logscale in enumerate(self.pm.parameters())}
            d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
            d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
            my_tqdm.set_postfix(d)

        # Early stopping for errors
        if torch.isnan(L):
            raise ValueError("Loss is NaN")

    # Final maintenance: close the progress bar and update attributes
    my_tqdm.close()
    self.embeddings_ = X_embed.data.detach()
    self.loss_history_ = losses
    self.is_fitted_ = True

    return self
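The training loop above can be distilled into a dependency-free toy to show the mechanics of coordinate learning. This is a simplified sketch under explicit assumptions, not manify's method: coordinates live in flat Euclidean space, plain gradient descent with a hand-derived gradient replaces Riemannian Adam, the loss is the sum of (d_ij / D_ij - 1)^2 over pairs i < j, and there is no test split or curvature learning. It keeps the source's structure of a Gaussian initialization with variance 1/dim, a burn-in phase with a smaller step size, and a NaN guard. The function name toy_coordinate_learning is hypothetical.

```python
import math
import random


def toy_coordinate_learning(
    D: list[list[float]],
    dim: int = 2,
    burn_in_iterations: int = 200,
    training_iterations: int = 2_000,
    burn_in_lr: float = 1e-2,
    lr: float = 5e-2,
    seed: int = 0,
) -> tuple[list[list[float]], list[float]]:
    """Fit Euclidean coordinates whose pairwise distances match D (illustrative toy)."""
    rng = random.Random(seed)  # plays the role of random_state
    n = len(D)
    # Random initialization around the origin with variance 1/dim per coordinate
    X = [[rng.gauss(0.0, 1.0 / math.sqrt(dim)) for _ in range(dim)] for _ in range(n)]
    history: list[float] = []
    for it in range(burn_in_iterations + training_iterations):
        step = burn_in_lr if it < burn_in_iterations else lr  # two-phase schedule
        grads = [[0.0] * dim for _ in range(n)]
        loss = 0.0
        for i in range(n):
            for j in range(i + 1, n):
                diff = [a - b for a, b in zip(X[i], X[j])]
                d = math.sqrt(sum(c * c for c in diff)) + 1e-12  # avoid division by zero
                r = d / D[i][j] - 1.0
                loss += r * r
                # d(loss)/d(x_i) = 2 r / (D_ij * d) * (x_i - x_j); opposite sign for x_j
                coef = 2.0 * r / (D[i][j] * d)
                for k in range(dim):
                    grads[i][k] += coef * diff[k]
                    grads[j][k] -= coef * diff[k]
        history.append(loss)
        if math.isnan(loss):
            raise ValueError("Loss is NaN")
        X = [[x - step * g for x, g in zip(xi, gi)] for xi, gi in zip(X, grads)]
    return X, history


D = [[0.0, 3.0, 4.0], [3.0, 0.0, 5.0], [4.0, 5.0, 0.0]]  # a 3-4-5 right triangle
X, history = toy_coordinate_learning(D)
print(f"initial loss {history[0]:.3f}, final loss {history[-1]:.2e}")
```

Because this target triangle is exactly realizable in the plane, the loss can approach zero; in curved product spaces the attainable distortion depends on how well the chosen signature fits the data.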

transform(X=None)

Return the embeddings learned during fit(). Coordinate learning cannot embed new data, so any X passed here is ignored.

Parameters:
  • X (None, default: None ) –

    Ignored.

Returns:
  • embeddings (Float[Tensor, 'n_points embedding_dim']) –

    Learned embeddings.

Raises:
  • ValueError

    If the embedder has not been fitted yet.

  • UserWarning

    Issued (not raised) if X is provided; X is ignored.

Source code in manify/embedders/coordinate_learning.py
def transform(self, X: None = None) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Return the embeddings learned during `fit()`. Coordinate learning cannot embed new data.

    Args:
        X: Ignored.

    Returns:
        embeddings: Learned embeddings.

    Raises:
        ValueError: If the embedder has not been fitted yet.
        UserWarning: Issued (not raised) if X is provided; X is ignored.
    """
    if not self.is_fitted_:
        raise ValueError("The embedder has not been fitted yet.")

    if X is not None:
        warnings.warn("Coordinate learning can only return trained embeddings. X will be ignored.", stacklevel=2)

    return self.embeddings_

fit_transform(X, D, **fit_kwargs)

Fit the embedder to the distance matrix D and return the learned embeddings.

This method overrides the base class method BaseEmbedder.fit_transform() so that the input data X is not used; it is equivalent to fit(None, D, **fit_kwargs).transform().

Parameters:
  • X (None) –

    Ignored.

  • D (Float[Tensor, 'n_points n_points']) –

    Distance matrix for the points.

  • fit_kwargs (Any, default: {} ) –

    Additional keyword arguments passed to fit() (for example lr or burn_in_iterations).

Returns:
  • embeddings (Float[Tensor, 'n_points embedding_dim']) –

    Learned embeddings.

Source code in manify/embedders/coordinate_learning.py
def fit_transform(  # type: ignore[override]
    self, X: None, D: Float[torch.Tensor, "n_points n_points"], **fit_kwargs: Any
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Fit the embedder to the distance matrix D and return the learned embeddings.

    This method overrides the base class method `BaseEmbedder.fit_transform()` so that the input data X is not used.

    Args:
        X: Ignored.
        D: Distance matrix for the points.
        fit_kwargs: Additional keyword arguments passed to `fit()`.

    Returns:
        embeddings: Learned embeddings.
    """
    return self.fit(X=None, D=D, **fit_kwargs).transform(X=None)
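The fit/transform contract documented here can be illustrated with a minimal stand-in class. ToyEmbedder is hypothetical and performs no optimization (its "embedding" is a placeholder); it only reproduces the control flow: fit() returns self so calls can be chained, transform() refuses to run before fitting and ignores X with a warning, and fit_transform() never forwards X.

```python
import warnings


class ToyEmbedder:
    """Hypothetical stand-in that reproduces the fit/transform contract of CoordinateLearning."""

    def __init__(self) -> None:
        self.is_fitted_ = False

    def fit(self, X, D, **fit_kwargs):
        if D is None:
            raise ValueError("Distance matrix D is needed")
        # Placeholder result: one coordinate per point, no real optimization happens here
        self.embeddings_ = [[float(i)] for i in range(len(D))]
        self.is_fitted_ = True
        return self  # returning self is what makes fit(...).transform(...) chainable

    def transform(self, X=None):
        if not self.is_fitted_:
            raise ValueError("The embedder has not been fitted yet.")
        if X is not None:
            warnings.warn("Coordinate learning can only return trained embeddings. X will be ignored.", stacklevel=2)
        return self.embeddings_  # the stored result; new points are never embedded

    def fit_transform(self, X, D, **fit_kwargs):
        # X is deliberately not forwarded, matching the source
        return self.fit(None, D, **fit_kwargs).transform(None)


emb = ToyEmbedder().fit_transform(None, [[0.0, 1.0], [1.0, 0.0]], lr=1e-2)
print(emb)  # [[0.0], [1.0]]
```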