Predictors

manify.predictors

Initialize predictors in the product space.

Implemented predictors include:

  • Decision Trees and Random Forests: Use geodesic splits in product manifolds via projection angles and manifold-specific geodesic midpoints.

For points in a product manifold \(\mathcal{P} = \mathcal{M}_1 \times \mathcal{M}_2 \times \cdots \times \mathcal{M}_k\), we define splits along two-dimensional subspaces for each component manifold.

Angular representation of splits

For each component manifold \(\mathcal{M}_i\), we project points onto two-dimensional subspaces and represent potential splits using angles. Given a point \(\mathbf{x}\) and a basis dimension \(d\), the projection angle is computed as:

\[ \theta(\mathbf{x}, d) = \tan^{-1}\left(\frac{x_0}{x_d}\right) \]

where \(x_0\) and \(x_d\) are the coordinates in the selected two-dimensional subspace.

The splitting criterion is then:

\[ S(\mathbf{x}, d, \theta) = \begin{cases} 1 & \text{if } \tan^{-1}\left( \frac{x_0}{x_d} \right) \in [\theta, \theta + \pi) \\ 0 & \text{otherwise} \end{cases} \]

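The angle and the split test can be sketched in plain Python as follows. This is an illustration, not manify's implementation: it assumes points are ambient-coordinate lists with the special coordinate at index 0, and it uses `math.atan2`, which agrees with \(\tan^{-1}(x_0/x_d)\) whenever \(x_d > 0\) and keeps the angle in \((0, \pi)\) whenever \(x_0 > 0\) (as on the hyperboloid). The function names are hypothetical.

```python
import math


def projection_angle(x: list[float], d: int) -> float:
    """Angle of x in the (x_0, x_d) plane.

    Sketch assumption: atan2 is used, so the result equals tan^{-1}(x_0 / x_d)
    when x_d > 0 and lies in (0, pi) whenever x_0 > 0.
    """
    return math.atan2(x[0], x[d])


def split(x: list[float], d: int, theta: float) -> int:
    """S(x, d, theta): 1 if the projection angle lies in [theta, theta + pi), else 0."""
    angle = projection_angle(x, d)
    return 1 if theta <= angle < theta + math.pi else 0


# A point on the hyperboloid of curvature -1: (cosh r, sinh r, 0) with r = 1
x = [math.cosh(1.0), math.sinh(1.0), 0.0]
print(projection_angle(x, 1))  # tan^{-1}(coth 1), about 0.92
print(split(x, 1, 0.5), split(x, 1, 1.0))  # 1 0
```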
Geodesic Midpoints for Decision Boundaries

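To place each decision boundary between two adjacent training points, we compute the geodesic midpoint between consecutive angles in the sorted list of projection angles. Because the projection angle is not, in general, proportional to geodesic distance, simply averaging two angles would misplace the boundary on Euclidean and hyperbolic components, so the midpoint calculation is specific to each manifold type: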

\[ \begin{align} \theta_u &= \tan^{-1}\left(\frac{u_0}{u_d}\right) \\ \theta_v &= \tan^{-1}\left(\frac{v_0}{v_d}\right)\\ m_{\mathbb{E}}(\theta_u, \theta_v) &= \tan^{-1}\left( \frac{2}{\cot\theta_u + \cot\theta_v} \right)\\ m_{\mathbb{S}}(\theta_u, \theta_v) &= \frac{\theta_u + \theta_v}{2}\\ m_{\mathbb{H}}(\theta_u, \theta_v) &= \begin{cases} \cot^{-1}\left(V - \sqrt{V^2-1}\right) &\text{if } \theta_u + \theta_v < \pi, \\ \cot^{-1}\left(V + \sqrt{V^2-1}\right) &\text{otherwise.} \end{cases}\\ V &= \frac{\sin(2\theta_u - 2\theta_v)}{2\sin(\theta_u + \theta_v)\sin(\theta_u - \theta_v)} \end{align} \]
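Each midpoint can be checked by mapping the two angles back to a coordinate along the geodesic, averaging there, and converting back to an angle. The sketch below does exactly that under assumptions that the source does not state: on a Euclidean component the special coordinate is a constant \(x_0 = 1\), so \(x_d = \cot\theta\); on a hyperbolic component (curvature \(-1\)) the relevant slice is \((x_0, x_d) = (\cosh t, \sinh t)\), so \(\cot\theta = \tanh t\) and \(\theta \in (\pi/4, 3\pi/4)\); on a spherical component the angle already measures arc length. The function names are hypothetical.

```python
import math


def midpoint_euclidean(theta_u: float, theta_v: float) -> float:
    # Assumes x_0 = 1, so x_d = cot(theta): average x_d, then convert back.
    x_mid = (1.0 / math.tan(theta_u) + 1.0 / math.tan(theta_v)) / 2.0
    return math.atan2(1.0, x_mid)


def midpoint_spherical(theta_u: float, theta_v: float) -> float:
    # The angle is proportional to arc length along the great circle.
    return (theta_u + theta_v) / 2.0


def midpoint_hyperbolic(theta_u: float, theta_v: float) -> float:
    # Assumes (x_0, x_d) = (cosh t, sinh t), so cot(theta) = tanh(t).
    # math.atanh raises ValueError if theta is outside (pi/4, 3pi/4).
    t_u = math.atanh(1.0 / math.tan(theta_u))
    t_v = math.atanh(1.0 / math.tan(theta_v))
    t_mid = (t_u + t_v) / 2.0
    # atan2(1, c) returns a value in (0, pi), i.e. it acts as cot^{-1}(c).
    return math.atan2(1.0, math.tanh(t_mid))


# Two hyperboloid points at t = 0.2 and t = 0.8 along the x_d geodesic
theta_u = math.atan2(math.cosh(0.2), math.sinh(0.2))
theta_v = math.atan2(math.cosh(0.8), math.sinh(0.8))
m = midpoint_hyperbolic(theta_u, theta_v)
print(math.isclose(m, math.atan2(math.cosh(0.5), math.sinh(0.5))))  # True: the point at t = 0.5
```

Averaging the raw angles instead would generally not land on the point at \(t = 0.5\), which is why the hyperbolic and Euclidean cases need their own formulas.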

  • \(\kappa\)-GCNs: Extend standard GCNs to operate on both positive and negative curvature using gyrovector operations. Variants include:

Graph Convolutional Networks Background

In a typical (Euclidean) graph convolutional network (GCN), each layer takes the form:

\[ \begin{align} \mathbf{H}^{(0)} &= \mathbf{X} \\ \mathbf{H}^{(l+1)} &= \sigma\left( \hat{\mathbf{A}} \mathbf{H}^{(l)} \mathbf{W}^{(l)} + \mathbf{b}^{(l)} \right) \end{align} \]

where \(\hat{\mathbf{A}} \in \mathbb{R}^{n \times n}\) is a normalized adjacency matrix with self-connections, \(\mathbf{H}^{(l)} \in \mathbb{R}^{n \times d}\) is the matrix of node features at layer \(l\) (with \(\mathbf{H}^{(0)} = \mathbf{X}\)), \(\mathbf{W}^{(l)} \in \mathbb{R}^{d \times e}\) is a weight matrix, \(\mathbf{b}^{(l)} \in \mathbb{R}^e\) is a bias term, and \(\sigma\) is a nonlinearity (e.g., ReLU).
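A single Euclidean GCN layer can be sketched with plain Python lists. The text above only says that \(\hat{\mathbf{A}}\) is normalized with self-connections; the symmetric normalization \(\tilde{\mathbf{D}}^{-1/2}(\mathbf{A} + \mathbf{I})\tilde{\mathbf{D}}^{-1/2}\) used here, with \(\tilde{\mathbf{D}}\) the degree matrix of \(\mathbf{A} + \mathbf{I}\), is the usual choice and an assumption of this sketch. The helper names are hypothetical.

```python
import math

Matrix = list[list[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Plain-Python matrix product."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def normalized_adjacency(A: Matrix) -> Matrix:
    """Symmetric normalization with self-loops: D^{-1/2} (A + I) D^{-1/2}."""
    n = len(A)
    A_tilde = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in A_tilde]
    return [[A_tilde[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def gcn_layer(A_hat: Matrix, H: Matrix, W: Matrix, b: list[float]) -> Matrix:
    """H^{(l+1)} = ReLU(A_hat H W + b), with the bias broadcast over rows."""
    AHW = matmul(matmul(A_hat, H), W)
    return [[max(v + b_j, 0.0) for v, b_j in zip(row, b)] for row in AHW]


# Two connected nodes with one feature each
A_hat = normalized_adjacency([[0.0, 1.0], [1.0, 0.0]])
H1 = gcn_layer(A_hat, H=[[1.0], [3.0]], W=[[2.0]], b=[-1.0])
print(H1)  # [[3.0], [3.0]]
```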

Graph Convolution Layers

Bachmann et al. (2020) describe a way to adapt the typical GCN model for use with \(\mathbf{X} \in \mathbb{S}^d_\kappa\), using gyrovector operations:

\[ \mathbf{H}^{(l+1)} = \sigma^{\otimes_\kappa} \left( \hat{\mathbf{A}} \boxtimes_\kappa \left( \mathbf{H}^{(l)} \otimes_\kappa \mathbf{W}^{(l)} \right) \right) \]

where the nonlinearity is applied in the tangent space at the origin:

\[ \sigma^{\otimes_\kappa}(\cdot) = \exp_{\mathbf{0}} \left( \sigma\left( \log_{\mathbf{0}}(\cdot) \right) \right) \]

Note that Bachmann et al. (2020) do not include a bias term, although it is reasonable to extend the definition of a GCN layer to include one:

\[ \mathbf{H}^{(l+1)} = \sigma^{\otimes_\kappa} \left( \hat{\mathbf{A}} \boxtimes_\kappa \left( \mathbf{H}^{(l)} \otimes_\kappa \mathbf{W}^{(l)} \right) \oplus_\kappa \mathbf{b} \right) \]

where \(\mathbf{b} \in \mathbb{S}^d_\kappa\) is a bias vector.

Also note that, in order for each \(\mathbf{H}^{(i)}\) to remain on the same manifold, \(\mathbf{W}^{(i)}\) must be a square matrix. However, this assumption can be relaxed to allow for different dimensionalities and curvatures for each layer.
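The gyrovector operations in the layer equations can be sketched for a single node as follows. This is not manify's `KappaGCNLayer`: it omits the \(\hat{\mathbf{A}} \boxtimes_\kappa\) aggregation, uses the standard \(\kappa\)-stereographic formulas for Möbius addition and the exponential and logarithmic maps at the origin, and assumes \(\|\mathbf{x}\|^2 < 1/|\kappa|\) when \(\kappa < 0\). Setting \(\kappa = 0\) recovers ordinary vector operations, which makes a convenient sanity check. All names are hypothetical.

```python
import math

Vector = list[float]


def tan_k(x: float, k: float) -> float:
    """Curvature-dependent tangent: tan if k > 0, tanh if k < 0, identity if k = 0."""
    if k > 0:
        return math.tan(math.sqrt(k) * x) / math.sqrt(k)
    if k < 0:
        return math.tanh(math.sqrt(-k) * x) / math.sqrt(-k)
    return x


def arctan_k(x: float, k: float) -> float:
    """Inverse of tan_k."""
    if k > 0:
        return math.atan(math.sqrt(k) * x) / math.sqrt(k)
    if k < 0:
        return math.atanh(math.sqrt(-k) * x) / math.sqrt(-k)
    return x


def norm(v: Vector) -> float:
    return math.sqrt(sum(vi * vi for vi in v))


def exp0(v: Vector, k: float) -> Vector:
    """Exponential map at the origin (tangent vector -> point)."""
    n = norm(v)
    return list(v) if n == 0.0 else [tan_k(n, k) / n * vi for vi in v]


def log0(y: Vector, k: float) -> Vector:
    """Logarithmic map at the origin (point -> tangent vector)."""
    n = norm(y)
    return list(y) if n == 0.0 else [arctan_k(n, k) / n * yi for yi in y]


def mobius_add(x: Vector, y: Vector, k: float) -> Vector:
    """Möbius addition x ⊕_k y in the k-stereographic model."""
    xy = sum(a * b for a, b in zip(x, y))
    x2, y2 = sum(a * a for a in x), sum(b * b for b in y)
    den = 1 - 2 * k * xy + k * k * x2 * y2
    return [((1 - 2 * k * xy - k * y2) * a + (1 + k * x2) * b) / den for a, b in zip(x, y)]


def node_update(h: Vector, W: list[Vector], b: Vector, k: float) -> Vector:
    """sigma^{⊗k}((h ⊗_k W) ⊕_k b) for one node with sigma = ReLU (no A_hat aggregation)."""
    v = log0(h, k)
    hW = exp0([sum(vi * wij for vi, wij in zip(v, col)) for col in zip(*W)], k)  # h ⊗_k W
    pre = mobius_add(hW, b, k)  # add the bias with Möbius addition
    return exp0([max(t, 0.0) for t in log0(pre, k)], k)  # ReLU in the tangent space at 0


out_euclid = node_update([1.0, 2.0], [[1.0, 0.0], [0.0, -1.0]], [0.5, 0.5], k=0.0)
print(out_euclid)  # [1.5, 0.0]: with k = 0 this is ReLU(h W + b)
```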

Stereographic Logits

For classification, we define a \(\kappa\)-stereographic equivalent of a logit layer:

\[ \mathbf{H}^{(L)} = \text{softmax} \left( \hat{\mathbf{A}} \, \operatorname{logits}_{\mathbb{S}^d_\kappa} \left( \mathbf{H}^{(L-1)}, \mathbf{W}^{(L-1)} \right) \right) \]

To implement logits in \(\mathbb{S}^d_\kappa\), we begin by noting that Euclidean logits can be interpreted as scaled signed distances from a hyperplane. This follows from the linear form \(\mathbf{w}_i^\top \mathbf{x} + b_i\) used in traditional classification, where \(\mathbf{w}_i\) is a column of the final weight matrix and \(b_i\) is its corresponding bias: dividing \(\mathbf{w}_i^\top \mathbf{x} + b_i\) by \(\|\mathbf{w}_i\|\) gives exactly the signed Euclidean distance from \(\mathbf{x}\) to the hyperplane.

The magnitude reflects the point's distance from the decision boundary (the hyperplane \(\mathbf{w}_i^\top \mathbf{x} + b_i = 0\)), and the sign determines which side of the hyperplane the point lies on. This formulation encodes both the model’s decision and its confidence.
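A quick numeric check of this interpretation, with illustrative values only: the Euclidean logit equals \(\|\mathbf{w}_i\|\) times the signed distance, and its sign tells which side of the hyperplane the point is on. The function names are hypothetical.

```python
import math


def euclidean_logit(w: list[float], b: float, x: list[float]) -> float:
    """Linear logit w^T x + b."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def signed_distance(w: list[float], b: float, x: list[float]) -> float:
    """Signed distance from x to the hyperplane {z : w^T z + b = 0}."""
    return euclidean_logit(w, b, x) / math.sqrt(sum(wi * wi for wi in w))


w, b = [3.0, 4.0], -5.0
print(euclidean_logit(w, b, [3.0, 4.0]))  # 20.0
print(signed_distance(w, b, [3.0, 4.0]))  # 4.0: positive side, ||w|| = 5
print(signed_distance(w, b, [0.0, 0.0]))  # -1.0: the origin is on the negative side
```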

Bachmann et al. (2020) and Ganea et al. (2018) extend this intuition to non-Euclidean spaces by defining logits using appropriate distance functions.

In \(\kappa\)-GCN, this becomes:

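\[ \mathbb{P}(y = k \mid \mathbf{x}) = \left[ \operatorname{softmax} \left( \operatorname{logits}_\mathcal{M}(\mathbf{x}, \cdot) \right) \right]_k \]
\[ \operatorname{logits}_\mathcal{M}(\mathbf{x}, k) = \frac{ \| \mathbf{a}_k \|_{\mathbf{p}_k} }{ \sqrt{|\kappa|} } \, \sin_\kappa^{-1} \left( \frac{ 2 \sqrt{|\kappa|} \langle \mathbf{z}_k, \mathbf{a}_k \rangle }{ (1 + \kappa \| \mathbf{z}_k \|^2) \| \mathbf{a}_k \| } \right) \]

where \(\mathbf{p}_k \in \mathbb{S}^d_\kappa\) and \(\mathbf{a}_k \in T_{\mathbf{p}_k}\mathbb{S}^d_\kappa\) are the learned point and normal vector defining the hyperplane for class \(k\), \(\mathbf{z}_k = (-\mathbf{p}_k) \oplus_\kappa \mathbf{x}\), \(\|\mathbf{a}_k\|_{\mathbf{p}_k}\) is the Riemannian norm of \(\mathbf{a}_k\) at \(\mathbf{p}_k\), and \(\sin_\kappa^{-1}\) denotes \(\sin^{-1}\) for \(\kappa > 0\) and \(\sinh^{-1}\) for \(\kappa < 0\).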
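The per-manifold logit can be sketched directly from this formula. This is an illustrative plain-Python version, not manify's `StereographicLogits` layer: it computes \(\|\mathbf{a}_k\|_{\mathbf{p}_k} = \lambda_{\mathbf{p}_k}\|\mathbf{a}_k\|\) with the conformal factor \(\lambda_{\mathbf{p}} = 2/(1 + \kappa\|\mathbf{p}\|^2)\), and it handles \(\kappa = 0\) through the Euclidean limit \(4\langle \mathbf{x} - \mathbf{p}_k, \mathbf{a}_k \rangle\). The function names are hypothetical.

```python
import math

Vector = list[float]


def mobius_add(x: Vector, y: Vector, k: float) -> Vector:
    """Möbius addition x ⊕_k y in the k-stereographic model."""
    xy = sum(a * b for a, b in zip(x, y))
    x2, y2 = sum(a * a for a in x), sum(b * b for b in y)
    den = 1 - 2 * k * xy + k * k * x2 * y2
    return [((1 - 2 * k * xy - k * y2) * a + (1 + k * x2) * b) / den for a, b in zip(x, y)]


def stereographic_logit(x: Vector, p: Vector, a: Vector, k: float) -> float:
    """Scaled signed distance from x to the hyperplane through p with normal a."""
    z = mobius_add([-pj for pj in p], x, k)  # z_k = (-p_k) ⊕_k x
    za = sum(zi * ai for zi, ai in zip(z, a))
    z2 = sum(zi * zi for zi in z)
    a_norm = math.sqrt(sum(ai * ai for ai in a))
    lam_p = 2.0 / (1.0 + k * sum(pj * pj for pj in p))  # conformal factor at p
    if k == 0.0:
        return 4.0 * za  # Euclidean limit: lam_p = 2, so 2 * lam_p * <z, a>
    sk = math.sqrt(abs(k))
    arg = 2.0 * sk * za / ((1.0 + k * z2) * a_norm)
    inv_sin = math.asin(arg) if k > 0 else math.asinh(arg)
    return lam_p * a_norm / sk * inv_sin


# Poincaré disk (k = -1): hyperplane through the origin with normal e_1
print(stereographic_logit(x=[0.5, 0.0], p=[0.0, 0.0], a=[1.0, 0.0], k=-1.0))  # about 2.197
```

In this example \(\|\mathbf{a}\|_{\mathbf{p}} = 2\) and the hyperbolic distance from \((0.5, 0)\) to the hyperplane is \(\ln 3\), so the logit is \(2\ln 3\); a point on the other side gives the same magnitude with a negative sign.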

Although it is not explicitly stated in Bachmann et al. (2020), we follow Cho et al. (2023) and later Chlenski et al. (2024) in aggregating logits across product manifolds using the \(\ell_2\)-norm of component manifold logits, scaled by the sign of the sum of component inner products:

\[ \operatorname{logits}_\mathcal{P}(\mathbf{x}, k) = \sqrt{ \sum_{\mathcal{M} \in \mathcal{P}} \left( \operatorname{logits}_\mathcal{M}(\mathbf{x}^\mathcal{M}, k) \right)^2 } \cdot \text{sign} \left( \sum_{\mathcal{M} \in \mathcal{P}} \langle \mathbf{x}^\mathcal{M}, \mathbf{a}_k^\mathcal{M} \rangle_\mathcal{M} \right) \]
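The aggregation rule itself is simple arithmetic once the per-component quantities are known. In this hypothetical sketch the component logits and the component inner products \(\langle \mathbf{x}^\mathcal{M}, \mathbf{a}_k^\mathcal{M} \rangle_\mathcal{M}\) are taken as precomputed inputs, since the inner product depends on each component's model.

```python
import math


def product_logit(component_logits: list[float], component_inner_products: list[float]) -> float:
    """Aggregate per-component logits for one class k.

    component_logits: logits_M(x^M, k) for each component manifold M.
    component_inner_products: <x^M, a_k^M>_M for each component manifold M.
    """
    magnitude = math.sqrt(sum(v * v for v in component_logits))  # l2-norm of the logits
    s = sum(component_inner_products)
    sign = (s > 0) - (s < 0)  # sign(0) = 0, as for the usual sign function
    return sign * magnitude


print(product_logit([3.0, 4.0], [-1.0, 0.5]))  # -5.0
```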

Finally, for link prediction, we follow Chami et al. (2019) in adopting the standard approach of applying the Fermi–Dirac decoder [Krioukov et al., 2010; Nickel and Kiela, 2017] to predict edges:

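\[ \mathbb{P}\big((i, j) \in \mathcal{E} \,\big|\, \mathbf{x}_i, \mathbf{x}_j\big) = \left( \exp \left( \frac{\delta_\mathcal{M}(\mathbf{x}_i, \mathbf{x}_j)^2 - r}{t} \right) + 1 \right)^{-1} \]

where \(\delta_\mathcal{M}\) is the geodesic distance on \(\mathcal{M}\) and \(r\) and \(t > 0\) are radius and temperature parameters.
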
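As an illustration, here is the decoder with the Poincaré-ball distance of curvature \(-1\) as \(\delta_\mathcal{M}\); the model and curvature are assumptions of this sketch, and the decoder is written in an algebraically equivalent form that avoids overflow for large distances. The function names are hypothetical.

```python
import math


def poincare_distance(x: list[float], y: list[float]) -> float:
    """Geodesic distance in the Poincaré ball of curvature -1."""
    diff2 = sum((a - b) ** 2 for a, b in zip(x, y))
    x2, y2 = sum(a * a for a in x), sum(b * b for b in y)
    return math.acosh(1.0 + 2.0 * diff2 / ((1.0 - x2) * (1.0 - y2)))


def fermi_dirac(dist: float, r: float, t: float) -> float:
    """Edge probability 1 / (exp((dist^2 - r) / t) + 1), computed stably."""
    z = (dist**2 - r) / t
    if z >= 0:
        e = math.exp(-z)  # 1 / (e^z + 1) = e^{-z} / (1 + e^{-z})
        return e / (1.0 + e)
    return 1.0 / (math.exp(z) + 1.0)


d = poincare_distance([0.0, 0.0], [0.5, 0.0])  # equals ln 3
print(fermi_dirac(d, r=d**2, t=1.0))  # 0.5: squared distance exactly at the radius
print(fermi_dirac(10.0, r=2.0, t=1.0) < 1e-6)  # True: distant nodes are unlikely to be linked
```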
  • Perceptrons: Use manifold-specific linear combinations of inputs with inverse sine and inverse hyperbolic sine terms.
Product Space Perceptron

A linear classifier on the product manifold \(\mathcal{P}\) is defined as:

\[ \operatorname{LC}(\mathbf{x}, \mathbf{w}) = \operatorname{sign} \left( \langle \mathbf{w}_\mathbb{E}, \mathbf{x}_\mathbb{E} \rangle + \alpha_\mathbb{S} \sin^{-1} \left( \langle \mathbf{w}_\mathbb{S}, \mathbf{x}_\mathbb{S} \rangle \right) + \alpha_\mathbb{H} \sinh^{-1} \left( \langle \mathbf{w}_\mathbb{H}, \mathbf{x}_\mathbb{H} \rangle_\mathbb{H} \right) + b \right) \]

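where \(\mathbf{x}_\mathcal{M}\) (and similarly \(\mathbf{w}_\mathcal{M}\)) denotes the restriction of \(\mathbf{x} \in \mathcal{P}\) to one of its component manifolds. The coefficients \(\alpha_\mathbb{S}\) and \(\alpha_\mathbb{H}\) weight the spherical and hyperbolic terms relative to the Euclidean inner product, and \(b\) is a bias term.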
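The decision rule can be sketched for a product of one Euclidean, one spherical, and one hyperbolic component. Assumptions of this sketch: the hyperbolic component uses the hyperboloid model with the Lorentzian inner product \(-u_0 v_0 + \sum_{i \ge 1} u_i v_i\), and the spherical weight has unit norm so the arcsine argument stays in \([-1, 1]\). The dictionary layout and function names are hypothetical; manify's `ProductSpacePerceptron` trains a kernelized version of this classifier.

```python
import math


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def minkowski_inner(u: list[float], v: list[float]) -> float:
    """Lorentzian inner product -u_0 v_0 + sum_{i>=1} u_i v_i (assumed convention)."""
    return -u[0] * v[0] + dot(u[1:], v[1:])


def product_linear_classifier(
    x: dict[str, list[float]], w: dict[str, list[float]], alpha_s: float, alpha_h: float, b: float
) -> int:
    """LC(x, w) for components keyed "E" (Euclidean), "S" (spherical), "H" (hyperbolic)."""
    score = (
        dot(w["E"], x["E"])
        + alpha_s * math.asin(dot(w["S"], x["S"]))
        + alpha_h * math.asinh(minkowski_inner(w["H"], x["H"]))
        + b
    )
    return (score > 0) - (score < 0)


x = {"E": [0.5], "S": [0.6, 0.8], "H": [math.cosh(1.0), math.sinh(1.0)]}
w = {"E": [1.0], "S": [0.0, 1.0], "H": [0.0, 1.0]}
print(product_linear_classifier(x, w, alpha_s=1.0, alpha_h=1.0, b=0.0))  # 1
print(product_linear_classifier(x, w, alpha_s=1.0, alpha_h=1.0, b=-3.0))  # -1
```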

  • SVMs: Extend kernel SVMs with constraints respecting the geometry of each component manifold, optimizing a margin-based objective under convex and relaxed constraints.
Product Space SVM

The Product Space SVM extends the kernel-based approach described in the Product Space Perceptron section by finding a maximum-margin classifier in the product space. The optimization problem is formulated as:

\[ \text{maximize } \varepsilon - \sum_{i=1}^{n} \xi_i \]

subject to

\[ y_i \sum_{j=1}^{n} \beta_j K(\mathbf{x}_i, \mathbf{x}_j) \geq \varepsilon - \xi_i \quad \text{for all } i \in \{1, \ldots, n\} \]

where \(\varepsilon > 0\) and \(\xi_i \geq 0\).
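For a fixed candidate \(\beta\), the constraints determine the smallest feasible slacks, \(\xi_i = \max(0, \varepsilon - y_i \sum_j \beta_j K(\mathbf{x}_i, \mathbf{x}_j))\), so the objective can be evaluated directly. The hypothetical sketch below performs only this evaluation on a precomputed kernel matrix; it does not solve for \(\beta\), which in practice requires a convex solver.

```python
def svm_objective(
    K: list[list[float]], y: list[int], beta: list[float], eps: float
) -> tuple[float, list[float]]:
    """Return (eps - sum(xi), xi) using the smallest slacks that satisfy every constraint."""
    # Margin of sample i: y_i * sum_j beta_j K(x_i, x_j)
    margins = [y_i * sum(b_j * k_ij for b_j, k_ij in zip(beta, row)) for y_i, row in zip(y, K)]
    # xi_i >= 0 and margin_i >= eps - xi_i
    slacks = [max(0.0, eps - m) for m in margins]
    return eps - sum(slacks), slacks


K = [[1.0, 0.0], [0.0, 1.0]]
print(svm_objective(K, y=[1, -1], beta=[1.0, -1.0], eps=0.5))  # (0.5, [0.0, 0.0])
print(svm_objective(K, y=[1, -1], beta=[1.0, -1.0], eps=2.0))  # (0.0, [1.0, 1.0])
```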

ProductSpaceDT(pm, max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, task='classification', use_special_dims=False, batch_size=None, n_features='d', ablate_midpoints=False, random_state=None, device=None)

Bases: BasePredictor

Decision tree in the product space to handle hyperbolic, Euclidean, and hyperspherical data.

Source code in manify/predictors/decision_tree.py
def __init__(
    self,
    pm: ProductManifold,
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
    min_samples_split: int = 2,
    min_impurity_decrease: float = 0.0,
    task: Literal["classification", "regression", "link_prediction"] = "classification",
    use_special_dims: bool = False,
    batch_size: int | None = None,
    n_features: Literal["d", "d_choose_2"] = "d",
    ablate_midpoints: bool = False,
    random_state: int | None = None,
    device: str | None = None,
):
    # Initialize the base class
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)

    # Raise error if manifold is stereographic
    if pm.is_stereographic:
        raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
    if task == "link_prediction":
        raise ValueError(
            "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
        )

    # Store hyperparameters
    self.pm = pm
    self.max_depth = max_depth or -1
    self.min_samples_leaf = min_samples_leaf
    self.min_samples_split = min_samples_split
    self.min_impurity_decrease = min_impurity_decrease
    self.use_special_dims = use_special_dims
    self.n_features = n_features
    self.ablate_midpoints = ablate_midpoints

    # I use "batched" to mean "all at once" and "batch_size" to mean "in chunks"
    self.batch_size = batch_size
    if batch_size is not None:
        self.batched = False
    else:
        self.batched = True

    # Task-specific stuff
    self.task = task
    self.criterion = "gini" if task == "classification" else "mse"

    # These will become important later, when fit is called
    self.nodes: list[_DecisionNode] = []  # For fitted nodes
    self.permutations: Int[torch.Tensor, "n_classes"] | None = None  # If used as part of a random forest
    self.angle2man: list[int] = []  # Maps preprocessed angles to manifold indices
    self.special_first: list[bool] = []  # Whether the first dimension is special in a projection
    self.angle_dims: list[tuple[int, int]] = []  # Maps preprocessed angles to dimension indices
    self.tree: _DecisionNode = _DecisionNode()  # The root of the tree
    self.classes_: Float[torch.Tensor, "n_classes"] = torch.empty(0)  # Initialize as an empty tensor
    self.labels_: Int[torch.Tensor, "batch n_classes"] = torch.tensor([])  # sklearn-style labels
    self.signature: list[tuple[float, int]] = pm.signature  # The signature of the manifold

fit(X, y)

Fit the product-space decision tree to training data.

Parameters:
  • X (Float[Tensor, 'batch ambient_dim']) –

    (batch, ambient_dim) tensor of training data (ambient coordinate representation)

  • y (Real[Tensor, 'batch']) –

    (batch,) tensor of labels (integer representation)

Returns:
  • ProductSpaceDT

    The fitted tree (self).

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def fit(self, X: Float[torch.Tensor, "batch ambient_dim"], y: Real[torch.Tensor, "batch"]) -> ProductSpaceDT:
    """Fit the product-space decision tree to training data.

    Args:
        X: (batch, ambient_dim) tensor of training data (ambient coordinate representation)
        y: (batch,) tensor of labels (integer representation)

    Returns:
        self: The fitted tree.
    """
    # Pre-preprocessing step: aggregate special dimensions into a new Euclidean component
    if self.use_special_dims:
        X, self.pm = self._aggregate_special_dims(X)

    # Preprocess data
    angles, labels, comparisons_reshaped = self._preprocess(X=X, y=y)

    # Fit node
    self.tree = self._fit_node(angles=angles, labels=labels, comparisons=comparisons_reshaped, depth=self.max_depth)

    self.is_fitted_ = True  # Mark the model as fitted
    return self
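The recursive `_fit_node` routine is not part of this excerpt. As a hypothetical illustration of the split search it implies (the criterion is Gini impurity for classification, as set in `__init__`), the sketch below scans one column of precomputed projection angles, tries a boundary between each pair of consecutive distinct angles, and keeps the lowest weighted impurity. The `midpoint` argument is where a manifold-specific rule such as \(m_\mathbb{H}\) would be plugged in; the default is the spherical average.

```python
from collections import Counter
from typing import Callable


def gini(labels: list[int]) -> float:
    """Gini impurity 1 - sum_c p_c^2."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(
    angles: list[float],
    labels: list[int],
    midpoint: Callable[[float, float], float] = lambda u, v: (u + v) / 2.0,
) -> tuple[float, float]:
    """Return (threshold, weighted Gini) of the best split along one angle column.

    Points with angle >= threshold go to the right child.
    """
    order = sorted(range(len(angles)), key=lambda i: angles[i])
    a = [angles[i] for i in order]
    y = [labels[i] for i in order]
    n = len(a)
    best = (float("nan"), float("inf"))
    for i in range(1, n):
        if a[i] == a[i - 1]:
            continue  # no boundary can separate identical angles
        left, right = y[:i], y[i:]
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best[1]:
            best = (midpoint(a[i - 1], a[i]), score)
    return best


print(best_split([0.8, 0.1, 0.9, 0.2], [1, 0, 1, 0]))  # threshold about 0.5, impurity 0.0
```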

predict_proba(X)

Predict class probabilities for samples in X.

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def predict_proba(self, X: Float[torch.Tensor, "batch ambient_dim"]) -> Float[torch.Tensor, "batch n_classes"]:
    """Predict class probabilities for samples in X."""
    if self.use_special_dims:
        X, _ = self._aggregate_special_dims(X)
    angles, _, _ = self._preprocess(X=X)
    if self.permutations is not None:
        angles = angles[:, self.permutations]
    return torch.vstack([self._traverse(angles_row, self.tree).probs for angles_row in angles])

ProductSpaceRF(pm, task='classification', use_special_dims=False, n_features='d', max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, ablate_midpoints=False, n_estimators=100, max_features='sqrt', max_samples=1.0, batch_size=None, random_state=None, n_jobs=-1, device=None)

Bases: BasePredictor

Random Forest in the product space.

Source code in manify/predictors/decision_tree.py
def __init__(
    self,
    pm: ProductManifold,
    task: Literal["classification", "regression"] = "classification",
    use_special_dims: bool = False,
    n_features: Literal["d", "d_choose_2"] = "d",
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
    min_samples_split: int = 2,
    min_impurity_decrease: float = 0.0,
    ablate_midpoints: bool = False,
    n_estimators: int = 100,
    max_features: Literal["sqrt", "log2", "none"] = "sqrt",
    max_samples: float = 1.0,
    batch_size: int | None = None,
    random_state: int | None = None,
    n_jobs: int = -1,
    device: str | None = None,
):
    # Initialize the base class
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)

    # Raise error if manifold is stereographic
    if pm.is_stereographic:
        raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
    if task == "link_prediction":
        raise ValueError(
            "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
        )

    # Tree hyperparameters
    tree_kwargs: Dict[str, Any] = {}
    self.pm = tree_kwargs["pm"] = pm
    self.task = tree_kwargs["task"] = task
    self.max_depth = tree_kwargs["max_depth"] = max_depth or -1
    self.min_samples_leaf = tree_kwargs["min_samples_leaf"] = min_samples_leaf
    self.min_samples_split = tree_kwargs["min_samples_split"] = min_samples_split
    self.min_impurity_decrease = tree_kwargs["min_impurity_decrease"] = min_impurity_decrease
    self.use_special_dims = tree_kwargs["use_special_dims"] = use_special_dims
    self.n_features = tree_kwargs["n_features"] = n_features
    self.batch_size = tree_kwargs["batch_size"] = batch_size
    self.ablate_midpoints = tree_kwargs["ablate_midpoints"] = ablate_midpoints

    # I use "batched" to mean "all at once" and "batch_size" to mean "in chunks"
    self.batch_size = batch_size
    if batch_size is not None:
        self.batched = False
    else:
        self.batched = True

    # Random forest hyperparameters
    self.n_estimators = n_estimators
    self.max_features = max_features
    self.max_samples = max_samples
    self.random_state = random_state
    self.n_jobs = n_jobs
    self.trees = [ProductSpaceDT(**tree_kwargs) for _ in range(n_estimators)]

    # These will become important later - just the sklearn-style stuff
    # For other special attributes, we just use ProductSpaceDT's attributes
    self.classes_: Float[torch.Tensor, "n_classes"] | None = None
    self.labels_: Int[torch.Tensor, "batch n_classes"] | None = None

fit(X, y)

Preprocess and fit an ensemble of trees on subsampled data.

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def fit(self, X: Float[torch.Tensor, "batch ambient_dim"], y: Real[torch.Tensor, "batch"]) -> ProductSpaceRF:
    """Preprocess and fit an ensemble of trees on subsampled data."""
    # Pre-preprocessing step: aggregate special dimensions
    if self.use_special_dims:
        X, self.pm = self.trees[0]._aggregate_special_dims(X)
        for tree in self.trees:
            tree.pm = self.pm

    # Can use any tree to preprocess X and y
    angles, labels, comparisons = self.trees[0]._preprocess(X=X, y=y)

    # Also update angle2man and special_first
    for tree in self.trees:
        tree.angle2man = self.trees[0].angle2man
        tree.special_first = self.trees[0].special_first
        tree.classes_ = self.trees[0].classes_
    self.classes_ = self.trees[0].classes_

    # Use seed here
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Subsample - just the indices
    n, d = angles.shape
    idx_sample_all, idx_dim_all = self._generate_subsample(n_rows=n, n_cols=d, n_trees=self.n_estimators)

    # Fit trees
    for tree, idx_sample, idx_dim in zip(self.trees, idx_sample_all, idx_dim_all, strict=False):
        tree.permutations = idx_dim
        if self.batched:
            comparisons_subsample = comparisons[idx_sample][:, idx_dim][:, :, idx_sample]
        else:
            comparisons_subsample = comparisons
        tree.tree = tree._fit_node(
            angles=angles[idx_sample][:, idx_dim],
            labels=labels[idx_sample],
            comparisons=comparisons_subsample,
            depth=self.max_depth,
        )

    self.is_fitted_ = True
    return self

predict_proba(X)

Predict class probabilities for samples in X.

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def predict_proba(self, X: Float[torch.Tensor, "batch ambient_dim"]) -> Float[torch.Tensor, "batch n_classes"]:
    """Predict class probabilities for samples in X."""
    return torch.stack([tree.predict_proba(X) for tree in self.trees]).mean(dim=0)
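The helper `_generate_subsample` is not shown in this excerpt. A hypothetical sketch of the subsampling and of the ensemble averaging performed by `predict_proba`, assuming scikit-learn-style bootstrap sampling of rows and \(\lfloor\sqrt{d}\rfloor\) angle columns per tree for `max_features="sqrt"`:

```python
import math
import random
from typing import Optional


def subsample_indices(
    n_rows: int, n_cols: int, max_samples: float = 1.0, seed: Optional[int] = None
) -> tuple[list[int], list[int]]:
    """Bootstrap rows with replacement and pick floor(sqrt(n_cols)) columns without replacement."""
    rng = random.Random(seed)
    n_draw = max(1, int(max_samples * n_rows))
    rows = [rng.randrange(n_rows) for _ in range(n_draw)]
    n_feat = max(1, int(math.sqrt(n_cols)))  # assumed meaning of max_features="sqrt"
    cols = rng.sample(range(n_cols), n_feat)
    return rows, cols


def average_proba(per_tree: list[list[list[float]]]) -> list[list[float]]:
    """Average per-tree (n_points x n_classes) probability tables, as in predict_proba."""
    n_trees = len(per_tree)
    n_points, n_classes = len(per_tree[0]), len(per_tree[0][0])
    return [
        [sum(tree[i][c] for tree in per_tree) / n_trees for c in range(n_classes)]
        for i in range(n_points)
    ]


rows, cols = subsample_indices(n_rows=10, n_cols=9, seed=0)
print(len(rows), len(cols))  # 10 3
print(average_proba([[[0.2, 0.8]], [[0.6, 0.4]]]))  # roughly [[0.4, 0.6]]
```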

KappaGCN(pm, output_dim, num_hidden=2, nonlinearity=torch.relu, task='classification', random_state=None, device=None)

Bases: BasePredictor, Module

Implementation for the Kappa GCN.

Attributes:
  • pm

    ProductManifold object for the Kappa GCN.

  • output_dim

    Number of output features.

  • num_hidden

    Number of hidden layers.

  • nonlinearity

    Function for nonlinear activation.

  • task

    Task type, one of ["classification", "regression", "link_prediction"]

  • random_state

    Random seed for reproducibility.

  • device

    Device to run the model on (default: None, uses current device).

  • is_fitted_ (bool) –

    Whether the model has been fitted.

  • loss_history_ (dict[str, list[float]]) –

    History of loss values during training.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object for the Kappa GCN

  • output_dim (int) –

    Number of output features

  • num_hidden (int, default: 2 ) –

    Number of hidden layers.

  • nonlinearity (Callable, default: relu ) –

    Function for nonlinear activation.

  • task (Literal['classification', 'regression', 'link_prediction'], default: 'classification' ) –

    Task type, one of ["classification", "regression", "link_prediction"].

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device to run the model on (default: None, uses current device).

Raises:
  • ValueError

    If the ProductManifold is not stereographic.

Source code in manify/predictors/kappa_gcn.py
def __init__(
    self,
    pm: ProductManifold,
    output_dim: int,
    num_hidden: int = 2,
    nonlinearity: Callable = torch.relu,
    task: Literal["classification", "regression", "link_prediction"] = "classification",
    random_state: int | None = None,
    device: str | None = None,
):
    BasePredictor.__init__(self, pm=pm, task=task, random_state=random_state, device=device)
    torch.nn.Module.__init__(self)

    self.pm = pm
    self.task = task
    self.output_dim = output_dim
    self.num_hidden = num_hidden
    self.nonlinearity = nonlinearity

    # Ensure pm is stereographic
    if not pm.is_stereographic:
        raise ValueError(
            "ProductManifold must be stereographic for KappaGCN to work. Please use pm.stereographic() to convert."
        )

    # Build layer dimensions
    dims = [pm.dim] + [pm.dim] * num_hidden

    # Build the main GCN layers using Sequential
    gcn_layers = []
    for i in range(len(dims) - 1):
        gcn_layers.append(KappaGCNLayer(dims[i], dims[i + 1], pm, nonlinearity))

    self.gcn_layers = KappaSequential(*gcn_layers)

    # Task-specific output layers - much cleaner now!
    if task == "link_prediction":
        self.output_layer = FermiDiracDecoder(pm, learnable_params=True)
    else:
        # This is the same for classification/regression since we apply softmax in the loss function, not here
        self.output_layer = StereographicLogits(output_dim, pm, apply_softmax=False)

forward(X, A_hat=None, aggregate_logits=True, softmax=False)

Forward pass through the GCN layers and output head.

Source code in manify/predictors/kappa_gcn.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    aggregate_logits: bool = True,
    softmax: bool = False,
) -> (
    Float[torch.Tensor, "n_nodes n_classes"]
    | Float[torch.Tensor, "n_nodes"]
    | Float[torch.Tensor, "n_nodes n_nodes"]
):
    """Forward pass through the GCN layers and output head."""
    # Pass through main GCN layers
    H = self.gcn_layers(X, A_hat)

    # Task-specific output using the specialized layers
    if self.task == "link_prediction":
        return self.output_layer(H)  # Flattened for link prediction
    else:
        # For classification/regression, use stereographic logits
        logits = self.output_layer(H, A_hat, aggregate_logits=aggregate_logits)

        if softmax:
            logits = torch.softmax(logits, dim=-1)

        return logits.squeeze()

fit(X, y, A=None, epochs=2000, lr=0.01, use_tqdm=True, tqdm_prefix=None)

Fit the Kappa GCN model.

Parameters:
  • X (Tensor) –

    Feature matrix.

  • y (Tensor) –

    Labels for training nodes.

  • A (Tensor, default: None ) –

    Adjacency or distance matrix.

  • epochs (int, default: 2000 ) –

    Number of training epochs (default=2000).

  • lr (float, default: 0.01 ) –

    Learning rate (default=1e-2).

  • use_tqdm (bool, default: True ) –

    Whether to use tqdm for progress bar.

  • tqdm_prefix (str | None, default: None ) –

    Prefix for tqdm progress bar.

Source code in manify/predictors/kappa_gcn.py
def fit(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    y: Real[torch.Tensor, "n_nodes"],
    A: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    epochs: int = 2_000,
    lr: float = 1e-2,
    use_tqdm: bool = True,
    tqdm_prefix: str | None = None,
) -> KappaGCN:
    """Fit the Kappa GCN model.

    Args:
        X (torch.Tensor): Feature matrix.
        y (torch.Tensor): Labels for training nodes.
        A (torch.Tensor): Adjacency or distance matrix.
        epochs: Number of training epochs (default=2000).
        lr: Learning rate (default=1e-2).
        use_tqdm: Whether to use tqdm for progress bar.
        tqdm_prefix: Prefix for tqdm progress bar.
    """
    # Copy everything
    X = X.clone()
    y = y.clone()
    A = A.clone() if A is not None else None

    # Convert A to A_hat
    A_hat = get_A_hat(A, make_symmetric=True, add_self_loops=True) if A is not None else None

    # Collect all parameters
    euclidean_params = []
    riemannian_params = []
    for layer in self.gcn_layers.layers:
        euclidean_params.append(layer.W)
    if self.task == "link_prediction":
        euclidean_params += [self.output_layer.temperature, self.output_layer.bias]
    else:
        euclidean_params += [self.output_layer.W]
        riemannian_params += [self.output_layer.p_ks]

    # Optimizers
    opt = torch.optim.Adam(euclidean_params, lr=lr)
    ropt = geoopt.optim.RiemannianAdam(riemannian_params, lr=lr) if riemannian_params else None

    if self.task == "classification":
        loss_fn = torch.nn.CrossEntropyLoss()
        y = y.long()
    elif self.task == "regression":
        loss_fn = torch.nn.MSELoss()
        y = y.float()
    elif self.task == "link_prediction":
        loss_fn = torch.nn.BCEWithLogitsLoss()
        # y = y.flatten().float()
        y = y.float()
    else:
        raise ValueError("Invalid task!")

    self.train()
    if use_tqdm:
        my_tqdm = tqdm(total=epochs, desc=tqdm_prefix)

    losses = []
    for i in range(epochs):
        opt.zero_grad()
        if riemannian_params:
            ropt.zero_grad()  # type: ignore
        y_pred = self(X, A_hat)
        loss = loss_fn(y_pred, y)
        loss.backward()
        opt.step()
        if riemannian_params:
            ropt.step()  # type: ignore

        # Progress bar
        if use_tqdm:
            my_tqdm.update(1)
            my_tqdm.set_description(f"Epoch {i + 1}/{epochs}, Loss: {loss.item():.4f}")

        # Early termination for nan loss
        if torch.isnan(loss):
            print("Loss is NaN, stopping training.")
            break
        losses.append(loss.item())

    if use_tqdm:
        my_tqdm.close()

    self.is_fitted_ = True
    self.loss_history_["train"] = losses
    return self

predict_proba(X, A=None)

Predict class probabilities using the trained Kappa GCN.

Parameters:
  • X (Tensor) –

    Feature matrix (NxD).

  • A (Tensor, default: None ) –

    Adjacency or distance matrix (NxN).

Returns:
  • Real[Tensor, 'n_nodes n_classes'] | Real[Tensor, 'n_nodes']

    torch.Tensor: Predicted class probabilities / regression targets.

Source code in manify/predictors/kappa_gcn.py
def predict_proba(
    self, X: Float[torch.Tensor, "n_nodes dim"], A: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Real[torch.Tensor, "n_nodes n_classes"] | Real[torch.Tensor, "n_nodes"]:
    """Predict class probabilities using the trained Kappa GCN.

    Args:
        X (torch.Tensor): Feature matrix (NxD).
        A (torch.Tensor): Adjacency or distance matrix (NxN).

    Returns:
        torch.Tensor: Predicted class probabilities / regression targets.
    """
    # Copy everything
    X = X.clone()
    A = A.clone() if A is not None else None
    A_hat = get_A_hat(A, make_symmetric=True, add_self_loops=True) if A is not None else None

    # Forward pass in evaluation mode
    self.eval()
    y_pred = self(X, A_hat)
    return y_pred

ProductSpacePerceptron(pm, max_epochs=1000, patience=5, weights=None, task='classification', random_state=None, device=None)

Bases: BasePredictor

A product-space perceptron model for multiclass classification in the product manifold space.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object for the product space.

  • max_epochs (int, default: 1000 ) –

    Maximum number of training epochs.

  • patience (int, default: 5 ) –

    Number of consecutive epochs without improvement to consider convergence.

  • weights (Float[Tensor, 'n_manifolds'] | None, default: None ) –

    Per-manifold weights for kernel combination.

  • task (str, default: 'classification' ) –

    Task type (defaults to "classification").

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device for tensor computations.

Attributes:
  • pm

    ProductManifold object associated with the predictor.

  • max_epochs

    Maximum number of training epochs.

  • patience

    Number of consecutive epochs without improvement to consider convergence.

  • weights

    Per-manifold weights for kernel combination.

  • alpha

    Dictionary storing perceptron coefficients for each class.

  • X_train_

    Training data points.

  • y_train_

    Training labels.

  • is_fitted_ (bool) –

    Boolean flag indicating if the predictor has been fitted.

Source code in manify/predictors/perceptron.py
def __init__(
    self,
    pm: ProductManifold,
    max_epochs: int = 1_000,
    patience: int = 5,
    weights: Float[torch.Tensor, "n_manifolds"] | None = None,
    task: str = "classification",
    random_state: int | None = None,
    device: str | None = None,
):
    # Initialize base class
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)
    self.pm = pm  # ProductManifold instance
    self.max_epochs = max_epochs
    self.patience = patience  # Number of consecutive epochs without improvement to consider convergence
    self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
    assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."

fit(X, y)

Trains the perceptron model using the provided data and labels.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Training data tensor.

  • y (Int[Tensor, 'n_samples']) –

    Class labels for the training data.

Returns:
  • ProductSpacePerceptron

    The fitted perceptron model (self).

Source code in manify/predictors/perceptron.py
def fit(
    self, X: Float[torch.Tensor, "n_samples n_manifolds"], y: Int[torch.Tensor, "n_samples"]
) -> ProductSpacePerceptron:
    """Trains the perceptron model using the provided data and labels.

    Args:
        X: Training data tensor.
        y: Class labels for the training data.

    Returns:
        self: Fitted perceptron model.
    """
    # Identify unique classes for multiclass classification
    self._store_classes(y)
    n_samples = X.shape[0]

    # Precompute kernel matrix
    Ks, _ = product_kernel(self.pm, X, None)
    K = torch.ones((n_samples, n_samples), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K += w * K_m

    # Store training data and labels for prediction
    self.X_train_ = X
    self.y_train_ = y

    # Initialize dictionary to store alpha coefficients for each class
    self.alpha = {}

    for class_label in self.classes_:
        class_label_item = class_label.item()

        # For patience checking (reset for each one-vs-rest problem)
        best_epoch, least_errors = 0, n_samples + 1

        # One-vs-rest labels
        y_binary = torch.where(y == class_label_item, 1, -1)  # Shape: (n_samples,)

        # Initialize alpha coefficients for this class
        alpha = torch.zeros(n_samples, dtype=X.dtype, device=X.device)

        for epoch in range(self.max_epochs):
            # Compute decision function: f = K @ (alpha * y_binary)
            f = K @ (alpha * y_binary)  # Shape: (n_samples,)

            # Compute predictions
            predictions = torch.sign(f)

            # Find misclassified samples
            misclassified = predictions != y_binary

            # If no misclassifications, break early
            if not misclassified.any():
                break

            # Test patience
            n_errors = misclassified.sum().item()
            if n_errors < least_errors:
                best_epoch, least_errors = epoch, n_errors
            if epoch - best_epoch >= self.patience:
                break

            # Update alpha coefficients for misclassified samples
            alpha[misclassified] += 1

        # Store the alpha coefficients for the current class
        self.alpha[class_label_item] = alpha

    self.is_fitted_ = True
    return self
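The `fit` and `predict_proba` listings implement a one-vs-rest kernel (dual) perceptron on the aggregated Gram matrix K = 1 + Σ_m w_m K_m. The NumPy sketch below shows the same idea outside manify. It is only a sketch: plain linear kernels stand in for `product_kernel`, and the helper names (`combine_kernels`, `fit_ovr_kernel_perceptron`, `decision_values`) are hypothetical, not part of the library.

```python
import numpy as np


def combine_kernels(kernels, weights):
    """Aggregate per-manifold Gram matrices as K = 1 + sum_m w_m * K_m."""
    K = np.ones_like(kernels[0], dtype=float)
    for K_m, w in zip(kernels, weights):
        K = K + w * K_m
    return K


def fit_ovr_kernel_perceptron(K, y, max_epochs=1000, patience=5):
    """One-vs-rest dual perceptron on a precomputed Gram matrix; returns {label: alpha}."""
    n = K.shape[0]
    alphas = {}
    for c in np.unique(y):
        y_bin = np.where(y == c, 1, -1)
        alpha = np.zeros(n)
        best_epoch, least_errors = 0, n + 1  # patience state is reset for every class
        for epoch in range(max_epochs):
            f = K @ (alpha * y_bin)  # decision function on the training points
            misclassified = np.sign(f) != y_bin  # f == 0 counts as an error
            n_errors = int(misclassified.sum())
            if n_errors == 0:
                break
            if n_errors < least_errors:
                best_epoch, least_errors = epoch, n_errors
            if epoch - best_epoch >= patience:
                break
            alpha[misclassified] += 1  # bump the dual weight of every mistake
        alphas[c.item()] = alpha
    return alphas


def decision_values(K_train_test, y_train, alphas):
    """Column j holds (alpha_j * y_bin_j) @ K_train_test for the j-th sorted class label."""
    cols = [(alphas[c] * np.where(y_train == c, 1, -1)) @ K_train_test for c in sorted(alphas)]
    return np.stack(cols, axis=1)


# Toy data: two 1-D "components" with linear kernels standing in for product_kernel
X1 = np.array([[-2.0], [-1.0], [1.0], [2.0]])
X2 = np.array([[0.5], [0.2], [-0.3], [-0.1]])
y = np.array([0, 0, 1, 1])
K = combine_kernels([X1 @ X1.T, X2 @ X2.T], weights=[1.0, 1.0])
alphas = fit_ovr_kernel_perceptron(K, y)
scores = decision_values(K, y, alphas)  # here the "test" set is the training set
pred = np.array(sorted(alphas))[np.argmax(scores, axis=1)]
```

Each class gets its own alpha vector. The predicted label is the argmax over the per-class decision values, which is how the output of `predict_proba` would typically be turned into class labels.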

predict_proba(X)

Predicts the decision values for each class.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Test data tensor.

Returns:
  • decision_values( Float[Tensor, 'n_points n_classes'] ) –

    Decision values for each test sample and each class.

Source code in manify/predictors/perceptron.py
def predict_proba(
    self,
    X: Float[torch.Tensor, "n_points n_features"],  # type: ignore[override]
) -> Float[torch.Tensor, "n_points n_classes"]:
    """Predicts the decision values for each class.

    Args:
        X: Test data tensor.

    Returns:
        decision_values: Decision values for each test sample and each class.
    """
    n_samples = X.shape[0]
    n_classes = len(self.classes_)
    decision_values = torch.zeros((n_samples, n_classes), dtype=X.dtype, device=X.device)

    # Compute kernel matrix between training data and test data
    Ks, _ = product_kernel(self.pm, self.X_train_, X)
    K_test = torch.ones((self.X_train_.shape[0], n_samples), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K_test += w * K_m
    # K_test = self.X_train_ @ X.T

    for idx, class_label in enumerate(self.classes_):
        class_label_item = class_label.item()
        alpha = self.alpha[class_label_item]  # Shape: (n_samples_train,)
        y_binary = torch.where(self.y_train_ == class_label_item, 1, -1)  # Shape: (n_samples_train,)

        # Compute decision function for test samples
        f = (alpha * y_binary) @ K_test  # Shape: (n_samples_test,)
        decision_values[:, idx] = f

    return decision_values

ProductSpaceSVM(pm, weights=None, h_constraints=True, e_constraints=True, s_constraints=True, task='classification', epsilon=1e-05, random_state=None, device=None)

Bases: BasePredictor

Product Space SVM class in a product manifold setting.

Trains one-vs-rest SVMs with Euclidean, spherical, and hyperbolic constraints enforced via second-order-cone (SOC) formulations for convexity.

Parameters:
  • pm (ProductManifold) –

    A ProductManifold instance specifying component manifolds.

  • weights (Float[Tensor, 'n_manifolds'] | None, default: None ) –

    Optional per-manifold weights tensor.

  • h_constraints (bool, default: True ) –

    Whether to enforce hyperbolic constraints.

  • e_constraints (bool, default: True ) –

    Whether to enforce Euclidean constraints.

  • s_constraints (bool, default: True ) –

    Whether to enforce spherical constraints.

  • task (Literal['classification', 'regression'], default: 'classification' ) –

    Task type, either "classification" or "regression".

  • epsilon (float, default: 1e-05 ) –

    Slack parameter for SOC constraints.

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device for tensor computations.

Attributes:
  • pm

    ProductManifold object associated with the predictor.

  • weights

    Per-manifold weights for kernel combination.

  • h_constraints

    Whether to enforce hyperbolic constraints.

  • e_constraints

    Whether to enforce Euclidean constraints.

  • s_constraints

    Whether to enforce spherical constraints.

  • eps

    Slack parameter for SOC constraints.

  • beta

    Dictionary storing SVM coefficients for each class.

  • zeta

    Dictionary storing slack variables for each class.

  • epsilon

    Dictionary storing epsilon values for each class.

  • b

    Dictionary storing bias terms for each class.

  • X_train_

    Training data points.

  • is_fitted_ (bool) –

    Boolean flag indicating if the predictor has been fitted.

Initialize the ProductSpaceSVM.

Source code in manify/predictors/svm.py
def __init__(
    self,
    pm: ProductManifold,
    weights: Float[torch.Tensor, "n_manifolds"] | None = None,
    h_constraints: bool = True,
    e_constraints: bool = True,
    s_constraints: bool = True,
    task: Literal["classification", "regression"] = "classification",
    epsilon: float = 1e-5,
    random_state: int | None = None,
    device: str | None = None,
):
    """Initialize the ProductSpaceSVM.

    Args:
        pm: A ProductManifold instance specifying component manifolds.
        weights: Optional per-manifold weights tensor.
        h_constraints: Whether to enforce hyperbolic constraints.
        e_constraints: Whether to enforce Euclidean constraints.
        s_constraints: Whether to enforce spherical constraints.
        task: Task type, either "classification" or "regression".
        epsilon: Slack parameter for SOC constraints.
        random_state: Random seed for reproducibility.
        device: Device for tensor computations.
    """
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)
    self.pm = pm
    self.h_constraints = h_constraints
    self.s_constraints = s_constraints
    self.e_constraints = e_constraints
    self.eps = epsilon
    self.task = task
    self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
    assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."

fit(X, y)

Fit one-vs-rest SVMs on the product manifold data.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Training points tensor.

  • y (Int[Tensor, 'n_samples']) –

    Integer class labels tensor.

Returns:
  • self( ProductSpaceSVM ) –

    Fitted ProductSpaceSVM instance.

Source code in manify/predictors/svm.py
def fit(
    self,
    X: Float[torch.Tensor, "n_samples n_manifolds"],
    y: Int[torch.Tensor, "n_samples"],
) -> ProductSpaceSVM:
    """Fit one-vs-rest SVMs on the product manifold data.

    Args:
        X: Training points tensor.
        y: Integer class labels tensor.

    Returns:
        self: Fitted ProductSpaceSVM instance.
    """
    # unique classes
    self._store_classes(y)
    n = X.shape[0]

    # aggregated kernel
    Ks, _ = product_kernel(self.pm, X, None)
    K_sum = torch.ones((n, n), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K_sum += w * K_m

    X_np = X.detach().cpu().numpy()
    K_np = K_sum.detach().cpu().numpy()

    def sqrtm_psd(P: np.ndarray) -> Any:
        w, V = np.linalg.eigh(P)
        w_s = np.sqrt(np.clip(w, 0, None))
        B = V @ np.diag(w_s) @ V.T
        return (B + B.T) * 0.5

    # containers
    self.beta = {}
    self.zeta = {}
    self.epsilon = {}
    self.b = {}

    for cls in self.classes_:
        cls_item = cls.item() if isinstance(cls, torch.Tensor) else cls
        # one-vs-rest labels: +1 for cls, -1 for others
        y_bin = torch.where(y == cls_item, 1, -1)
        Y = torch.diagflat(y_bin).detach().cpu().numpy()

        # variables
        beta_var = cp.Variable(n)
        zeta = cp.Variable(n, nonneg=True)
        eps_var = cp.Variable(1)
        b_var = cp.Variable(1)

        # base constraints
        constraints = [eps_var >= 0]
        constraints.append(Y @ (K_np @ beta_var + b_var) >= eps_var - zeta)

        # per-manifold SOC
        for M, K_comp in zip(self.pm.P, Ks, strict=False):
            P_np = K_comp.detach().cpu().numpy()
            if M.type == "E" and self.e_constraints:
                B = sqrtm_psd(P_np)
                constraints.append(cp.norm(B @ beta_var, 2) <= 1.0)
            elif M.type == "S" and self.s_constraints:
                B = sqrtm_psd(P_np)
                constraints.append(cp.norm(B @ beta_var, 2) <= np.sqrt(np.pi / 2))
            elif M.type == "H" and self.h_constraints:
                # PSD split
                eigvals, eigvecs = np.linalg.eigh(P_np)
                plus = np.clip(eigvals, 0, None)
                minus = np.clip(-eigvals, 0, None)
                Kp = (eigvecs @ np.diag(plus) @ eigvecs.T + (eigvecs @ np.diag(plus) @ eigvecs.T).T) * 0.5
                Km = (eigvecs @ np.diag(minus) @ eigvecs.T + (eigvecs @ np.diag(minus) @ eigvecs.T).T) * 0.5
                Bp = sqrtm_psd(Kp)
                Bm = sqrtm_psd(Km)

                C_H = abs(M.curvature)
                R = -M.scale
                r_h = abs(np.arcsinh(-(R**2) * C_H))
                r = self.eps

                constraints.append(cp.norm(Bm @ beta_var, 2) <= np.sqrt(max(r, 0.0)))
                constraints.append(cp.norm(Bp @ beta_var, 2) <= np.sqrt(max(r + r_h, 0.0)))

        # solve
        prob = cp.Problem(cp.Minimize(-eps_var + cp.sum(zeta)), constraints)
        prob.solve(solver="SCS")

        # save results
        self.beta[cls_item] = np.ravel(beta_var.value)
        self.zeta[cls_item] = zeta.value
        self.epsilon[cls_item] = float(eps_var.value)
        self.b[cls_item] = float(b_var.value)

    # store training data
    self.X_train_ = torch.tensor(X_np, dtype=torch.float32)
    self.is_fitted_ = True
    return self
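The hyperbolic branch of `fit` splits a possibly indefinite component kernel into positive and negative semidefinite parts. It then bounds each part with a norm constraint on a matrix square root. The NumPy sketch below isolates that linear-algebra step. `sqrtm_psd` mirrors the nested helper in the listing, `psd_split` is a hypothetical name, and the 2×2 matrix is an illustrative stand-in rather than a real hyperbolic kernel. The final check shows why the second-order-cone form is valid: ||B β||² equals βᵀ K β for the symmetric square root B of a PSD matrix K.

```python
import numpy as np


def sqrtm_psd(P):
    """Symmetric square root of a PSD matrix; negative eigenvalues are clipped to zero."""
    w, V = np.linalg.eigh(P)
    B = V @ np.diag(np.sqrt(np.clip(w, 0, None))) @ V.T
    return (B + B.T) * 0.5


def psd_split(K):
    """Split a symmetric (possibly indefinite) matrix as K = K_plus - K_minus, both PSD."""
    w, V = np.linalg.eigh(K)
    K_plus = V @ np.diag(np.clip(w, 0, None)) @ V.T
    K_minus = V @ np.diag(np.clip(-w, 0, None)) @ V.T
    return (K_plus + K_plus.T) * 0.5, (K_minus + K_minus.T) * 0.5


# Indefinite symmetric matrix with eigenvalues 3 and -1
K = np.array([[1.0, 2.0], [2.0, 1.0]])
K_plus, K_minus = psd_split(K)
B_plus = sqrtm_psd(K_plus)

beta = np.array([0.3, -0.7])
quad_plus = beta @ K_plus @ beta  # quadratic form beta^T K_plus beta
norm_sq_plus = np.linalg.norm(B_plus @ beta) ** 2  # same value, written as a squared 2-norm
```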

predict_proba(X)

Predict class probabilities using the fitted SVMs.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Test points tensor.

Returns:
  • class_probabilities( Float[Tensor, 'n_samples n_classes'] ) –

    Class probabilities for each test sample.

Source code in manify/predictors/svm.py
def predict_proba(
    self,
    X: Float[torch.Tensor, "n_samples n_manifolds"],
) -> Float[torch.Tensor, "n_samples n_classes"]:
    """Predict class probabilities using the fitted SVMs.

    Args:
        X: Test points tensor.

    Returns:
        class_probabilities: Class probabilities for each test sample.
    """
    X_tensor = torch.tensor(X, dtype=torch.float32) if not isinstance(X, torch.Tensor) else X
    X_tensor = X_tensor.to(self.X_train_.device)

    Ks_test, _ = product_kernel(self.pm, self.X_train_, X_tensor)
    Kt = torch.ones((self.X_train_.shape[0], X_tensor.shape[0]), device=X_tensor.device)
    for K_m, w in zip(Ks_test, self.weights, strict=False):
        Kt += w * K_m
    Kt_np = Kt.detach().cpu().numpy()

    n_test = X_tensor.shape[0]
    n_cls = len(self.classes_)
    dec = np.zeros((n_test, n_cls))
    for idx, cls in enumerate(self.classes_):
        cls_item = cls.item() if isinstance(cls, torch.Tensor) else cls
        beta_vec: np.ndarray = np.ravel(self.beta[cls_item])
        dec[:, idx] = Kt_np.T @ beta_vec + self.b[cls_item]

    exp_scores = np.exp(dec - dec.max(axis=1, keepdims=True))
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return torch.tensor(probs, dtype=torch.float32)
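`predict_proba` applies a row-wise softmax to the one-vs-rest decision values. The outputs therefore sum to one and keep the argmax of the raw SVM scores, but they are a monotone rescaling of the margins rather than calibrated probabilities. Below is a standalone NumPy sketch of that step (the function name is hypothetical). Subtracting the row maximum before exponentiating keeps the computation finite even for very large margins.

```python
import numpy as np


def ovr_scores_to_probs(dec):
    """Row-wise softmax of one-vs-rest decision values, shifted by the row max for stability."""
    exp_scores = np.exp(dec - dec.max(axis=1, keepdims=True))
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


dec = np.array([[2.0, -1.0, 0.5], [1000.0, 999.0, -5.0]])
probs = ovr_scores_to_probs(dec)
```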

decision_tree

Decision tree and random forest predictors for product space manifolds.

For more information, see Chlenski et al. (2024): https://arxiv.org/abs/2410.13879

ProductSpaceDT(pm, max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, task='classification', use_special_dims=False, batch_size=None, n_features='d', ablate_midpoints=False, random_state=None, device=None)

Bases: BasePredictor

Decision tree in the product space for hyperbolic, Euclidean, and hyperspherical data.

Source code in manify/predictors/decision_tree.py
def __init__(
    self,
    pm: ProductManifold,
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
    min_samples_split: int = 2,
    min_impurity_decrease: float = 0.0,
    task: Literal["classification", "regression", "link_prediction"] = "classification",
    use_special_dims: bool = False,
    batch_size: int | None = None,
    n_features: Literal["d", "d_choose_2"] = "d",
    ablate_midpoints: bool = False,
    random_state: int | None = None,
    device: str | None = None,
):
    # Initialize the base class
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)

    # Raise error if manifold is stereographic
    if pm.is_stereographic:
        raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
    if task == "link_prediction":
        raise ValueError(
            "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
        )

    # Store hyperparameters
    self.pm = pm
    self.max_depth = max_depth or -1
    self.min_samples_leaf = min_samples_leaf
    self.min_samples_split = min_samples_split
    self.min_impurity_decrease = min_impurity_decrease
    self.use_special_dims = use_special_dims
    self.n_features = n_features
    self.ablate_midpoints = ablate_midpoints

    # I use "batched" to mean "all at once" and "batch_size" to mean "in chunks"
    self.batch_size = batch_size
    if batch_size is not None:
        self.batched = False
    else:
        self.batched = True

    # Task-specific stuff
    self.task = task
    self.criterion = "gini" if task == "classification" else "mse"

    # These will become important later, when fit is called
    self.nodes: list[_DecisionNode] = []  # For fitted nodes
    self.permutations: Int[torch.Tensor, "n_classes"] | None = None  # If used as part of a random forest
    self.angle2man: list[int] = []  # Maps preprocessed angles to manifold indices
    self.special_first: list[bool] = []  # Whether the first dimension is special in a projection
    self.angle_dims: list[tuple[int, int]] = []  # Maps preprocessed angles to dimension indices
    self.tree: _DecisionNode = _DecisionNode()  # The root of the tree
    self.classes_: Float[torch.Tensor, "n_classes"] = torch.empty(0)  # Initialize as an empty tensor
    self.labels_: Int[torch.Tensor, "batch n_classes"] = torch.tensor([])  # sklearn-style labels
    self.signature: list[tuple[float, int]] = pm.signature  # The signature of the manifold

fit(X, y)

Fit the decision tree to training data in the product space.

Parameters:
  • X (Float[Tensor, 'batch ambient_dim']) –

    (batch, ambient_dim) tensor of training data (ambient coordinate representation).

  • y (Real[Tensor, 'batch']) –

    (batch,) tensor of labels (integer representation).

Returns:
  • self( ProductSpaceDT ) –

    The fitted decision tree (the tree is also fitted in place).
Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def fit(self, X: Float[torch.Tensor, "batch ambient_dim"], y: Real[torch.Tensor, "batch"]) -> ProductSpaceDT:
    """Fit the decision tree to training data in the product space.

    Args:
        X: (batch, ambient_dim) tensor of training data (ambient coordinate representation)
        y: (batch,) tensor of labels (integer representation)

    Returns:
        self: The fitted ProductSpaceDT (the tree is also fitted in place).
    """
    # Pre-preprocessing step: aggregate special dimensions into a new Euclidean component
    if self.use_special_dims:
        X, self.pm = self._aggregate_special_dims(X)

    # Preprocess data
    angles, labels, comparisons_reshaped = self._preprocess(X=X, y=y)

    # Fit node
    self.tree = self._fit_node(angles=angles, labels=labels, comparisons=comparisons_reshaped, depth=self.max_depth)

    self.is_fitted_ = True  # Mark the model as fitted
    return self
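`fit` first converts ambient coordinates into projection angles. It then searches for half-circle splits of the form "angle in [θ, θ + π)". The NumPy sketch below is a simplified, single-plane illustration of that idea and is not manify's `_preprocess` or `_fit_node`. It assumes a two-argument arctangent of (x_0, x_d) for the angle, uses plain arithmetic midpoints between consecutive sorted angles instead of manifold-specific geodesic midpoints, and scores candidates with weighted Gini impurity. All function names are hypothetical.

```python
import numpy as np


def projection_angles(X, d):
    """Angle of each row projected onto the (x_0, x_d) coordinate plane, in (-pi, pi]."""
    return np.arctan2(X[:, 0], X[:, d])


def split_mask(angles, theta):
    """True where the angle lies in the half-circle [theta, theta + pi), with wraparound."""
    return np.mod(angles - theta, 2 * np.pi) < np.pi


def gini(labels):
    """Gini impurity of a label array (0 for an empty array)."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p**2)


def best_angular_split(X, y, d):
    """Try midpoints of consecutive sorted angles; return (theta, weighted Gini impurity)."""
    angles = projection_angles(X, d)
    sorted_angles = np.sort(angles)
    candidates = (sorted_angles[:-1] + sorted_angles[1:]) / 2
    best = (None, np.inf)
    for theta in candidates:
        mask = split_mask(angles, theta)
        score = (mask.sum() * gini(y[mask]) + (~mask).sum() * gini(y[~mask])) / y.size
        if score < best[1]:
            best = (theta, score)
    return best


X = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
y = np.array([0, 0, 1, 1])
theta, score = best_angular_split(X, y, d=1)  # separates x_0 > 0 from x_0 < 0
```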

predict_proba(X)

Predict class probabilities for samples in X.

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def predict_proba(self, X: Float[torch.Tensor, "batch intrinsic_dim"]) -> Float[torch.Tensor, "batch n_classes"]:
    """Predict class probabilities for samples in X."""
    if self.use_special_dims:
        X, _ = self._aggregate_special_dims(X)
    angles, _, _ = self._preprocess(X=X)
    if self.permutations is not None:
        angles = angles[:, self.permutations]
    return torch.vstack([self._traverse(angles_row, self.tree).probs for angles_row in angles])

ProductSpaceRF(pm, task='classification', use_special_dims=False, n_features='d', max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, ablate_midpoints=False, n_estimators=100, max_features='sqrt', max_samples=1.0, batch_size=None, random_state=None, n_jobs=-1, device=None)

Bases: BasePredictor

Random Forest in the product space.

Source code in manify/predictors/decision_tree.py
def __init__(
    self,
    pm: ProductManifold,
    task: Literal["classification", "regression"] = "classification",
    use_special_dims: bool = False,
    n_features: Literal["d", "d_choose_2"] = "d",
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
    min_samples_split: int = 2,
    min_impurity_decrease: float = 0.0,
    ablate_midpoints: bool = False,
    n_estimators: int = 100,
    max_features: Literal["sqrt", "log2", "none"] = "sqrt",
    max_samples: float = 1.0,
    batch_size: int | None = None,
    random_state: int | None = None,
    n_jobs: int = -1,
    device: str | None = None,
):
    # Initialize the base class
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)

    # Raise error if manifold is stereographic
    if pm.is_stereographic:
        raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
    if task == "link_prediction":
        raise ValueError(
            "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
        )

    # Tree hyperparameters
    tree_kwargs: Dict[str, Any] = {}
    self.pm = tree_kwargs["pm"] = pm
    self.task = tree_kwargs["task"] = task
    self.max_depth = tree_kwargs["max_depth"] = max_depth or -1
    self.min_samples_leaf = tree_kwargs["min_samples_leaf"] = min_samples_leaf
    self.min_samples_split = tree_kwargs["min_samples_split"] = min_samples_split
    self.min_impurity_decrease = tree_kwargs["min_impurity_decrease"] = min_impurity_decrease
    self.use_special_dims = tree_kwargs["use_special_dims"] = use_special_dims
    self.n_features = tree_kwargs["n_features"] = n_features
    self.batch_size = tree_kwargs["batch_size"] = batch_size
    self.ablate_midpoints = tree_kwargs["ablate_midpoints"] = ablate_midpoints

    # I use "batched" to mean "all at once" and "batch_size" to mean "in chunks"
    self.batched = batch_size is None

    # Random forest hyperparameters
    self.n_estimators = n_estimators
    self.max_features = max_features
    self.max_samples = max_samples
    self.random_state = random_state
    self.n_jobs = n_jobs
    self.trees = [ProductSpaceDT(**tree_kwargs) for _ in range(n_estimators)]

    # These will become important later - just the sklearn-style stuff
    # For other special attributes, we just use ProductSpaceDT's attributes
    self.classes_: Float[torch.Tensor, "n_classes"] | None = None
    self.labels_: Int[torch.Tensor, "batch n_classes"] | None = None

fit(X, y)

Preprocess and fit an ensemble of trees on subsampled data.

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def fit(self, X: Float[torch.Tensor, "batch ambient_dim"], y: Real[torch.Tensor, "batch"]) -> ProductSpaceRF:
    """Preprocess and fit an ensemble of trees on subsampled data."""
    # Pre-preprocessing step: aggregate special dimensions
    if self.use_special_dims:
        X, self.pm = self.trees[0]._aggregate_special_dims(X)
        for tree in self.trees:
            tree.pm = self.pm

    # Can use any tree to preprocess X and y
    angles, labels, comparisons = self.trees[0]._preprocess(X=X, y=y)

    # Also update angle2man and special_first
    for tree in self.trees:
        tree.angle2man = self.trees[0].angle2man
        tree.special_first = self.trees[0].special_first
        tree.classes_ = self.trees[0].classes_
    self.classes_ = self.trees[0].classes_

    # Use seed here
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Subsample - just the indices
    n, d = angles.shape
    idx_sample_all, idx_dim_all = self._generate_subsample(n_rows=n, n_cols=d, n_trees=self.n_estimators)

    # Fit trees
    for tree, idx_sample, idx_dim in zip(self.trees, idx_sample_all, idx_dim_all, strict=False):
        tree.permutations = idx_dim
        if self.batched:
            comparisons_subsample = comparisons[idx_sample][:, idx_dim][:, :, idx_sample]
        else:
            comparisons_subsample = comparisons
        tree.tree = tree._fit_node(
            angles=angles[idx_sample][:, idx_dim],
            labels=labels[idx_sample],
            comparisons=comparisons_subsample,
            depth=self.max_depth,
        )

    self.is_fitted_ = True
    return self

predict_proba(X)

Predict class probabilities for samples in X.

Source code in manify/predictors/decision_tree.py
@torch.no_grad()  # type: ignore
def predict_proba(self, X: Float[torch.Tensor, "batch intrinsic_dim"]) -> Float[torch.Tensor, "batch n_classes"]:
    """Predict class probabilities for samples in X."""
    return torch.stack([tree.predict_proba(X) for tree in self.trees]).mean(dim=0)
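The forest fits every tree on a row and column subsample of the shared angle matrix. In `predict_proba` it averages the trees' class probabilities. The body of `_generate_subsample` is not shown on this page, so the NumPy sketch below is a hypothetical stand-in. It assumes bootstrap sampling of rows (with replacement, scaled by `max_samples`) and sampling of angle columns without replacement according to `max_features`. The averaging helper mirrors the mean over trees in `ProductSpaceRF.predict_proba`.

```python
import math

import numpy as np


def generate_subsample(n_rows, n_cols, n_trees, max_samples=1.0, max_features="sqrt", seed=None):
    """Hypothetical per-tree row (bootstrap) and column (angle feature) index draws."""
    rng = np.random.default_rng(seed)
    n_samples = max(1, int(round(max_samples * n_rows)))
    if max_features == "sqrt":
        n_features = max(1, int(math.sqrt(n_cols)))
    elif max_features == "log2":
        n_features = max(1, int(math.log2(n_cols)))
    else:  # "none": use every angle column
        n_features = n_cols
    idx_rows = [rng.choice(n_rows, size=n_samples, replace=True) for _ in range(n_trees)]
    idx_cols = [rng.choice(n_cols, size=n_features, replace=False) for _ in range(n_trees)]
    return idx_rows, idx_cols


def average_tree_probs(per_tree_probs):
    """Forest output: unweighted mean of the trees' (n_samples, n_classes) probability arrays."""
    return np.mean(np.stack(per_tree_probs), axis=0)


rows, cols = generate_subsample(n_rows=10, n_cols=16, n_trees=3, seed=0)
forest_probs = average_tree_probs([np.array([[1.0, 0.0]]), np.array([[0.5, 0.5]])])
```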

kappa_gcn

\(\kappa\)-GCN implementation.

KappaGCN(pm, output_dim, num_hidden=2, nonlinearity=torch.relu, task='classification', random_state=None, device=None)

Bases: BasePredictor, Module

Implementation for the Kappa GCN.

Attributes:
  • pm

    ProductManifold object for the Kappa GCN.

  • output_dim

    Number of output features.

  • num_hidden

    Number of hidden layers.

  • nonlinearity

    Function for nonlinear activation.

  • task

    Task type, one of ["classification", "regression", "link_prediction"]

  • random_state

    Random seed for reproducibility.

  • device

    Device to run the model on (default: None, uses current device).

  • is_fitted_ (bool) –

    Whether the model has been fitted.

  • loss_history_ (dict[str, list[float]]) –

    History of loss values during training.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object for the Kappa GCN

  • output_dim (int) –

    Number of output features

  • num_hidden (int, default: 2 ) –

    Number of hidden layers.

  • nonlinearity (Callable, default: relu ) –

    Function for nonlinear activation.

  • task (Literal['classification', 'regression', 'link_prediction'], default: 'classification' ) –

    Task type, one of ["classification", "regression", "link_prediction"].

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device to run the model on (default: None, uses current device).

Raises:
  • ValueError

    If the ProductManifold is not stereographic.

Source code in manify/predictors/kappa_gcn.py
def __init__(
    self,
    pm: ProductManifold,
    output_dim: int,
    num_hidden: int = 2,
    nonlinearity: Callable = torch.relu,
    task: Literal["classification", "regression", "link_prediction"] = "classification",
    random_state: int | None = None,
    device: str | None = None,
):
    BasePredictor.__init__(self, pm=pm, task=task, random_state=random_state, device=device)
    torch.nn.Module.__init__(self)

    self.pm = pm
    self.task = task
    self.output_dim = output_dim
    self.num_hidden = num_hidden
    self.nonlinearity = nonlinearity

    # Ensure pm is stereographic
    if not pm.is_stereographic:
        raise ValueError(
            "ProductManifold must be stereographic for KappaGCN to work. Please use pm.stereographic() to convert."
        )

    # Build layer dimensions
    dims = [pm.dim] + [pm.dim] * num_hidden

    # Build the main GCN layers using Sequential
    gcn_layers = []
    for i in range(len(dims) - 1):
        gcn_layers.append(KappaGCNLayer(dims[i], dims[i + 1], pm, nonlinearity))

    self.gcn_layers = KappaSequential(*gcn_layers)

    # Task-specific output layers
    if task == "link_prediction":
        self.output_layer = FermiDiracDecoder(pm, learnable_params=True)
    else:
        # This is the same for classification/regression since we apply softmax in the loss function, not here
        self.output_layer = StereographicLogits(output_dim, pm, apply_softmax=False)
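KappaGCN requires a stereographic ProductManifold because its layers are built from gyrovector operations in the κ-stereographic model. That model covers the Poincaré ball (κ < 0), Euclidean space (κ = 0), and the stereographic sphere (κ > 0) with a single set of formulas. As an illustration only, the NumPy sketch below implements the standard κ-addition (Möbius addition) for one pair of points. It is not manify's `KappaGCNLayer`, which relies on the library's own manifold operations, and `kappa_add` is a hypothetical name.

```python
import numpy as np


def kappa_add(x, y, kappa):
    """kappa-addition (Mobius addition) in the kappa-stereographic model.

    kappa < 0: Poincare ball; kappa = 0: ordinary vector addition; kappa > 0: stereographic sphere.
    """
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    xy, x2, y2 = x @ y, x @ x, y @ y
    num = (1 - 2 * kappa * xy - kappa * y2) * x + (1 + kappa * x2) * y
    den = 1 - 2 * kappa * xy + kappa**2 * x2 * y2
    return num / den


x = np.array([0.3, -0.2])
y = np.array([0.1, 0.4])
euclidean_sum = kappa_add(x, y, 0.0)  # equals x + y
hyperbolic_sum = kappa_add(x, y, -1.0)  # stays inside the unit ball
spherical_sum = kappa_add(x, y, 1.0)
```

In one dimension with κ = -1 this reduces to (a + b) / (1 + ab), the relativistic velocity-addition rule. This is a quick way to sanity-check the formula.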

forward(X, A_hat=None, aggregate_logits=True, softmax=False)

Forward pass through the GCN layers and output head.

Source code in manify/predictors/kappa_gcn.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    aggregate_logits: bool = True,
    softmax: bool = False,
) -> (
    Float[torch.Tensor, "n_nodes n_classes"]
    | Float[torch.Tensor, "n_nodes"]
    | Float[torch.Tensor, "n_nodes n_nodes"]
):
    """Forward pass through the GCN layers and output head."""
    # Pass through main GCN layers
    H = self.gcn_layers(X, A_hat)

    # Task-specific output using the specialized layers
    if self.task == "link_prediction":
        return self.output_layer(H)  # Flattened for link prediction
    else:
        # For classification/regression, use stereographic logits
        logits = self.output_layer(H, A_hat, aggregate_logits=aggregate_logits)

        if softmax:
            logits = torch.softmax(logits, dim=-1)

        return logits.squeeze()

fit(X, y, A=None, epochs=2000, lr=0.01, use_tqdm=True, tqdm_prefix=None)

Fit the Kappa GCN model.

Parameters:
  • X (Tensor) –

    Feature matrix.

  • y (Tensor) –

    Labels for training nodes.

  • A (Tensor, default: None ) –

    Adjacency or distance matrix.

  • epochs (int, default: 2000 ) –

    Number of training epochs (default=2000).

  • lr (float, default: 0.01 ) –

    Learning rate (default=1e-2).

  • use_tqdm (bool, default: True ) –

    Whether to use tqdm for progress bar.

  • tqdm_prefix (str | None, default: None ) –

    Prefix for tqdm progress bar.

Source code in manify/predictors/kappa_gcn.py
def fit(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    y: Real[torch.Tensor, "n_nodes"],
    A: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    epochs: int = 2_000,
    lr: float = 1e-2,
    use_tqdm: bool = True,
    tqdm_prefix: str | None = None,
) -> KappaGCN:
    """Fit the Kappa GCN model.

    Args:
        X (torch.Tensor): Feature matrix.
        y (torch.Tensor): Labels for training nodes.
        A (torch.Tensor): Adjacency or distance matrix.
        epochs: Number of training epochs (default=2000).
        lr: Learning rate (default=1e-2).
        use_tqdm: Whether to use tqdm for progress bar.
        tqdm_prefix: Prefix for tqdm progress bar.
    """
    # Copy everything
    X = X.clone()
    y = y.clone()
    A = A.clone() if A is not None else None

    # Convert A to A_hat
    A_hat = get_A_hat(A, make_symmetric=True, add_self_loops=True) if A is not None else None

    # Collect all parameters
    euclidean_params = []
    riemannian_params = []
    for layer in self.gcn_layers.layers:
        euclidean_params.append(layer.W)
    if self.task == "link_prediction":
        euclidean_params += [self.output_layer.temperature, self.output_layer.bias]
    else:
        euclidean_params += [self.output_layer.W]
        riemannian_params += [self.output_layer.p_ks]

    # Optimizers
    opt = torch.optim.Adam(euclidean_params, lr=lr)
    ropt = geoopt.optim.RiemannianAdam(riemannian_params, lr=lr) if riemannian_params else None

    if self.task == "classification":
        loss_fn = torch.nn.CrossEntropyLoss()
        y = y.long()
    elif self.task == "regression":
        loss_fn = torch.nn.MSELoss()
        y = y.float()
    elif self.task == "link_prediction":
        loss_fn = torch.nn.BCEWithLogitsLoss()
        # y = y.flatten().float()
        y = y.float()
    else:
        raise ValueError("Invalid task!")

    self.train()
    if use_tqdm:
        my_tqdm = tqdm(total=epochs, desc=tqdm_prefix)

    losses = []
    for i in range(epochs):
        opt.zero_grad()
        if ropt is not None:
            ropt.zero_grad()
        y_pred = self(X, A_hat)
        loss = loss_fn(y_pred, y)

        # Early termination for NaN loss, checked before stepping so the parameters stay finite
        if torch.isnan(loss):
            print("Loss is NaN, stopping training.")
            break

        loss.backward()
        opt.step()
        if ropt is not None:
            ropt.step()
        losses.append(loss.item())

        # Progress bar
        if use_tqdm:
            my_tqdm.update(1)
            my_tqdm.set_description(f"Epoch {i + 1}/{epochs}, Loss: {loss.item():.4f}")

    if use_tqdm:
        my_tqdm.close()

    self.is_fitted_ = True
    self.loss_history_["train"] = losses
    return self
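`fit` keeps two optimizers over disjoint parameter groups: `torch.optim.Adam` for the Euclidean weights and geoopt's `RiemannianAdam` for the manifold-valued bias points `p_ks`. A minimal sketch of that loop structure on a toy regression problem follows. To stay runnable without geoopt, it uses a second `torch.optim.Adam` as a stand-in for `RiemannianAdam` (an assumption that is only valid for Euclidean parameters); the data and parameter names are hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy regression data: 64 samples, 3 features
X = torch.randn(64, 3)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.3

# Two disjoint parameter groups, mirroring euclidean_params / riemannian_params
W = torch.nn.Parameter(torch.zeros(3))
b = torch.nn.Parameter(torch.zeros(()))
opt = torch.optim.Adam([W], lr=1e-2)
ropt = torch.optim.Adam([b], lr=1e-2)  # stand-in for geoopt.optim.RiemannianAdam
loss_fn = torch.nn.MSELoss()

losses = []
for epoch in range(200):
    opt.zero_grad()
    ropt.zero_grad()
    loss = loss_fn(X @ W + b, y)
    if torch.isnan(loss):  # stop before stepping on a NaN loss
        break
    loss.backward()
    opt.step()
    ropt.step()
    losses.append(loss.item())
```

Zeroing and stepping both optimizers every epoch matters: each optimizer only sees its own group, so skipping either one would leave stale gradients or frozen parameters.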
predict_proba(X, A=None)

Predict class probabilities using the trained Kappa GCN.

Parameters:
  • X (Tensor) –

    Feature matrix (NxD).

  • A (Tensor, default: None ) –

    Adjacency or distance matrix (NxN).

Returns:
  • Real[Tensor, 'n_nodes n_classes'] | Real[Tensor, 'n_nodes']

    torch.Tensor: Predicted class probabilities / regression targets.

Source code in manify/predictors/kappa_gcn.py
def predict_proba(
    self, X: Float[torch.Tensor, "n_nodes dim"], A: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Real[torch.Tensor, "n_nodes n_classes"] | Real[torch.Tensor, "n_nodes"]:
    """Predict class probabilities using the trained Kappa GCN.

    Args:
        X (torch.Tensor): Feature matrix (NxD).
        A (torch.Tensor): Adjacency or distance matrix (NxN).

    Returns:
        torch.Tensor: Predicted class probabilities / regression targets.
    """
    # Copy everything
    X = X.clone()
    A = A.clone() if A is not None else None
    A_hat = get_A_hat(A, make_symmetric=True, add_self_loops=True) if A is not None else None

    # Switch to evaluation mode and run a forward pass without tracking gradients
    self.eval()
    with torch.no_grad():
        y_pred = self(X, A_hat)
    return y_pred

get_A_hat(A, make_symmetric=True, add_self_loops=True)

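Symmetrically normalize an adjacency matrix as \(\hat{A} = D^{-1/2} A D^{-1/2}\), where \(D\) is the diagonal degree matrix of \(A\) after the optional symmetrization and self-loop steps. NaN entries of A are set to zero in place.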

Parameters:
  • A (Float[Tensor, 'n_nodes n_nodes']) –

    Adjacency matrix.

  • make_symmetric (bool, default: True ) –

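    Whether to make the adjacency matrix symmetric. If True and A is not already symmetric, A is replaced by \(A + A^\top\).

  • add_self_loops (bool, default: True ) –

    Whether to add self-loops to the adjacency matrix. If True and the diagonal of A is not all ones, the identity matrix is added.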

Returns:
  • A_hat( Float[Tensor, 'n_nodes n_nodes'] ) –

    Normalized adjacency matrix.

Source code in manify/predictors/kappa_gcn.py
def get_A_hat(
    A: Float[torch.Tensor, "n_nodes n_nodes"], make_symmetric: bool = True, add_self_loops: bool = True
) -> Float[torch.Tensor, "n_nodes n_nodes"]:
    """Symmetrically normalize an adjacency matrix as D^(-1/2) A D^(-1/2).

    Args:
        A: Adjacency matrix.
        make_symmetric: Whether to make the adjacency matrix symmetric.
        add_self_loops: Whether to add self-loops to the adjacency matrix.

    Returns:
        A_hat: Normalized adjacency matrix.
    """
    # Fix nans
    A[torch.isnan(A)] = 0

    # Optional steps to make symmetric and add self-loops
    if make_symmetric and not torch.allclose(A, A.T):
        A = A + A.T
    if add_self_loops and not torch.allclose(torch.diag(A), torch.ones(A.shape[0], dtype=A.dtype, device=A.device)):
        A = A + torch.eye(A.shape[0], device=A.device, dtype=A.dtype)

    # Get node degrees (row sums)
    deg = torch.sum(A, dim=1)

    # Compute D^(-1/2) directly from the degrees instead of inverting a dense matrix
    D_inv_sqrt = torch.diag(deg.pow(-0.5))

    # Normalize adjacency matrix
    A_hat = D_inv_sqrt @ A @ D_inv_sqrt

    return A_hat.detach()
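For reference, the same normalization can be written without forming or inverting a dense degree matrix. The following is a minimal sketch under stated assumptions, not manify's implementation: the function name `normalize_adjacency` is hypothetical, self-loops are always added (get_A_hat adds them only conditionally), and NaN handling works on a copy rather than modifying the caller's tensor.

```python
import torch


def normalize_adjacency(A: torch.Tensor) -> torch.Tensor:
    """Sketch of symmetric GCN normalization D^(-1/2) (A + I) D^(-1/2)."""
    # Replace NaN entries with 0 on a copy (the input tensor is left untouched)
    A = torch.where(torch.isnan(A), torch.zeros_like(A), A)
    # Always add self-loops (simplification relative to get_A_hat)
    A = A + torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    # Node degrees are row sums; with self-loops and non-negative A every degree is >= 1
    deg = A.sum(dim=1)
    d_inv_sqrt = deg.pow(-0.5)
    # Scaling row i and column j by d_i^(-1/2) d_j^(-1/2) equals D^(-1/2) A D^(-1/2)
    return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]


# Path graph 0 - 1 - 2
A_hat = normalize_adjacency(torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]))
```

With self-loops the path graph has degrees (2, 3, 2), so the endpoint entries on the diagonal become 1/2 and each edge weight between node 0 and node 1 becomes \(1/\sqrt{6}\).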

nn

Neural network layers for KappaGCN and related models.

FermiDiracDecoder(manifold, learnable_params=True)

Bases: Module

Fermi-Dirac decoder for link prediction tasks.

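Computes pairwise distances \(d\) between node embeddings with manifold.pdist2 and maps them to edge logits \(-(d - b)/t\) with bias \(b\) and temperature \(t\). Applying a sigmoid to these logits gives the Fermi-Dirac edge probability \(1 / (e^{(d - b)/t} + 1)\).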

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • learnable_params (bool, default: True ) –

    If True, temperature and bias are learnable parameters. If False, they are fixed to 1.0 and 0.0, respectively.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, learnable_params: bool = True):
    super().__init__()

    self.manifold = manifold

    if learnable_params:
        self.temperature = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.tensor(0.0))
    else:
        self.register_buffer("temperature", torch.tensor(1.0))
        self.register_buffer("bias", torch.tensor(0.0))
forward(X, A_hat=None)

Forward pass through Fermi-Dirac decoder.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Node embeddings

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Ignored (for compatibility)

Returns:
  • Float[Tensor, 'n_nodes n_nodes']

    Pairwise edge logits \(-(d - b)/t\); apply a sigmoid to obtain edge probabilities.

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes n_nodes"]:
    """Forward pass through Fermi-Dirac decoder.

    Args:
        X: Node embeddings
        A_hat: Ignored (for compatibility)

    Returns:
        Pairwise edge logits; apply a sigmoid to obtain edge probabilities.
    """
    # Compute pairwise distances
    pairwise_dist = self.manifold.pdist2(X)

    # Apply Fermi-Dirac transformation
    logits = -(pairwise_dist - self.bias) / self.temperature

    return logits
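The decoder stops at logits, which is what torch.nn.BCEWithLogitsLoss (used for link prediction in fit) expects. The sketch below shows how those logits turn into Fermi-Dirac probabilities. It assumes Euclidean embeddings and uses squared Euclidean distances from torch.cdist as a stand-in for manifold.pdist2; the function name is hypothetical.

```python
import torch


def fermi_dirac_probs(X: torch.Tensor, bias: float = 0.0, temperature: float = 1.0) -> torch.Tensor:
    """Sketch: Fermi-Dirac edge probabilities from Euclidean embeddings."""
    # Squared Euclidean pairwise distances stand in for manifold.pdist2 (assumption)
    d = torch.cdist(X, X) ** 2
    # Same transformation as FermiDiracDecoder.forward
    logits = -(d - bias) / temperature
    # sigmoid(-(d - b) / t) == 1 / (exp((d - b) / t) + 1)
    return torch.sigmoid(logits)


X = torch.tensor([[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]])
P = fermi_dirac_probs(X, bias=25.0, temperature=1.0)
```

Pairs whose distance equals the bias get probability 0.5; closer pairs approach 1 and farther pairs approach 0, with the temperature controlling how sharp that transition is.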

KappaGCNLayer(in_features, out_features, manifold, nonlinearity=torch.relu)

Bases: Module

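κ-GCN layer: multiplies the embeddings by a weight matrix with a Möbius matrix-vector product, aggregates neighbors with the normalized adjacency matrix, and applies the nonlinearity through the manifold.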

Parameters:
  • in_features (int) –

    Number of input features

  • out_features (int) –

    Number of output features

  • manifold (Manifold) –

    Manifold object for the Kappa GCN

  • nonlinearity (Callable | None, default: relu ) –

    Function for nonlinear activation.

Attributes:
  • W

    Weight matrix parameter.

  • sigma

    Nonlinear activation function applied via the manifold.

  • manifold

    The manifold object for geometric operations.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, in_features: int, out_features: int, manifold: Manifold, nonlinearity: Callable | None = torch.relu
):
    super().__init__()

    # Parameters are Euclidean, straightforwardly
    self.W = torch.nn.Parameter(torch.randn(in_features, out_features) * 0.01)

    # Nonlinearity must be applied via the manifold
    self.sigma = manifold.apply(nonlinearity) if nonlinearity else lambda x: x

    # Also store manifold
    self.manifold = manifold
forward(X, A_hat=None)

Forward pass for the Kappa GCN layer.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Embedding matrix

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Normalized adjacency matrix

Returns:
  • AXW( Float[Tensor, 'n_nodes dim'] ) –

    Transformed node features after message passing and nonlinear activation.

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the Kappa GCN layer.

    Args:
        X: Embedding matrix
        A_hat: Normalized adjacency matrix

    Returns:
        AXW: Transformed node features after message passing and nonlinear activation.
    """
    # 1. right-multiply X by W - mobius_matvec broadcasts correctly (verified)
    XW = self.manifold.manifold.mobius_matvec(m=self.W, x=X)

    # 2. left-multiply (X @ W) by A_hat - we need our own implementation for this
    if A_hat is None:
        AXW = XW
    elif isinstance(self.manifold, ProductManifold):
        XWs = self.manifold.factorize(XW)
        AXW = torch.hstack([self._left_multiply(A_hat, XW_i, M) for XW_i, M in zip(XWs, self.manifold.P, strict=False)])
    else:
        AXW = self._left_multiply(A_hat, XW, self.manifold)

    # 3. Apply nonlinearity - note that sigma is wrapped with our manifold.apply decorator
    AXW = self.sigma(AXW)

    return AXW
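Step 1 relies on the Möbius matrix-vector product. As a reference for what that operation computes, here is a self-contained sketch of the Poincaré-ball formula from Ganea et al. (2018) for curvature \(-c\), written row-wise as \(x \mapsto xW\). It is an illustration under that assumption, not geoopt's or manify's implementation; geoopt's own mobius_matvec has its own argument and shape conventions, and the function name below is hypothetical.

```python
import torch


def mobius_matvec_rows(W: torch.Tensor, X: torch.Tensor, c: float) -> torch.Tensor:
    """Sketch: row-wise Mobius matrix-vector product on the Poincare ball of curvature -c (c > 0).

    Each row x of X ([n, d_in]) is mapped to
        tanh(|xW| / |x| * artanh(sqrt(c) |x|)) * xW / (sqrt(c) |xW|),
    where W has shape [d_in, d_out].
    """
    sqrt_c = c ** 0.5
    x_norm = X.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    XW = X @ W
    xw_norm = XW.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    # Clamp so artanh stays finite for points numerically on the boundary
    arg = (sqrt_c * x_norm).clamp(max=1 - 1e-7)
    scale = torch.tanh(xw_norm / x_norm * torch.atanh(arg))
    return scale * XW / (sqrt_c * xw_norm)
```

As \(c \to 0\) the result approaches the ordinary product \(XW\), consistent with κ-GCN layers recovering Euclidean GCN behavior at zero curvature, while for \(c > 0\) every output row stays inside the ball of radius \(1/\sqrt{c}\).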

KappaSequential(*layers)

Bases: Module

Sequential container for κ-layers that properly handles adjacency matrices.

Similar to nn.Sequential but passes the adjacency matrix through each layer. All layers should accept (X, A_hat) and return X.

Parameters:
  • *layers (Module, default: () ) –

    Variable number of layers to be added to the sequence.

Source code in manify/predictors/nn/layers.py
def __init__(self, *layers: nn.Module):
    super().__init__()
    self.layers = nn.ModuleList(layers)
forward(X, A_hat=None)

Forward pass through all layers.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Adjacency matrix passed to each layer

Returns:
  • Float[Tensor, 'n_nodes out_dim']

    Output after passing through all layers

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes out_dim"]:
    """Forward pass through all layers.

    Args:
        X: Input features
        A_hat: Adjacency matrix passed to each layer

    Returns:
        Output after passing through all layers
    """
    for layer in self.layers:
        X = layer(X, A_hat)
    return X
append(layer)

Add a layer to the end of the sequence.

Source code in manify/predictors/nn/layers.py
def append(self, layer: nn.Module) -> None:
    """Add a layer to the end of the sequence."""
    self.layers.append(layer)

StereographicLogits(out_features, manifold, apply_softmax=False)

Bases: Module

Stereographic logits layer for classification and regression on product manifolds.

Computes signed distances from hyperplanes in the product manifold space. Can optionally apply softmax for classification tasks.

Parameters:
  • out_features (int) –

    Number of output classes (dimensionality of output space)

  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • apply_softmax (bool, default: False ) –

    Whether to apply softmax to the output logits (default: False)

Source code in manify/predictors/nn/layers.py
def __init__(self, out_features: int, manifold: Manifold | ProductManifold, apply_softmax: bool = False):
    super().__init__()

    self.out_features = out_features
    self.manifold = manifold
    self.apply_softmax = apply_softmax

    # Weight matrix (Euclidean parameters)
    self.W = nn.Parameter(torch.randn(manifold.dim, out_features) * 0.01)

    # Bias points on the manifold
    self.p_ks = geoopt.ManifoldParameter(torch.zeros(out_features, manifold.dim), manifold=manifold.manifold)
forward(X, A_hat=None, aggregate_logits=False)

Forward pass through stereographic logits.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Optional adjacency matrix for logit aggregation

  • aggregate_logits (bool, default: False ) –

    Whether to aggregate logits using adjacency matrix

Returns:
  • Float[Tensor, 'n_nodes n_classes']

    Logits (or probabilities if apply_softmax=True)

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    aggregate_logits: bool = False,
) -> Float[torch.Tensor, "n_nodes n_classes"]:
    """Forward pass through stereographic logits.

    Args:
        X: Input features
        A_hat: Optional adjacency matrix for logit aggregation
        aggregate_logits: Whether to aggregate logits using adjacency matrix

    Returns:
        Logits (or probabilities if apply_softmax=True)
    """
    # Compute logits based on manifold type
    if isinstance(self.manifold, ProductManifold):
        logits = self._get_logits_product_manifold(X, self.W, self.p_ks, self.manifold)
    else:
        logits = self._get_logits_single_manifold(X, self.W, self.p_ks, self.manifold, return_inner_products=False)

    # Optional aggregation via adjacency matrix
    if A_hat is not None and aggregate_logits:
        logits = A_hat @ logits

    # Optional softmax for classification
    if self.apply_softmax:
        logits = torch.softmax(logits, dim=-1)

    return logits
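In the flat (κ = 0) limit, signed-distance logits reduce to Euclidean multinomial logistic regression with one hyperplane per class. The following sketch assumes that limit and reuses the shapes of W ([dim, out_features], one normal per column) and p_ks ([out_features, dim], one offset point per row). It is not manify's _get_logits_single_manifold, and the function names are hypothetical.

```python
import torch


def euclidean_hyperplane_logits(X: torch.Tensor, W: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Sketch: logit[n, k] = <x_n - p_k, a_k>, with a_k = W[:, k] and p_k = P[k]."""
    # X @ W gives <x_n, a_k>; the second term subtracts <p_k, a_k> for each class k
    return X @ W - (P * W.T).sum(dim=-1)


def signed_distances(X: torch.Tensor, W: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    # Dividing each logit by |a_k| gives the signed distance from x_n to hyperplane k
    return euclidean_hyperplane_logits(X, W, P) / W.norm(dim=0)
```

Because each logit equals \(\|a_k\|\) times a signed distance, the norm of a column of W acts as a per-class scale on top of the geometry, which is why the layer keeps W as unconstrained Euclidean parameters.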

layers

Neural network layers for product manifolds.

KappaGCNLayer(in_features, out_features, manifold, nonlinearity=torch.relu)

Bases: Module

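κ-GCN layer: multiplies the embeddings by a weight matrix with a Möbius matrix-vector product, aggregates neighbors with the normalized adjacency matrix, and applies the nonlinearity through the manifold.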

Parameters:
  • in_features (int) –

    Number of input features

  • out_features (int) –

    Number of output features

  • manifold (Manifold) –

    Manifold object for the Kappa GCN

  • nonlinearity (Callable | None, default: relu ) –

    Function for nonlinear activation.

Attributes:
  • W

    Weight matrix parameter.

  • sigma

    Nonlinear activation function applied via the manifold.

  • manifold

    The manifold object for geometric operations.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, in_features: int, out_features: int, manifold: Manifold, nonlinearity: Callable | None = torch.relu
):
    super().__init__()

    # Parameters are Euclidean, straightforwardly
    self.W = torch.nn.Parameter(torch.randn(in_features, out_features) * 0.01)

    # Nonlinearity must be applied via the manifold
    self.sigma = manifold.apply(nonlinearity) if nonlinearity else lambda x: x

    # Also store manifold
    self.manifold = manifold
forward(X, A_hat=None)

Forward pass for the Kappa GCN layer.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Embedding matrix

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Normalized adjacency matrix

Returns:
  • AXW( Float[Tensor, 'n_nodes dim'] ) –

    Transformed node features after message passing and nonlinear activation.

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the Kappa GCN layer.

    Args:
        X: Embedding matrix
        A_hat: Normalized adjacency matrix

    Returns:
        AXW: Transformed node features after message passing and nonlinear activation.
    """
    # 1. right-multiply X by W - mobius_matvec broadcasts correctly (verified)
    XW = self.manifold.manifold.mobius_matvec(m=self.W, x=X)

    # 2. left-multiply (X @ W) by A_hat - we need our own implementation for this
    if A_hat is None:
        AXW = XW
    elif isinstance(self.manifold, ProductManifold):
        XWs = self.manifold.factorize(XW)
        AXW = torch.hstack([self._left_multiply(A_hat, XW_i, M) for XW_i, M in zip(XWs, self.manifold.P, strict=False)])
    else:
        AXW = self._left_multiply(A_hat, XW, self.manifold)

    # 3. Apply nonlinearity - note that sigma is wrapped with our manifold.apply decorator
    AXW = self.sigma(AXW)

    return AXW
KappaSequential(*layers)

Bases: Module

Sequential container for κ-layers that properly handles adjacency matrices.

Similar to nn.Sequential but passes the adjacency matrix through each layer. All layers should accept (X, A_hat) and return X.

Parameters:
  • *layers (Module, default: () ) –

    Variable number of layers to be added to the sequence.

Source code in manify/predictors/nn/layers.py
def __init__(self, *layers: nn.Module):
    super().__init__()
    self.layers = nn.ModuleList(layers)
forward(X, A_hat=None)

Forward pass through all layers.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Adjacency matrix passed to each layer

Returns:
  • Float[Tensor, 'n_nodes out_dim']

    Output after passing through all layers

Source code in manify/predictors/nn/layers.py
def forward(
    self, X: Float[torch.Tensor, "n_nodes dim"], A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None
) -> Float[torch.Tensor, "n_nodes out_dim"]:
    """Forward pass through all layers.

    Args:
        X: Input features
        A_hat: Adjacency matrix passed to each layer

    Returns:
        Output after passing through all layers
    """
    for layer in self.layers:
        X = layer(X, A_hat)
    return X
append(layer)

Add a layer to the end of the sequence.

Source code in manify/predictors/nn/layers.py
def append(self, layer: nn.Module) -> None:
    """Add a layer to the end of the sequence."""
    self.layers.append(layer)
StereographicLogits(out_features, manifold, apply_softmax=False)

Bases: Module

Stereographic logits layer for classification and regression on product manifolds.

Computes signed distances from hyperplanes in the product manifold space. Can optionally apply softmax for classification tasks.

Parameters:
  • out_features (int) –

    Number of output classes (dimensionality of output space)

  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • apply_softmax (bool, default: False ) –

    Whether to apply softmax to the output logits (default: False)

Source code in manify/predictors/nn/layers.py
def __init__(self, out_features: int, manifold: Manifold | ProductManifold, apply_softmax: bool = False):
    super().__init__()

    self.out_features = out_features
    self.manifold = manifold
    self.apply_softmax = apply_softmax

    # Weight matrix (Euclidean parameters)
    self.W = nn.Parameter(torch.randn(manifold.dim, out_features) * 0.01)

    # Bias points on the manifold
    self.p_ks = geoopt.ManifoldParameter(torch.zeros(out_features, manifold.dim), manifold=manifold.manifold)
forward(X, A_hat=None, aggregate_logits=False)

Forward pass through stereographic logits.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Input features

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Optional adjacency matrix for logit aggregation

  • aggregate_logits (bool, default: False ) –

    Whether to aggregate logits using adjacency matrix

Returns:
  • Float[Tensor, 'n_nodes n_classes']

    Logits (or probabilities if apply_softmax=True)

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
    aggregate_logits: bool = False,
) -> Float[torch.Tensor, "n_nodes n_classes"]:
    """Forward pass through stereographic logits.

    Args:
        X: Input features
        A_hat: Optional adjacency matrix for logit aggregation
        aggregate_logits: Whether to aggregate logits using adjacency matrix

    Returns:
        Logits (or probabilities if apply_softmax=True)
    """
    # Compute logits based on manifold type
    if isinstance(self.manifold, ProductManifold):
        logits = self._get_logits_product_manifold(X, self.W, self.p_ks, self.manifold)
    else:
        logits = self._get_logits_single_manifold(X, self.W, self.p_ks, self.manifold, return_inner_products=False)

    # Optional aggregation via adjacency matrix
    if A_hat is not None and aggregate_logits:
        logits = A_hat @ logits

    # Optional softmax for classification
    if self.apply_softmax:
        logits = torch.softmax(logits, dim=-1)

    return logits
FermiDiracDecoder(manifold, learnable_params=True)

Bases: Module

Fermi-Dirac decoder for link prediction tasks.

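Computes pairwise distances \(d\) between node embeddings with manifold.pdist2 and maps them to edge logits \(-(d - b)/t\) with bias \(b\) and temperature \(t\). Applying a sigmoid to these logits gives the Fermi-Dirac edge probability \(1 / (e^{(d - b)/t} + 1)\).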

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry

  • learnable_params (bool, default: True ) –

    If True, temperature and bias are learnable parameters. If False, they are fixed to 1.0 and 0.0, respectively.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, learnable_params: bool = True):
    super().__init__()

    self.manifold = manifold

    if learnable_params:
        self.temperature = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.tensor(0.0))
    else:
        self.register_buffer("temperature", torch.tensor(1.0))
        self.register_buffer("bias", torch.tensor(0.0))
forward(X, A_hat=None)

Forward pass through Fermi-Dirac decoder.

Parameters:
  • X (Float[Tensor, 'n_nodes dim']) –

    Node embeddings

  • A_hat (Float[Tensor, 'n_nodes n_nodes'] | None, default: None ) –

    Ignored (for compatibility)

Returns:
  • Float[Tensor, 'n_nodes n_nodes']

    Pairwise edge logits \(-(d - b)/t\); apply a sigmoid to obtain edge probabilities.

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    A_hat: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes n_nodes"]:
    """Forward pass through Fermi-Dirac decoder.

    Args:
        X: Node embeddings
        A_hat: Ignored (for compatibility)

    Returns:
        Pairwise edge logits; apply a sigmoid to obtain edge probabilities.
    """
    # Compute pairwise distances
    pairwise_dist = self.manifold.pdist2(X)

    # Apply Fermi-Dirac transformation
    logits = -(pairwise_dist - self.bias) / self.temperature

    return logits
StereographicLayerNorm(manifold, embedding_dim, curvatures)

Bases: Module

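Stereographic layer normalization: applies nn.LayerNorm through manifold.apply (in the tangent space) and projects the result back onto the stereographic manifold using the per-head curvatures.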

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry.

  • embedding_dim (int) –

    Embedding dimension of the input points.

  • curvatures (Float[Tensor, 'num_heads 1 1']) –

    Tensor of shape [num_heads, 1, 1] representing the curvature value used per head in geometric computations.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • stereographic_norm

    Stereographic layernorm used in the tangent space.

  • curvatures

    Tensor of shape [num_heads, 1, 1] representing the curvature value used per head in geometric computations.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, manifold: Manifold | ProductManifold, embedding_dim: int, curvatures: Float[torch.Tensor, "num_heads 1 1"]
):
    super().__init__()

    self.manifold = manifold
    self.stereographic_norm = self.manifold.apply(nn.LayerNorm(embedding_dim))
    self.curvatures = curvatures
forward(X)

Apply layer normalization on the stereographic manifold.

Source code in manify/predictors/nn/layers.py
def forward(self, X: Float[torch.Tensor, "n_nodes dim"]) -> Float[torch.Tensor, "n_nodes dim"]:
    """Apply layer normalization on the stereographic manifold."""
    norm_X = self.stereographic_norm(X)
    output = geoopt.manifolds.stereographic.math.project(norm_X, self.curvatures)
    return output
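The layer normalizes in the tangent space and then projects back onto the manifold. The sketch below spells that out for the Poincaré ball of curvature \(-c\), under the assumption that manifold.apply amounts to a logarithmic map at the origin, the wrapped function, and an exponential map back. The helper names are hypothetical and the formulas cover only negative curvature.

```python
import torch
from torch import nn


def expmap0(v: torch.Tensor, c: float) -> torch.Tensor:
    # Poincare ball of curvature -c: exp_0(v) = tanh(sqrt(c)|v|) v / (sqrt(c)|v|)
    sqrt_c = c ** 0.5
    norm = v.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    return torch.tanh(sqrt_c * norm) * v / (sqrt_c * norm)


def logmap0(y: torch.Tensor, c: float) -> torch.Tensor:
    # Inverse map: log_0(y) = artanh(sqrt(c)|y|) y / (sqrt(c)|y|)
    sqrt_c = c ** 0.5
    norm = y.norm(dim=-1, keepdim=True).clamp_min(1e-15)
    return torch.atanh((sqrt_c * norm).clamp(max=1 - 1e-7)) * y / (sqrt_c * norm)


def tangent_layer_norm(X: torch.Tensor, norm: nn.LayerNorm, c: float) -> torch.Tensor:
    # Map to the tangent space at the origin, normalize there, then map back onto the ball
    return expmap0(norm(logmap0(X, c)), c)
```

Normalizing in the tangent space keeps nn.LayerNorm unchanged, and the exponential map guarantees the output lies strictly inside the ball, which is also the role the final project call plays in StereographicLayerNorm.forward.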
GeometricLinearizedAttention(curvatures, num_heads, head_dim)

Bases: Module

Linearized (kernelized) attention on stereographic manifolds with per-head curvatures.

Parameters:
  • curvatures (float | list[float]) –

    Per-head curvatures. The signature is annotated as float | list[float], but StereographicAttention passes a tensor of shape [num_heads, 1, 1].

  • num_heads (int) –

    Number of attention heads.

  • head_dim (int) –

    Dimension of each attention head.

Attributes:
  • num_heads

    Number of attention heads.

  • head_dim

    Dimension of each attention head.

  • _epsilon

    Value substituted for zero normalization terms before they are inverted (1e-5).

  • _clamp_epsilon

    Minimum absolute value used when clamping the denominator gamma - 1 for numerical stability (1e-10).

Source code in manify/predictors/nn/layers.py
def __init__(self, curvatures: float | list[float], num_heads: int, head_dim: int):
    super().__init__()

    self.num_heads = num_heads
    self.curvatures = curvatures

    self.head_dim = head_dim
    self._epsilon = 1e-5
    self._clamp_epsilon = 1e-10
forward(Q, K, V, mask)

Forward pass for the geometric linearized attention layer.

Parameters:
  • Q (Float[Tensor, 'batch_size num_heads n_nodes head_dim']) –

    Query tensor.

  • K (Float[Tensor, 'batch_size num_heads n_nodes head_dim']) –

    Key tensor.

  • V (Float[Tensor, 'batch_size num_heads n_nodes head_dim']) –

    Value tensor.

  • mask (Float[Tensor, '1 1 n_nodes n_nodes']) –

    Mask tensor for attention.

Returns:
  • Float[Tensor, 'batch_size num_heads n_nodes head_dim']

    Output tensor after applying attention, still split by head.

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    Q: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
    K: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
    V: Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"],
    mask: Float[torch.Tensor, "1 1 n_nodes n_nodes"],
) -> Float[torch.Tensor, "batch_size num_heads n_nodes head_dim"]:
    """Forward pass for the geometric linearized attention layer.

    Args:
        Q: Query tensor.
        K: Key tensor.
        V: Value tensor.
        mask: Mask tensor for attention.

    Returns:
        Output tensor after applying attention.
    """
    v1 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, Q, k=self.curvatures)
    v2 = geoopt.manifolds.stereographic.math.parallel_transport0back(V, K, k=self.curvatures)

    gamma = geoopt.manifolds.stereographic.math.lambda_x(x=V, k=self.curvatures, keepdim=True, dim=-1)
    denominator = geoopt.utils.clamp_abs((gamma - 1), self._clamp_epsilon)

    x = ((gamma / denominator) * V) * mask[None, :, None]

    v1 = nn.functional.elu(v1) + 1
    v2 = (denominator * (nn.functional.elu(v2) + 1)) * mask[None, :, None]

    # Linearized approximation
    v2_cumsum = v2.sum(dim=-2)  # [B, H, D]
    D = torch.einsum("...nd,...d->...n", v1, v2_cumsum.type_as(v1))  # normalization terms
    D_inv = 1.0 / D.masked_fill_(D == 0, self._epsilon)
    context = torch.einsum("...nd,...ne->...de", v2, x)
    X = torch.einsum("...de,...nd,...n->...ne", context, v1, D_inv)

    X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)
    X = geoopt.manifolds.stereographic.math.mobius_scalar_mul(
        torch.tensor(0.5, dtype=X.dtype, device=X.device), X, k=self.curvatures, dim=-1
    )
    X = geoopt.manifolds.stereographic.math.project(X, k=self.curvatures)

    return X
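Stripped of the stereographic steps (parallel transport, the gamma-based weighting, masking, the Möbius scalar multiplication, and the projections), the core of the forward pass is kernelized linear attention with feature map \(\phi(x) = \mathrm{elu}(x) + 1\). It forms \(\sum_n \phi(k_n) v_n^\top\) once and reuses it for every query, so the cost grows linearly rather than quadratically in the number of nodes. The Euclidean sketch below checks this against the explicit \(N \times N\) computation; it is illustrative only, omits everything geometric, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def linear_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Feature map phi(x) = elu(x) + 1 keeps every kernel value positive
    q, k = F.elu(Q) + 1, F.elu(K) + 1
    # Context sum_n phi(k_n) v_n^T: shape [..., D, E], cost linear in N
    kv = torch.einsum("...nd,...ne->...de", k, V)
    # Per-query normalizers phi(q_n) . sum_m phi(k_m): shape [..., N]
    z = torch.einsum("...nd,...d->...n", q, k.sum(dim=-2))
    return torch.einsum("...nd,...de->...ne", q, kv) / z[..., None]


def quadratic_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # Same kernel, computed with the explicit N x N weight matrix for comparison
    q, k = F.elu(Q) + 1, F.elu(K) + 1
    weights = q @ k.transpose(-1, -2)
    return (weights / weights.sum(dim=-1, keepdim=True)) @ V
```

The two functions agree up to floating-point error; the linear form simply reorders the sums so that no \(N \times N\) matrix is ever materialized.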
StereographicAttention(manifold, num_heads, dim, head_dim)

Bases: Module

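Stereographic multi-head attention layer: Euclidean linear projections for queries and keys, κ-GCN layers for the value and output projections, and GeometricLinearizedAttention in between.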

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry.

  • num_heads (int) –

    Number of attention heads.

  • dim (int) –

    Embedding dimension of the input points.

  • head_dim (int) –

    Dimension of each attention head.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • curvatures

    Tensor of shape [num_heads, 1, 1] representing the curvature value used per head in geometric computations.

  • num_heads

    Number of attention heads.

  • head_dim

    Dimensionality of each attention head.

  • W_q

    Linear layer projecting inputs to query vectors.

  • W_k

    Linear layer projecting inputs to key vectors.

  • W_v

    Manifold-aware linear layer projecting to value vectors.

  • attn

    GeometricLinearizedAttention module applied to the per-head queries, keys, and values.

  • ff

    Manifold-aware linear layer for the feedforward output.

Source code in manify/predictors/nn/layers.py
def __init__(self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int):
    super().__init__()

    self.manifold = manifold
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), self.num_heads)

    self.W_q = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
    self.W_k = nn.Linear(in_features=dim, out_features=self.num_heads * self.head_dim)
    self.W_v = KappaGCNLayer(in_features=dim, out_features=self.num_heads * self.head_dim, manifold=self.manifold)

    self.attn = GeometricLinearizedAttention(
        curvatures=self.curvatures, num_heads=self.num_heads, head_dim=self.head_dim
    )
    self.ff = KappaGCNLayer(in_features=self.num_heads * self.head_dim, out_features=dim, manifold=self.manifold)
forward(X, mask=None)

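Forward pass for the stereographic attention layer. Although mask defaults to None, the code calls mask.unsqueeze(0), so a mask of shape (n_nodes, n_nodes) must currently be passed.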

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass for the stereographic attention layer."""
    Q = self._split_heads(self.W_q(X))  # [B, H, N, D]
    K = self._split_heads(self.W_k(X))
    V = self._split_heads(self.W_v(X=X))

    attn_out = self.attn(Q, K, V, mask.unsqueeze(0).unsqueeze(0))  # type: ignore
    attn_out = self._combine_heads(attn_out)

    out = self.ff(X=attn_out)

    return out
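_split_heads and _combine_heads are private helpers whose bodies are not shown here. Judging from the `# [B, H, N, D]` comment, _split_heads turns an [n_nodes, num_heads * head_dim] tensor into a per-head layout. The following is a hypothetical sketch of such a pair of reshapes (adding a batch dimension of 1), not manify's code.

```python
import torch


def split_heads(X: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical: [N, H * D] -> [1, H, N, D]."""
    n, hd = X.shape
    # Column block h (of width D) becomes head h
    return X.view(n, num_heads, hd // num_heads).transpose(0, 1).unsqueeze(0)


def combine_heads(X: torch.Tensor) -> torch.Tensor:
    """Hypothetical inverse of split_heads: [1, H, N, D] -> [N, H * D]."""
    _, h, n, d = X.shape
    return X.squeeze(0).transpose(0, 1).reshape(n, h * d)
```

Whatever the actual layout, the two helpers must be exact inverses so that the combined attention output lines up with the num_heads * head_dim input columns expected by ff.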
StereographicTransformer(manifold, num_heads, dim, head_dim, use_layer_norm=True)

Bases: Module

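Stereographic transformer block: an attention sublayer and a feedforward sublayer, each preceded by optional normalization and joined to its input by a Möbius-addition residual connection followed by a projection onto the manifold.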

Parameters:
  • manifold (Manifold | ProductManifold) –

    Manifold or ProductManifold object defining the geometry.

  • num_heads (int) –

    Number of attention heads.

  • dim (int) –

    Dimensionality of the input features.

  • head_dim (int) –

    Dimensionality of each attention head.

  • use_layer_norm (bool, default: True ) –

    Whether to apply layer normalization in tangent space.

Attributes:
  • manifold

    The manifold object for geometric operations.

  • curvatures

    Manifold curvatures reshaped to [num_heads, 1, 1] for broadcasting.

  • mha

    Multi-head stereographic attention module.

  • norm1

    First normalization layer (Identity or StereographicLayerNorm), applied before attention.

  • norm2

    Second normalization layer (Identity or StereographicLayerNorm), applied before the feedforward block.

  • mlpblock

    Feedforward network in stereographic space.

  • stereographic_activation

    Activation wrapped to operate in tangent space.

Source code in manify/predictors/nn/layers.py
def __init__(
    self, manifold: Manifold | ProductManifold, num_heads: int, dim: int, head_dim: int, use_layer_norm: bool = True
):
    super().__init__()

    # Check that manifold is stereographic
    if not manifold.is_stereographic:
        raise ValueError(
            "Manifold must be stereographic for StereographicLayerNorm to work. Please use manifold.stereographic() to convert."
        )

    self.manifold = manifold
    self.curvatures = _reshape_curvatures(_get_curvatures(self.manifold), num_heads)
    self.stereographic_activation = self.manifold.apply(nn.ReLU())
    self.mha = StereographicAttention(manifold=self.manifold, num_heads=num_heads, dim=dim, head_dim=head_dim)

    if use_layer_norm:
        self.norm1 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
        self.norm2 = StereographicLayerNorm(manifold=self.manifold, embedding_dim=dim, curvatures=self.curvatures)
    else:
        self.norm1 = nn.Identity()
        self.norm2 = nn.Identity()

    self.mlpblock = nn.Sequential(
        KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
        self.stereographic_activation,
        KappaGCNLayer(in_features=dim, out_features=dim, manifold=self.manifold),
    )
forward(X, mask=None)

Forward pass through the stereographic transformer block.

Source code in manify/predictors/nn/layers.py
def forward(
    self,
    X: Float[torch.Tensor, "n_nodes dim"],
    mask: Float[torch.Tensor, "n_nodes n_nodes"] | None = None,
) -> Float[torch.Tensor, "n_nodes dim"]:
    """Forward pass through the stereographic transformer block."""
    X = geoopt.manifolds.stereographic.math.mobius_add(self.mha(self.norm1(X), mask), X, self.curvatures)
    X = geoopt.manifolds.stereographic.math.project(X, self.curvatures)
    X = geoopt.manifolds.stereographic.math.mobius_add(self.mlpblock(self.norm2(X)), X, self.curvatures)
    X = geoopt.manifolds.stereographic.math.project(X, self.curvatures)

    return X
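Both residual connections in the forward pass above use geoopt's `mobius_add` instead of `+`. As an illustration only, here is a standalone NumPy sketch of κ-stereographic Möbius addition, assuming the standard convention (κ < 0 hyperbolic, κ = 0 Euclidean, κ > 0 spherical). `mobius_add` here is a hypothetical helper, not geoopt's implementation. Because Möbius addition is not commutative for κ ≠ 0, the argument order (sublayer output first, skip input second) matters.

```python
import numpy as np


def mobius_add(x: np.ndarray, y: np.ndarray, k: float) -> np.ndarray:
    """Illustrative kappa-stereographic Mobius addition x (+)_k y along the last axis.

    Assumes the standard formula; reduces to x + y when k == 0.
    """
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    xy = np.sum(x * y, axis=-1, keepdims=True)
    num = (1 - 2 * k * xy - k * y2) * x + (1 + k * x2) * y
    # The denominator is >= 0 by Cauchy-Schwarz; clamp to avoid division by zero.
    denom = 1 - 2 * k * xy + k**2 * x2 * y2
    return num / np.maximum(denom, 1e-15)


# Residual connection in the Poincare ball (k = -1), with hypothetical values:
skip = np.array([0.1, -0.2])           # block input X
sublayer_out = np.array([0.3, 0.05])   # e.g. attention output
out = mobius_add(sublayer_out, skip, k=-1.0)
```

For κ < 0 the result stays inside the ball of radius 1/√|κ|, and the left-cancellation law (−a) ⊕ (a ⊕ b) = b holds, which is what makes Möbius addition a reasonable stand-in for vector addition in a residual path.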

perceptron

Product space perceptron implementation.

ProductSpacePerceptron(pm, max_epochs=1000, patience=5, weights=None, task='classification', random_state=None, device=None)

Bases: BasePredictor

A kernel perceptron for multiclass classification on a product manifold. One binary perceptron is trained per class (one-vs-rest) on a weighted sum of the component-manifold kernels.

Parameters:
  • pm (ProductManifold) –

    ProductManifold object for the product space.

  • max_epochs (int, default: 1000 ) –

    Maximum number of training epochs.

  • patience (int, default: 5 ) –

    Number of consecutive epochs without improvement to consider convergence.

  • weights (Float[Tensor, 'n_manifolds'] | None, default: None ) –

    Per-manifold weights for combining the component kernels; defaults to all ones.

  • task (str, default: 'classification' ) –

    Task type (defaults to "classification").

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device for tensor computations.

Attributes:
  • pm

    ProductManifold object associated with the predictor.

  • max_epochs

    Maximum number of training epochs.

  • patience

    Number of consecutive epochs without a reduction in training errors before training stops early.

  • weights

    Per-manifold weights for kernel combination.

  • alpha

    Dictionary storing perceptron coefficients for each class.

  • X_train_

    Training data points.

  • y_train_

    Training labels.

  • is_fitted_ (bool) –

    Boolean flag indicating if the predictor has been fitted.

Source code in manify/predictors/perceptron.py
def __init__(
    self,
    pm: ProductManifold,
    max_epochs: int = 1_000,
    patience: int = 5,
    weights: Float[torch.Tensor, "n_manifolds"] | None = None,
    task: str = "classification",
    random_state: int | None = None,
    device: str | None = None,
):
    # Initialize base class
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)
    self.pm = pm  # ProductManifold instance
    self.max_epochs = max_epochs
    self.patience = patience  # Number of consecutive epochs without improvement to consider convergence
    self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
    assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."
fit(X, y)

Trains the perceptron model using the provided data and labels.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Training data tensor.

  • y (Int[Tensor, 'n_samples']) –

    Class labels for the training data.

Returns:
  • self (ProductSpacePerceptron) –

    Fitted perceptron model.

Source code in manify/predictors/perceptron.py
def fit(
    self, X: Float[torch.Tensor, "n_samples n_manifolds"], y: Int[torch.Tensor, "n_samples"]
) -> ProductSpacePerceptron:
    """Trains the perceptron model using the provided data and labels.

    Args:
        X: Training data tensor.
        y: Class labels for the training data.

    Returns:
        self: Fitted perceptron model.
    """
    # Identify unique classes for multiclass classification
    self._store_classes(y)
    n_samples = X.shape[0]

    # Precompute kernel matrix
    Ks, _ = product_kernel(self.pm, X, None)
    K = torch.ones((n_samples, n_samples), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K += w * K_m

    # Store training data and labels for prediction
    self.X_train_ = X
    self.y_train_ = y

    # Initialize dictionary to store alpha coefficients for each class
    self.alpha = {}

    for class_label in self.classes_:
        class_label_item = class_label.item()

        # Reset patience tracking for each one-vs-rest problem
        best_epoch, least_errors = 0, n_samples + 1

        # One-vs-rest labels
        y_binary = torch.where(y == class_label_item, 1, -1)  # Shape: (n_samples,)

        # Initialize alpha coefficients for this class
        alpha = torch.zeros(n_samples, dtype=X.dtype, device=X.device)

        for epoch in range(self.max_epochs):
            # Compute decision function: f = K @ (alpha * y_binary)
            f = K @ (alpha * y_binary)  # Shape: (n_samples,)

            # Compute predictions
            predictions = torch.sign(f)

            # Find misclassified samples
            misclassified = predictions != y_binary

            # If no misclassifications, break early
            if not misclassified.any():
                break

            # Test patience
            n_errors = misclassified.sum().item()
            if n_errors < least_errors:
                best_epoch, least_errors = epoch, n_errors
            if epoch - best_epoch >= self.patience:
                break

            # Update alpha coefficients for misclassified samples
            alpha[misclassified] += 1

        # Store the alpha coefficients for the current class
        self.alpha[class_label_item] = alpha

    self.is_fitted_ = True
    return self
predict_proba(X)

Predicts the raw one-vs-rest decision values for each class. Despite the method name, these are unnormalized scores, not probabilities.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Test data tensor.

Returns:
  • decision_values( Float[Tensor, 'n_points n_classes'] ) –

    Decision values for each test sample and each class.

Source code in manify/predictors/perceptron.py
def predict_proba(
    self,
    X: Float[torch.Tensor, "n_points n_features"],  # type: ignore[override]
) -> Float[torch.Tensor, "n_points n_classes"]:
    """Predicts the decision values for each class.

    Args:
        X: Test data tensor.

    Returns:
        decision_values: Decision values for each test sample and each class.
    """
    n_samples = X.shape[0]
    n_classes = len(self.classes_)
    decision_values = torch.zeros((n_samples, n_classes), dtype=X.dtype, device=X.device)

    # Compute kernel matrix between training data and test data
    Ks, _ = product_kernel(self.pm, self.X_train_, X)
    K_test = torch.ones((self.X_train_.shape[0], n_samples), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K_test += w * K_m

    for idx, class_label in enumerate(self.classes_):
        class_label_item = class_label.item()
        alpha = self.alpha[class_label_item]  # Shape: (n_samples_train,)
        y_binary = torch.where(self.y_train_ == class_label_item, 1, -1)  # Shape: (n_samples_train,)

        # Compute decision function for test samples
        f = (alpha * y_binary) @ K_test  # Shape: (n_samples_test,)
        decision_values[:, idx] = f

    return decision_values
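Together, fit and predict_proba implement a one-vs-rest kernel perceptron on the combined kernel K = 1 + Σ_m w_m K_m. The NumPy sketch below mirrors that logic on precomputed kernel matrices so it runs without manify. `combine_kernels`, `fit_kernel_perceptron_ovr`, and `decision_values` are hypothetical names, and a 1-D linear kernel stands in for product_kernel. The sketch resets its patience counters for every class.

```python
import numpy as np


def combine_kernels(Ks, weights):
    """Hypothetical helper: constant-1 kernel plus a weighted sum of component kernels."""
    K = np.ones_like(Ks[0])
    for K_m, w in zip(Ks, weights):
        K = K + w * K_m
    return K


def fit_kernel_perceptron_ovr(K, y, max_epochs=1000, patience=5):
    """One-vs-rest batch kernel perceptron on a precomputed (n, n) kernel K."""
    classes = np.unique(y)
    n = K.shape[0]
    alphas = {}
    for c in classes:
        y_bin = np.where(y == c, 1.0, -1.0)
        alpha = np.zeros(n)
        best_epoch, least_errors = 0, n + 1  # patience state, reset per class
        for epoch in range(max_epochs):
            f = K @ (alpha * y_bin)
            misclassified = np.sign(f) != y_bin
            n_errors = int(misclassified.sum())
            if n_errors == 0:
                break
            if n_errors < least_errors:
                best_epoch, least_errors = epoch, n_errors
            if epoch - best_epoch >= patience:
                break
            alpha[misclassified] += 1.0  # update every misclassified point at once
        alphas[c.item()] = alpha
    return classes, alphas


def decision_values(K_test, y_train, classes, alphas):
    """K_test has shape (n_train, n_test); returns (n_test, n_classes) raw scores."""
    out = np.zeros((K_test.shape[1], len(classes)))
    for idx, c in enumerate(classes):
        y_bin = np.where(y_train == c, 1.0, -1.0)
        out[:, idx] = (alphas[c.item()] * y_bin) @ K_test
    return out


# Toy usage with a single linear "component kernel":
x_train = np.array([-2.0, -1.0, 1.0, 2.0])
y_train = np.array([0, 0, 1, 1])
K = combine_kernels([np.outer(x_train, x_train)], [1.0])
classes, alphas = fit_kernel_perceptron_ovr(K, y_train)

x_test = np.array([-1.5, 1.5])
K_test = combine_kernels([np.outer(x_train, x_test)], [1.0])
scores = decision_values(K_test, y_train, classes, alphas)
preds = classes[scores.argmax(axis=1)]
```

Taking the argmax over the per-class scores is one natural way to turn the one-vs-rest decision values into a single label.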

svm

Support vector machine (SVM) implementation for product manifolds.

ProductSpaceSVM(pm, weights=None, h_constraints=True, e_constraints=True, s_constraints=True, task='classification', epsilon=1e-05, random_state=None, device=None)

Bases: BasePredictor

Support vector machine for product manifolds.

Trains one-vs-rest SVMs whose Euclidean, spherical, and hyperbolic constraints are expressed as second-order cone (SOC) constraints, so each per-class problem is convex; the problems are solved with CVXPY using the SCS solver.

Parameters:
  • pm (ProductManifold) –

    A ProductManifold instance specifying component manifolds.

  • weights (Float[Tensor, 'n_manifolds'] | None, default: None ) –

    Optional per-manifold kernel weights; defaults to all ones.

  • h_constraints (bool, default: True ) –

    Whether to enforce hyperbolic constraints.

  • e_constraints (bool, default: True ) –

    Whether to enforce Euclidean constraints.

  • s_constraints (bool, default: True ) –

    Whether to enforce spherical constraints.

  • task (Literal['classification', 'regression'], default: 'classification' ) –

    Task type, either "classification" or "regression".

  • epsilon (float, default: 1e-05 ) –

    Slack parameter for SOC constraints.

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device for tensor computations.

Attributes:
  • pm

    ProductManifold object associated with the predictor.

  • weights

    Per-manifold weights for kernel combination.

  • h_constraints

    Whether to enforce hyperbolic constraints.

  • e_constraints

    Whether to enforce Euclidean constraints.

  • s_constraints

    Whether to enforce spherical constraints.

  • eps

    Slack parameter for SOC constraints.

  • beta

    Dictionary storing SVM coefficients for each class.

  • zeta

    Dictionary storing slack variables for each class.

  • epsilon

    Dictionary storing the optimized margin for each class (distinct from the constructor's epsilon argument, which is stored in eps).

  • b

    Dictionary storing bias terms for each class.

  • X_train_

    Training data points.

  • is_fitted_ (bool) –

    Boolean flag indicating if the predictor has been fitted.

Source code in manify/predictors/svm.py
def __init__(
    self,
    pm: ProductManifold,
    weights: Float[torch.Tensor, "n_manifolds"] | None = None,
    h_constraints: bool = True,
    e_constraints: bool = True,
    s_constraints: bool = True,
    task: Literal["classification", "regression"] = "classification",
    epsilon: float = 1e-5,
    random_state: int | None = None,
    device: str | None = None,
):
    """Initialize the ProductSpaceSVM.

    Args:
        pm: A ProductManifold instance specifying component manifolds.
        weights: Optional per-manifold weights tensor.
        h_constraints: Whether to enforce hyperbolic constraints.
        e_constraints: Whether to enforce Euclidean constraints.
        s_constraints: Whether to enforce spherical constraints.
        task: Task type, either "classification" or "regression".
        epsilon: Slack parameter for SOC constraints.
        random_state: Random seed for reproducibility.
        device: Device for tensor computations.
    """
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)
    self.pm = pm
    self.h_constraints = h_constraints
    self.s_constraints = s_constraints
    self.e_constraints = e_constraints
    self.eps = epsilon
    self.task = task
    self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
    assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."
fit(X, y)

Fit one-vs-rest SVMs on the product manifold data.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Training points tensor.

  • y (Int[Tensor, 'n_samples']) –

    Integer class labels tensor.

Returns:
  • self (ProductSpaceSVM) –

    Fitted ProductSpaceSVM instance.

Source code in manify/predictors/svm.py
def fit(
    self,
    X: Float[torch.Tensor, "n_samples n_manifolds"],
    y: Int[torch.Tensor, "n_samples"],
) -> ProductSpaceSVM:
    """Fit one-vs-rest SVMs on the product manifold data.

    Args:
        X: Training points tensor.
        y: Integer class labels tensor.

    Returns:
        self: Fitted ProductSpaceSVM instance.
    """
    # unique classes
    self._store_classes(y)
    n = X.shape[0]

    # aggregated kernel
    Ks, _ = product_kernel(self.pm, X, None)
    K_sum = torch.ones((n, n), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K_sum += w * K_m

    X_np = X.detach().cpu().numpy()
    K_np = K_sum.detach().cpu().numpy()

    def sqrtm_psd(P: np.ndarray) -> Any:
        w, V = np.linalg.eigh(P)
        w_s = np.sqrt(np.clip(w, 0, None))
        B = V @ np.diag(w_s) @ V.T
        return (B + B.T) * 0.5

    # containers
    self.beta = {}
    self.zeta = {}
    self.epsilon = {}
    self.b = {}

    for cls in self.classes_:
        cls_item = cls.item() if isinstance(cls, torch.Tensor) else cls
        # one-vs-rest labels: +1 for cls, -1 for others
        y_bin = torch.where(y == cls_item, 1, -1)
        Y = torch.diagflat(y_bin).detach().cpu().numpy()

        # variables
        beta_var = cp.Variable(n)
        zeta = cp.Variable(n, nonneg=True)
        eps_var = cp.Variable(1)
        b_var = cp.Variable(1)

        # base constraints
        constraints = [eps_var >= 0]
        constraints.append(Y @ (K_np @ beta_var + b_var) >= eps_var - zeta)

        # per-manifold SOC
        for M, K_comp in zip(self.pm.P, Ks, strict=False):
            P_np = K_comp.detach().cpu().numpy()
            if M.type == "E" and self.e_constraints:
                B = sqrtm_psd(P_np)
                constraints.append(cp.norm(B @ beta_var, 2) <= 1.0)
            elif M.type == "S" and self.s_constraints:
                B = sqrtm_psd(P_np)
                constraints.append(cp.norm(B @ beta_var, 2) <= np.sqrt(np.pi / 2))
            elif M.type == "H" and self.h_constraints:
                # PSD split
                eigvals, eigvecs = np.linalg.eigh(P_np)
                plus = np.clip(eigvals, 0, None)
                minus = np.clip(-eigvals, 0, None)
                Kp = (eigvecs @ np.diag(plus) @ eigvecs.T + (eigvecs @ np.diag(plus) @ eigvecs.T).T) * 0.5
                Km = (eigvecs @ np.diag(minus) @ eigvecs.T + (eigvecs @ np.diag(minus) @ eigvecs.T).T) * 0.5
                Bp = sqrtm_psd(Kp)
                Bm = sqrtm_psd(Km)

                C_H = abs(M.curvature)
                R = -M.scale
                r_h = abs(np.arcsinh(-(R**2) * C_H))
                r = self.eps

                constraints.append(cp.norm(Bm @ beta_var, 2) <= np.sqrt(max(r, 0.0)))
                constraints.append(cp.norm(Bp @ beta_var, 2) <= np.sqrt(max(r + r_h, 0.0)))

        # solve
        prob = cp.Problem(cp.Minimize(-eps_var + cp.sum(zeta)), constraints)
        prob.solve(solver="SCS")

        # save results
        self.beta[cls_item] = np.ravel(beta_var.value)
        self.zeta[cls_item] = zeta.value
        self.epsilon[cls_item] = eps_var.value.item()  # eps_var has shape (1,)
        self.b[cls_item] = b_var.value.item()

    # store training data
    self.X_train_ = torch.tensor(X_np, dtype=torch.float32)
    self.is_fitted_ = True
    return self
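Each SOC constraint in fit has the form ‖Bβ‖₂ ≤ c, where B is a symmetric square root of a component kernel. Because B² = K, this is equivalent to βᵀKβ ≤ c². Hyperbolic component kernels may be indefinite, so the code first splits them into positive and negative semidefinite parts through an eigendecomposition. The NumPy sketch below isolates that step outside CVXPY; `psd_split` is a hypothetical name, and `sqrtm_psd` follows the helper defined inside fit.

```python
import numpy as np


def sqrtm_psd(P: np.ndarray) -> np.ndarray:
    """Symmetric square root of a PSD matrix; small negative eigenvalues are clipped to 0."""
    w, V = np.linalg.eigh(P)
    B = (V * np.sqrt(np.clip(w, 0, None))) @ V.T  # same as V @ diag(sqrt(w)) @ V.T
    return (B + B.T) * 0.5


def psd_split(P: np.ndarray):
    """Hypothetical helper: write symmetric P as K_plus - K_minus with both parts PSD."""
    w, V = np.linalg.eigh(P)
    K_plus = (V * np.clip(w, 0, None)) @ V.T
    K_minus = (V * np.clip(-w, 0, None)) @ V.T
    return (K_plus + K_plus.T) * 0.5, (K_minus + K_minus.T) * 0.5


P = np.array([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues 3 and -1, so P is indefinite
K_plus, K_minus = psd_split(P)
B_plus = sqrtm_psd(K_plus)

beta = np.array([0.5, -1.0])
lhs = np.linalg.norm(B_plus @ beta) ** 2  # squared SOC norm
rhs = beta @ K_plus @ beta                # quadratic form it encodes
```

Bounding ‖B₊β‖₂ and ‖B₋β‖₂ separately, as fit does for hyperbolic components, therefore bounds each semidefinite part of the quadratic form.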
predict_proba(X)

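Predict class probabilities by applying a softmax to the one-vs-rest SVM decision values. The outputs sum to one for each sample but are not calibrated probabilities.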

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Test points tensor.

Returns:
  • class_probabilities( Float[Tensor, 'n_samples n_classes'] ) –

    Class probabilities for each test sample.

Source code in manify/predictors/svm.py
def predict_proba(
    self,
    X: Float[torch.Tensor, "n_samples n_manifolds"],
) -> Float[torch.Tensor, "n_samples n_classes"]:
    """Predict class probabilities using the fitted SVMs.

    Args:
        X: Test points tensor.

    Returns:
        class_probabilities: Class probabilities for each test sample.
    """
    X_tensor = torch.tensor(X, dtype=torch.float32) if not isinstance(X, torch.Tensor) else X
    X_tensor = X_tensor.to(self.X_train_.device)

    Ks_test, _ = product_kernel(self.pm, self.X_train_, X_tensor)
    Kt = torch.ones((self.X_train_.shape[0], X_tensor.shape[0]), device=X_tensor.device)
    for K_m, w in zip(Ks_test, self.weights, strict=False):
        Kt += w * K_m
    Kt_np = Kt.detach().cpu().numpy()

    n_test = X_tensor.shape[0]
    n_cls = len(self.classes_)
    dec = np.zeros((n_test, n_cls))
    for idx, cls in enumerate(self.classes_):
        cls_item = cls.item() if isinstance(cls, torch.Tensor) else cls
        beta_vec: np.ndarray = np.ravel(self.beta[cls_item])
        dec[:, idx] = Kt_np.T @ beta_vec + self.b[cls_item]

    exp_scores = np.exp(dec - dec.max(axis=1, keepdims=True))
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return torch.tensor(probs, dtype=torch.float32)
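predict_proba computes decision values Kᵀβ + b for every class and normalizes them with a softmax, subtracting each row's maximum first so that `exp` cannot overflow. A minimal NumPy sketch of just that step, assuming the per-class coefficients have been stacked into a (n_train, n_classes) array (`ovr_softmax` is a hypothetical name):

```python
import numpy as np


def ovr_softmax(Kt: np.ndarray, betas: np.ndarray, biases: np.ndarray) -> np.ndarray:
    """Kt: (n_train, n_test); betas: (n_train, n_classes); biases: (n_classes,)."""
    dec = Kt.T @ betas + biases                      # (n_test, n_classes) decision values
    shifted = dec - dec.max(axis=1, keepdims=True)   # row-max shift keeps exp() finite
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


# Large decision values would overflow a naive softmax; the shifted version does not.
Kt = np.eye(2)
betas = np.array([[1000.0, 0.0], [0.0, 1000.0]])
biases = np.zeros(2)
probs = ovr_softmax(Kt, betas, biases)
```

The shift leaves the softmax output unchanged, since a common factor exp(−max) cancels between numerator and denominator.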