Layers

manify.predictors.nn.layers

Neural network layers for product manifolds.

KappaGCNLayer(in_features, out_features, manifold, nonlinearity=torch.relu)

Bases: Module

Implementation of the Kappa GCN (κ-GCN) layer.

Parameters:
  • in_features (int) –

    Number of input features

  • out_features (int) –

    Number of output features

  • manifold (Manifold) –

    Manifold object for the Kappa GCN

  • nonlinearity (Callable | None, default: relu ) –

    Function for nonlinear activation.

Attributes:
  • W

    Weight matrix parameter.

  • sigma

    Nonlinear activation function applied via the manifold.

  • manifold

    The manifold object for geometric operations.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, in_features: int, out_features: int, manifold: Manifold, nonlinearity: Callable | None = torch.relu
):
    super().__init__()

    # Parameters are Euclidean, straightforwardly
    self.W = torch.nn.Parameter(torch.randn(in_features, out_features) * 0.01)

    # Nonlinearity must be applied via the manifold
    self.sigma = manifold.apply(nonlinearity) if nonlinearity else lambda x: x

    # Also store manifold
    self.manifold = manifold

forward(X, A_hat=None)

Forward pass for the Kappa GCN layer.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Embedding matrix

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Normalized adjacency matrix

Returns:
  • AXW (Float[Tensor, 'n_nodes dim']) –

    Transformed node features after message passing and nonlinear activation.

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the Kappa GCN layer.

    Args:
        X: Embedding matrix
        A_hat: Normalized adjacency matrix

    Returns:
        AXW: Transformed node features after message passing and nonlinear activation.
    """
    # 1. right-multiply X by W - mobius_matvec broadcasts correctly (verified)
    XW = self.manifold.manifold.mobius_matvec(m=self.W, x=X)

    # 2. left-multiply (X @ W) by A_hat - we need our own implementation for this
    if A_hat is None:
        AXW = XW
    elif isinstance(self.manifold, ProductManifold):
        XWs = self.manifold.factorize(XW)
        AXW = torch.hstack([self._left_multiply(A_hat, XW_i, M) for XW_i, M in zip(XWs, self.manifold.P, strict=False)])
    else:
        AXW = self._left_multiply(A_hat, XW, self.manifold)

    # 3. Apply nonlinearity - note that sigma is wrapped with our manifold.apply decorator
    AXW = self.sigma(AXW)

    return AXW
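Steps 1 and 3 of the forward pass can be sketched in plain Python for a single κ-stereographic component. This is a minimal, dependency-free sketch and not manify's code. The closed forms for tan_κ, expmap0, logmap0 and the Mobius matrix-vector product are the standard κ-stereographic formulas, which we believe match geoopt's conventions. `apply_in_tangent` encodes an assumption: that `manifold.apply` maps to the tangent space at the origin, applies the function there, and maps back. Step 2 (`_left_multiply`) is not shown in this section and is not sketched. All helper names are hypothetical.

```python
import math


def tan_k(x: float, k: float) -> float:
    """kappa-tangent: tan for k > 0, tanh for k < 0, identity for k == 0."""
    if k > 0:
        return math.tan(math.sqrt(k) * x) / math.sqrt(k)
    if k < 0:
        return math.tanh(math.sqrt(-k) * x) / math.sqrt(-k)
    return x


def artan_k(y: float, k: float) -> float:
    """Inverse of tan_k (needs |y| < 1/sqrt(-k) when k < 0)."""
    if k > 0:
        return math.atan(math.sqrt(k) * y) / math.sqrt(k)
    if k < 0:
        return math.atanh(math.sqrt(-k) * y) / math.sqrt(-k)
    return y


def norm(v: list[float]) -> float:
    return math.sqrt(sum(t * t for t in v))


def expmap0(u, k):
    """Exponential map at the origin: tan_k(|u|) u / |u|."""
    n = norm(u)
    return list(u) if n == 0.0 else [tan_k(n, k) * t / n for t in u]


def logmap0(y, k):
    """Logarithmic map at the origin: artan_k(|y|) y / |y|."""
    n = norm(y)
    return list(y) if n == 0.0 else [artan_k(n, k) * t / n for t in y]


def matvec(M, x):
    """Plain Euclidean M @ x, with M given as a list of rows."""
    return [sum(m * t for m, t in zip(row, x)) for row in M]


def mobius_matvec(M, x, k):
    """Mobius matrix-vector product: tan_k(|Mx| / |x| * artan_k(|x|)) * Mx / |Mx|."""
    mx = matvec(M, x)
    nx, nmx = norm(x), norm(mx)
    if nx == 0.0 or nmx == 0.0:
        return [0.0] * len(mx)
    scale = tan_k(nmx / nx * artan_k(nx, k), k) / nmx
    return [scale * t for t in mx]


def apply_in_tangent(f, x, k):
    """Pointwise f through the origin's tangent space (assumed behaviour of manifold.apply)."""
    return expmap0([f(t) for t in logmap0(x, k)], k)


def relu(t: float) -> float:
    return max(t, 0.0)


k = -1.0                                   # Poincare ball of radius 1
x = [0.3, -0.2]                            # a point inside the ball
W = [[1.0, 0.5], [-0.5, 2.0], [0.1, 0.1]]  # rows are outputs: maps 2 -> 3 dims
xw = mobius_matvec(W, x, k)                # step 1
h = apply_in_tangent(relu, xw, k)          # step 3 (step 2 skipped: no A_hat)
```

The closed form is equivalent to expmap0(M · logmap0(x)), a Euclidean matrix product carried out in the tangent space at the origin. At κ = 0 it collapses to the ordinary M x.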

KappaSequential(*layers)

Bases: Module

Sequential container for κ-layers that properly handles adjacency matrices.

Similar to nn.Sequential but passes the adjacency matrix through each layer. All layers should accept (X, A_hat) and return X.

Parameters:
  • *layers (Module, default: () ) –

    Variable number of layers to be added to the sequence.

Source code in manify/predictors/nn/layers.py
def __init__(self, *layers: nn.Module):
    super().__init__()
    self.layers = nn.ModuleList(layers)

forward(X, A_hat=None)

Forward pass through all layers.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Adjacency matrix passed to each layer

Returns:
  • Float[Tensor, 'n_nodes out_dim']

    Output after passing through all layers

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes out_dim"]:
    """Forward pass through all layers.

    Args:
        X: Input features
        A_hat: Adjacency matrix passed to each layer

    Returns:
        Output after passing through all layers
    """
    for layer in self.layers:
        X = layer(X, A_hat)
    return X

append(layer)

Add a layer to the end of the sequence.

Source code in manify/predictors/nn/layers.py
def append(self, layer: nn.Module) -> None:
    """Add a layer to the end of the sequence."""
    self.layers.append(layer)
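The container's contract is that every layer receives the same (X, A_hat) pair. A hypothetical pure-Python stand-in illustrates it. `ToyKappaSequential`, `propagate` and `double` are illustrative names, not manify APIs, and matrices are plain lists of rows.

```python
from typing import Callable, Optional

Matrix = list[list[float]]
Layer = Callable[[Matrix, Optional[Matrix]], Matrix]


class ToyKappaSequential:
    """Hypothetical stand-in for KappaSequential: every layer is called as layer(X, A_hat)."""

    def __init__(self, *layers: Layer) -> None:
        self.layers: list[Layer] = list(layers)

    def append(self, layer: Layer) -> None:
        """Add a layer to the end of the sequence."""
        self.layers.append(layer)

    def __call__(self, X: Matrix, A_hat: Optional[Matrix] = None) -> Matrix:
        for layer in self.layers:
            X = layer(X, A_hat)  # the same A_hat is threaded through every layer
        return X


def propagate(X: Matrix, A_hat: Optional[Matrix]) -> Matrix:
    """Toy message passing: A_hat @ X, or X unchanged when no graph is given."""
    if A_hat is None:
        return X
    n_cols = len(X[0])
    return [[sum(a * X[j][c] for j, a in enumerate(row)) for c in range(n_cols)] for row in A_hat]


def double(X: Matrix, A_hat: Optional[Matrix]) -> Matrix:
    """Toy layer that ignores A_hat but still accepts it."""
    return [[2.0 * v for v in row] for row in X]


model = ToyKappaSequential(propagate)
model.append(double)
X = [[1.0, 2.0], [3.0, 4.0]]
A_hat = [[0.5, 0.5], [0.0, 1.0]]
out = model(X, A_hat)  # [[4.0, 6.0], [6.0, 8.0]]
```

Layers that do not use the graph, such as `double`, must still accept the second argument, because the container always passes it.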

StereographicLogits(out_features, manifold, apply_softmax=False)

Bases: Module

Stereographic logits layer for classification and regression on product manifolds.

Computes signed distances from the input points to learned hyperplanes on the manifold (or product manifold), one hyperplane per output feature, and uses them as logits. Can optionally apply softmax for classification tasks.

Parameters:
  • out_features (int) –

    Number of output classes (dimensionality of output space)

  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • apply_softmax (bool, default: False ) –

    Whether to apply softmax to the output logits.

Source code in manify/predictors/nn/layers.py
def __init__(self, out_features: int, manifold: Manifold | ProductManifold, apply_softmax: bool = False):
    super().__init__()

    self.out_features = out_features
    self.manifold = manifold
    self.apply_softmax = apply_softmax

    # Weight matrix (Euclidean parameters)
    self.W = nn.Parameter(torch.randn(manifold.dim, out_features) * 0.01)

    # Bias points on the manifold
    self.p_ks = geoopt.ManifoldParameter(torch.zeros(out_features, manifold.dim), manifold=manifold.manifold)

forward(X, A_hat=None, aggregate_logits=False)

Forward pass through stereographic logits.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Optional adjacency matrix for logit aggregation

  • aggregate_logits (bool, default: False ) –

    Whether to aggregate logits using adjacency matrix

Returns:
  • Float[Tensor, 'n_nodes n_classes']

    Logits (or probabilities if apply_softmax=True)

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    aggregate_logits: bool = False,
) -> Float[torch.Tensor, "n_nodes n_classes"]:
    """Forward pass through stereographic logits.

    Args:
        X: Input features
        A_hat: Optional adjacency matrix for logit aggregation
        aggregate_logits: Whether to aggregate logits using adjacency matrix

    Returns:
        Logits (or probabilities if apply_softmax=True)
    """
    # Compute logits based on manifold type
    if isinstance(self.manifold, ProductManifold):
        logits = self._get_logits_product_manifold(X, self.W, self.p_ks, self.manifold)
    else:
        logits = self._get_logits_single_manifold(X, self.W, self.p_ks, self.manifold, return_inner_products=False)

    # Optional aggregation via adjacency matrix
    if A_hat is not None and aggregate_logits:
        logits = A_hat @ logits

    # Optional softmax for classification
    if self.apply_softmax:
        logits = torch.softmax(logits, dim=-1)

    return logits
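The two optional post-processing steps in `forward` are aggregation by A_hat followed by softmax. The order matters, so here is a plain-Python sketch of just those steps. The hyperplane-distance logits themselves are not reproduced. `postprocess_logits` and the toy inputs are hypothetical.

```python
import math
from typing import Optional

Matrix = list[list[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Plain A @ B for lists of rows."""
    return [[sum(a * B[t][c] for t, a in enumerate(row)) for c in range(len(B[0]))] for row in A]


def softmax_rows(Z: Matrix) -> Matrix:
    """Row-wise softmax, the analogue of torch.softmax(..., dim=-1)."""
    out = []
    for row in Z:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(z - m) for z in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out


def postprocess_logits(
    logits: Matrix, A_hat: Optional[Matrix] = None, aggregate_logits: bool = False, apply_softmax: bool = False
) -> Matrix:
    """Same order as StereographicLogits.forward: optional A_hat aggregation, then optional softmax."""
    if A_hat is not None and aggregate_logits:
        logits = matmul(A_hat, logits)
    if apply_softmax:
        logits = softmax_rows(logits)
    return logits


logits = [[2.0, 0.0], [0.0, 2.0]]
A_hat = [[0.5, 0.5], [0.5, 0.5]]
smoothed = postprocess_logits(logits, A_hat, aggregate_logits=True)  # [[1.0, 1.0], [1.0, 1.0]]
probs = postprocess_logits(logits, A_hat, apply_softmax=True)  # A_hat ignored: aggregate_logits=False
```

A_hat only has an effect when aggregate_logits=True. If you train with torch.nn.CrossEntropyLoss, keep apply_softmax=False, since that loss expects unnormalized logits.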

FermiDiracDecoder(manifold, learnable_params=True)

Bases: Module

Fermi-Dirac decoder for link prediction tasks.

Computes pairwise distances with manifold.pdist2 and applies the Fermi-Dirac transformation to produce edge logits. Applying a sigmoid to these logits gives edge probabilities.

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • learnable_params (bool, default: True ) –

    If True, temperature and bias are learnable parameters. If False, they are fixed to 1.0 and 0.0, respectively.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, learnable_params: bool = True):
    super().__init__()

    self.manifold = manifold

    if learnable_params:
        self.temperature = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.tensor(0.0))
    else:
        self.register_buffer("temperature", torch.tensor(1.0))
        self.register_buffer("bias", torch.tensor(0.0))

forward(X, A_hat=None)

Forward pass through Fermi-Dirac decoder.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Node embeddings

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Ignored; accepted for signature compatibility with the other layers.

Returns:
  • Float[Tensor, 'n_nodes n_nodes']

    Edge logits; apply a sigmoid to obtain edge probabilities.

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes n_nodes"]:
    """Forward pass through Fermi-Dirac decoder.

    Args:
        X: Node embeddings
        A_hat: Ignored (for compatibility)

    Returns:
        Edge logits; apply a sigmoid to obtain edge probabilities.
    """
    # Compute pairwise distances (pdist2: squared distances)
    pairwise_dist = self.manifold.pdist2(X)

    # Apply Fermi-Dirac transformation
    logits = -(pairwise_dist - self.bias) / self.temperature

    return logits
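The returned value is the logit of the Fermi-Dirac edge probability 1 / (exp((d - r)/t) + 1), with r = bias and t = temperature, because sigmoid(-(d - r)/t) is exactly that expression. The standalone check below is plain Python. The function names are hypothetical, and d stands for whatever manifold.pdist2 returns, which we take to be a squared distance.

```python
import math


def fermi_dirac_prob(d: float, r: float = 0.0, t: float = 1.0) -> float:
    """Fermi-Dirac edge probability 1 / (exp((d - r) / t) + 1)."""
    return 1.0 / (math.exp((d - r) / t) + 1.0)


def fermi_dirac_logit(d: float, r: float = 0.0, t: float = 1.0) -> float:
    """What the decoder returns per pair: -(d - r) / t, with r = bias and t = temperature."""
    return -(d - r) / t


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# The sigmoid of the logit reproduces the Fermi-Dirac probability.
for d, r, t in [(0.0, 0.0, 1.0), (2.5, 1.0, 0.5), (0.3, 2.0, 2.0)]:
    print(d, r, t, fermi_dirac_prob(d, r, t), sigmoid(fermi_dirac_logit(d, r, t)))
```

Returning logits rather than probabilities pairs naturally with torch.nn.BCEWithLogitsLoss, which applies the sigmoid internally in a numerically stable way.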

StereographicLayerNorm(manifold, embedding_dim)

Bases: Module

Stereographic Layer Normalization.

Layer normalization is not defined directly on a curved manifold, so we apply an ordinary Euclidean nn.LayerNorm in the tangent space at the origin (logmap0 -> LayerNorm -> expmap0). For a stereographic ProductManifold, the tangent space at the origin is Euclidean of dimension manifold.dim, and logmap0/expmap0 handle the per-component curvatures, so no explicit curvature broadcasting is required. The output is re-projected onto the manifold for numerical safety. In the zero-curvature limit this reduces to a plain LayerNorm.
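For a single component of curvature κ, the logmap0 -> LayerNorm -> expmap0 recipe can be sketched as follows. This plain-Python illustration rests on three assumptions. It uses the standard κ-stereographic maps at the origin. Its LayerNorm has no learned affine parameters, which matches nn.LayerNorm at initialization with its default eps = 1e-5. It also omits the final projx re-projection. The function names are hypothetical.

```python
import math


def tan_k(x, k):
    if k > 0:
        return math.tan(math.sqrt(k) * x) / math.sqrt(k)
    if k < 0:
        return math.tanh(math.sqrt(-k) * x) / math.sqrt(-k)
    return x


def artan_k(y, k):
    if k > 0:
        return math.atan(math.sqrt(k) * y) / math.sqrt(k)
    if k < 0:
        return math.atanh(math.sqrt(-k) * y) / math.sqrt(-k)
    return y


def norm(v):
    return math.sqrt(sum(t * t for t in v))


def expmap0(u, k):
    n = norm(u)
    return list(u) if n == 0.0 else [tan_k(n, k) * t / n for t in u]


def logmap0(y, k):
    n = norm(y)
    return list(y) if n == 0.0 else [artan_k(n, k) * t / n for t in y]


def layer_norm(u, eps=1e-5):
    """LayerNorm without affine parameters: (u - mean) / sqrt(biased variance + eps)."""
    mean = sum(u) / len(u)
    var = sum((t - mean) ** 2 for t in u) / len(u)
    return [(t - mean) / math.sqrt(var + eps) for t in u]


def stereographic_layer_norm(x, k, eps=1e-5):
    """logmap0 -> LayerNorm -> expmap0 for a single component (projx omitted)."""
    return expmap0(layer_norm(logmap0(x, k), eps), k)


x = [0.1, 0.5, -0.3]                  # a point in the Poincare ball (k = -1)
y = stereographic_layer_norm(x, -1.0)
```

After the round trip, the point's tangent coordinates at the origin have (approximately) zero mean and unit variance. At κ = 0 both maps are identities, so the result is a plain LayerNorm.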

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry. Must be stereographic.

  • embedding_dim (int) –

    Embedding dimension of the input points (manifold.dim).

Attributes:
  • manifold

    The manifold object for geometric operations.

  • norm

    Tangent-space layer-norm wrapper.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, embedding_dim: int):
    super().__init__()

    self.manifold = manifold
    self.norm = _tangent_module(manifold, nn.LayerNorm(embedding_dim))

forward(X)

Apply layer normalization on the stereographic manifold.

Source code in manify/predictors/nn/layers.py
def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
    """Apply layer normalization on the stereographic manifold."""
    return self.manifold.manifold.projx(self.norm(X))

GeometricLinearizedAttention(num_heads, head_dim)

Bases: Module

Faithful gyrovector linear attention (FPS-T, arXiv:2309.04082, Eqs 6, 7, 11).

This is the kernelized mixed-curvature attention from "Curve Your Attention: Mixed-Curvature Transformers for Graph Representation Learning". Each head operates on its own \(\kappa_h\)-stereographic space. Like the rest of manify, it works on a single graph (no batch dimension): inputs have shape [num_heads, n_nodes, head_dim], and the mask is the [n_nodes, n_nodes] adjacency matrix (None means full attention).

Inputs
  • V: value points on the per-head \(\kappa_h\)-stereographic manifold.
  • Q, K: query/key tangent vectors at the corresponding V_i (Eq 5).

Scores (Eq 6): parallel-transport Q_i and K_j to the origin (parallel_transport0back), apply the feature map \(\phi(x)=\mathrm{ELU}(x)+1\), and take a Euclidean inner product there: \(\alpha_{ij}\approx\phi(\tilde Q_i)^\top\phi(\tilde K_j)\).
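The score computation can be sketched per node pair with the standard gyrovector parallel transports between the origin and a point y: P_{0→y}(v) = (1 + κ|y|²) v and P_{y→0}(v) = v / (1 + κ|y|²). This is our reading of geoopt's `parallel_transport0` and `parallel_transport0back`. StereographicAttention builds Q and K by transporting origin-level projections to each value point, so transporting them back recovers those projections exactly. The code is plain Python, and all names and toy inputs are hypothetical.

```python
import math


def sq_norm(v):
    return sum(t * t for t in v)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def transport0(y, v, k):
    """Parallel transport of v from the origin to y: (lambda_0 / lambda_y) v = (1 + k|y|^2) v."""
    s = 1.0 + k * sq_norm(y)
    return [s * t for t in v]


def transport0back(x, v, k):
    """Parallel transport of v from x back to the origin: (lambda_x / lambda_0) v = v / (1 + k|x|^2)."""
    s = 1.0 + k * sq_norm(x)
    return [t / s for t in v]


def phi(v):
    """Feature map ELU(x) + 1 with alpha = 1: x + 1 for x > 0, exp(x) otherwise (always > 0)."""
    return [t + 1.0 if t > 0.0 else math.exp(t) for t in v]


k = -1.0
V = [[0.2, -0.1], [0.0, 0.5]]       # value points inside the unit ball
q_tan = [[0.3, -1.2], [0.8, 0.1]]   # origin-level query projections
k_tan = [[-0.4, 0.9], [1.5, -2.0]]  # origin-level key projections

# Queries/keys as tangent vectors at each value point (as StereographicAttention builds them).
Q = [transport0(v, q, k) for v, q in zip(V, q_tan)]
K = [transport0(v, kv, k) for v, kv in zip(V, k_tan)]

# Eq 6: transport back to the origin, apply phi, take Euclidean inner products.
q_tilde = [transport0back(v, q, k) for v, q in zip(V, Q)]
k_tilde = [transport0back(v, kv, k) for v, kv in zip(V, K)]
scores = [[dot(phi(qi), phi(kj)) for kj in k_tilde] for qi in q_tilde]
```

Because φ is strictly positive, every score is positive. This is what lets the scores act as unnormalized attention weights without a softmax.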

Aggregation (Eq 7) is the Einstein midpoint, kernelized (Eq 11):

\[ \mathrm{Aggregate}_\kappa(V,\alpha)_i = \tfrac{1}{2}\otimes_\kappa \sum_j \frac{\alpha_{ij}\,\lambda^\kappa_{V_j}} {\sum_k \alpha_{ik}(\lambda^\kappa_{V_k}-1)}\, V_j, \]

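where \(\lambda^\kappa\) is the conformal factor and \(\tfrac12\otimes_\kappa\) is Mobius scalar multiplication. Write \(\tilde V_i=\frac{\lambda^\kappa_{V_i}}{\lambda^\kappa_{V_i}-1}V_i\) and \(\phi'(\tilde K)_i=\phi(\tilde K_i)(\lambda^\kappa_{V_i}-1)\), and let \(\mathbf{1}\) be the all-ones vector. Then the numerator and normalizer become \(\big[\phi(\tilde Q)\,(\phi'(\tilde K)^\top \tilde V)\big]_i\) and \(\big[\phi(\tilde Q)\,(\phi'(\tilde K)^\top \mathbf{1})\big]_i\), so the per-query output is

\[ \tfrac12\otimes_\kappa \frac{\big[\phi(\tilde Q)\,(\phi'(\tilde K)^\top \tilde V)\big]_i}{\big[\phi(\tilde Q)\,(\phi'(\tilde K)^\top \mathbf{1})\big]_i}. \]

The head_dim × head_dim matrix \(\phi'(\tilde K)^\top \tilde V\) is computed once and shared by all queries. In the full-attention case this costs O(N d²) per head for N nodes and head dimension d, linear rather than quadratic in N. With a mask, the scores are formed explicitly and the cost is O(N² d). As \(\kappa\to0\), \(\lambda^\kappa\to2\) and the Einstein midpoint reduces to the ordinary (Euclidean) weighted mean.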
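The reassociation behind Eq 11 can be checked numerically. The explicit O(N² d) gyro-sum and the kernelized form built from φ'(K̃)ᵀṼ give the same point. The Mobius half-scaling ½ ⊗_κ x = tan_κ(½ · artan_κ(|x|)) · x / |x| then completes the midpoint. The sketch below is plain Python for a single head, with the φ-features given directly and the standard closed form for Mobius scalar multiplication. All names are hypothetical. It is only exercised for κ ≤ 0, where λ − 1 > 0 and no clamping is needed.

```python
import math


def tan_k(x, k):
    if k > 0:
        return math.tan(math.sqrt(k) * x) / math.sqrt(k)
    if k < 0:
        return math.tanh(math.sqrt(-k) * x) / math.sqrt(-k)
    return x


def artan_k(y, k):
    if k > 0:
        return math.atan(math.sqrt(k) * y) / math.sqrt(k)
    if k < 0:
        return math.atanh(math.sqrt(-k) * y) / math.sqrt(-k)
    return y


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def lam(v, k):
    """Conformal factor lambda^k_v = 2 / (1 + k|v|^2)."""
    return 2.0 / (1.0 + k * dot(v, v))


def explicit_gyro_sum(phi_q, phi_k, V, k):
    """sum_j a_ij lam_j V_j / sum_j a_ij (lam_j - 1) with a_ij = phi_q_i . phi_k_j; O(N^2 d)."""
    out = []
    for q in phi_q:
        alpha = [dot(q, pk) for pk in phi_k]
        den = sum(a * (lam(v, k) - 1.0) for a, v in zip(alpha, V))
        out.append([sum(a * lam(v, k) * v[c] for a, v in zip(alpha, V)) / den for c in range(len(V[0]))])
    return out


def kernelized_gyro_sum(phi_q, phi_k, V, k):
    """Same quantity via the shared context phi'(K)^T V~ (Eq 11); O(N d^2)."""
    lams = [lam(v, k) for v in V]
    v_tilde = [[lv / (lv - 1.0) * t for t in v] for lv, v in zip(lams, V)]
    phi_k_prime = [[(lv - 1.0) * t for t in pk] for lv, pk in zip(lams, phi_k)]
    d_k, d_v = len(phi_k[0]), len(V[0])
    context = [[sum(pk[a] * vt[b] for pk, vt in zip(phi_k_prime, v_tilde)) for b in range(d_v)] for a in range(d_k)]
    k_sum = [sum(pk[a] for pk in phi_k_prime) for a in range(d_k)]
    out = []
    for q in phi_q:
        den = dot(q, k_sum)  # normalizer: phi(Q~)_i . sum_j phi'(K~)_j
        out.append([sum(q[a] * context[a][b] for a in range(d_k)) / den for b in range(d_v)])
    return out


def mobius_half(x, k):
    """Mobius scalar multiplication by 1/2: tan_k(0.5 * artan_k(|x|)) * x / |x|."""
    n = math.sqrt(dot(x, x))
    if n == 0.0:
        return list(x)
    return [tan_k(0.5 * artan_k(n, k), k) * t / n for t in x]


phi_q = [[1.0, 0.5], [0.2, 1.5]]              # phi(Q~), positive features
phi_k = [[1.0, 1.0], [0.5, 2.0], [1.5, 0.3]]  # phi(K~), positive features
V = [[0.1, 0.2], [-0.3, 0.1], [0.05, -0.4]]   # value points inside the unit ball
k = -1.0

midpoints = [mobius_half(m, k) for m in kernelized_gyro_sum(phi_q, phi_k, V, k)]
```

For a single value point the midpoint returns that point itself. At κ = 0, where λ = 2, it reduces to the α-weighted Euclidean mean.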

Parameters:
  • num_heads (int) –

    Number of attention heads.

  • head_dim (int) –

    Dimension of each attention head.

Attributes:
  • num_heads

    Number of attention heads.

  • head_dim

    Dimension of each attention head.

Source code in manify/predictors/nn/layers.py
def __init__(self, num_heads: int, head_dim: int):
    super().__init__()

    self.num_heads = num_heads
    self.head_dim = head_dim
    self._epsilon = 1e-6
    self._clamp_epsilon = 1e-10

forward(Q, K, V, curvatures, mask=None)

Forward pass for faithful gyrovector linear attention.

Parameters:
  • Q (Float[Tensor, 'num_heads n_nodes head_dim']) –

    Query tangent vectors at V, shape [num_heads, n_nodes, head_dim].

  • K (Float[Tensor, 'num_heads n_nodes head_dim']) –

    Key tangent vectors at V, shape [num_heads, n_nodes, head_dim].

  • V (Float[Tensor, 'num_heads n_nodes head_dim']) –

    Value points on the per-head manifold, shape [num_heads, n_nodes, head_dim].

  • curvatures (Float[Tensor, 'num_heads 1 1']) –

    Per-head curvatures, shape [num_heads, 1, 1].

  • mask (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Optional adjacency/attention mask, shape [n_nodes, n_nodes]. Entry (i, j) weights how much query i attends to key/value j. None means full attention.

Returns:
  • Float[Tensor, 'num_heads n_nodes head_dim']

    Aggregated value points on the per-head manifold, shape [num_heads, n_nodes, head_dim].

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    Q: Float[torch.Tensor, "num_heads n_nodes head_dim"],
    K: Float[torch.Tensor, "num_heads n_nodes head_dim"],
    V: Float[torch.Tensor, "num_heads n_nodes head_dim"],
    curvatures: Float[torch.Tensor, "num_heads 1 1"],
    mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "num_heads n_nodes head_dim"]:
    """Forward pass for faithful gyrovector linear attention.

    Args:
        Q: Query tangent vectors at ``V``, shape ``[num_heads, n_nodes, head_dim]``.
        K: Key tangent vectors at ``V``, shape ``[num_heads, n_nodes, head_dim]``.
        V: Value points on the per-head manifold, shape ``[num_heads, n_nodes, head_dim]``.
        curvatures: Per-head curvatures, shape ``[num_heads, 1, 1]``.
        mask: Optional adjacency/attention mask, shape ``[n_nodes, n_nodes]``. Entry ``(i, j)``
            weights how much query ``i`` attends to key/value ``j``. ``None`` means full attention.

    Returns:
        Aggregated value points on the per-head manifold, shape ``[num_heads, n_nodes, head_dim]``.
    """
    k = curvatures
    math = geoopt.manifolds.stereographic.math

    # Eq 6: parallel-transport Q, K (tangent at V) to the origin, then feature map phi = elu + 1.
    q_tilde = math.parallel_transport0back(V, Q, k=k)  # [H, N, d]
    k_tilde = math.parallel_transport0back(V, K, k=k)  # [H, N, d]
    phi_q = nn.functional.elu(q_tilde) + 1.0  # phi(Q~)
    phi_k = nn.functional.elu(k_tilde) + 1.0  # phi(K~)

    # Conformal factor lambda^kappa at each value point (Eq 7).
    lam = math.lambda_x(x=V, k=k, keepdim=True, dim=-1)  # [H, N, 1]
    denom = geoopt.utils.clamp_abs(lam - 1, self._clamp_epsilon)  # (lambda - 1), [H, N, 1]

    v_tilde = (lam / denom) * V  # V~_i = lambda_i / (lambda_i - 1) * V_i
    phi_k_prime = denom * phi_k  # phi'(K~)_i = phi(K~)_i * (lambda_i - 1)

    if mask is None:
        # Kernelized O(N + M) form (Eq 11). The numerator effectively weights V_j by
        # alpha_ij * lambda_j = (phi(Q~)_i . phi'(K~)_j) * V~_j, and the normalizer by
        # alpha_ij * (lambda_j - 1) = phi(Q~)_i . phi'(K~)_j.
        context = torch.einsum("hnd,hne->hde", phi_k_prime, v_tilde)  # phi'(K)^T V~, [H, d, d]
        numerator = torch.einsum("hnd,hde->hne", phi_q, context)  # [H, N, d]
        normalizer = torch.einsum("hnd,hd->hn", phi_q, phi_k_prime.sum(dim=-2))  # [H, N]
    else:
        # Explicit (masked) attention scores -- supports a sparse adjacency, O(N^2 d).
        # Identical math to the kernelized path: numerator weights V_j by alpha_ij * lambda_j,
        # normalizer by alpha_ij * (lambda_j - 1).
        alpha = torch.einsum("hnd,hmd->hnm", phi_q, phi_k) * mask[None]  # [H, N, N]
        numerator = torch.einsum("hnm,hme->hne", alpha, lam * V)  # [H, N, d]
        normalizer = torch.einsum("hnm,hm->hn", alpha, denom.squeeze(-1))  # [H, N]

    norm_inv = 1.0 / normalizer.masked_fill(normalizer == 0, self._epsilon)  # [H, N]
    midpoint = numerator * norm_inv.unsqueeze(-1)  # weighted gyro-sum, [H, N, d]

    # Einstein midpoint finishes with a Mobius half-scaling; project for numerical safety.
    out = math.project(midpoint, k=k)
    out = math.mobius_scalar_mul(torch.tensor(0.5, dtype=out.dtype, device=out.device), out, k=k, dim=-1)
    out = math.project(out, k=k)
    return out

StereographicAttention(manifold, num_heads, dim, head_dim, init_curvatures=None)

Bases: Module

Mixed-curvature multi-head attention for a single graph of [n_nodes, dim] tokens (FPS-T).

Faithful implementation of the multi-head attention from "Curve Your Attention" (arXiv:2309.04082). Each head h operates on its OWN \(\kappa_h\)-stereographic space with an independent learnable curvature, and the multi-head output lives on the product over heads, \(\bigotimes_h \mathrm{st}_{\kappa_h}\). Heads are product-manifold components, so per-head outputs are concatenated and never reshaped across heads. The per-head curvatures are decoupled from the input manifold.

Per head (Eq 5): values are points on \(\mathrm{st}_{\kappa_h}\), and queries/keys live in the tangent space at the corresponding value point. We obtain these by mapping the input to the tangent space at the origin (logmap0) and applying ordinary Euclidean nn.Linear projections to the per-head dimensions. The value projections are exponentiated into \(\mathrm{st}_{\kappa_h}\) (expmap0), and the query/key projections are moved from the origin to the tangent space at each value point by parallel transport. Aggregation is the gyrovector Einstein midpoint (GeometricLinearizedAttention). Finally, the multi-head output, a point on the product of the per-head spaces, is mapped back to the input manifold by a tangent-space linear projection (logmap0 -> Linear -> expmap0).
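The helpers `_split_heads` and `_combine_heads` are not shown in this section. The sketch below assumes a contiguous layout: head h owns columns h·d to (h+1)·d − 1 of the [n_nodes, num_heads·head_dim] projection, as with a view followed by a transpose. Under that assumption, splitting and recombining is a lossless reshuffle that concatenates per-head coordinates rather than mixing them. This is a hypothetical plain-Python version.

```python
def split_heads(Y, num_heads):
    """[n_nodes, num_heads * d] -> [num_heads, n_nodes, d]; head h owns columns h*d .. (h+1)*d - 1."""
    d = len(Y[0]) // num_heads
    return [[row[h * d:(h + 1) * d] for row in Y] for h in range(num_heads)]


def combine_heads(Z):
    """[num_heads, n_nodes, d] -> [n_nodes, num_heads * d]: concatenate each node's per-head coordinates."""
    n_nodes = len(Z[0])
    return [[t for head in Z for t in head[n]] for n in range(n_nodes)]


Y = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]  # 3 nodes, 2 heads of dim 2
heads = split_heads(Y, num_heads=2)
# heads[1] == [[2, 3], [6, 7], [10, 11]]
```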

Parameters:
  • manifold (Manifold | ProductManifold) –

    Stereographic Manifold or ProductManifold defining the input/output geometry.

  • num_heads (int) –

    Number of attention heads (product-manifold components).

  • dim (int) –

    Embedding dimension of the input/output points (manifold.dim).

  • head_dim (int) –

    Dimension of each attention head.

  • init_curvatures (list[float] | None, default: None ) –

    Optional list of initial per-head curvatures (length num_heads). Defaults to -1 for the first num_heads // 2 heads and +1 for the remaining heads (mixed curvature), so an odd head count gets one more positive head than negative.
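The default initialization in `__init__` reduces to the expression below. With an odd number of heads the extra head is spherical (+1), and a single head gets +1. This is a standalone sketch with a hypothetical function name.

```python
def default_curvatures(num_heads: int) -> list[float]:
    """Mirror of the default in StereographicAttention.__init__: hyperbolic heads first, spherical rest."""
    return [-1.0] * (num_heads // 2) + [1.0] * (num_heads - num_heads // 2)


print(default_curvatures(4))  # [-1.0, -1.0, 1.0, 1.0]
print(default_curvatures(3))  # [-1.0, 1.0, 1.0]
```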

Attributes:
  • manifold

    The (input/output) manifold object.

  • num_heads

    Number of attention heads.

  • head_dim

    Dimensionality of each attention head.

  • curvatures

    Learnable per-head curvatures, shape [num_heads].

  • W_q

    Euclidean (tangent-space) linear projection to query vectors.

  • W_k

    Euclidean (tangent-space) linear projection to key vectors.

  • W_v

    Euclidean (tangent-space) linear projection to value vectors.

  • attn

    Gyrovector Einstein-midpoint attention module.

  • W_o

    Euclidean (tangent-space) linear output projection.

Source code in manify/predictors/nn/layers.py
def __init__(
    self,
    manifold: Manifold | ProductManifold,
    num_heads: int,
    dim: int,
    head_dim: int,
    init_curvatures: list[float] | None = None,
):
    super().__init__()

    self.manifold = manifold
    self.num_heads = num_heads
    self.head_dim = head_dim
    inner = num_heads * head_dim

    # Learnable per-head curvatures (decoupled from the input manifold). Default: mixed curvature.
    if init_curvatures is None:
        init_curvatures = [-1.0] * (num_heads // 2) + [1.0] * (num_heads - num_heads // 2)
    self.curvatures = nn.Parameter(torch.tensor(init_curvatures, dtype=torch.float))

    # Tangent-space linear projections (paper-faithful; allow free head_dim).
    self.W_q = nn.Linear(dim, inner)
    self.W_k = nn.Linear(dim, inner)
    self.W_v = nn.Linear(dim, inner)
    self.W_o = nn.Linear(inner, dim)

    self.attn = GeometricLinearizedAttention(num_heads=num_heads, head_dim=head_dim)

forward(X, mask=None)

Forward pass for the mixed-curvature attention layer.

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the mixed-curvature attention layer."""
    math = geoopt.manifolds.stereographic.math
    k = self._per_head_curvatures()  # [H, 1, 1]

    # Tangent space at the origin of the input manifold, then per-head Euclidean projections.
    X_tan = self.manifold.manifold.logmap0(X)  # [N, dim]
    q_tan = self._split_heads(self.W_q(X_tan))  # [H, N, d] (Euclidean, at origin)
    k_tan = self._split_heads(self.W_k(X_tan))
    v_tan = self._split_heads(self.W_v(X_tan))

    # Values are points on the per-head st_{kappa_h}; Q/K are tangent vectors AT each value point.
    V = math.expmap0(v_tan, k=k)  # [H, N, d] points on st_kappa_h
    Q = math.parallel_transport0(V, q_tan, k=k)  # tangent at V_i (Eq 5)
    K = math.parallel_transport0(V, k_tan, k=k)

    attn_out = self.attn(Q, K, V, curvatures=k, mask=mask)  # [H, N, d] points on st_kappa_h

    # Multi-head output is the product over heads: concat per-head tangent coords, project back.
    attn_tan = self._combine_heads(math.logmap0(attn_out, k=k))  # [N, inner] (tangent)
    out_tan = self.W_o(attn_tan)  # [N, dim]
    return self.manifold.manifold.expmap0(out_tan)

StereographicTransformer(manifold, num_heads, dim, head_dim, use_layer_norm=True, init_curvatures=None)

Bases: Module

Mixed-curvature transformer block on a single graph of [n_nodes, dim] tokens (FPS-T).

A pre-norm transformer block adapted to a stereographic (product) manifold per "Curve Your Attention" (arXiv:2309.04082). It consists of a mixed-curvature multi-head attention sublayer (StereographicAttention, faithful gyrovector Einstein-midpoint aggregation with learnable per-head curvatures) followed by a manifold feedforward sublayer, each wrapped in a Mobius-addition residual connection. Tokens are graph nodes; the mask is the adjacency matrix A_hat (None for full attention). When both the per-head curvatures and the curvatures of the input manifold vanish, the block reduces to a standard Euclidean linear-attention transformer block.
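Each residual connection in `forward` is `projx(mobius_add(sublayer_out, X))`. The κ-stereographic Mobius addition has the closed form below, the standard gyrovector formula, which we believe matches geoopt's `mobius_add`. It is not commutative, so the argument order matters. This is a plain-Python sketch for single points that omits projx; the names are hypothetical.

```python
def mobius_add(x, y, k):
    """kappa-stereographic Mobius addition x (+)_k y."""
    xy = sum(a * b for a, b in zip(x, y))
    x2 = sum(a * a for a in x)
    y2 = sum(b * b for b in y)
    coef_x = 1.0 - 2.0 * k * xy - k * y2
    coef_y = 1.0 + k * x2
    den = 1.0 - 2.0 * k * xy + k * k * x2 * y2
    return [(coef_x * a + coef_y * b) / den for a, b in zip(x, y)]


def prenorm_residual(X, sublayer, norm, k):
    """One pre-norm sublayer with a Mobius residual: sublayer(norm(X)) (+)_k X (projx omitted)."""
    return mobius_add(sublayer(norm(X)), X, k)


a, b = [0.5, 0.0], [0.0, 0.5]
ab = mobius_add(a, b, -1.0)  # approx [0.588, 0.353]
ba = mobius_add(b, a, -1.0)  # approx [0.353, 0.588]: order matters
```

At κ = 0 the formula reduces to x + y, so the block's residuals become ordinary additions. Adding a zero sublayer output returns X unchanged for any κ.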

Parameters:
  • manifold (Manifold | ProductManifold) –

    Stereographic Manifold or ProductManifold defining the geometry.

  • num_heads (int) –

    Number of attention heads.

  • dim (int) –

    Dimensionality of the input features (manifold.dim).

  • head_dim (int) –

    Dimensionality of each attention head.

  • use_layer_norm (bool, default: True ) –

    Whether to apply (tangent-space) layer normalization.

  • init_curvatures (list[float] | None, default: None ) –

    Optional initial per-head curvatures passed to the attention sublayer.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • mha

    Mixed-curvature multi-head attention module.

  • norm1 (Module) –

    First normalization layer (Identity or StereographicLayerNorm).

  • norm2 (Module) –

    Second normalization layer (Identity or StereographicLayerNorm).

  • ff1

    First feedforward KappaGCNLayer (with ReLU nonlinearity).

  • ff2

    Second feedforward KappaGCNLayer (no nonlinearity).

Source code in manify/predictors/nn/layers.py
def __init__(
    self,
    manifold: Manifold | ProductManifold,
    num_heads: int,
    dim: int,
    head_dim: int,
    use_layer_norm: bool = True,
    init_curvatures: list[float] | None = None,
):
    super().__init__()

    if not manifold.is_stereographic:
        raise ValueError(
            "Manifold must be stereographic for StereographicTransformer to work. "
            "Please use manifold.stereographic() to convert."
        )

    self.manifold = manifold
    self.mha = StereographicAttention(
        manifold=manifold, num_heads=num_heads, dim=dim, head_dim=head_dim, init_curvatures=init_curvatures
    )

    if use_layer_norm:
        self.norm1: nn.Module = StereographicLayerNorm(manifold=manifold, embedding_dim=dim)
        self.norm2: nn.Module = StereographicLayerNorm(manifold=manifold, embedding_dim=dim)
    else:
        self.norm1 = nn.Identity()
        self.norm2 = nn.Identity()

    self.ff1 = KappaGCNLayer(in_features=dim, out_features=dim, manifold=manifold, nonlinearity=torch.relu)
    self.ff2 = KappaGCNLayer(in_features=dim, out_features=dim, manifold=manifold, nonlinearity=None)

forward(X, mask=None)

Forward pass through the mixed-curvature transformer block.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Node features as points on the manifold, shape [n_nodes, dim].

  • mask (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Optional adjacency matrix A_hat; None means full attention.

Returns:
  • Float[Tensor, 'n_nodes dim']

    Updated node features as points on the manifold, shape [n_nodes, dim].

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass through the mixed-curvature transformer block.

    Args:
        X: Node features as points on the manifold, shape ``[n_nodes, dim]``.
        mask: Optional adjacency matrix ``A_hat``; ``None`` means full attention.

    Returns:
        Updated node features as points on the manifold, shape ``[n_nodes, dim]``.
    """
    man = self.manifold.manifold

    # Pre-norm attention sublayer with Mobius-addition residual.
    attn = self.mha(self.norm1(X), mask)
    X = man.projx(man.mobius_add(attn, X))

    # Pre-norm feedforward sublayer with Mobius-addition residual.
    ff = self._mlpblock(self.norm2(X))
    X = man.projx(man.mobius_add(ff, X))

    return X