Bases: BaseEmbedder
Coordinate learning method class.
This embedder implements the approach described in Gu et al., "Learning Mixed-Curvature Representations in Product
Spaces". It directly optimizes point coordinates to preserve a given distance matrix, using Riemannian optimization
techniques.
The optimization is performed in two phases:
- Burn-in phase: Initial optimization with a smaller learning rate to find a good starting configuration.
- Training phase: Fine-tuning of the coordinates with a larger learning rate, and optionally optimizing the scale
factors (curvatures) of the manifold components.
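The two-phase schedule can be sketched as a small helper. This is an illustration only, not the library's code: `phase_learning_rates` is a hypothetical name, and its defaults simply mirror the `fit` signature documented on this page.

```python
def phase_learning_rates(
    i: int,
    burn_in_iterations: int = 2_000,
    burn_in_lr: float = 1e-3,
    lr: float = 1e-2,
    curvature_lr: float = 0.0,
) -> tuple[float, float]:
    """Return (coordinate_lr, curvature_lr) for iteration i (hypothetical helper)."""
    if i < burn_in_iterations:
        # Burn-in: coordinates move slowly and the scale factors stay frozen.
        return burn_in_lr, 0.0
    # Training: coordinates use the main rate; curvature only learns if curvature_lr > 0.
    return lr, curvature_lr
```

With the defaults, curvature stays fixed throughout; passing a positive `curvature_lr` enables it only once burn-in has finished.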
The optimization uses the Riemannian Adam optimizer so that gradient updates respect the manifold structure. The loss
measures the distortion between the pairwise distances in the embedding and the target distances.
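As a rough illustration of a distortion-style loss, the sketch below assumes the squared-ratio form used in Gu et al., summing |(d_est / d_true)^2 - 1| over unordered pairs. The `distortion_loss` function that manify actually calls is not shown on this page and may differ in normalization or other details; `distortion_sketch` is a hypothetical name.

```python
def distortion_sketch(d_est: list[list[float]], d_true: list[list[float]]) -> float:
    """Sum |(d_est / d_true)^2 - 1| over pairs i < j (assumed form, not manify's code)."""
    n = len(d_true)
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):  # upper triangle only: each unordered pair once
            # Assumes distinct points, so d_true[i][j] > 0 for i != j.
            ratio = d_est[i][j] / d_true[i][j]
            total += abs(ratio**2 - 1.0)
    return total
```

A perfect embedding gives a loss of zero; uniformly doubling every distance contributes 3 per pair under this form.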
For non-transductive settings, the class supports a split between training and test points and optimizes three groups
of distances: train-train, test-test, and train-test.
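The three distance groups correspond to sub-blocks of the target matrix. The source listing for `fit` selects them with boolean masks (`D[train][:, train]` and so on); the pure-Python sketch below mirrors that selection with nested lists. `split_blocks` is a hypothetical helper, not part of manify.

```python
def split_blocks(D: list[list[float]], test_indices: list[int]) -> dict[str, list[list[float]]]:
    """Split a square distance matrix into train/test sub-blocks (hypothetical helper)."""
    n = len(D)
    test_set = set(test_indices)
    test_idx = [i for i in range(n) if i in test_set]
    train_idx = [i for i in range(n) if i not in test_set]

    def block(rows: list[int], cols: list[int]) -> list[list[float]]:
        # Select the given rows, then the given columns within each row.
        return [[D[r][c] for c in cols] for r in rows]

    return {
        "train_train": block(train_idx, train_idx),
        "test_test": block(test_idx, test_idx),
        "train_test": block(train_idx, test_idx),
    }
```

In the listed `fit` code the training coordinates are detached when the train-test term is computed, so that term only moves the test points.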
| Attributes: |
-
pm
–
Product manifold defining the target embedding space.
-
embeddings_
–
Optimized point coordinates after fitting.
-
loss_history_
(dict[str, list[float]])
–
Per-iteration loss values, keyed by "train_train", "test_test", "train_test", and "total".
-
is_fitted_
(bool)
–
Boolean flag indicating if the embedder has been fitted.
|
| Parameters: |
-
pm
(ProductManifold)
–
ProductManifold object defining the target embedding space.
-
random_state
(int | None, default:
None
)
–
Optional random state for reproducibility.
-
device
(str | None, default:
None
)
–
Optional device for tensor computations.
|
Source code in manify/embedders/coordinate_learning.py
| def __init__(self, pm: ProductManifold, random_state: int | None = None, device: str | None = None) -> None:
    super().__init__(pm=pm, random_state=random_state, device=device)
|
fit(X, D, test_indices=None, lr=0.01, burn_in_lr=0.001, curvature_lr=0.0, burn_in_iterations=2000, training_iterations=18000, loss_window_size=100, logging_interval=10)
Fit the Coordinate Learning Embedder. Sets attributes embeddings_, loss_history_, and is_fitted_.
| Parameters: |
-
X
(None)
–
Ignored.
-
D
(Float[Tensor, 'n_points n_points'])
–
Tensor representing the target pairwise distance matrix between points.
-
test_indices
(Int[Tensor, 'n_test'] | None, default:
None
)
–
Tensor containing indices of test points for transductive learning.
Defaults to None, which is treated as an empty tensor (all points are used for training).
-
lr
(float, default:
0.01
)
–
Learning rate for the main training phase.
-
burn_in_lr
(float, default:
0.001
)
–
Learning rate for the burn-in phase.
-
curvature_lr
(float, default:
0.0
)
–
Learning rate for optimizing manifold scale factors. Off (no learning) by default.
-
burn_in_iterations
(int, default:
2000
)
–
Number of iterations for the burn-in phase.
-
training_iterations
(int, default:
18000
)
–
Number of iterations for the main training phase.
-
loss_window_size
(int, default:
100
)
–
Window size for computing moving average loss.
-
logging_interval
(int, default:
10
)
–
Interval for logging training progress.
|
| Returns: |
-
self
('CoordinateLearning')
–
Fitted embedder instance.
|
| Raises: |
-
ValueError
–
If the distance matrix D is None, or if the loss becomes NaN during training.
-
Warning
–
If X is provided, it will be ignored during fitting.
|
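The roles of `loss_window_size` and `logging_interval` can be sketched as follows. In the source listing, the progress bar's `L_avg` field is the mean of the most recent `loss_window_size` total losses, refreshed whenever `i % logging_interval == 0`. The helper names below are hypothetical.

```python
from statistics import fmean


def windowed_average(losses: list[float], loss_window_size: int = 100) -> float:
    """Mean of the last `loss_window_size` losses (all of them if fewer exist)."""
    return fmean(losses[-loss_window_size:])


def should_log(i: int, logging_interval: int = 10) -> bool:
    """True on iterations 0, logging_interval, 2 * logging_interval, ..."""
    return i % logging_interval == 0
```

A larger window smooths the reported loss at the cost of reacting more slowly to recent changes.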
Source code in manify/embedders/coordinate_learning.py
| def fit(  # type: ignore[override]
    self,
    X: None,
    D: Float[torch.Tensor, "n_points n_points"],
    test_indices: Int[torch.Tensor, "n_test"] | None = None,
    lr: float = 1e-2,
    burn_in_lr: float = 1e-3,
    curvature_lr: float = 0.0,  # Off by default
    burn_in_iterations: int = 2_000,
    training_iterations: int = 18_000,
    loss_window_size: int = 100,
    logging_interval: int = 10,
) -> "CoordinateLearning":
    """Fit the Coordinate Learning Embedder. Sets attributes `embeddings_`, `loss_history_`, and `is_fitted_`.
    Args:
        X: Ignored.
        D: Tensor representing the target pairwise distance matrix between points.
        test_indices: Tensor containing indices of test points for transductive learning.
            Defaults to None, which is treated as an empty tensor (all points are used for training).
        lr: Learning rate for the main training phase.
        burn_in_lr: Learning rate for the burn-in phase.
        curvature_lr: Learning rate for optimizing manifold scale factors. Off (no learning) by default.
        burn_in_iterations: Number of iterations for the burn-in phase.
        training_iterations: Number of iterations for the main training phase.
        loss_window_size: Window size for computing moving average loss.
        logging_interval: Interval for logging training progress.
    Returns:
        self: Fitted embedder instance.
    Raises:
        ValueError: If the distance matrix D is None, or if the loss becomes NaN during training.
        Warning: If X is provided, it will be ignored during fitting.
    """
    # Input validation
    if D is None:
        raise ValueError("Distance matrix D is needed for coordinate learning")
    if X is not None:
        warnings.warn(
            "Input X has been given. This will be ignored during fitting. If you have provided a distance matrix, please run embedder.fit(None, D) instead.",
            stacklevel=2,
        )
    # Set random seed if provided
    if self.random_state is not None:
        torch.manual_seed(self.random_state)
    # Move everything to the device; initialize random embeddings
    n = D.shape[0]
    covs = [torch.stack([torch.eye(M.dim) / self.pm.dim] * n).to(self.device) for M in self.pm.P]
    means = torch.vstack([self.pm.mu0] * n).to(self.device)
    X_embed = self.pm.sample(z_mean=means, sigma_factorized=covs)
    D = D.to(self.device)
    # Get train and test indices set up
    test_indices = test_indices if test_indices is not None else torch.tensor([])
    use_test = len(test_indices) > 0
    test = torch.tensor([i in test_indices for i in range(len(D))]).to(self.device)
    train = ~test
    # Initialize optimizer
    X_embed = geoopt.ManifoldParameter(X_embed, manifold=self.pm.manifold)
    ropt = geoopt.optim.RiemannianAdam(
        [{"params": [X_embed], "lr": burn_in_lr}, {"params": self.pm.parameters(), "lr": 0}]
    )
    # Init TQDM
    my_tqdm = tqdm(total=burn_in_iterations + training_iterations, leave=False)
    # Outer training loop - mostly setting optimizer learning rates up here
    losses: dict[str, list[float]] = {"train_train": [], "test_test": [], "train_test": [], "total": []}
    # Actual training loop
    for i in range(burn_in_iterations + training_iterations):
        if i == burn_in_iterations:
            # Optimize curvature by changing lr
            ropt.param_groups[0]["lr"] = lr
            ropt.param_groups[1]["lr"] = curvature_lr
        # Zero grad
        ropt.zero_grad()
        # 1. Train-train loss
        X_t = X_embed[train]
        D_tt = self.pm.pdist(X_t)
        L_tt = distortion_loss(D_tt, D[train][:, train], pairwise=True)
        L_tt.backward(retain_graph=True)
        losses["train_train"].append(L_tt.item())
        if use_test:
            # 2. Test-test loss
            X_q = X_embed[test]
            D_qq = self.pm.pdist(X_q)
            L_qq = distortion_loss(D_qq, D[test][:, test], pairwise=True)
            L_qq.backward(retain_graph=True)
            losses["test_test"].append(L_qq.item())
            # 3. Train-test loss
            X_t_detached = X_embed[train].detach()
            D_tq = self.pm.dist(X_t_detached, X_q)  # Note 'dist' not 'pdist', as we're comparing different sets
            L_tq = distortion_loss(D_tq, D[train][:, test], pairwise=False)
            L_tq.backward()
            losses["train_test"].append(L_tq.item())
        else:
            L_qq = 0
            L_tq = 0
        # Step
        ropt.step()
        L = L_tt + L_qq + L_tq
        losses["total"].append(L.item())
        # TQDM management
        my_tqdm.update(1)
        my_tqdm.set_description(f"Loss: {L.item():.3e}")
        # Logging
        if i % logging_interval == 0:
            d = {f"r{i}": f"{logscale.item():.3f}" for i, logscale in enumerate(self.pm.parameters())}
            d["D_avg"] = f"{d_avg(D_tt, D[train][:, train], pairwise=True):.4f}"
            d["L_avg"] = f"{np.mean(losses['total'][-loss_window_size:]):.3e}"
            my_tqdm.set_postfix(d)
        # Early stopping for errors
        if torch.isnan(L):
            raise ValueError("Loss is NaN")
    # Final maintenance: update attributes
    self.embeddings_ = X_embed.data.detach()
    self.loss_history_ = losses
    self.is_fitted_ = True
    return self
|
transform(X=None)
Return the learned embeddings. Coordinate learning cannot embed new data, so this method only returns the coordinates learned during fitting.
| Parameters: |
-
X
(None, default:
None
)
–
Ignored.
|
| Returns: |
-
embeddings
(Float[Tensor, 'n_points embedding_dim'])
–
Learned embeddings.
|
| Raises: |
-
ValueError
–
If the embedder has not been fitted yet.
-
Warning
–
If X is provided, as it will be ignored.
|
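The fitted-state contract described here (raise before fitting, warn and ignore `X` afterwards) can be sketched with a toy class. `ToyEmbedder` is a hypothetical stand-in that performs no real optimization; only the control flow mirrors the listing on this page.

```python
import warnings


class ToyEmbedder:
    """Hypothetical stand-in showing the fitted-state contract, not manify's class."""

    def __init__(self) -> None:
        self.is_fitted_ = False
        self.embeddings_ = None

    def fit(self, X, D):
        # Placeholder "training": store one 1-D coordinate per row of D.
        self.embeddings_ = [[float(i)] for i in range(len(D))]
        self.is_fitted_ = True
        return self  # returning self allows fit(...).transform() chaining

    def transform(self, X=None):
        if not self.is_fitted_:
            raise ValueError("The embedder has not been fitted yet.")
        if X is not None:
            warnings.warn("X will be ignored.", stacklevel=2)
        return self.embeddings_
```

Because `fit` returns the instance itself, `ToyEmbedder().fit(None, D).transform()` works as a single chained call.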
Source code in manify/embedders/coordinate_learning.py
| def transform(self, X: None = None) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using learned embedding. This is not meaningful for new data during coordinate learning.
    Args:
        X: Ignored.
    Returns:
        embeddings: Learned embeddings.
    Raises:
        ValueError: If the embedder has not been fitted yet.
        Warning: If X is provided, as it will be ignored.
    """
    if not self.is_fitted_:
        raise ValueError("The embedder has not been fitted yet.")
    if X is not None:
        warnings.warn("Coordinate learning can only return trained embeddings. X will be ignored.", stacklevel=2)
    return self.embeddings_
|
fit_transform(X, D, **fit_kwargs)
Fit the embedder to the provided distance matrix D and return the learned embeddings.
This method overrides the base class method BaseEmbedder.fit_transform() so that the input data X is not used.
| Parameters: |
-
X
(None)
–
Ignored.
-
D
(Float[Tensor, 'n_points n_points'])
–
Distance matrix for the points.
-
fit_kwargs
(Any, default:
{}
)
–
Additional keyword arguments passed to the model.fit() method.
|
| Returns: |
-
embeddings
(Float[Tensor, 'n_points embedding_dim'])
–
Learned embeddings.
|
Source code in manify/embedders/coordinate_learning.py
| def fit_transform(  # type: ignore[override]
    self, X: None, D: Float[torch.Tensor, "n_points n_points"], **fit_kwargs: Any
) -> Float[torch.Tensor, "n_points embedding_dim"]:
    """Transform data using learned embedding based on the provided distance matrix D.
    This method overrides the base class method `BaseEmbedder.fit_transform()` to not use the input data X.
    Args:
        X: Ignored.
        D: Distance matrix for the points.
        fit_kwargs: Additional keyword arguments passed to the `model.fit()` method.
    Returns:
        embeddings: Learned embeddings.
    """
    return self.fit(X=None, D=D, **fit_kwargs).transform(X=None)
|