Manifolds

manify.manifolds

Tools for generating Riemannian manifolds and product manifolds.

The module consists of two classes: Manifold and ProductManifold. The Manifold class represents hyperbolic, Euclidean, or spherical manifolds of constant curvature. The ProductManifold class supports Cartesian products of multiple manifolds, combining their geometric properties to create mixed-curvature spaces. Both classes provide methods for key geometric operations and are built on top of the corresponding geoopt classes (Lorentz, Euclidean, Sphere, Stereographic, Scaled, ProductManifold, and StereographicProductManifold).

Manifold(curvature, dim, device='cpu', stereographic=False)

Constant-curvature Riemannian manifold class.

This class provides tools for creating and manipulating Riemannian manifolds with constant curvature (hyperbolic, Euclidean, or spherical).

Parameters:
  • curvature (float) –

    The curvature of the manifold.

  • dim (int) –

    The dimension of the manifold.

  • device (str, default: 'cpu' ) –

    The device on which the manifold is stored.

  • stereographic (bool, default: False ) –

    Whether to use stereographic coordinates.

Attributes:
  • curvature

    The curvature of the manifold. Negative for hyperbolic, zero for Euclidean, and positive for spherical manifolds.

  • dim

    The dimension of the manifold.

  • device

    The device on which the manifold is stored.

  • is_stereographic

    Whether stereographic coordinates are used for the manifold.

  • scale

    The scale factor derived from the curvature: 1 / sqrt(|curvature|), or 1 when the curvature is zero.

  • type

    A string identifier for the manifold type ('H' for hyperbolic, 'E' for Euclidean, 'S' for spherical, 'P' for Poincaré ball, 'D' for stereographic sphere).

  • ambient_dim

    The dimension of the ambient space: dim + 1 for hyperbolic and spherical manifolds in non-stereographic coordinates, dim otherwise.

  • manifold

    The underlying geoopt manifold object.

  • mu0

    The origin point on the manifold.

  • name

    A string identifier for the manifold.

Source code in manify/manifolds.py
def __init__(self, curvature: float, dim: int, device: str = "cpu", stereographic: bool = False):
    # Device management
    self.device = device

    # Basic properties
    self.curvature = curvature
    self.dim = dim
    self.scale = abs(curvature) ** -0.5 if curvature != 0 else 1
    self.is_stereographic = stereographic

    # A couple of manifold-specific quirks we need to deal with here
    if stereographic:
        self.manifold = geoopt.Stereographic(k=curvature, learnable=True).to(self.device)
        if curvature < 0:
            self.type = "P"
        elif curvature == 0:
            self.type = "E"
        else:  # curvature > 0
            self.type = "D"
        self.ambient_dim = dim
        self.mu0 = torch.zeros(self.dim).to(self.device).reshape(1, -1)
    else:
        if curvature < 0:
            self.type = "H"
            man = geoopt.Lorentz(k=1.0)
            # Use 'k=1.0' because the scale will take care of the curvature
            # For more information, see the bottom of page 5 of Gu et al. (2019):
            # https://openreview.net/pdf?id=HJxeWnCcF7
        elif curvature == 0:
            self.type = "E"
            man = geoopt.Euclidean(ndim=1)
            # Use 'ndim=1' because dim means the *shape* of the Euclidean space, not the dimensionality...
        else:
            self.type = "S"
            man = geoopt.Sphere()
        self.manifold = geoopt.Scaled(man, self.scale, learnable=True).to(self.device)

        self.ambient_dim = dim if curvature == 0 else dim + 1
        if curvature == 0:
            self.mu0 = torch.zeros(self.dim).to(self.device).reshape(1, -1)
        else:
            self.mu0 = torch.Tensor([1.0] + [0.0] * dim).to(self.device).reshape(1, -1)

    self.name = f"{self.type}_{abs(self.curvature):.1f}^{dim}"

    # Couple of assertions to check
    assert self.manifold.check_point(self.mu0)
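
The constructor's bookkeeping can be followed without torch or geoopt. The sketch below mirrors only how `type`, `scale`, `ambient_dim`, and `name` are derived in the listing above; `describe_manifold` is a hypothetical helper written for illustration and is not part of manify.

```python
def describe_manifold(curvature: float, dim: int, stereographic: bool = False):
    """Mirror the constructor's derived properties (illustrative helper only)."""
    # Scale is 1 / sqrt(|K|); the flat case uses 1 as a placeholder.
    scale = abs(curvature) ** -0.5 if curvature != 0 else 1.0

    if stereographic:
        # Poincare ball, Euclidean space, or stereographic sphere.
        kind = "P" if curvature < 0 else ("E" if curvature == 0 else "D")
        ambient_dim = dim  # stereographic coordinates need no extra axis
    else:
        # Hyperboloid, Euclidean space, or sphere.
        kind = "H" if curvature < 0 else ("E" if curvature == 0 else "S")
        ambient_dim = dim if curvature == 0 else dim + 1

    name = f"{kind}_{abs(curvature):.1f}^{dim}"
    return kind, scale, ambient_dim, name


print(describe_manifold(-2.0, 3))
print(describe_manifold(1.0, 2, stereographic=True))
```

Note that the curved, non-stereographic cases are the only ones that gain an extra ambient coordinate.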

to(device)

Move the Manifold object to a specified device.

Parameters:
  • device (str) –

    The device to which the manifold will be moved.

Returns:
  • manifold( Manifold ) –

    The updated manifold object on the specified device.

Source code in manify/manifolds.py
def to(self, device: str) -> Manifold:
    """Move the Manifold object to a specified device.

    Args:
        device: The device to which the manifold will be moved.

    Returns:
        manifold: The updated manifold object on the specified device.
    """
    self.device = device
    self.manifold = self.manifold.to(device)
    self.mu0 = self.mu0.to(device)
    return self

inner(X, Y)

Compute the inner product between two points on the manifold.

This ensures the correct inner product is computed for all manifold types (flipping the sign of dim 0 for hyperbolic manifolds).

Parameters:
  • X (Float[Tensor, 'n_points1 n_dim']) –

    Tensor of points in the manifold.

  • Y (Float[Tensor, 'n_points2 n_dim']) –

    Tensor of points in the manifold.

Returns:
  • inner_products( Float[Tensor, 'n_points1 n_points2'] ) –

    Tensor of inner products between points.

Source code in manify/manifolds.py
def inner(
    self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
) -> Float[torch.Tensor, "n_points1 n_points2"]:
    """Compute the inner product between two points on the manifold.

    This ensures the correct inner product is computed for all manifold types
    (flipping the sign of dim 0 for hyperbolic manifolds).

    Args:
        X: Tensor of points in the manifold.
        Y: Tensor of points in the manifold.

    Returns:
        inner_products: Tensor of inner products between points.
    """
    # Not inherited because of weird broadcasting stuff, plus need for scale.
    # This ensures we compute the right inner product for all manifolds (flip sign of dim 0 for hyperbolic)
    X_fixed = torch.cat([-X[:, 0:1], X[:, 1:]], dim=1) if self.type == "H" else X

    # This prevents dividing by zero in the Euclidean case
    scaler = 1 / abs(self.curvature) if self.type != "E" else 1
    return X_fixed @ Y.T * scaler
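
As a rough illustration of the sign flip and the 1/|K| scaling, here is a plain-Python version operating on lists of lists. `manifold_inner` is a hypothetical stand-in for the method above and assumes non-stereographic coordinates, so that negative curvature corresponds to type 'H'.

```python
def manifold_inner(X, Y, curvature):
    """Pairwise inner products between the rows of X and the rows of Y."""
    hyperbolic = curvature < 0
    # Avoid dividing by zero in the Euclidean case.
    scaler = 1.0 / abs(curvature) if curvature != 0 else 1.0
    out = []
    for x in X:
        # Flip the sign of coordinate 0 to get the Minkowski form.
        x_fixed = [-x[0]] + list(x[1:]) if hyperbolic else list(x)
        out.append([scaler * sum(a * b for a, b in zip(x_fixed, y)) for y in Y])
    return out


origin = [[1.0, 0.0, 0.0]]
print(manifold_inner(origin, origin, -1.0))  # hyperboloid origin with itself
print(manifold_inner(origin, origin, 1.0))   # sphere origin with itself
```

The result has one row per point in X and one column per point in Y, matching the 'n_points1 n_points2' shape documented above.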

dist(X, Y)

Compute the distance between two sets of points on the manifold.

Parameters:
  • X (Float[Tensor, 'n_points1 n_dim']) –

    Tensor of points in the manifold.

  • Y (Float[Tensor, 'n_points2 n_dim']) –

    Tensor of points in the manifold.

Returns:
  • distances( Float[Tensor, 'n_points1 n_points2'] ) –

    Tensor of distances between points.

Source code in manify/manifolds.py
def dist(
    self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
) -> Float[torch.Tensor, "n_points1 n_points2"]:
    """Compute the distance between two sets of points on the manifold.

    Args:
        X: Tensor of points in the manifold.
        Y: Tensor of points in the manifold.

    Returns:
        distances: Tensor of distances between points.
    """
    return self.manifold.dist(X[:, None], Y[None, :])
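
The distance itself is delegated to geoopt, so the following is only a sketch of the closed-form geodesic distances one would expect, under the assumption that points live on the unit hyperboloid or unit sphere and that distances are multiplied by the scale R = 1/sqrt(|K|). The helper names are hypothetical.

```python
import math


def geodesic_dist(x, y, curvature):
    """Closed-form geodesic distance between two points (lists of floats)."""
    R = abs(curvature) ** -0.5 if curvature != 0 else 1.0
    if curvature < 0:
        # Minkowski inner product; the clamp guards against rounding below 1.
        ip = -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))
        return R * math.acosh(max(-ip, 1.0))
    if curvature > 0:
        ip = sum(a * b for a, b in zip(x, y))
        return R * math.acos(max(-1.0, min(1.0, ip)))
    return math.dist(x, y)


def cross_dist(X, Y, curvature):
    """Distance matrix with one row per point in X and one column per point in Y."""
    return [[geodesic_dist(x, y, curvature) for y in Y] for x in X]


north, east = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]
print(cross_dist([north, east], [north, east], 1.0))
```

The nested comprehension plays the role of the `X[:, None]` and `Y[None, :]` broadcasting in the listing above.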

dist2(X, Y)

Compute the squared distance between two sets of points on the manifold.

Parameters:
  • X (Float[Tensor, 'n_points1 n_dim']) –

    Tensor of points in the manifold.

  • Y (Float[Tensor, 'n_points2 n_dim']) –

    Tensor of points in the manifold.

Returns:
  • squared_distances( Float[Tensor, 'n_points1 n_points2'] ) –

    Tensor of squared distances between points.

Source code in manify/manifolds.py
def dist2(
    self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
) -> Float[torch.Tensor, "n_points1 n_points2"]:
    """Compute the squared distance between two sets of points on the manifold.

    Args:
        X: Tensor of points in the manifold.
        Y: Tensor of points in the manifold.

    Returns:
        squared_distances: Tensor of squared distances between points.
    """
    return self.manifold.dist2(X[:, None], Y[None, :])

pdist(X)

Compute pairwise distances between points on the manifold.

Parameters:
  • X (Float[Tensor, 'n_points n_dim']) –

    Tensor of points in the manifold.

Returns:
  • pairwise_distances( Float[Tensor, 'n_points n_points'] ) –

    Tensor of pairwise distances.

Source code in manify/manifolds.py
def pdist(self, X: Float[torch.Tensor, "n_points n_dim"]) -> Float[torch.Tensor, "n_points n_points"]:
    """Compute pairwise distances between points on the manifold.

    Args:
        X: Tensor of points in the manifold.

    Returns:
        pairwise_distances: Tensor of pairwise distances.
    """
    dists = self.dist(X, X)

    # Fill diagonal with zeros
    dists.fill_diagonal_(0.0)

    return dists
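
One plausible reason for zeroing the diagonal is that self-distances computed through an inverse hyperbolic cosine can come out slightly off zero because of floating-point rounding. The sketch below shows the same pattern in plain Python on the unit hyperboloid; `lorentz_dist` and `pdist` here are hypothetical helpers, not the library's code.

```python
import math


def lorentz_dist(x, y):
    """Distance on the unit hyperboloid via the Minkowski inner product."""
    ip = -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))
    return math.acosh(max(-ip, 1.0))  # clamp: rounding can push -ip below 1


def pdist(X):
    """Pairwise distance matrix with an exactly zero diagonal."""
    D = [[lorentz_dist(x, y) for y in X] for x in X]
    for i in range(len(D)):
        D[i][i] = 0.0  # self-distances are zero by definition
    return D


# Points on the 1-dimensional hyperboloid, parameterized by arc length t.
points = [[math.cosh(t), math.sinh(t)] for t in (0.5, 3.0, 8.0)]
raw_self = [lorentz_dist(p, p) for p in points]  # may contain small nonzero values
print(raw_self)
print(pdist(points))
```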

pdist2(X)

Compute pairwise squared distances between points on the manifold.

Parameters:
  • X (Float[Tensor, 'n_points n_dim']) –

    Tensor of points in the manifold.

Returns:
  • pairwise_squared_distances( Float[Tensor, 'n_points n_points'] ) –

    Tensor of pairwise squared distances.

Source code in manify/manifolds.py
def pdist2(self, X: Float[torch.Tensor, "n_points n_dim"]) -> Float[torch.Tensor, "n_points n_points"]:
    """Compute pairwise squared distances between points on the manifold.

    Args:
        X: Tensor of points in the manifold.

    Returns:
        pairwise_squared_distances: Tensor of pairwise squared distances.
    """
    dists2 = self.dist2(X, X)

    dists2.fill_diagonal_(0.0)

    return dists2

sample(n_samples=1, z_mean=None, sigma=None, return_tangent=False)

Sample points from the variational distribution on the manifold.

Parameters:
  • n_samples (int, default: 1 ) –

    Number of points to sample.

  • z_mean (Float[Tensor, 'n_points n_ambient_dim'] | Float[Tensor, 'n_ambient_dim'] | None, default: None ) –

    Tensor representing the mean of the sample distribution.

  • sigma (Float[Tensor, 'n_points n_dim n_dim'] | None, default: None ) –

    Optional tensor representing the covariance matrix. If None, defaults to an identity matrix.

  • return_tangent (bool, default: False ) –

    Whether to return the tangent vectors along with the sampled points.

Returns:
  • x( tuple[Float[Tensor, 'n_points n_ambient_dim'], Float[Tensor, 'n_points n_dim']] | Float[Tensor, 'n_points n_ambient_dim'] ) –

    Tensor of sampled points on the manifold.

  • v( tuple[Float[Tensor, 'n_points n_ambient_dim'], Float[Tensor, 'n_points n_dim']] | Float[Tensor, 'n_points n_ambient_dim'] ) –

    Tensor of tangent vectors (if return_tangent is True).

Source code in manify/manifolds.py
def sample(
    self,
    n_samples: int = 1,
    z_mean: Float[torch.Tensor, "n_points n_ambient_dim"] | Float[torch.Tensor, "n_ambient_dim"] | None = None,
    sigma: Float[torch.Tensor, "n_points n_dim n_dim"] | None = None,
    return_tangent: bool = False,
) -> (
    tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points n_dim"]]
    | Float[torch.Tensor, "n_points n_ambient_dim"]
):
    """Sample points from the variational distribution on the manifold.

    Args:
        n_samples: Number of points to sample.
        z_mean: Tensor representing the mean of the sample distribution.
        sigma: Optional tensor representing the covariance matrix. If None, defaults to an identity matrix.
        return_tangent: Whether to return the tangent vectors along with the sampled points.

    Returns:
        x: Tensor of sampled points on the manifold
        v: Tensor of tangent vectors (if `return_tangent` is True).
    """
    z_mean = self.mu0 if z_mean is None else z_mean
    z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
    n = z_mean.shape[0]

    sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device) if sigma is None else sigma
    sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)
    assert sigma.shape == (
        n,
        self.dim,
        self.dim,
    ), f"Expected sigma shape {(n, self.dim, self.dim)}, got {sigma.shape}"
    assert torch.allclose(sigma, sigma.transpose(-1, -2)), "Covariance matrix must be symmetric"
    assert z_mean.shape[-1] == self.ambient_dim, f"Expected z_mean shape {self.ambient_dim}, got {z_mean.shape[-1]}"

    # Adjust for n_points:
    z_mean = torch.repeat_interleave(z_mean, n_samples, dim=0)
    sigma = torch.repeat_interleave(sigma, n_samples, dim=0)

    # Sample initial vector from N(0, sigma)
    N = torch.distributions.MultivariateNormal(
        loc=torch.zeros((n * n_samples, self.dim), device=self.device), covariance_matrix=sigma
    )
    v = N.sample()

    # Don't need to adjust normal vectors for the Scaled manifold class in geoopt - very cool!

    # Enter tangent plane
    v_tangent = self._to_tangent_plane_mu0(v)

    # Move to z_mean via parallel transport
    z = self.manifold.transp(x=self.mu0, y=z_mean, v=v_tangent)

    # If we're sampling at the origin, z and v should be the same
    mask = torch.all(z == self.mu0, dim=1)
    assert torch.allclose(v_tangent[mask], z[mask]), (
        "Tangent vectors at the origin should be equal to the sampled points at the origin"
    )

    # Exp map onto the manifold
    x = self.manifold.expmap(x=z_mean, u=z)

    return (x, v) if return_tangent else x
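
The sampling recipe above (draw a Gaussian vector, lift it to the tangent space at the origin, transport it, then apply the exponential map) can be sketched for the simplest case: the unit hyperboloid, mean at the origin, isotropic covariance. At the origin the parallel transport step is the identity, so it is omitted. `sample_wrapped_normal_at_origin` is a hypothetical helper and assumes curvature -1.

```python
import math
import random


def sample_wrapped_normal_at_origin(dim, sigma=1.0, rng=random):
    """Draw one wrapped-normal sample centred at the hyperboloid origin."""
    # 1. Sample v ~ N(0, sigma^2 I) in R^dim.
    v = [rng.gauss(0.0, sigma) for _ in range(dim)]
    # 2. Lifting to the tangent space at (1, 0, ..., 0) prepends a zero,
    #    so the tangent vector's norm equals the Euclidean norm of v.
    r = math.sqrt(sum(c * c for c in v))
    if r == 0.0:
        return [1.0] + [0.0] * dim, v
    # 3. Exponential map at the origin of the unit hyperboloid.
    x = [math.cosh(r)] + [math.sinh(r) * c / r for c in v]
    return x, v


x, v = sample_wrapped_normal_at_origin(3, rng=random.Random(0))
# The sample should satisfy -x0^2 + x1^2 + ... + xn^2 = -1 up to rounding.
print(-x[0] ** 2 + sum(c * c for c in x[1:]))
```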

log_likelihood(z, mu=None, sigma=None)

Compute the log-probability density of \(\mathcal{WN}(\mathbf{z}; \mu, \Sigma)\) on the manifold.

Parameters:
  • z (Float[Tensor, 'n_points n_ambient_dim']) –

    Tensor of points on the manifold for which to compute the likelihood.

  • mu (Float[Tensor, 'n_points n_ambient_dim'] | None, default: None ) –

    Tensor representing the mean of the distribution. If None, defaults to the origin self.mu0.

  • sigma (Float[Tensor, 'n_points n_dim n_dim'] | None, default: None ) –

    Tensor representing the covariance matrix. If None, defaults to an identity matrix.

Returns:
  • log_likelihoods( Float[Tensor, 'n_points'] ) –

    Tensor containing the log-likelihood of the points z under the distribution with mean mu and covariance sigma.

Source code in manify/manifolds.py
def log_likelihood(
    self,
    z: Float[torch.Tensor, "n_points n_ambient_dim"],
    mu: Float[torch.Tensor, "n_points n_ambient_dim"] | None = None,
    sigma: Float[torch.Tensor, "n_points n_dim n_dim"] | None = None,
) -> Float[torch.Tensor, "n_points"]:
    r"""Compute the log-probability density of $\mathcal{WN}(\mathbf{z}; \mu, \Sigma)$ on the manifold.

    Args:
        z: Tensor of points on the manifold for which to compute the likelihood.
        mu: Tensor representing the mean of the distribution. If None, defaults to the origin `self.mu0`.
        sigma: Tensor representing the covariance matrix. If None, defaults to an identity matrix.

    Returns:
        log_likelihoods: Tensor containing the log-likelihood of the points `z` under the distribution with mean
            `mu` and covariance `sigma`.
    """
    # Default to mu=self.mu0 and sigma=I
    mu = self.mu0 if mu is None else mu
    mu = torch.Tensor(mu).reshape(-1, self.ambient_dim).to(self.device)
    n = mu.shape[0]
    sigma = torch.stack([torch.eye(self.dim)] * n).to(self.device) if sigma is None else sigma
    sigma = torch.Tensor(sigma).reshape(-1, self.dim, self.dim).to(self.device)

    # Euclidean case is regular old Gaussian log-likelihood
    if self.type == "E":
        return torch.distributions.MultivariateNormal(mu, sigma).log_prob(z)

    u = self.manifold.logmap(x=mu, y=z)  # Map z to tangent space at mu
    v = self.manifold.transp(x=mu, y=self.mu0, v=u)  # Parallel transport to origin
    # assert torch.allclose(v[:, 0], torch.Tensor([0.])) # For tangent vectors at origin this should be true
    # OK, so this assertion doesn't actually pass, but it's spiritually true
    if torch.isnan(v).any():
        print("NANs in parallel transport")
        v = torch.nan_to_num(v, nan=0.0)
    N = torch.distributions.MultivariateNormal(torch.zeros(self.dim, device=self.device), sigma)
    ll = N.log_prob(v[:, 1:])

    # For convenience
    R = self.scale
    n = self.dim

    # Final formula (epsilon to avoid log(0))
    if self.type == "S":
        sin_M = torch.sin
        u_norm = self.manifold.norm(x=mu, u=u)

    else:
        sin_M = torch.sinh
        u_norm = self.manifold.base.norm(u=u)  # Horrible workaround needed for geoopt bug # type: ignore

    return ll - (n - 1) * torch.log(R * torch.abs(sin_M(u_norm / R) / u_norm) + 1e-8)
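
The final line of the listing subtracts a volume-change term from the Gaussian log-density in the tangent space. A minimal sketch of that term follows, where `r` is the norm of the tangent vector, `R` the scale, and `n` the intrinsic dimension. Both function names are hypothetical, and the epsilon mirrors the one in the listing.

```python
import math


def log_det_correction(r, R, n, spherical=False, eps=1e-8):
    """(n - 1) * log(R * |sin_M(r / R) / r| + eps), with sin_M = sin or sinh."""
    sin_m = math.sin if spherical else math.sinh
    return (n - 1) * math.log(R * abs(sin_m(r / R) / r) + eps)


def wrapped_normal_logpdf(gauss_logpdf, r, R, n, spherical=False):
    """Tangent-space Gaussian log-density minus the volume-change term."""
    return gauss_logpdf - log_det_correction(r, R, n, spherical)


# Near the base point the manifold looks flat, so the correction vanishes.
print(log_det_correction(1e-4, 1.0, 3))
# Farther away it is positive in the hyperbolic case and negative on the sphere.
print(log_det_correction(1.0, 1.0, 3), log_det_correction(1.0, 1.0, 3, spherical=True))
```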

logmap(x, base=None)

Compute the logarithmic map of points on the manifold at a base point.

Parameters:
  • x (Float[Tensor, 'n_points n_dim']) –

    Tensor representing points on the manifold.

  • base (Float[Tensor, 'n_points n_dim'] | Float[Tensor, '1 n_dim'] | None, default: None ) –

    Tensor representing the base point for the map. If None, defaults to the origin self.mu0.

Returns:
  • logmap_result( Float[Tensor, 'n_points n_dim'] ) –

    Tensor representing the result of the logarithmic map from base to x on the manifold.

Source code in manify/manifolds.py
def logmap(
    self,
    x: Float[torch.Tensor, "n_points n_dim"],
    base: Float[torch.Tensor, "n_points n_dim"] | Float[torch.Tensor, "1 n_dim"] | None = None,
) -> Float[torch.Tensor, "n_points n_dim"]:
    """Compute the logarithmic map of points on the manifold at a base point.

    Args:
        x: Tensor representing points on the manifold.
        base: Tensor representing the base point for the map. If None, defaults to the origin `self.mu0`.

    Returns:
        logmap_result: Tensor representing the result of the logarithmic map from `base` to `x` on the manifold.
    """
    base = self.mu0 if base is None else base
    return self.manifold.logmap(x=base, y=x)

expmap(u, base=None)

Compute the exponential map of a tangent vector \(\mathbf{u}\) at base point.

Parameters:
  • u (Float[Tensor, 'n_points n_dim']) –

    Tensor representing the tangent vector at the base point to map.

  • base (Float[Tensor, 'n_points n_dim'] | Float[Tensor, '1 n_dim'] | None, default: None ) –

    Tensor representing the base point for the exponential map. If None, defaults to the origin self.mu0.

Returns:
  • expmap_result( Float[Tensor, 'n_points n_dim'] ) –

    Tensor representing the result of the exponential map applied to u at the base point.

Source code in manify/manifolds.py
def expmap(
    self,
    u: Float[torch.Tensor, "n_points n_dim"],
    base: Float[torch.Tensor, "n_points n_dim"] | Float[torch.Tensor, "1 n_dim"] | None = None,
) -> Float[torch.Tensor, "n_points n_dim"]:
    r"""Compute the exponential map of a tangent vector $\mathbf{u}$ at base point.

    Args:
        u: Tensor representing the tangent vector at the base point to map.
        base: Tensor representing the base point for the exponential map. If None, defaults to the origin
            `self.mu0`.

    Returns:
        expmap_result: Tensor representing the result of the exponential map applied to `u` at the base point.
    """
    base = self.mu0 if base is None else base
    return self.manifold.expmap(x=base, u=u)
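
To make the relationship between logmap and expmap concrete, the sketch below implements both at the origin of the unit hyperboloid in plain Python and checks that they invert each other. It is an illustration under simplifying assumptions (curvature -1, base point fixed at the origin, tangent vectors given by their intrinsic coordinates without the leading zero); `expmap0` and `logmap0` are hypothetical names.

```python
import math


def expmap0(u):
    """Map intrinsic tangent coordinates at the origin onto the hyperboloid."""
    r = math.sqrt(sum(c * c for c in u))
    if r == 0.0:
        return [1.0] + [0.0] * len(u)
    return [math.cosh(r)] + [math.sinh(r) * c / r for c in u]


def logmap0(x):
    """Inverse of expmap0: recover the tangent coordinates of a point x."""
    r = math.acosh(max(x[0], 1.0))  # geodesic distance from the origin
    if r == 0.0:
        return [0.0] * (len(x) - 1)
    return [r * c / math.sinh(r) for c in x[1:]]


u = [0.3, -1.2]
x = expmap0(u)
print(x)
print(logmap0(x))  # should reproduce u up to rounding
```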

stereographic(*points)

Convert the manifold to its stereographic equivalent. If points are given, convert them as well.

Formula for stereographic projection (for \(i \geq 1\)):

\begin{equation}
\rho_K(x_i) = \frac{x_i}{1 + \sqrt{|K|} \cdot x_0}
\end{equation}

For more information, see https://arxiv.org/pdf/1911.08411

Parameters:
  • *points (Float[Tensor, 'n_points n_dim'], default: () ) –

    Variable number of tensors representing points on the manifold to convert to stereographic coords.

Returns:
  • stereo_manifold( Manifold ) –

    The manifold in stereographic coordinates.

  • stereo_points( ... ) –

    The provided points converted to stereographic coordinates (if any).

Source code in manify/manifolds.py
def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> tuple[Manifold, ...]:
    r"""Convert the manifold to its stereographic equivalent. If points are given, convert them as well.

    Formula for stereographic projection (for $i \geq 1$):
    \begin{equation}
        \rho_K(x_i) = \frac{x_i}{1 + \sqrt{|K|} \cdot x_0}
    \end{equation}

    For more information, see https://arxiv.org/pdf/1911.08411

    Args:
        *points: Variable number of tensors representing points on the manifold to convert to stereographic coords.

    Returns:
        stereo_manifold: The manifold in stereographic coordinates.
        stereo_points: The provided points converted to stereographic coordinates (if any).
    """
    if self.is_stereographic:
        print("Manifold is already in stereographic coordinates.")
        return self, *points

    # Convert manifold
    stereo_manifold = Manifold(self.curvature, self.dim, device=self.device, stereographic=True)

    # Euclidean edge case
    if self.type == "E":
        return stereo_manifold, *points

    # Convert points
    num = [X[:, 1:] for X in points]
    denom = [1 + abs(self.curvature) ** 0.5 * X[:, 0:1] for X in points]
    for X in denom:
        X[X.abs() < 1e-6] = 1e-6  # Avoid division by zero
    stereo_points = [n / d for n, d in zip(num, denom, strict=False)]
    assert all(stereo_manifold.manifold.check_point(X) for X in stereo_points), (
        "Generated points do not lie on the target manifold"
    )

    return stereo_manifold, *stereo_points
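
The point conversion above reduces to dividing the trailing coordinates by 1 + sqrt(|K|) * x_0. A plain-Python sketch for a single point follows; `stereographic_project` is a hypothetical helper that mirrors the listing, including its guard against small denominators.

```python
import math


def stereographic_project(x, curvature):
    """Project one ambient point x = (x_0, x_1, ..., x_n) to stereographic coords."""
    lam = math.sqrt(abs(curvature))
    denom = 1.0 + lam * x[0]
    if abs(denom) < 1e-6:
        denom = 1e-6  # avoid division by zero, as in the listing
    return [c / denom for c in x[1:]]


# A point at distance 1 from the origin of the unit hyperboloid.
p = [math.cosh(1.0), math.sinh(1.0), 0.0]
y = stereographic_project(p, -1.0)
print(y)  # lies inside the unit ball; first coordinate equals tanh(0.5)
```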

inverse_stereographic(*points)

Convert the manifold from its stereographic coordinates back to the original coordinates.

If points are given, convert them as well.

Formula for inverse stereographic projection, with \(\sigma = sign(K)\) and \(\lambda = \sqrt{|K|}\):

\begin{align}
x_0 &= \frac{1 - \sigma \lambda^2 \|y\|^2}{1 + \sigma \lambda^2 \|y\|^2} \\
x_i &= \frac{2 \lambda y_i}{1 + \sigma \lambda^2 \|y\|^2}
\end{align}

Parameters:
  • *points (Float[Tensor, 'n_points n_dim_stereo'], default: () ) –

    Variable number of tensors representing points in stereographic coords to convert back to original coords.

Returns:
  • inv_stereo_manifold( Manifold ) –

    The manifold in original coords.

  • inv_stereo_points( ... ) –

    The provided points converted from stereographic to original coords (if any).

Source code in manify/manifolds.py
def inverse_stereographic(self, *points: Float[torch.Tensor, "n_points n_dim_stereo"]) -> tuple[Manifold, ...]:
    r"""Convert the manifold from its stereographic coordinates back to the original coordinates.

    If points are given, convert them as well.

    Formula for inverse stereographic projection, with $\sigma = sign(K)$ and $\lambda = \sqrt{|K|}$:
    \begin{align}
        x_0 &= \frac{1 - \sigma \lambda^2 \|y\|^2}{1 + \sigma \lambda^2 \|y\|^2} \\\\
        x_i &= \frac{2 \lambda y_i}{1 + \sigma \lambda^2 \|y\|^2}
    \end{align}

    Args:
        *points: Variable number of tensors representing points in stereographic coords to convert back to original
            coords.

    Returns:
        inv_stereo_manifold: The manifold in original coords.
        inv_stereo_points: The provided points converted from stereographic to original coords (if any).
    """
    if not self.is_stereographic:
        print("Manifold is already in original coordinates.")
        return self, *points

    # Convert manifold
    orig_manifold = Manifold(self.curvature, self.dim, device=self.device, stereographic=False)

    # Euclidean edge case
    if self.type == "E":
        return orig_manifold, *points

    # Inverse projection for points
    out = []
    for X in points:
        # Calculate squared norm
        # let σ = sign(K)  and  λ = sqrt(|K|)
        sign = torch.sign(torch.tensor(self.curvature, device=self.device))
        lam = abs(self.curvature) ** 0.5

        # compute the ‖·‖² in the *scaled* ball
        norm2 = torch.sum((lam * X) ** 2, dim=1)

        # inverse‐stereographic denom must be (1 + σ⋅‖y‖²), *not* (1 – σ⋅‖y‖²)
        denom = 1.0 + sign * norm2
        # clamp to avoid blow‐up at the boundary
        denom = torch.clamp_min(denom.abs(), 1e-6) * denom.sign()

        # then
        X0 = (1.0 - sign * norm2) / denom
        Xi = 2.0 * lam * X / denom.unsqueeze(1)

        # Combine into full coordinates
        inv_points = torch.cat([X0.unsqueeze(1), Xi], dim=1)

        # Let the manifold class validate the points
        if not orig_manifold.manifold.check_point(inv_points):
            raise ValueError("Generated points do not lie on the target manifold")

        out.append(inv_points)

    return orig_manifold, *out
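
The loop body above can be restated for a single point in plain Python, which makes it easy to check that the result lands on the unit hyperboloid (K < 0) or unit sphere (K > 0). `inverse_stereographic` below is a hypothetical stand-alone function, assumes nonzero curvature, and handles an exactly zero denominator slightly differently from the torch version.

```python
import math


def inverse_stereographic(y, curvature):
    """Map one stereographic point y back to ambient coordinates (x_0, x_1, ...)."""
    sign = math.copysign(1.0, curvature)  # sigma = sign(K)
    lam = math.sqrt(abs(curvature))       # lambda = sqrt(|K|)
    norm2 = sum((lam * c) ** 2 for c in y)
    denom = 1.0 + sign * norm2
    if abs(denom) < 1e-6:
        denom = math.copysign(1e-6, denom)  # clamp near the boundary
    x0 = (1.0 - sign * norm2) / denom
    return [x0] + [2.0 * lam * c / denom for c in y]


h = inverse_stereographic([0.3, 0.4], -1.0)
s = inverse_stereographic([0.3, 0.4], 1.0)
print(-h[0] ** 2 + h[1] ** 2 + h[2] ** 2)  # Minkowski norm, close to -1
print(s[0] ** 2 + s[1] ** 2 + s[2] ** 2)   # Euclidean norm, close to 1
```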

apply(f)

Create a decorator for logmap -> function -> expmap. The origin self.mu0 is always used as the base point.

Parameters:
  • f (Callable) –

    Function to apply in the tangent space.

Returns:
  • wrapper( Callable ) –

    Callable representing the composed map that:

    1. Maps points to the tangent space using logmap
    2. Applies the function f
    3. Maps the result back to the manifold using expmap

Source code in manify/manifolds.py
def apply(self, f: Callable) -> Callable:
    """Create a decorator for logmap -> function -> expmap. The origin `self.mu0` is always used as the base point.

    Args:
        f: Function to apply in the tangent space.

    Returns:
        wrapper: Callable representing the composed map that:

            1. Maps points to the tangent space using logmap
            2. Applies the function f
            3. Maps the result back to the manifold using expmap
    """

    def wrapper(x: Float[torch.Tensor, "n_points n_dim"]) -> Float[torch.Tensor, "n_points n_dim"]:
        return self.expmap(
            f(self.logmap(x, base=self.mu0)),
            base=self.mu0,
        )

    return wrapper
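
The same logmap -> function -> expmap composition can be tried out without the library. In this sketch, `expmap0` and `logmap0` are hypothetical origin-based maps on the unit hyperboloid (curvature -1 assumed), and `apply` is a free function rather than a method.

```python
import math


def expmap0(u):
    r = math.sqrt(sum(c * c for c in u))
    if r == 0.0:
        return [1.0] + [0.0] * len(u)
    return [math.cosh(r)] + [math.sinh(r) * c / r for c in u]


def logmap0(x):
    r = math.acosh(max(x[0], 1.0))
    if r == 0.0:
        return [0.0] * (len(x) - 1)
    return [r * c / math.sinh(r) for c in x[1:]]


def apply(f):
    """Wrap f so it acts in the tangent space at the origin."""
    def wrapper(x):
        return expmap0(f(logmap0(x)))
    return wrapper


@apply
def double_distance(u):
    # Scaling a tangent vector by 2 doubles the distance from the origin.
    return [2.0 * c for c in u]


x = expmap0([0.5, 0.0])
y = double_distance(x)
print(math.acosh(x[0]), math.acosh(y[0]))  # distances from the origin
```

Because the wrapped function only ever sees tangent vectors, ordinary Euclidean operations such as scaling or linear maps can be reused unchanged.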

ProductManifold(signature, device='cpu', stereographic=False)

Bases: Manifold

Tools for constructing product manifolds with multiple factors.

A product manifold combines multiple manifolds with different curvatures and dimensions into a single product space.

Parameters:
  • signature (list[tuple[float, int]]) –

    List of (curvature, dimension) tuples for each factor manifold.

  • device (str, default: 'cpu' ) –

    The device on which the manifold is stored.

  • stereographic (bool, default: False ) –

    Whether to use stereographic coordinates.

Attributes:
  • signature

    List of tuples defining the curvature and dimension of each factor manifold.

  • device

    The device on which the manifold is stored.

  • is_stereographic

    Whether stereographic coordinates are used.

  • type

    String identifier for product manifold (always 'P').

  • curvatures

    List of curvature values for each factor manifold.

  • dims

    List of dimensions for each factor manifold.

  • n_manifolds

    Number of factor manifolds.

  • P

    List of individual Manifold objects that make up this product manifold.

  • manifold

    The underlying geoopt ProductManifold object.

  • name

    String identifier for the product manifold.

  • mu0

    The origin point on the product manifold.

  • ambient_dim

    Total ambient dimension of the product manifold.

  • dim

    Total intrinsic dimension of the product manifold.

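  • dim2man

    Dictionary mapping each ambient coordinate index to the index of the factor manifold that owns it.

  • man2dim

    Dictionary mapping each factor manifold index to the list of its ambient coordinate indices.

  • man2intrinsic

    Dictionary mapping each factor manifold index to the list of its intrinsic coordinate indices.

  • intrinsic2man

    Dictionary mapping each intrinsic coordinate index to the index of the factor manifold that owns it.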

  • projection_matrix

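    Lift matrix of shape (dim, ambient_dim). Right-multiplying an intrinsic vector by it yields an ambient vector that is tangent at the origin of the product manifold.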

Source code in manify/manifolds.py
def __init__(self, signature: list[tuple[float, int]], device: str = "cpu", stereographic: bool = False):
    # Device management
    self.device = device

    # Basic properties
    self.type = "P"
    self.signature = signature
    self.curvatures = [curvature for curvature, _ in signature]
    self.dims = [dim for _, dim in signature]
    self.n_manifolds = len(signature)
    self.is_stereographic = stereographic

    # Actually initialize the geoopt manifolds; other derived properties
    self.P = [Manifold(curvature, dim, device=device, stereographic=stereographic) for curvature, dim in signature]
    manifold_class = geoopt.StereographicProductManifold if stereographic else geoopt.ProductManifold
    self.manifold = manifold_class(*[(M.manifold, M.ambient_dim) for M in self.P]).to(device)
    self.name = " x ".join([M.name for M in self.P])

    # Origin
    self.mu0 = torch.cat([M.mu0 for M in self.P], axis=1).to(self.device)

    # Manifold <-> Dimension mapping
    self.ambient_dim, self.n_manifolds, self.dim = 0, 0, 0
    self.dim2man, self.man2dim, self.man2intrinsic, self.intrinsic2man = {}, {}, {}, {}

    for M in self.P:
        for d in range(self.ambient_dim, self.ambient_dim + M.ambient_dim):
            self.dim2man[d] = self.n_manifolds
        for d in range(self.dim, self.dim + M.dim):
            self.intrinsic2man[d] = self.n_manifolds
        self.man2dim[self.n_manifolds] = list(range(self.ambient_dim, self.ambient_dim + M.ambient_dim))
        self.man2intrinsic[self.n_manifolds] = list(range(self.dim, self.dim + M.dim))

        self.ambient_dim += M.ambient_dim
        self.n_manifolds += 1
        self.dim += M.dim

    # Lift matrix - useful for tensor stuff
    # The idea here is to right-multiply by this to lift a vector in R^dim to a vector in R^ambient_dim
    # such that there are zeros in all the right places, i.e. to make it a tangent vector at the origin of P
    self.projection_matrix = torch.zeros(self.dim, self.ambient_dim, device=self.device)
    for i in range(len(self.P)):
        intrinsic_dims = self.man2intrinsic[i]
        ambient_dims = self.man2dim[i]
        for j, k in zip(intrinsic_dims, ambient_dims[-len(intrinsic_dims) :], strict=False):
            self.projection_matrix[j, k] = 1.0
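
The index bookkeeping and the lift matrix are easier to see on a small signature. The sketch below rebuilds them with plain dictionaries and nested lists, assuming non-stereographic coordinates (so curved factors get one extra ambient coordinate). `build_index_maps` and `lift_vector` are hypothetical helpers.

```python
def build_index_maps(signature):
    """Return (man2dim, man2intrinsic, lift) for a list of (curvature, dim) pairs."""
    man2dim, man2intrinsic = {}, {}
    ambient_dim = dim = 0
    for i, (curvature, d) in enumerate(signature):
        a = d if curvature == 0 else d + 1  # ambient size of this factor
        man2dim[i] = list(range(ambient_dim, ambient_dim + a))
        man2intrinsic[i] = list(range(dim, dim + d))
        ambient_dim += a
        dim += d

    # Lift matrix: row j (intrinsic) has a single 1 in column k (ambient).
    lift = [[0.0] * ambient_dim for _ in range(dim)]
    for i, intrinsic in man2intrinsic.items():
        # Skip each curved factor's leading coordinate by taking trailing columns.
        for j, k in zip(intrinsic, man2dim[i][-len(intrinsic):]):
            lift[j][k] = 1.0
    return man2dim, man2intrinsic, lift


def lift_vector(v, lift):
    """Right-multiply the row vector v by the lift matrix."""
    return [sum(v[j] * lift[j][k] for j in range(len(v))) for k in range(len(lift[0]))]


man2dim, man2intrinsic, lift = build_index_maps([(-1.0, 2), (0.0, 2), (1.0, 2)])
print(man2dim)
print(lift_vector([1, 2, 3, 4, 5, 6], lift))
```

The zeros land in the leading coordinate of each curved factor, which is what makes the lifted vector tangent at the origin.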

parameters()

Get the learnable log-scale parameters of all component manifolds.

Returns:
  • scales( list[Parameter] ) –

    List containing the _log_scale parameter of each component manifold.

Source code in manify/manifolds.py
def parameters(self) -> list[torch.nn.parameter.Parameter]:
    """Get scale parameters for all component manifolds.

    Returns:
        scales: List of scale parameters for each component manifold.
    """
    return [x._log_scale for x in self.manifold.manifolds]

to(device)

Move all components to a new device.

Parameters:
  • device (str) –

    The device to which to move all components.

Returns:
  • manifold( ProductManifold ) –

    The updated ProductManifold object on the specified device.

Source code in manify/manifolds.py
def to(self, device: str) -> ProductManifold:
    """Move all components to a new device.

    Args:
        device: The device to which to move all components.

    Returns:
        manifold: The updated ProductManifold object on the specified device.
    """
    self.device = device
    self.P = [M.to(device) for M in self.P]
    self.manifold = self.manifold.to(device)
    self.mu0 = self.mu0.to(device)
    self.projection_matrix = self.projection_matrix.to(device)
    return self

inner(X, Y)

Compute the inner product between points on the product manifold.

The inner product is the sum of inner products in each component manifold.

Parameters:
  • X (Float[Tensor, 'n_points1 n_dim']) –

    Tensor of points in the product manifold.

  • Y (Float[Tensor, 'n_points2 n_dim']) –

    Tensor of points in the product manifold.

Returns:
  • inner_products( Float[Tensor, 'n_points1 n_points2'] ) –

    Tensor of inner products between points.

Source code in manify/manifolds.py
def inner(
    self, X: Float[torch.Tensor, "n_points1 n_dim"], Y: Float[torch.Tensor, "n_points2 n_dim"]
) -> Float[torch.Tensor, "n_points1 n_points2"]:
    """Compute the inner product between points on the product manifold.

    The inner product is the sum of inner products in each component manifold.

    Args:
        X: Tensor of points in the product manifold.
        Y: Tensor of points in the product manifold.

    Returns:
        inner_products: Tensor of inner products between points.
    """
    ips = [M.inner(x, y) for x, y, M in zip(self.factorize(X), self.factorize(Y), self.P, strict=False)]
    return torch.stack(ips, dim=0).sum(dim=0)
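
To make the "sum of component inner products" concrete, here is a minimal plain-Python sketch. It assumes a hypothetical product of one Lorentz-model component (coordinates 0 to 2, Minkowski form with a time-like first coordinate) and one Euclidean component (coordinates 3 to 4); the helper names and the layout are illustrative assumptions, and the input vectors are arbitrary rather than valid manifold points.

```python
def minkowski(x, y):
    """Minkowski bilinear form: -x0*y0 + sum_i xi*yi (first coordinate is time-like)."""
    return -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))


def euclidean(x, y):
    """Ordinary dot product."""
    return sum(a * b for a, b in zip(x, y))


# Hypothetical layout: (coordinate slice, bilinear form) for each component.
layout = [(slice(0, 3), minkowski), (slice(3, 5), euclidean)]


def product_inner(X, Y):
    """Pairwise inner products: entry [i][j] sums the component forms of X[i] and Y[j]."""
    return [[sum(form(x[s], y[s]) for s, form in layout) for y in Y] for x in X]


X = [[1.0, 0.0, 0.0, 1.0, 2.0]]
Y = [[1.0, 0.0, 0.0, 3.0, 4.0], [2.0, 1.0, 0.0, 0.0, 1.0]]
G = product_inner(X, Y)
print(G)  # [[10.0, 0.0]]
```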

factorize(X, intrinsic=False)

Factorize the embeddings into the individual manifolds.

Parameters:
  • X (Float[Tensor, 'n_points n_dim']) –

    Tensor representing the embeddings to be factorized.

  • intrinsic (bool, default: False ) –

    Whether to split along the intrinsic dimensions of the manifolds instead of their ambient dimensions.

Returns:
  • X_factorized( list[Float[Tensor, 'n_points n_dim_manifold']] ) –

    A list of tensors representing the factorized embeddings in each manifold.

Source code in manify/manifolds.py
def factorize(
    self, X: Float[torch.Tensor, "n_points n_dim"], intrinsic: bool = False
) -> list[Float[torch.Tensor, "n_points n_dim_manifold"]]:
    """Factorize the embeddings into the individual manifolds.

    Args:
        X: Tensor representing the embeddings to be factorized.
        intrinsic: Whether to split along the intrinsic dimensions of the manifolds instead of their ambient dimensions.

    Returns:
        X_factorized: A list of tensors representing the factorized embeddings in each manifold.
    """
    dims_dict = self.man2intrinsic if intrinsic else self.man2dim
    return [X[..., dims_dict[i]] for i in range(len(self.P))]
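
The column selection performed by factorize can be sketched with nested lists. The dictionaries below are hypothetical stand-ins for `man2dim` and `man2intrinsic` of a two-component product (a 3-coordinate ambient block followed by a 2-coordinate block); this is an illustration of the indexing only, not the library's tensor code.

```python
# Hypothetical index maps for a product with two components.
man2dim = {0: [0, 1, 2], 1: [3, 4]}        # ambient coordinates per component
man2intrinsic = {0: [0, 1], 1: [2, 3]}     # intrinsic coordinates per component


def factorize(X, intrinsic=False):
    """Split each row of X into one block of columns per component."""
    dims = man2intrinsic if intrinsic else man2dim
    return [[[row[d] for d in dims[i]] for row in X] for i in range(len(dims))]


X = [[10, 11, 12, 13, 14], [20, 21, 22, 23, 24]]
parts = factorize(X)
print(parts[0])  # [[10, 11, 12], [20, 21, 22]]
print(parts[1])  # [[13, 14], [23, 24]]
```

With `intrinsic=True` the same rows would be split by the intrinsic index map instead, which expects inputs with `dim` rather than `ambient_dim` columns.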

sample(n_samples=1, z_mean=None, sigma_factorized=None, return_tangent=False)

Sample from the variational distribution.

Parameters:
  • n_samples (int, default: 1 ) –

    Number of points to sample.

  • z_mean (Float[Tensor, 'n_points n_ambient_dim'] | None, default: None ) –

    Tensor representing the mean of the sample distribution. If None, defaults to the origin self.mu0.

  • sigma_factorized (list[Float[Tensor, 'n_points ...']] | None, default: None ) –

    List of tensors representing factorized covariance matrices for each manifold. If None, defaults to a list of identity matrices for each manifold.

  • return_tangent (bool, default: False ) –

    Whether to return the tangent vectors along with the sampled points.

Returns:
  • x( Float[Tensor, 'n_points n_ambient_dim'] ) –

    Tensor of sampled points on the manifold.

  • v( Float[Tensor, 'n_points total_intrinsic_dim'] ) –

    Tensor of tangent vectors, returned together with x as a tuple (x, v) only if return_tangent is True.

Source code in manify/manifolds.py
def sample(
    self,
    n_samples: int = 1,
    z_mean: Float[torch.Tensor, "n_points n_ambient_dim"] | None = None,
    sigma_factorized: list[Float[torch.Tensor, "n_points ..."]] | None = None,
    return_tangent: bool = False,
) -> (
    tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Float[torch.Tensor, "n_points total_intrinsic_dim"]]
    | Float[torch.Tensor, "n_points n_ambient_dim"]
):
    """Sample from the variational distribution.

    Args:
        n_samples: Number of points to sample.
        z_mean: Tensor representing the mean of the sample distribution. If None, defaults to the origin `self.mu0`.
        sigma_factorized: List of tensors representing factorized covariance matrices for each manifold. If None,
            defaults to a list of identity matrices for each manifold.
        return_tangent: Whether to return the tangent vectors along with the sampled points.

    Returns:
        x: Tensor of sampled points on the manifold.
        v: Tensor of tangent vectors (if `return_tangent` is True).
    """
    z_mean = self.mu0 if z_mean is None else z_mean
    z_mean = torch.Tensor(z_mean).reshape(-1, self.ambient_dim).to(self.device)
    n = z_mean.shape[0]

    sigma_factorized = (
        [torch.stack([torch.eye(M.dim)] * n) for M in self.P] if sigma_factorized is None else sigma_factorized
    )
    sigma_factorized = [
        torch.Tensor(sigma).reshape(-1, M.dim, M.dim).to(self.device)
        for M, sigma in zip(self.P, sigma_factorized, strict=False)
    ]

    # Repeat each mean and covariance n_samples times, so that n_samples points are drawn per mean:
    z_mean = torch.repeat_interleave(z_mean, n_samples, dim=0)
    sigma_factorized = [torch.repeat_interleave(sigma, n_samples, dim=0) for sigma in sigma_factorized]

    assert all(
        sigma.shape == (n * n_samples, M.dim, M.dim) for M, sigma in zip(self.P, sigma_factorized, strict=False)
    ), "Sigma matrices must match the dimensions of the manifolds."
    assert z_mean.shape == (n * n_samples, self.ambient_dim), (
        "z_mean must have the same ambient dimension as the product manifold."
    )

    # Sample initial vector from N(0, sigma)
    samples = [
        M.sample(1, z_M, sigma_M, return_tangent=True)
        for M, z_M, sigma_M in zip(self.P, self.factorize(z_mean), sigma_factorized, strict=False)
    ]

    x = torch.cat([s[0] for s in samples], dim=1)
    v = torch.cat([s[1] for s in samples], dim=1)

    # Return the sampled points, together with their tangent vectors if requested
    return (x, v) if return_tangent else x
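
The row layout produced by the `repeat_interleave` step determines which output rows belong to which mean. The sketch below mimics that expansion in plain Python and then draws noise around each repeated mean; it is a Euclidean stand-in under that assumption (the curved components would map tangent noise onto the manifold instead), and `repeat_interleave_rows` is a hypothetical helper.

```python
import random


def repeat_interleave_rows(rows, repeats):
    """Repeat each row `repeats` times consecutively (row 0, row 0, ..., row 1, ...)."""
    return [list(row) for row in rows for _ in range(repeats)]


means = [[0.0, 0.0], [5.0, 5.0]]
n_samples = 3
expanded = repeat_interleave_rows(means, n_samples)
# Rows 0-2 are copies of means[0]; rows 3-5 are copies of means[1].

rng = random.Random(0)
# Euclidean stand-in: add unit-variance Gaussian noise to every repeated mean.
samples = [[m + rng.gauss(0.0, 1.0) for m in row] for row in expanded]
print(len(samples))  # 6
```

The output therefore has `n * n_samples` rows, grouped by mean, which matches the shape assertions in the method.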

log_likelihood(z, mu=None, sigma_factorized=None)

Compute the log-likelihood (log probability density) of \(\mathbf{z}\) under \(\mathcal{WN}(\mathbf{z} ; \mu, \Sigma)\) on the product manifold.

Parameters:
  • z (Float[Tensor, 'batch_size n_dim']) –

    Tensor representing the points for which the log-likelihood is computed.

  • mu (Float[Tensor, 'batch_size n_dim'] | None, default: None ) –

    Tensor representing the mean of the distribution. If None, defaults to the origin self.mu0.

  • sigma_factorized (list[Float[Tensor, 'batch_size ...']] | None, default: None ) –

    List of tensors representing factorized covariance matrices for each manifold. If None, defaults to a list of identity matrices for each manifold.

Returns:
  • log_likelihoods( Float[Tensor, 'batch_size'] ) –

    Tensor containing the log-likelihood of the points z under the distribution with mean mu and covariance sigma.

Source code in manify/manifolds.py
def log_likelihood(
    self,
    z: Float[torch.Tensor, "batch_size n_dim"],
    mu: Float[torch.Tensor, "batch_size n_dim"] | None = None,
    sigma_factorized: list[Float[torch.Tensor, "batch_size ..."]] | None = None,
) -> Float[torch.Tensor, "batch_size"]:
    r"""Compute the log-likelihood of $\mathbf{z}$ under $\mathcal{WN}(\mathbf{z} ; \mu, \Sigma)$ on the product manifold.

    Args:
        z: Tensor representing the points for which the log-likelihood is computed.
        mu: Tensor representing the mean of the distribution. If None, defaults to the origin `self.mu0`.
        sigma_factorized: List of tensors representing factorized covariance matrices for each manifold. If None,
            defaults to a list of identity matrices for each manifold.

    Returns:
        log_likelihoods: Tensor containing the log-likelihood of the points `z` under the distribution with mean
            `mu` and covariance `sigma`.
    """
    n = z.shape[0]
    mu = torch.vstack([self.mu0] * n).to(self.device) if mu is None else mu

    sigma_factorized = (
        [torch.stack([torch.eye(M.dim, device=self.device)] * n) for M in self.P]
        if sigma_factorized is None
        else sigma_factorized
    )
    # Note that this factorization assumes block-diagonal covariance matrices

    mu_factorized = self.factorize(mu)
    z_factorized = self.factorize(z)
    component_lls = [
        M.log_likelihood(z_M, mu_M, sigma_M).unsqueeze(dim=1)
        for M, z_M, mu_M, sigma_M in zip(self.P, z_factorized, mu_factorized, sigma_factorized, strict=False)
    ]
    return torch.cat(component_lls, dim=1).sum(dim=1)
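
The block-diagonal assumption noted in the code is what lets the total log-likelihood be a plain sum over components. The sketch below checks this for an ordinary Euclidean Gaussian with diagonal covariance, which is a simplifying assumption: on curved components the wrapped normal density also includes a change-of-volume term inside each component's contribution, which is not modeled here. All names are illustrative.

```python
import math


def diag_gaussian_logpdf(z, mu, var):
    """Log-density of a Gaussian with diagonal covariance (variances in `var`)."""
    return sum(
        -0.5 * ((zi - mi) ** 2 / vi + math.log(2.0 * math.pi * vi))
        for zi, mi, vi in zip(z, mu, var)
    )


# Hypothetical split of three coordinates into two components.
blocks = [[0, 1], [2]]
z = [0.5, -1.0, 2.0]
mu = [0.0, 0.0, 1.0]
var = [1.0, 4.0, 0.25]

joint = diag_gaussian_logpdf(z, mu, var)
per_component = [
    diag_gaussian_logpdf([z[d] for d in b], [mu[d] for d in b], [var[d] for d in b])
    for b in blocks
]

# With no cross-component covariance, the joint log-density is the sum of the parts.
print(math.isclose(joint, sum(per_component)))  # True
```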

stereographic(*points)

Convert the manifold to its stereographic equivalent. If points are given, convert them as well.

Formula for stereographic projection (for \(i \geq 1\)): \begin{equation} \rho_K(x_i) = \frac{x_i}{1 + \sqrt{|K|} \cdot x_0} \end{equation}

For more information, see https://arxiv.org/pdf/1911.08411

Parameters:
  • *points (Float[Tensor, 'n_points n_dim'], default: () ) –

    Variable number of tensors representing points on the manifold to convert to stereographic coords.

Returns:
  • stereo_manifold( ProductManifold ) –

    The manifold in stereographic coords.

  • stereo_points( ... ) –

    The provided points converted to stereographic coords (if any).

Source code in manify/manifolds.py
def stereographic(self, *points: Float[torch.Tensor, "n_points n_dim"]) -> tuple[ProductManifold, ...]:
    r"""Convert the manifold to its stereographic equivalent. If points are given, convert them as well.

    Formula for stereographic projection (for $i \geq 1$):
    \begin{equation}
        \rho_K(x_i) = \frac{x_i}{1 + \sqrt{|K|} \cdot x_0}
    \end{equation}

    For more information, see https://arxiv.org/pdf/1911.08411

    Args:
        *points: Variable number of tensors representing points on the manifold to convert to stereographic coords.

    Returns:
        stereo_manifold: The manifold in stereographic coords.
        stereo_points: The provided points converted to stereographic coords (if any).
    """
    if self.is_stereographic:
        print("Manifold is already in stereographic coords.")
        return self, *points

    # Convert manifold
    stereo_manifold = ProductManifold(self.signature, device=self.device, stereographic=True)

    # Convert points
    stereo_points = [
        torch.hstack([M.stereographic(x)[1] for x, M in zip(self.factorize(X), self.P, strict=False)])
        for X in points
    ]
    assert all(stereo_manifold.manifold.check_point(X) for X in stereo_points), (
        "Generated points do not lie on the target manifold"
    )

    return stereo_manifold, *stereo_points
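
The projection formula in the docstring can be tried on single points. This is a minimal sketch of \(\rho_K\) applied to one point at a time, using unit-curvature examples (K = 1 for the sphere, K = -1 for the hyperboloid); the function name and the sample points are assumptions for illustration and do not reproduce the library's batched implementation.

```python
import math


def stereographic_point(x, K):
    """Apply rho_K(x_i) = x_i / (1 + sqrt(|K|) * x_0) to the coordinates i >= 1."""
    denom = 1.0 + math.sqrt(abs(K)) * x[0]
    return [xi / denom for xi in x[1:]]


# Point on the unit sphere (0.6^2 + 0.8^2 = 1).
y_sphere = stereographic_point([0.6, 0.8, 0.0], K=1.0)

# Point on the unit hyperboloid (-1.25^2 + 0.75^2 = -1).
y_hyperbolic = stereographic_point([1.25, 0.75, 0.0], K=-1.0)

print(y_sphere)       # [0.5, 0.0]
print(y_hyperbolic)   # first entry is 1/3, inside the unit ball
```

The projected point has one coordinate fewer than the input, and the hyperbolic image lies strictly inside the unit ball, as expected for the Poincaré model.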

inverse_stereographic(*points)

Convert the manifold from its stereographic coordinates back to the original coordinates.

If points are given, convert them as well.

Formula for inverse stereographic projection: \begin{align} x_0 &= \frac{1 - sign(K) \cdot \|y\|^2}{1 + sign(K) \cdot \|y\|^2} \\ x_i &= \frac{2 \cdot y_i}{1 + sign(K) \cdot \|y\|^2} \end{align}

Parameters:
  • *points (Float[Tensor, 'n_points n_dim_stereo'], default: () ) –

    Variable number of tensors representing points in stereographic coords to convert back to original coords.

Returns:
  • inv_stereo_manifold( ProductManifold ) –

    The manifold in original coords.

  • inv_stereo_points( ... ) –

    The provided points converted from stereographic to original coords (if any).

Source code in manify/manifolds.py
def inverse_stereographic(
    self, *points: Float[torch.Tensor, "n_points n_dim_stereo"]
) -> tuple[ProductManifold, ...]:
    r"""Convert the manifold from its stereographic coordinates back to the original coordinates.

    If points are given, convert them as well.

    Formula for inverse stereographic projection:
    \begin{align}
        x_0 &= \frac{1 - sign(K) \cdot \|y\|^2}{1 + sign(K) \cdot \|y\|^2} \\\\
        x_i &= \frac{2 \cdot y_i}{1 + sign(K) \cdot \|y\|^2}
    \end{align}

    Args:
        *points: Variable number of tensors representing points in stereographic coords to convert back to original
            coords.

    Returns:
        inv_stereo_manifold: The manifold in original coords.
        inv_stereo_points: The provided points converted from stereographic to original coords (if any).
    """
    if not self.is_stereographic:
        print("Manifold is already in original coordinates.")
        return self, *points

    # Convert manifold
    orig_manifold = ProductManifold(self.signature, device=self.device, stereographic=False)

    orig_points = [
        torch.hstack([M.inverse_stereographic(x)[1] for x, M in zip(self.factorize(X), self.P, strict=False)])
        for X in points
    ]
    assert all(orig_manifold.manifold.check_point(X) for X in orig_points), (
        "Generated points do not lie on the target manifold"
    )

    return orig_manifold, *orig_points
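
A round trip is a quick way to check an inverse projection. The sketch below assumes unit curvature (|K| = 1) and takes sign(K) = +1 for the sphere and -1 for the hyperboloid; under that convention the denominators are 1 + sign(K)·‖y‖², which is the choice that inverts \(x_i / (1 + x_0)\). Function names are illustrative, and scaling for general K is not handled.

```python
import math


def stereographic_point(x):
    """Forward map for |K| = 1: y_i = x_i / (1 + x_0)."""
    return [xi / (1.0 + x[0]) for xi in x[1:]]


def inverse_stereographic_point(y, sign_K):
    """Inverse map for |K| = 1; sign_K is +1 (sphere) or -1 (hyperboloid)."""
    sq_norm = sum(yi * yi for yi in y)
    denom = 1.0 + sign_K * sq_norm
    x0 = (1.0 - sign_K * sq_norm) / denom
    return [x0] + [2.0 * yi / denom for yi in y]


sphere_point = [0.6, 0.8, 0.0]          # on the unit sphere
hyperbolic_point = [1.25, 0.75, 0.0]    # on the unit hyperboloid

sphere_back = inverse_stereographic_point(stereographic_point(sphere_point), +1.0)
hyperbolic_back = inverse_stereographic_point(stereographic_point(hyperbolic_point), -1.0)

# Both round trips recover the original coordinates up to floating-point error.
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(sphere_back, sphere_point)))
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(hyperbolic_back, hyperbolic_point)))
```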

gaussian_mixture(num_points=1000, num_classes=2, num_clusters=None, seed=None, cov_scale_means=1.0, cov_scale_points=1.0, regression_noise_std=0.1, task='classification', adjust_for_dims=False)

Generate a set of labeled samples from a Gaussian mixture model.

Parameters:
  • num_points (int, default: 1000 ) –

    The number of points to generate.

  • num_classes (int, default: 2 ) –

    The number of classes to generate.

  • num_clusters (int | None, default: None ) –

    The number of clusters to generate. If None, defaults to num_classes.

  • seed (int | None, default: None ) –

    An optional seed for the random number generator. If None, no random seed is set.

  • cov_scale_means (float, default: 1.0 ) –

    The scale of the covariance matrix for the means.

  • cov_scale_points (float, default: 1.0 ) –

    The scale of the covariance matrix for the points.

  • regression_noise_std (float, default: 0.1 ) –

    The standard deviation of the noise for regression labels.

  • task (Literal['classification', 'regression'], default: 'classification' ) –

    The type of labels to generate. Either "classification" or "regression".

  • adjust_for_dims (bool, default: False ) –

    Whether to adjust the covariance matrices for the number of dimensions in each manifold.

Returns:
  • samples( Float[Tensor, 'n_points n_ambient_dim'] ) –

    A tensor of generated samples.

  • class_assignments( Real[Tensor, 'n_points'] ) –

    A tensor of labels for the samples: class indices for classification, or regression targets normalized to [0, 1] for regression.

Source code in manify/manifolds.py
@torch.no_grad()  # type: ignore
def gaussian_mixture(
    self,
    num_points: int = 1_000,
    num_classes: int = 2,
    num_clusters: int | None = None,
    seed: int | None = None,
    cov_scale_means: float = 1.0,
    cov_scale_points: float = 1.0,
    regression_noise_std: float = 0.1,
    task: Literal["classification", "regression"] = "classification",
    adjust_for_dims: bool = False,
) -> tuple[Float[torch.Tensor, "n_points n_ambient_dim"], Real[torch.Tensor, "n_points"]]:
    """Generate a set of labeled samples from a Gaussian mixture model.

    Args:
        num_points: The number of points to generate.
        num_classes: The number of classes to generate.
        num_clusters: The number of clusters to generate. If None, defaults to num_classes.
        seed: An optional seed for the random number generator. If None, no random seed is set.
        cov_scale_means: The scale of the covariance matrix for the means.
        cov_scale_points: The scale of the covariance matrix for the points.
        regression_noise_std: The standard deviation of the noise for regression labels.
        task: The type of labels to generate. Either "classification" or "regression".
        adjust_for_dims: Whether to adjust the covariance matrices for the number of dimensions in each manifold.

    Returns:
        samples: A tensor of generated samples.
        class_assignments: A tensor of labels for the samples: class indices for classification, or regression
            targets normalized to [0, 1] for regression.
    """
    # Set seed
    if seed is not None:
        torch.manual_seed(seed)

    # Deal with clusters
    num_clusters = num_clusters or num_classes
    assert num_clusters >= num_classes, "Number of clusters must be at least as large as number of classes."

    # Adjust covariance matrices for number of dimensions
    if adjust_for_dims:
        cov_scale_points /= self.dim
        cov_scale_means /= self.dim

    # Generate cluster means
    cluster_means = self.sample(num_clusters, sigma_factorized=[torch.eye(M.dim) * cov_scale_means for M in self.P])
    assert cluster_means.shape == (num_clusters, self.ambient_dim), "Cluster means shape mismatch."  # type: ignore

    # Generate cluster probabilities
    cluster_probs = torch.rand(num_clusters)
    cluster_probs /= cluster_probs.sum()

    # Draw cluster assignments: ensure at least 2 points per cluster. This is to ensure splits can always happen.
    cluster_assignments = torch.multinomial(input=cluster_probs, num_samples=num_points, replacement=True)
    while (cluster_assignments.bincount() < 2).any():
        cluster_assignments = torch.multinomial(input=cluster_probs, num_samples=num_points, replacement=True)
    assert cluster_assignments.shape == (num_points,), "Cluster assignments shape mismatch."

    # Generate covariance matrices for each cluster - Wishart distribution
    cov_matrices = [
        torch.distributions.Wishart(df=M.dim + 1, covariance_matrix=torch.eye(M.dim) * cov_scale_points).sample(
            sample_shape=(num_clusters,)
        )
        + torch.eye(M.dim) * 1e-5  # jitter to avoid singularity
        for M in self.P
    ]

    # Generate random samples for each cluster
    sample_means = torch.stack([cluster_means[c] for c in cluster_assignments])
    assert sample_means.shape == (num_points, self.ambient_dim), "Sample means shape mismatch."
    sample_covs = [torch.stack([cov_matrix[c] for c in cluster_assignments]) for cov_matrix in cov_matrices]
    samples, tangent_vals = self.sample(z_mean=sample_means, sigma_factorized=sample_covs, return_tangent=True)
    assert samples.shape == (num_points, self.ambient_dim), "Sample shape mismatch."

    # Map clusters to classes
    cluster_to_class = torch.cat(
        [
            torch.arange(num_classes, device=self.device),
            torch.randint(0, num_classes, (num_clusters - num_classes,), device=self.device),
        ]
    )
    assert cluster_to_class.shape == (num_clusters,), "Cluster to class mapping shape mismatch."
    assert torch.unique(cluster_to_class).shape == (num_classes,), (
        "Cluster to class mapping must cover all classes."
    )

    # Generate outputs
    if task == "classification":
        labels = cluster_to_class[cluster_assignments]
    elif task == "regression":
        slopes = (0.5 - torch.randn(num_clusters, self.dim, device=self.device)) * 2
        intercepts = (0.5 - torch.randn(num_clusters, device=self.device)) * 20
        labels = (
            torch.einsum("ij,ij->i", slopes[cluster_assignments], tangent_vals) + intercepts[cluster_assignments]
        )

        # Noise component
        N = torch.distributions.Normal(0, regression_noise_std)
        v = N.sample((num_points,)).to(self.device)
        labels += v

        # Normalize regression labels to range [0, 1] so that RMSE can be more easily interpreted
        labels = (labels - labels.min()) / (labels.max() - labels.min())

    return samples, labels
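
Two labeling steps in this method are easy to check in isolation: the cluster-to-class mapping, which guarantees that every class owns at least one cluster, and the min-max normalization of regression labels. The following plain-Python sketch uses the standard `random` module as a stand-in for torch's samplers; the sizes and the `raw` values are made up for illustration.

```python
import random

rng = random.Random(0)
num_classes, num_clusters, num_points = 3, 5, 20

# The first num_classes clusters map to distinct classes, so every class is covered;
# the remaining clusters are assigned to classes at random.
cluster_to_class = list(range(num_classes)) + [
    rng.randrange(num_classes) for _ in range(num_clusters - num_classes)
]

# Stand-in for the multinomial draw of cluster assignments.
cluster_assignments = [rng.randrange(num_clusters) for _ in range(num_points)]
labels = [cluster_to_class[c] for c in cluster_assignments]

# Min-max normalization, as applied to regression labels (hypothetical raw values).
raw = [3.0, -1.0, 7.0, 2.5]
lo, hi = min(raw), max(raw)
scaled = [(r - lo) / (hi - lo) for r in raw]

print(sorted(set(cluster_to_class)))  # [0, 1, 2]
print(min(scaled), max(scaled))       # 0.0 1.0
```

Note that the normalization divides by `hi - lo`, so it assumes the raw labels are not all identical.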