Base

manify.predictors._base

Base predictor class.

BasePredictor(pm, task, random_state=None, device=None)

Bases: BaseEstimator, ABC

Base class for everything in manify.predictors.

This is an abstract class that defines a common interface for all mixed-curvature predictors. We assume only that a ProductManifold object is given. We try to follow the scikit-learn API's fit/predict_proba/predict paradigm as closely as possible, while accommodating the nuances of product manifold geometry and PyTorch/Geoopt.
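The fit/predict_proba/predict paradigm described above can be sketched with Python's abc module. The class names below (SketchPredictor, ConstantPredictor) are hypothetical stand-ins, not part of manify, and plain lists replace tensors so the sketch runs without PyTorch or a ProductManifold.

```python
from abc import ABC, abstractmethod


class SketchPredictor(ABC):
    """Hypothetical stand-in for BasePredictor: abstract fit/predict_proba, concrete predict."""

    @abstractmethod
    def fit(self, X, y):
        """Subclasses must train on (X, y) and return self."""

    @abstractmethod
    def predict_proba(self, X):
        """Subclasses must return one row of class scores per input point."""

    def predict(self, X):
        # Concrete method built on predict_proba (classification case):
        # pick the index of the highest-scoring class for each point.
        return [max(range(len(row)), key=row.__getitem__) for row in self.predict_proba(X)]


class ConstantPredictor(SketchPredictor):
    """Hypothetical toy subclass: always predicts the class frequencies seen in training."""

    def fit(self, X, y):
        # Assumes integer labels 0..K-1.
        counts = [0] * (max(y) + 1)
        for label in y:
            counts[label] += 1
        self.proba_ = [c / len(y) for c in counts]
        return self  # returning self enables model.fit(X, y).predict(X)

    def predict_proba(self, X):
        return [list(self.proba_) for _ in X]


model = ConstantPredictor().fit([[0.0], [1.0], [2.0]], [1, 1, 0])
print(model.predict([[5.0], [6.0]]))  # [1, 1]
```

Instantiating SketchPredictor directly raises TypeError, just as BasePredictor cannot be instantiated until a subclass overrides both fit and predict_proba.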

Parameters:
  • pm (ProductManifold) –

    ProductManifold object associated with the predictor.

  • task (Literal['classification', 'regression', 'link_prediction']) –

    Task type: "classification", "regression", or "link_prediction".

  • random_state (int | None, default: None ) –

    Random state for reproducibility.

  • device (str | None, default: None ) –

    Device for tensor computations. If not provided, defaults to pm.device.

Attributes:
  • pm

    ProductManifold object associated with the predictor.

  • task

    Task type: "classification", "regression", or "link_prediction".

  • random_state

    Random state for reproducibility.

  • device

    Device for tensor computations. If not provided, defaults to pm.device.

  • loss_history_ (dict[str, list[float]]) –

    History of loss values during training.

  • is_fitted_ (bool) –

    Boolean flag indicating if the predictor has been fitted.

Source code in manify/predictors/_base.py
def __init__(
    self,
    pm: ProductManifold,
    task: Literal["classification", "regression", "link_prediction"],
    random_state: int | None = None,
    device: str | None = None,
) -> None:
    self.pm = pm
    self.task = task
    self.random_state = random_state
    self.device = device or pm.device
    self.loss_history_: dict[str, list[float]] = {}
    self.is_fitted_: bool = False

    # Initialize appropriate base class depending on task
    if task == "classification":
        ClassifierMixin.__init__(self)
    elif task == "regression":
        RegressorMixin.__init__(self)
    elif task == "link_prediction":
        # For link prediction, we also use ClassifierMixin, as we think of it as binary classification.
        ClassifierMixin.__init__(self)
    else:
        raise ValueError(f"Unknown task type: {task}")
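The constructor's bookkeeping can be mirrored in a few lines of plain Python. FakeManifold, init_state, and MIXIN_FOR_TASK are hypothetical names used only for illustration; the mapping records which scikit-learn mixin's __init__ the source above calls for each task, and the device line reproduces the `device or pm.device` fallback.

```python
from dataclasses import dataclass

# Which mixin __init__ the source above calls for each supported task.
MIXIN_FOR_TASK = {
    "classification": "ClassifierMixin",
    "regression": "RegressorMixin",
    "link_prediction": "ClassifierMixin",  # treated as binary classification
}


@dataclass
class FakeManifold:
    """Hypothetical stand-in for ProductManifold; only .device is used here."""

    device: str = "cpu"


def init_state(pm, task, random_state=None, device=None):
    """Sketch of the attributes BasePredictor.__init__ sets (not the library code)."""
    if task not in MIXIN_FOR_TASK:
        raise ValueError(f"Unknown task type: {task}")
    return {
        "task": task,
        "random_state": random_state,
        # `device or pm.device`: any falsy device (None or "") falls back to the manifold's device
        "device": device or pm.device,
        "loss_history_": {},
        "is_fitted_": False,
    }


print(init_state(FakeManifold("cuda:0"), "regression")["device"])  # cuda:0
```

One difference to keep in mind: the real constructor assigns its attributes before checking the task, whereas this sketch validates first; either way an unknown task ends in ValueError.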

fit(X, y) abstractmethod

Abstract method to fit a predictor. Requires features and labels.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Features to fit.

  • y (Float[Tensor, 'n_points n_classes'] | Float[Tensor, 'n_points']) –

    Labels for the features.

Returns:
  • self (BasePredictor) –

    Fitted predictor instance.

Source code in manify/predictors/_base.py
@abstractmethod
def fit(
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    y: Float[torch.Tensor, "n_points n_classes"] | Float[torch.Tensor, "n_points"],
) -> "BasePredictor":
    """Abstract method to fit a predictor. Requires features and labels.

    Args:
        X: Features to fit.
        y: Labels for the features.

    Returns:
        self: Fitted predictor instance.
    """
    pass
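The abstract fit only fixes the contract: take features and labels, return the fitted instance. The sketch below shows one plausible way a subclass could also maintain the loss_history_ and is_fitted_ attributes documented above. The class name, the "train" history key, and the gradient-descent loop are assumptions for illustration, not manify's implementation.

```python
class MeanRegressorSketch:
    """Hypothetical subclass-style predictor: learns a constant by gradient descent on MSE."""

    def __init__(self, lr=0.1, n_epochs=50):
        self.lr = lr
        self.n_epochs = n_epochs
        self.loss_history_ = {}
        self.is_fitted_ = False

    def fit(self, X, y):
        self.bias_ = 0.0
        losses = []
        n = len(y)
        for _ in range(self.n_epochs):
            residuals = [self.bias_ - t for t in y]
            losses.append(sum(r * r for r in residuals) / n)  # mean squared error
            grad = 2.0 * sum(residuals) / n  # d(MSE)/d(bias)
            self.bias_ -= self.lr * grad
        self.loss_history_["train"] = losses  # the "train" key is an assumption
        self.is_fitted_ = True
        return self

    def predict_proba(self, X):
        # For regression, assume one output column per point (shape n_points x 1),
        # which BasePredictor.predict would squeeze away.
        return [[self.bias_] for _ in X]


model = MeanRegressorSketch().fit([[0.0], [0.0], [0.0]], [1.0, 2.0, 3.0])
print(round(model.bias_, 3))  # 2.0
```

With lr=0.1 the distance to the target mean shrinks by a factor of 0.8 per epoch, so the recorded loss decreases monotonically.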

predict_proba(X) abstractmethod

Compute the predicted probabilities for the given features.

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    New inputs for which to make predictions.

Returns:
  • X_proba (Float[Tensor, 'n_points n_classes']) –

    Predicted probabilities for the input features.

Source code in manify/predictors/_base.py
@abstractmethod
def predict_proba(self, X: Float[torch.Tensor, "n_points n_features"]) -> Float[torch.Tensor, "n_points n_classes"]:
    """Compute the predicted probabilities for the given features.

    Args:
        X: New inputs for which to make predictions.

    Returns:
        X_proba: Predicted probabilities for the input features.
    """
    pass
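The abstract method only fixes the output shape, one row per point and one column per class. One common way a subclass could meet that contract is a row-wise softmax over per-class scores; this is an illustrative assumption, not a description of how any particular manify predictor computes its probabilities. softmax_rows is a hypothetical helper.

```python
import math


def softmax_rows(scores):
    """Row-wise softmax: map (n_points x n_classes) scores to rows that sum to 1."""
    out = []
    for row in scores:
        m = max(row)  # subtract the row max so exp() cannot overflow
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


print(softmax_rows([[0.0, 0.0], [1000.0, 0.0]]))  # [[0.5, 0.5], [1.0, 0.0]]
```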

predict(X, **kwargs)

Compute predictions for the given features: class labels (classification), binary link labels (link prediction), or predicted values (regression).

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    New inputs for which to make predictions.

  • **kwargs (dict, default: {} ) –

    Additional keyword arguments that get passed to self.predict_proba().

Returns:
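  • predictions (Float[Tensor, 'n_points']) –

    Predicted class labels (classification), 0/1 labels obtained by thresholding probabilities at 0.5 (link prediction), or predicted values (regression).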

Source code in manify/predictors/_base.py
def predict(self, X: Float[torch.Tensor, "n_points n_features"], **kwargs: dict) -> Float[torch.Tensor, "n_points"]:
    """Compute predictions (class labels, binary link labels, or regression values) for the given features.

    Args:
        X: New inputs for which to make predictions.
        **kwargs: Additional keyword arguments that get passed to `self.predict_proba()`.

    Returns:
        predictions: Predicted class labels (classification), 0/1 labels (link prediction), or values (regression).
    """
    if self.task == "regression":
        return self.predict_proba(X=X, **kwargs).squeeze(-1)
    elif self.task == "link_prediction":
        probs = self.predict_proba(X=X, **kwargs)
        return (probs > 0.5).float()  # Threshold probabilities at 0.5
    else:  # classification
        class_indices = self.predict_proba(X=X, **kwargs).argmax(dim=-1)
        return self._get_class_predictions(class_indices)
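The task dispatch in predict can be traced on plain lists. predict_from_proba is a hypothetical helper; its optional classes list stands in for the private _get_class_predictions method, whose body is not shown on this page, and for link prediction it assumes predict_proba yields one probability per point.

```python
def predict_from_proba(proba, task, classes=None):
    """Sketch of BasePredictor.predict's task dispatch on nested lists (hypothetical helper)."""
    if task == "regression":
        # squeeze(-1): drop the trailing single-output dimension
        return [row[0] for row in proba]
    if task == "link_prediction":
        # assumes one probability per point; strict > 0.5 threshold gives 1.0 / 0.0
        return [1.0 if p > 0.5 else 0.0 for p in proba]
    # classification: argmax over classes, then map indices to labels if a mapping is given
    indices = [max(range(len(row)), key=row.__getitem__) for row in proba]
    return [classes[i] for i in indices] if classes is not None else indices


print(predict_from_proba([0.9, 0.5, 0.1], "link_prediction"))  # [1.0, 0.0, 0.0]
```

Because the comparison is strict, a probability of exactly 0.5 maps to 0.0.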

score(X, y, sample_weight=None, **kwargs)

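Return the mean accuracy (classification and link prediction) or the mean squared error (regression). Unlike scikit-learn's RegressorMixin.score, which returns R², the regression branch computes a (weighted) mean squared error, so lower values are better.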

Parameters:
  • X (Float[Tensor, 'n_points n_features']) –

    Input features.

  • y (Float[Tensor, 'n_points n_classes'] | Float[Tensor, 'n_points']) –

    Target labels.

  • sample_weight (Float[Tensor, 'n_points'] | None, default: None ) –

    Per-point weights that multiply each point's term before averaging over all n_points (the weights are not normalized). Defaults to None, which weights every point equally.

  • **kwargs (dict, default: {} ) –

    Additional keyword arguments that get passed to self.predict_proba().

Returns:
  • score (float) –

    Mean accuracy (classification and link prediction) or weighted mean squared error (regression; lower is better).

Source code in manify/predictors/_base.py
def score(
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    y: Float[torch.Tensor, "n_points n_classes"] | Float[torch.Tensor, "n_points"],
    sample_weight: Float[torch.Tensor, "n_points"] | None = None,
    **kwargs: dict,
) -> float:
    """Return the mean accuracy (classification, link prediction) or mean squared error (regression).

    Args:
        X: Input features.
        y: Target labels.
        sample_weight: Per-point weights applied before averaging over all points (not normalized).
            Defaults to None, which weights every point equally.
        **kwargs: Additional keyword arguments that get passed to `self.predict_proba()`.

    Returns:
        score: Mean accuracy (classification, link prediction) or weighted mean squared error
            (regression; lower is better).
    """
    predictions = self.predict(X, **kwargs)

    if sample_weight is None:
        sample_weight = torch.ones_like(predictions, dtype=torch.float32)

    if self.task == "classification":
        out = ((predictions == y).float() * sample_weight).mean().item()
    elif self.task == "regression":
        out = (((predictions - y) ** 2 * sample_weight).mean()).item()
    else:  # link_prediction
        out = ((predictions == y).float() * sample_weight).mean().item()

    return float(out)
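The arithmetic in score can be checked by hand. weighted_score is a hypothetical list-based mirror of the method above: each point's term is multiplied by its weight and the result is averaged over all n points, so the weights are not normalized.

```python
def weighted_score(predictions, y, task, sample_weight=None):
    """List-based sketch of BasePredictor.score (hypothetical helper)."""
    n = len(predictions)
    if sample_weight is None:
        sample_weight = [1.0] * n
    if task == "regression":
        # weighted squared error, averaged over n (not over sum(sample_weight))
        terms = [(p - t) ** 2 * w for p, t, w in zip(predictions, y, sample_weight)]
    else:
        # classification / link_prediction: weighted 0/1 correctness, averaged over n
        terms = [float(p == t) * w for p, t, w in zip(predictions, y, sample_weight)]
    return sum(terms) / n


preds, labels = [1, 0, 1, 1], [1, 1, 1, 0]
print(weighted_score(preds, labels, "classification"))             # 0.5
print(weighted_score(preds, labels, "classification", [2.0] * 4))  # 1.0
print(weighted_score([2.0, 0.0], [1.0, 0.0], "regression"))        # 0.5
```

Two of the four predictions are correct, giving 0.5 unweighted; doubling every weight doubles the score to 1.0, which a normalized weighted average would not do.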