Bases: BaseEstimator, ABC
Base class for everything in manify.predictors.
This is an abstract class that defines a common interface for all mixed-curvature predictors. We assume only that a
ProductManifold object is given. We try to follow the scikit-learn API's fit/predict_proba/predict paradigm as
closely as possible, while accommodating the nuances of product manifold geometry and PyTorch/Geoopt.
| Parameters: |
-
pm
(ProductManifold)
–
ProductManifold object associated with the predictor.
-
task
(Literal['classification', 'regression', 'link_prediction'])
–
Task type: one of "classification", "regression", or "link_prediction".
-
random_state
(int | None, default:
None
)
–
Random state for reproducibility.
-
device
(str | None, default:
None
)
–
Device for tensor computations.
|
| Attributes: |
-
pm
–
ProductManifold object associated with the predictor.
-
task
–
Task type: one of "classification", "regression", or "link_prediction".
-
random_state
–
Random state for reproducibility.
-
device
–
Device for tensor computations. If not provided, defaults to pm.device.
-
loss_history_
(dict[str, list[float]])
–
History of loss values during training.
-
is_fitted_
(bool)
–
Boolean flag indicating if the predictor has been fitted.
|
Source code in manify/predictors/_base.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def __init__(
    self,
    pm: ProductManifold,
    task: Literal["classification", "regression", "link_prediction"],
    random_state: int | None = None,
    device: str | None = None,
) -> None:
    """Store the predictor configuration and initialize the task-specific mixin.

    Args:
        pm: ProductManifold object associated with the predictor.
        task: One of "classification", "regression", or "link_prediction".
        random_state: Random state for reproducibility.
        device: Device for tensor computations. Falls back to `pm.device` when not given.

    Raises:
        ValueError: If `task` is not one of the supported task types.
    """
    self.pm = pm
    self.task = task
    self.random_state = random_state
    # Any falsy `device` (None or an empty string) falls back to the manifold's device.
    self.device = device if device else pm.device
    self.loss_history_: dict[str, list[float]] = {}
    self.is_fitted_: bool = False

    # Link prediction is treated as binary classification, so it shares the classifier mixin.
    if task in ("classification", "link_prediction"):
        ClassifierMixin.__init__(self)
    elif task == "regression":
        RegressorMixin.__init__(self)
    else:
        raise ValueError(f"Unknown task type: {task}")
|
fit(X, y)
abstractmethod
Abstract method to fit a predictor. Requires features and labels.
| Parameters: |
-
X
(Float[Tensor, 'n_points n_features'])
–
Features to fit.
-
y
(Float[Tensor, 'n_points n_classes'] | Float[Tensor, 'n_points'])
–
Labels for the features.
|
| Returns: |
-
self( 'BasePredictor'
) –
Fitted predictor instance.
|
Source code in manify/predictors/_base.py
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@abstractmethod
def fit(
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    y: Float[torch.Tensor, "n_points n_classes"] | Float[torch.Tensor, "n_points"],
) -> "BasePredictor":
    """Fit the predictor to features and labels; concrete subclasses must override this.

    Args:
        X: Features to fit.
        y: Labels for the features.

    Returns:
        self: Fitted predictor instance.
    """
    ...
|
predict_proba(X)
abstractmethod
Compute the predicted probabilities for the given features.
| Parameters: |
-
X
(Float[Tensor, 'n_points n_features'])
–
New inputs for which to make predictions.
|
| Returns: |
-
X_proba( Float[Tensor, 'n_points n_classes']
) –
Predicted probabilities for the input features.
|
Source code in manify/predictors/_base.py
98
99
100
101
102
103
104
105
106
107
@abstractmethod
def predict_proba(self, X: Float[torch.Tensor, "n_points n_features"]) -> Float[torch.Tensor, "n_points n_classes"]:
    """Return predicted probabilities for the given features; concrete subclasses must override this.

    Args:
        X: New inputs for which to make predictions.

    Returns:
        X_proba: Predicted probabilities for the input features.
    """
    ...
|
predict(X, **kwargs)
Compute the predicted classes for the given features.
| Parameters: |
-
X
(Float[Tensor, 'n_points n_features'])
–
New inputs for which to make predictions.
-
**kwargs
(dict, default:
{}
)
–
Additional keyword arguments that get passed to self.predict_proba().
|
| Returns: |
-
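predictions( Float[Tensor, 'n_points']
) –
Predictions for the input features: class predictions for classification, thresholded 0/1 labels for link prediction, and the squeezed model outputs for regression.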
|
Source code in manify/predictors/_base.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def predict(self, X: Float[torch.Tensor, "n_points n_features"], **kwargs: dict) -> Float[torch.Tensor, "n_points"]:
    """Convert the output of `self.predict_proba()` into final, task-dependent predictions.

    Args:
        X: New inputs for which to make predictions.
        **kwargs: Additional keyword arguments that get passed to `self.predict_proba()`.

    Returns:
        predictions: For regression, the model outputs with the trailing dimension squeezed.
            For link prediction, 0/1 floats obtained by thresholding the outputs at 0.5.
            For classification, the argmax class indices passed through
            `self._get_class_predictions()` (presumably mapping indices back to class labels).
    """
    task = self.task
    outputs = self.predict_proba(X=X, **kwargs)

    if task == "regression":
        return outputs.squeeze(-1)
    if task == "link_prediction":
        # Strictly greater than 0.5 counts as a predicted link.
        return (outputs > 0.5).float()

    # Any other task is handled as classification.
    return self._get_class_predictions(outputs.argmax(dim=-1))
|
score(X, y, sample_weight=None, **kwargs)
Return the mean accuracy/R² score.
| Parameters: |
-
X
(Float[Tensor, 'n_points n_features'])
–
Input features.
-
y
(Float[Tensor, 'n_points n_classes'] | Float[Tensor, 'n_points'])
–
Target labels.
-
sample_weight
(Float[Tensor, 'n_points'] | None, default:
None
)
–
Sample weights for each point. Defaults to None, which means all points are equally weighted.
-
**kwargs
(dict, default:
{}
)
–
Additional keyword arguments that get passed to self.predict_proba().
|
| Returns: |
-
score( float
) –
Mean accuracy (classification) or R² score (regression).
|
Source code in manify/predictors/_base.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def score(
    self,
    X: Float[torch.Tensor, "n_points n_features"],
    y: Float[torch.Tensor, "n_points n_classes"] | Float[torch.Tensor, "n_points"],
    sample_weight: Float[torch.Tensor, "n_points"] | None = None,
    **kwargs: dict,
) -> float:
    """Return the weighted mean accuracy (classification/link prediction) or R² (regression).

    Higher is better for every task, matching the scikit-learn `score` convention.

    Args:
        X: Input features.
        y: Target labels.
        sample_weight: Sample weights for each point. Defaults to None, which means all points are equally weighted.
        **kwargs: Additional keyword arguments that get passed to `self.predict()`, which forwards them to
            `self.predict_proba()`.

    Returns:
        score: Weighted mean accuracy for classification and link prediction, or the weighted coefficient of
            determination R² for regression. When the targets have zero (weighted) variance, R² is reported as
            1.0 for a perfect fit and 0.0 otherwise.
    """
    predictions = self.predict(X, **kwargs)
    if sample_weight is None:
        weights = torch.ones_like(predictions, dtype=torch.float32)
    else:
        weights = torch.as_tensor(sample_weight, dtype=torch.float32, device=predictions.device)
    # Normalize by the total weight rather than the number of points, so that
    # non-uniform weights yield a true weighted average.
    total_weight = weights.sum()

    if self.task == "regression":
        residual_ss = (weights * (y - predictions) ** 2).sum()
        weighted_mean = (weights * y).sum() / total_weight
        total_ss = (weights * (y - weighted_mean) ** 2).sum()
        if total_ss.item() == 0.0:
            # Constant targets: R² is undefined, so report a finite value instead of nan/inf.
            return 1.0 if residual_ss.item() == 0.0 else 0.0
        return float((1.0 - residual_ss / total_ss).item())

    # Classification and link prediction are both scored by (weighted) accuracy.
    correct = (predictions == y).float()
    return float(((correct * weights).sum() / total_weight).item())
|