Nn

manify.predictors.nn

Neural network layers for KappaGCN and related models.

FermiDiracDecoder(manifold, learnable_params=True)

Bases: Module

Fermi-Dirac decoder for link prediction tasks.

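Computes pairwise distances between node embeddings with manifold.pdist2 and applies a Fermi-Dirac transformation. The output is a matrix of edge logits; apply a sigmoid to obtain edge probabilities.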

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • learnable_params (bool, default: True ) –

    If True, temperature and bias are learnable parameters. If False, they are fixed to 1.0 and 0.0, respectively.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, learnable_params: bool = True):
    super().__init__()

    self.manifold = manifold

    if learnable_params:
        self.temperature = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.tensor(0.0))
    else:
        self.register_buffer("temperature", torch.tensor(1.0))
        self.register_buffer("bias", torch.tensor(0.0))

forward(X, A_hat=None)

Forward pass through Fermi-Dirac decoder.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Node embeddings

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Ignored; accepted only for interface compatibility with the other layers.

Returns:
  • Float[Tensor, 'n_nodes n_nodes']

    Edge logits for every pair of nodes (the full pairwise matrix); apply a sigmoid to obtain edge probabilities.

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes n_nodes"]:
    """Forward pass through Fermi-Dirac decoder.

    Args:
        X: Node embeddings
        A_hat: Ignored; accepted only for interface compatibility with the other layers.

    Returns:
        Edge logits for every pair of nodes; apply a sigmoid to obtain edge probabilities.
    """
    # Compute pairwise distances
    pairwise_dist = self.manifold.pdist2(X)

    # Fermi-Dirac logits: sigmoid(logits) = 1 / (exp((d - bias) / temperature) + 1)
    logits = -(pairwise_dist - self.bias) / self.temperature

    return logits
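
The decoder returns logits rather than probabilities. Passing a logit through a sigmoid recovers the classic Fermi-Dirac form 1 / (exp((d - bias) / temperature) + 1), where d is the corresponding entry of manifold.pdist2 (presumably a squared distance, going by the name). The dependency-free sketch below checks that identity numerically; it uses squared Euclidean distance as a stand-in for manifold.pdist2, and all helper names are hypothetical.

```python
import math

def squared_euclidean_pdist(points):
    # Hypothetical stand-in for manifold.pdist2 on a flat (Euclidean) manifold
    return [[sum((a - b) ** 2 for a, b in zip(p, q)) for q in points] for p in points]

def fermi_dirac_logits(points, bias=0.0, temperature=1.0):
    # Mirrors FermiDiracDecoder.forward: logits = -(d - bias) / temperature
    dist = squared_euclidean_pdist(points)
    return [[-(d - bias) / temperature for d in row] for row in dist]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def fermi_dirac_prob(d, bias=0.0, temperature=1.0):
    # Classic Fermi-Dirac edge probability
    return 1.0 / (math.exp((d - bias) / temperature) + 1.0)

points = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
logits = fermi_dirac_logits(points, bias=1.0, temperature=0.5)
probs = [[sigmoid(z) for z in row] for row in logits]  # closer pairs get higher probability
```

Since the output is a logit, it can be fed directly to a logit-based loss such as torch.nn.BCEWithLogitsLoss instead of being passed through a sigmoid first.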

KappaGCNLayer(in_features, out_features, manifold, nonlinearity=torch.relu)

Bases: Module

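Graph convolution layer for κ-GCN. The linear transform (Möbius matrix-vector multiplication), the neighborhood aggregation with A_hat, and the nonlinearity are all carried out with manifold-aware operations.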

Parameters:
  • in_features (int) –

    Number of input features

  • out_features (int) –

    Number of output features

  • manifold (Manifold) –

    Manifold on which the layer operates. A ProductManifold is also accepted; aggregation is then performed factor by factor.

  • nonlinearity (Callable | None, default: relu ) –

    Activation function, applied through manifold.apply. If None, the identity is used.

Attributes:
  • W

    Weight matrix parameter.

  • sigma

    Nonlinear activation function applied via the manifold.

  • manifold

    The manifold object for geometric operations.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, in_features: int, out_features: int, manifold: Manifold, nonlinearity: Callable | None = torch.relu
):
    super().__init__()

    # The weight matrix is an ordinary Euclidean parameter
    self.W = torch.nn.Parameter(torch.randn(in_features, out_features) * 0.01)

    # Nonlinearity must be applied via the manifold
    self.sigma = manifold.apply(nonlinearity) if nonlinearity else lambda x: x

    # Also store manifold
    self.manifold = manifold

forward(X, A_hat=None)

Forward pass for the Kappa GCN layer.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Embedding matrix

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Normalized adjacency matrix. If None, the aggregation step is skipped and only the transform and nonlinearity are applied.

Returns:
  • AXW (Float[Tensor, 'n_nodes dim']) –

    Transformed node features after message passing and nonlinear activation.

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the Kappa GCN layer.

    Args:
        X: Embedding matrix
        A_hat: Normalized adjacency matrix

    Returns:
        AXW: Transformed node features after message passing and nonlinear activation.
    """
    # 1. right-multiply X by W - mobius_matvec broadcasts correctly (verified)
    XW = self.manifold.manifold.mobius_matvec(m=self.W, x=X)

    # 2. left-multiply (X @ W) by A_hat - we need our own implementation for this
    if A_hat is None:
        AXW = XW
    elif isinstance(self.manifold, ProductManifold):
        # Aggregate each factor of the product manifold separately, then concatenate
        XWs = self.manifold.factorize(XW)
        AXW = torch.hstack(
            [self._left_multiply(A_hat, XW_i, M) for XW_i, M in zip(XWs, self.manifold.P, strict=False)]
        )
    else:
        AXW = self._left_multiply(A_hat, XW, self.manifold)

    # 3. Apply nonlinearity - note that sigma is wrapped with our manifold.apply decorator
    AXW = self.sigma(AXW)

    return AXW
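
Step 1 relies on Möbius matrix-vector multiplication, M ⊗κ x = tan_κ(|Mx| / |x| · artan_κ(|x|)) · Mx / |Mx|. The plain-Python sketch below writes that formula out independently of geoopt. Here M is stored output-by-input (rows are output coordinates), which is a convention of this sketch only; it makes no claim about how self.W is laid out inside manify. At κ = 0 the result reduces to the ordinary product Mx.

```python
import math

def tan_k(x, k):
    # κ-tangent: tanh for k < 0, identity for k = 0, tan for k > 0
    if k < 0:
        s = math.sqrt(-k)
        return math.tanh(s * x) / s
    if k > 0:
        s = math.sqrt(k)
        return math.tan(s * x) / s
    return x

def artan_k(y, k):
    # Inverse of tan_k
    if k < 0:
        s = math.sqrt(-k)
        return math.atanh(s * y) / s
    if k > 0:
        s = math.sqrt(k)
        return math.atan(s * y) / s
    return y

def norm(v):
    return math.sqrt(sum(t * t for t in v))

def matvec(M, x):
    # Ordinary matrix-vector product; rows of M are output coordinates
    return [sum(m * t for m, t in zip(row, x)) for row in M]

def mobius_matvec(M, x, k):
    # M ⊗_k x = tan_k(|Mx| / |x| * artan_k(|x|)) * Mx / |Mx|
    mx = matvec(M, x)
    x_norm, mx_norm = norm(x), norm(mx)
    if x_norm == 0 or mx_norm == 0:
        return [0.0] * len(mx)
    scale = tan_k(mx_norm / x_norm * artan_k(x_norm, k), k) / mx_norm
    return [scale * t for t in mx]

x = [0.1, 0.2]
print(mobius_matvec([[10.0, 0.0], [0.0, 10.0]], x, k=-1.0))  # stays inside the unit ball
```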

KappaSequential(*layers)

Bases: Module

Sequential container for κ-layers that properly handles adjacency matrices.

Similar to nn.Sequential but passes the adjacency matrix through each layer. All layers should accept (X, A_hat) and return X.

Parameters:
  • *layers (Module, default: () ) –

    Variable number of layers to be added to the sequence.

Source code in manify/predictors/nn/layers.py
def __init__(self, *layers: nn.Module):
    super().__init__()
    self.layers = nn.ModuleList(layers)

forward(X, A_hat=None)

Forward pass through all layers.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Adjacency matrix passed to each layer

Returns:
  • Float[Tensor, 'n_nodes out_dim']

    Output after passing through all layers

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes out_dim"]:
    """Forward pass through all layers.

    Args:
        X: Input features
        A_hat: Adjacency matrix passed to each layer

    Returns:
        Output after passing through all layers
    """
    for layer in self.layers:
        X = layer(X, A_hat)
    return X

append(layer)

Add a layer to the end of the sequence.

Source code in manify/predictors/nn/layers.py
def append(self, layer: nn.Module) -> None:
    """Add a layer to the end of the sequence."""
    self.layers.append(layer)

StereographicLogits(out_features, manifold, apply_softmax=False)

Bases: Module

Stereographic logits layer for classification and regression on manifolds and product manifolds.

Computes signed distances from hyperplanes in the product manifold space. Can optionally apply softmax for classification tasks.

Parameters:
  • out_features (int) –

    Number of output classes (dimensionality of output space)

  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • apply_softmax (bool, default: False ) –

    If True, apply a softmax over the last (class) dimension so the layer returns probabilities instead of logits.

Source code in manify/predictors/nn/layers.py
def __init__(self, out_features: int, manifold: Manifold | ProductManifold, apply_softmax: bool = False):
    super().__init__()

    self.out_features = out_features
    self.manifold = manifold
    self.apply_softmax = apply_softmax

    # Weight matrix (Euclidean parameters)
    self.W = nn.Parameter(torch.randn(manifold.dim, out_features) * 0.01)

    # Bias points on the manifold
    self.p_ks = geoopt.ManifoldParameter(torch.zeros(out_features, manifold.dim), manifold=manifold.manifold)

forward(X, A_hat=None, aggregate_logits=False)

Forward pass through stereographic logits.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Optional adjacency matrix; used only when aggregate_logits=True.

  • aggregate_logits (bool, default: False ) –

    If True and A_hat is given, replace the logits with A_hat @ logits before the optional softmax.

Returns:
  • Float[Tensor, 'n_nodes n_classes']

    Logits (or probabilities if apply_softmax=True)

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    aggregate_logits: bool = False,
) -> Float[torch.Tensor, "n_nodes n_classes"]:
    """Forward pass through stereographic logits.

    Args:
        X: Input features
        A_hat: Optional adjacency matrix for logit aggregation
        aggregate_logits: Whether to aggregate logits using adjacency matrix

    Returns:
        Logits (or probabilities if apply_softmax=True)
    """
    # Compute logits based on manifold type
    if isinstance(self.manifold, ProductManifold):
        logits = self._get_logits_product_manifold(X, self.W, self.p_ks, self.manifold)
    else:
        logits = self._get_logits_single_manifold(X, self.W, self.p_ks, self.manifold, return_inner_products=False)

    # Optional aggregation via adjacency matrix
    if A_hat is not None and aggregate_logits:
        logits = A_hat @ logits

    # Optional softmax for classification
    if self.apply_softmax:
        logits = torch.softmax(logits, dim=-1)

    return logits
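
When aggregate_logits=True and A_hat is provided, forward replaces each node's logits with A_hat @ logits before the optional softmax. A small plain-Python illustration, using a hypothetical row-normalized adjacency matrix for a three-node path graph with self-loops:

```python
import math

def matmul(A, B):
    # Dense matrix product of two lists of rows
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax_rows(M):
    out = []
    for row in M:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out

# Hypothetical path graph 0-1-2 with self-loops; each row sums to 1
A_hat = [[0.5, 0.5, 0.0], [1 / 3, 1 / 3, 1 / 3], [0.0, 0.5, 0.5]]
logits = [[2.0, 0.0], [0.0, 0.0], [0.0, 2.0]]  # per-node logits for two classes

aggregated = matmul(A_hat, logits)  # logits = A_hat @ logits
probs = softmax_rows(aggregated)    # only when apply_softmax=True
```

The middle node has no preference of its own, but after aggregation it inherits an even split from its two neighbors, while the end nodes keep their original class but with softened logits.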

layers

Neural network layers for product manifolds.

StereographicLayerNorm(manifold, embedding_dim, curvatures)

Bases: Module

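Layer normalization for points on a stereographic manifold. nn.LayerNorm is applied in the tangent space (through manifold.apply), and the result is projected back onto the manifold using the per-head curvatures.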

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry.

  • embedding_dim (int) –

    Embedding dimension of the input points.

  • curvatures (Tensor['num_heads 1 1']) –

    Tensor of shape [num_heads, 1, 1] representing the curvature value used per head in geometric computations.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • stereographic_norm

    nn.LayerNorm(embedding_dim) wrapped with manifold.apply so that it acts in the tangent space.

  • curvatures

    Tensor of shape [num_heads, 1, 1] representing the curvature value used per head in geometric computations.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, manifold: Manifold | ProductManifold, embedding_dim: int, curvatures: torch.Tensor["num_heads 1 1"]
):
    super().__init__()

    self.manifold = manifold
    self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
    self.curvatures = curvatures

forward(X)

Apply layer normalization on the stereographic manifold.

Source code in manify/predictors/nn/layers.py
def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
    """Apply layer normalization on the stereographic manifold."""
    norm_X = self.stereographic_norm(X)
    output = geoopt.manifolds.stereographic.math.project(norm_X, k=self.curvatures)
    return output
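
The pipeline is: map to the tangent space at the origin, normalize, map back, project. The sketch below spells that out for a single point on a κ < 0 ball. It assumes manifold.apply wraps a function as logmap at the origin, then the function, then expmap at the origin (an assumption about manify's internals, not confirmed here); the LayerNorm has no learnable affine parameters, and project is approximated by clipping the norm just inside the ball of radius 1/sqrt(-κ).

```python
import math

def tan_k(x, k):  # this sketch only handles k < 0
    s = math.sqrt(-k)
    return math.tanh(s * x) / s

def artan_k(y, k):
    s = math.sqrt(-k)
    return math.atanh(s * y) / s

def norm(v):
    return math.sqrt(sum(t * t for t in v))

def expmap0(u, k):
    # Tangent vector at the origin -> point on the ball
    n = norm(u)
    return list(u) if n == 0 else [tan_k(n, k) * t / n for t in u]

def logmap0(y, k):
    # Point on the ball -> tangent vector at the origin
    n = norm(y)
    return list(y) if n == 0 else [artan_k(n, k) * t / n for t in y]

def layer_norm(u, eps=1e-5):
    # nn.LayerNorm without the affine weight/bias (biased variance, eps=1e-5)
    mean = sum(u) / len(u)
    var = sum((t - mean) ** 2 for t in u) / len(u)
    return [(t - mean) / math.sqrt(var + eps) for t in u]

def project(y, k, eps=1e-5):
    # Keep the point strictly inside the ball of radius 1/sqrt(-k)
    max_norm = (1 - eps) / math.sqrt(-k)
    n = norm(y)
    return [t * max_norm / n for t in y] if n > max_norm else list(y)

def stereographic_layer_norm(y, k):
    return project(expmap0(layer_norm(logmap0(y, k)), k), k)

y = [0.3, -0.1, 0.5, 0.2]
out = stereographic_layer_norm(y, k=-1.0)
```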

GeometricLinearizedAttention(curvatures, num_heads, head_dim)

Bases: Module

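Linearized (kernel-based) attention for points on κ-stereographic manifolds. Softmax is replaced by the positive feature map elu(x) + 1, so the cost grows linearly rather than quadratically with the number of nodes; values are combined with curvature-dependent weights and a final Möbius scaling by 0.5.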

Parameters:
  • curvatures (float | list[float]) –

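    Per-head curvatures. Although annotated as float | list[float], StereographicAttention passes a tensor of shape [num_heads, 1, 1], which broadcasts against the [batch_size, num_heads, n_nodes, head_dim] inputs.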

  • num_heads (int) –

    Number of attention heads.

  • head_dim (int) –

    Dimension of each attention head.

Attributes:
  • num_heads

    Number of attention heads.

  • head_dim

    Dimension of each attention head.

  • _epsilon

    Value (1e-5) substituted for zero entries of the attention normalizer D before it is inverted.

  • _clamp_epsilon

    Minimum absolute value (1e-10) allowed for gamma - 1, the denominator of the value weights.

Source code in manify/predictors/nn/layers.py
def __init__(self, curvatures: float | list[float], num_heads: int, head_dim: int):
    super().__init__()

    self.num_heads = num_heads
    self.curvatures = curvatures

    self.head_dim = head_dim
    self._epsilon = 1e-5
    self._clamp_epsilon = 1e-10

forward(Q, K, V, mask)

Forward pass for the geometric linearized attention layer.

Parameters:
  • Q (Float[Tensor, 'batch_size num_heads n_nodes head_dim']) –

    Query tensor.

  • K (Float[Tensor, 'batch_size num_heads n_nodes head_dim']) –

    Key tensor.

  • V (Float[Tensor, 'batch_size num_heads n_nodes head_dim']) –

    Value tensor.

  • mask (Float[Tensor, '1 1 n_nodes n_nodes']) –

    Mask tensor for attention.

Returns:
  • Float[Tensor, 'batch_size n_nodes dim']

    Output tensor after applying attention.
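
The "linearized" attention replaces softmax with the positive feature map phi(x) = elu(x) + 1. Because the score phi(q_i) · phi(k_j) factorizes, the sums over keys can be computed once and reused for every query, which is what the einsum calls in forward do. The Euclidean, single-head, unmasked sketch below (helper names hypothetical) checks that this reordering gives the same result as the explicit quadratic computation; the layer itself additionally reweights keys and values with gamma - 1 and gamma / (gamma - 1).

```python
import math

def phi(v):
    # elu(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise, so always positive
    return [t + 1.0 if t > 0 else math.exp(t) for t in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def quadratic_attention(Q, K, V):
    # O(n^2): explicit pairwise scores for every query/key pair
    out = []
    for q in Q:
        w = [dot(phi(q), phi(k)) for k in K]
        z = sum(w)
        out.append([sum(wj * v[e] for wj, v in zip(w, V)) / z for e in range(len(V[0]))])
    return out

def linear_attention(Q, K, V):
    # O(n): aggregate keys and values once, then reuse them for every query
    d, e_dim = len(K[0]), len(V[0])
    pk = [phi(k) for k in K]
    context = [[sum(p[i] * v[e] for p, v in zip(pk, V)) for e in range(e_dim)] for i in range(d)]
    k_sum = [sum(p[i] for p in pk) for i in range(d)]
    out = []
    for q in Q:
        pq = phi(q)
        z = dot(pq, k_sum)  # normalization term, like D in forward
        out.append([sum(pq[i] * context[i][e] for i in range(d)) / z for e in range(e_dim)])
    return out

Q = [[0.1, -0.2], [0.5, 0.3], [-1.0, 0.4]]
K = [[0.2, 0.1], [-0.3, 0.7], [0.9, -0.5]]
V = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 3.0, 1.0]]
```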

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
    K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
    V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
    mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"],
) -> Float[torch.Tensor, "batch_size n_nodes dim"]:
    """Forward pass for the geometric linearized attention layer.

    Args:
        Q: Query tensor.
        K: Key tensor.
        V: Value tensor.
        mask: Mask tensor for attention.

    Returns:
        Output tensor after applying attention.
    """
    v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
    v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

    gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
    denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

    x = ((gamma / denominator) * V) * mask[None, :, None]

    v1 = nn.functional.elu(v1) + 1
    v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

    # Linearized approximation
    v2_cumsum = v2.sum(dim=-2)  # [B, H, D]
    D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1))  # normalization terms
    D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)
    context = torch.einsum("...nd,...ne->...de", v2, x)
    X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)

    X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
    X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
        torch.tensor(0.5, dtype=X.dtype, device=X.device), X, k=self.curvatures, dim=-1
    )
    X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

    return X
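
The weighting in forward matches the κ-stereographic weighted gyromidpoint. Each value x_j is scaled by gamma_j / (gamma_j - 1), where gamma = lambda_x = 2 / (1 + κ|x|²) is the conformal factor. The attention weights carry a factor gamma_j - 1, and the result is halved with mobius_scalar_mul(0.5, ...). Written out, m = 0.5 ⊗κ (Σ_j α_j γ_j x_j / Σ_j α_j (γ_j - 1)). Below is a plain-Python version for one set of weights α (helper names hypothetical; no clamping or projection), useful for sanity checks. A single point is its own midpoint, and at κ = 0 it reduces to the weighted arithmetic mean.

```python
import math

def norm(v):
    return math.sqrt(sum(t * t for t in v))

def lambda_x(x, k):
    # Conformal factor 2 / (1 + k * |x|^2)
    return 2.0 / (1.0 + k * sum(t * t for t in x))

def tan_k(x, k):
    if k < 0:
        s = math.sqrt(-k)
        return math.tanh(s * x) / s
    if k > 0:
        s = math.sqrt(k)
        return math.tan(s * x) / s
    return x

def artan_k(y, k):
    if k < 0:
        s = math.sqrt(-k)
        return math.atanh(s * y) / s
    if k > 0:
        s = math.sqrt(k)
        return math.atan(s * y) / s
    return y

def mobius_scalar_mul(r, x, k):
    # r ⊗_k x = tan_k(r * artan_k(|x|)) * x / |x|
    n = norm(x)
    if n == 0:
        return [0.0] * len(x)
    return [tan_k(r * artan_k(n, k), k) * t / n for t in x]

def gyromidpoint(points, weights, k):
    lams = [lambda_x(p, k) for p in points]
    denom = sum(a * (lam - 1.0) for a, lam in zip(weights, lams))
    dim = len(points[0])
    y = [sum(a * lam * p[d] for a, lam, p in zip(weights, lams, points)) / denom for d in range(dim)]
    return mobius_scalar_mul(0.5, y, k)

print(gyromidpoint([[0.3, -0.2], [0.1, 0.4]], [1.0, 2.0], k=-1.0))
```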

StereographicAttention(manifold, num_heads, dim, head_dim)

Bases: Module

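Multi-head attention layer on a stereographic manifold. Queries and keys come from ordinary nn.Linear projections and values from a KappaGCNLayer; the heads are mixed with GeometricLinearizedAttention, merged, and mapped back to dim by a second KappaGCNLayer.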

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry.

  • num_heads (int) –

    Number of attention heads.

  • dim (int) –

    Embedding dimension of the input points.

  • head_dim (int) –

    Dimension of each attention head.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • curvatures

    Tensor of shape [num_heads, 1, 1] representing the curvature value used per head in geometric computations.

  • num_heads

    Number of attention heads.

  • head_dim

    Dimensionality of each attention head.

  • W_q

    Linear layer projecting inputs to query vectors.

  • W_k

    Linear layer projecting inputs to key vectors.

  • W_v

    Manifold-aware linear layer projecting to value vectors.

  • attn

    GeometricLinearizedAttention module that applies linearized attention across the heads.

  • ff

    Manifold-aware linear layer for the feedforward output.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
    super().__init__()

    self.manifold = manifold
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

    self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
    self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
    self.W_v = KappaGCNLayer(in_features=dim, out_features=self.num_heads * self.head_dim, manifold=self.manifold)

    self.attn = GeometricLinearizedAttention(
        curvatures=self.curvatures, num_heads=self.num_heads, head_dim=self.head_dim
    )
    self.ff = KappaGCNLayer(in_features=self.num_heads * self.head_dim, out_features=dim, manifold=self.manifold)

forward(X, mask=None)

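Forward pass for the stereographic attention layer. The mask is unsqueezed to shape [1, 1, n_nodes, n_nodes] before it reaches the attention module, so it must be provided even though it defaults to None.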

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the stereographic attention layer."""
    Q = self._split_heads(self.W_q(X))  # [B, H, N, D]
    K = self._split_heads(self.W_k(X))
    V = self._split_heads(self.W_v(X=X))

    attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))  # type: ignore
    attn_out = self._combine_heads(attn_out)

    out = self.ff(X=attn_out)

    return out

StereographicTransformer(manifold, num_heads, dim, head_dim, use_layer_norm=True)

Bases: Module

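Stereographic transformer block. Each sub-layer (attention, then a two-layer feedforward network) is applied to the optionally normalized input and added back to the block input with Möbius addition, followed by a projection onto the manifold. A ValueError is raised if the manifold is not stereographic.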

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry.

  • num_heads (int) –

    Number of attention heads.

  • dim (int) –

    Dimensionality of the input features.

  • head_dim (int) –

    Dimensionality of each attention head.

  • use_layer_norm (bool, default: True ) –

    Whether to apply layer normalization in tangent space.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • curvatures

    Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.

  • mha

    Multi-head stereographic attention module.

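  • norm1

    Normalization applied before attention (StereographicLayerNorm, or Identity when use_layer_norm=False).

  • norm2

    Normalization applied before the feedforward block (StereographicLayerNorm, or Identity when use_layer_norm=False).

  • mlpblock

    Feedforward network: two KappaGCNLayers (dim → dim) with stereographic_activation between them.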

  • stereographic_activation

    Activation wrapped to operate in tangent space.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True
):
    super().__init__()

    # Check that manifold is stereographic
    if not manifold.is_stereographic:
        raise ValueError(
            "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
        )

    self.manifold = manifold
    self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
    self.stereographic_activation = self.manifold.apply(nn.ReLU())
    self.mha = StereographicAttention(manifold=self.manifold, num_heads=num_heads, dim=dim, head_dim=head_dim)

    if use_layer_norm:
        self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
        self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
    else:
        self.norm1 = nn.Identity()
        self.norm2 = nn.Identity()

    self.mlpblock = nn.Sequential(
        KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
        self.stereographic_activation,
        KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
    )

forward(X, mask=None)

Forward pass through the stereographic transformer block.

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass through the stereographic transformer block."""
    # Attention sub-layer with a Möbius-addition residual
    X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, k=self.curvatures)
    X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
    # Feedforward sub-layer with a Möbius-addition residual
    X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, k=self.curvatures)
    X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

    return X
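
Each residual connection uses Möbius addition instead of +, with the sub-layer output as the left operand. The sketch below re-implements κ-stereographic Möbius addition (the formula geoopt's mobius_add evaluates) in plain Python for illustration. It reduces to vector addition at κ = 0 and is not commutative in general, so the operand order in forward affects the direction of the result, though not its norm.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mobius_add(x, y, k):
    # x ⊕_k y = ((1 - 2k<x,y> - k|y|^2) x + (1 + k|x|^2) y) / (1 - 2k<x,y> + k^2 |x|^2 |y|^2)
    xy, x2, y2 = dot(x, y), dot(x, x), dot(y, y)
    coef_x = 1 - 2 * k * xy - k * y2
    coef_y = 1 + k * x2
    denom = 1 - 2 * k * xy + k * k * x2 * y2
    return [(coef_x * a + coef_y * b) / denom for a, b in zip(x, y)]

a, b = [0.5, 0.0], [0.0, 0.5]
ab = mobius_add(a, b, k=-1.0)  # sub-layer output on the left, as in forward
ba = mobius_add(b, a, k=-1.0)
print(ab, ba, math.sqrt(dot(ab, ab)), math.sqrt(dot(ba, ba)))
```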