SVM

manify.predictors.svm

Implementation of a support vector machine (SVM) for data on product manifolds.

ProductSpaceSVM(pm, weights=None, h_constraints=True, e_constraints=True, s_constraints=True, task='classification', epsilon=1e-05, random_state=None, device=None)

Bases: BasePredictor

SVM predictor for data on a product manifold.

Trains one-vs-rest SVMs with Euclidean, spherical, and hyperbolic constraints, each formulated as a second-order cone (SOC) constraint so that every per-class optimization problem stays convex.
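
The one-vs-rest scheme turns a multiclass problem into one binary problem per class. As a minimal NumPy sketch (mirroring the torch.where(y == cls_item, 1, -1) and torch.diagflat calls in fit; the helper name one_vs_rest_labels is hypothetical and not part of manify), each class gets a +1/-1 label vector and the diagonal label matrix Y used in the margin constraint:

```python
import numpy as np


def one_vs_rest_labels(y: np.ndarray) -> dict:
    """Hypothetical helper: map each class to a +1/-1 vector (+1 for that class)."""
    return {int(c): np.where(y == c, 1, -1) for c in np.unique(y)}


y = np.array([0, 2, 1, 2])
labels = one_vs_rest_labels(y)

# Diagonal label matrix Y, built per class in fit() before the margin constraint
Y = np.diag(labels[2])
print(labels[2])  # [-1  1 -1  1]
```

Each binary problem is solved independently, so fit runs one convex program per class.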

Parameters:
  • pm (ProductManifold) –

    A ProductManifold instance specifying component manifolds.

  • weights (Float[Tensor, 'n_manifolds'] | None, default: None ) –

    Optional tensor of per-manifold kernel weights, one per component manifold. Defaults to all ones.

  • h_constraints (bool, default: True ) –

    Whether to enforce hyperbolic constraints.

  • e_constraints (bool, default: True ) –

    Whether to enforce Euclidean constraints.

  • s_constraints (bool, default: True ) –

    Whether to enforce spherical constraints.

  • task (Literal['classification', 'regression'], default: 'classification' ) –

    Task type, either "classification" or "regression".

  • epsilon (float, default: 1e-05 ) –

    Slack parameter for the SOC constraints. It is stored as eps and used as the radius term r in the hyperbolic constraints.

  • random_state (int | None, default: None ) –

    Random seed for reproducibility.

  • device (str | None, default: None ) –

    Device for tensor computations.
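
The weights enter only through the aggregated kernel that fit and predict_proba build: an all-ones matrix (a constant term) plus the weighted sum of the per-manifold kernels returned by product_kernel. The sketch below reproduces that combination with NumPy on small, made-up kernel matrices, including the default of uniform weights and the length check from __init__. The name aggregate_kernels is hypothetical and not part of manify.

```python
from typing import List, Optional

import numpy as np


def aggregate_kernels(Ks: List[np.ndarray], weights: Optional[np.ndarray] = None) -> np.ndarray:
    """Hypothetical helper: K = 1 1^T + sum_m w_m K_m, as assembled in fit()."""
    if weights is None:
        weights = np.ones(len(Ks))  # default: uniform weights
    assert len(weights) == len(Ks), "Number of weights must match the number of manifolds."
    K = np.ones_like(Ks[0])  # constant (all-ones) term
    for K_m, w in zip(Ks, weights):
        K = K + w * K_m
    return K


# Two toy 2x2 component kernels (made-up values)
K_e = np.array([[1.0, 0.5], [0.5, 1.0]])
K_s = np.array([[1.0, 0.0], [0.0, 1.0]])

K_default = aggregate_kernels([K_e, K_s])                        # uniform weights
K_weighted = aggregate_kernels([K_e, K_s], np.array([2.0, 0.0]))  # ignore the second component
```

Setting a weight to zero drops that component from the aggregated kernel, although the per-manifold SOC constraint for that component is still applied in fit.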

Attributes:
  • pm

    ProductManifold object associated with the predictor.

  • weights

    Per-manifold weights for kernel combination.

  • h_constraints

    Whether to enforce hyperbolic constraints.

  • e_constraints

    Whether to enforce Euclidean constraints.

  • s_constraints

    Whether to enforce spherical constraints.

  • eps

    Value of the epsilon constructor argument (slack parameter for the SOC constraints).

  • beta

    Dictionary storing SVM coefficients for each class.

  • zeta

    Dictionary storing slack variables for each class.

  • epsilon

    Dictionary storing the optimized margin variable for each class. Not to be confused with the epsilon constructor argument, which is stored as eps.

  • b

    Dictionary storing bias terms for each class.

  • X_train_

    Training data points, stored as a float32 tensor on the CPU.

  • is_fitted_ (bool) –

    Boolean flag indicating if the predictor has been fitted.
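
Read directly off the fit source, the convex program solved for each class is sketched below; the beta, zeta, epsilon and b attributes hold its solution (β, ζ, ε, b) per class. Notation (introduced here, not in the source): y ∈ {±1}ⁿ are the one-vs-rest labels, Y = diag(y), K = 11ᵀ + Σₘ wₘKₘ is the aggregated kernel, Kₘ is the unweighted kernel of component Mₘ, Kₘ⁺ and Kₘ⁻ are its positive and negative parts from an eigendecomposition (so Kₘ = Kₘ⁺ − Kₘ⁻), and r is the epsilon constructor argument. Each per-manifold constraint is added only when the matching e_constraints, s_constraints or h_constraints flag is True.

```latex
\begin{aligned}
\min_{\beta,\ \zeta,\ \varepsilon,\ b} \quad & -\varepsilon + \sum_{i=1}^{n} \zeta_i \\
\text{s.t.} \quad & \zeta \ge 0, \qquad \varepsilon \ge 0, \qquad Y\,(K\beta + b\,\mathbf{1}) \ge \varepsilon\,\mathbf{1} - \zeta, \\
& \bigl\|(K_m^{+})^{1/2}\beta\bigr\|_2 \le 1 && \text{if } M_m \text{ is Euclidean}, \\
& \bigl\|(K_m^{+})^{1/2}\beta\bigr\|_2 \le \sqrt{\pi/2} && \text{if } M_m \text{ is spherical}, \\
& \bigl\|(K_m^{-})^{1/2}\beta\bigr\|_2 \le \sqrt{\max(r, 0)}, \quad \bigl\|(K_m^{+})^{1/2}\beta\bigr\|_2 \le \sqrt{\max(r + r_h, 0)} && \text{if } M_m \text{ is hyperbolic}, \\
& r_h = \bigl|\operatorname{arcsinh}(-R^2 C_H)\bigr|, \quad C_H = |\text{curvature of } M_m|, \quad R = -\text{scale of } M_m.
\end{aligned}
```

The problem is built with CVXPY and solved with the SCS solver.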

Initialize the ProductSpaceSVM.

Source code in manify/predictors/svm.py
def __init__(
    self,
    pm: ProductManifold,
    weights: Float[torch.Tensor, "n_manifolds"] | None = None,
    h_constraints: bool = True,
    e_constraints: bool = True,
    s_constraints: bool = True,
    task: Literal["classification", "regression"] = "classification",
    epsilon: float = 1e-5,
    random_state: int | None = None,
    device: str | None = None,
):
    """Initialize the ProductSpaceSVM.

    Args:
        pm: A ProductManifold instance specifying component manifolds.
        weights: Optional per-manifold weights tensor.
        h_constraints: Whether to enforce hyperbolic constraints.
        e_constraints: Whether to enforce Euclidean constraints.
        s_constraints: Whether to enforce spherical constraints.
        task: Task type, either "classification" or "regression".
        epsilon: Slack parameter for SOC constraints.
        random_state: Random seed for reproducibility.
        device: Device for tensor computations.
    """
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)
    self.pm = pm
    self.h_constraints = h_constraints
    self.s_constraints = s_constraints
    self.e_constraints = e_constraints
    self.eps = epsilon
    self.task = task
    self.weights = torch.ones(len(pm.P), dtype=torch.float32) if weights is None else weights
    assert len(self.weights) == len(pm.P), "Number of weights must match the number of manifolds."

fit(X, y)

Fit one-vs-rest SVMs on the product manifold data.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Training points tensor.

  • y (Int[Tensor, 'n_samples']) –

    Integer class labels tensor.

Returns:
  • self ( ProductSpaceSVM ) –

    Fitted ProductSpaceSVM instance.

Source code in manify/predictors/svm.py
def fit(
    self,
    X: Float[torch.Tensor, "n_samples n_manifolds"],
    y: Int[torch.Tensor, "n_samples"],
) -> ProductSpaceSVM:
    """Fit one-vs-rest SVMs on the product manifold data.

    Args:
        X: Training points tensor.
        y: Integer class labels tensor.

    Returns:
        self: Fitted ProductSpaceSVM instance.
    """
    # unique classes
    self._store_classes(y)
    n = X.shape[0]

    # aggregated kernel
    Ks, _ = product_kernel(self.pm, X, None)
    K_sum = torch.ones((n, n), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K_sum += w * K_m

    X_np = X.detach().cpu().numpy()
    K_np = K_sum.detach().cpu().numpy()

    def sqrtm_psd(P: np.ndarray) -> Any:
        w, V = np.linalg.eigh(P)
        w_s = np.sqrt(np.clip(w, 0, None))
        B = V @ np.diag(w_s) @ V.T
        return (B + B.T) * 0.5

    # containers
    self.beta = {}
    self.zeta = {}
    self.epsilon = {}
    self.b = {}

    for cls in self.classes_:
        cls_item = cls.item() if isinstance(cls, torch.Tensor) else cls
        # one-vs-rest labels: +1 for cls, -1 for others
        y_bin = torch.where(y == cls_item, 1, -1)
        Y = torch.diagflat(y_bin).detach().cpu().numpy()

        # variables
        beta_var = cp.Variable(n)
        zeta = cp.Variable(n, nonneg=True)
        eps_var = cp.Variable(1)
        b_var = cp.Variable(1)

        # base constraints
        constraints = [eps_var >= 0]
        constraints.append(Y @ (K_np @ beta_var + b_var) >= eps_var - zeta)

        # per-manifold SOC
        for M, K_comp in zip(self.pm.P, Ks, strict=False):
            P_np = K_comp.detach().cpu().numpy()
            if M.type == "E" and self.e_constraints:
                B = sqrtm_psd(P_np)
                constraints.append(cp.norm(B @ beta_var, 2) <= 1.0)
            elif M.type == "S" and self.s_constraints:
                B = sqrtm_psd(P_np)
                constraints.append(cp.norm(B @ beta_var, 2) <= np.sqrt(np.pi / 2))
            elif M.type == "H" and self.h_constraints:
                # PSD split
                eigvals, eigvecs = np.linalg.eigh(P_np)
                plus = np.clip(eigvals, 0, None)
                minus = np.clip(-eigvals, 0, None)
                Kp = (eigvecs @ np.diag(plus) @ eigvecs.T + (eigvecs @ np.diag(plus) @ eigvecs.T).T) * 0.5
                Km = (eigvecs @ np.diag(minus) @ eigvecs.T + (eigvecs @ np.diag(minus) @ eigvecs.T).T) * 0.5
                Bp = sqrtm_psd(Kp)
                Bm = sqrtm_psd(Km)

                C_H = abs(M.curvature)
                R = -M.scale
                r_h = abs(np.arcsinh(-(R**2) * C_H))
                r = self.eps

                constraints.append(cp.norm(Bm @ beta_var, 2) <= np.sqrt(max(r, 0.0)))
                constraints.append(cp.norm(Bp @ beta_var, 2) <= np.sqrt(max(r + r_h, 0.0)))

        # solve
        prob = cp.Problem(cp.Minimize(-eps_var + cp.sum(zeta)), constraints)
        prob.solve(solver="SCS")

        # save results
        self.beta[cls_item] = np.ravel(beta_var.value)
        self.zeta[cls_item] = zeta.value
        self.epsilon[cls_item] = float(eps_var.value)
        self.b[cls_item] = float(b_var.value)

    # store training data
    self.X_train_ = torch.tensor(X_np, dtype=torch.float32)
    self.is_fitted_ = True
    return self
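
The linear algebra inside fit can be checked in isolation. The NumPy sketch below re-implements the nested sqrtm_psd helper and the positive/negative eigenvalue split used for hyperbolic components (psd_split is a hypothetical name; in the source this split is inlined). It then confirms why the constraints are second-order cones: for a PSD matrix K with symmetric square root B, ||Bβ||₂ = sqrt(βᵀKβ), so each constraint bounds a kernel-induced norm of β.

```python
import numpy as np


def sqrtm_psd(P: np.ndarray) -> np.ndarray:
    """Symmetric square root of the PSD part of P (negative eigenvalues clipped to 0)."""
    w, V = np.linalg.eigh(P)
    B = V @ np.diag(np.sqrt(np.clip(w, 0, None))) @ V.T
    return (B + B.T) * 0.5  # re-symmetrize against round-off


def psd_split(P: np.ndarray) -> tuple:
    """Hypothetical helper: split symmetric P into PSD parts with P = K_plus - K_minus."""
    w, V = np.linalg.eigh(P)
    K_plus = V @ np.diag(np.clip(w, 0, None)) @ V.T
    K_minus = V @ np.diag(np.clip(-w, 0, None)) @ V.T
    return (K_plus + K_plus.T) * 0.5, (K_minus + K_minus.T) * 0.5


# Indefinite toy "kernel" (eigenvalues 3 and -1), standing in for a hyperbolic component
P = np.array([[1.0, 2.0], [2.0, 1.0]])
K_plus, K_minus = psd_split(P)
B_plus = sqrtm_psd(K_plus)
beta = np.array([0.3, -0.7])

# SOC form: ||B beta||_2 equals sqrt(beta^T K beta) when B is the PSD square root of K
lhs = np.linalg.norm(B_plus @ beta)
rhs = np.sqrt(beta @ K_plus @ beta)
```

Splitting an indefinite kernel this way lets each half be written as a norm constraint, which is what keeps the hyperbolic case convex.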

predict_proba(X)

Predict class probabilities by applying a softmax to the one-vs-rest decision values of the fitted SVMs.

Parameters:
  • X (Float[Tensor, 'n_samples n_manifolds']) –

    Test points tensor.

Returns:
  • class_probabilities ( Float[Tensor, 'n_samples n_classes'] ) –

    Class probabilities for each test sample.

Source code in manify/predictors/svm.py
def predict_proba(
    self,
    X: Float[torch.Tensor, "n_samples n_manifolds"],
) -> Float[torch.Tensor, "n_samples n_classes"]:
    """Predict class probabilities using the fitted SVMs.

    Args:
        X: Test points tensor.

    Returns:
        class_probabilities: Class probabilities for each test sample.
    """
    X_tensor = torch.tensor(X, dtype=torch.float32) if not isinstance(X, torch.Tensor) else X
    X_tensor = X_tensor.to(self.X_train_.device)

    Ks_test, _ = product_kernel(self.pm, self.X_train_, X_tensor)
    Kt = torch.ones((self.X_train_.shape[0], X_tensor.shape[0]), device=X_tensor.device)
    for K_m, w in zip(Ks_test, self.weights, strict=False):
        Kt += w * K_m
    Kt_np = Kt.detach().cpu().numpy()

    n_test = X_tensor.shape[0]
    n_cls = len(self.classes_)
    dec = np.zeros((n_test, n_cls))
    for idx, cls in enumerate(self.classes_):
        cls_item = cls.item() if isinstance(cls, torch.Tensor) else cls
        beta_vec: np.ndarray = np.ravel(self.beta[cls_item])
        dec[:, idx] = Kt_np.T @ beta_vec + self.b[cls_item]

    exp_scores = np.exp(dec - dec.max(axis=1, keepdims=True))
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return torch.tensor(probs, dtype=torch.float32)
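
The scoring step in predict_proba amounts to one decision value per class, f_c(x) = Σᵢ β_c,ᵢ K(xᵢ, x) + b_c, followed by a numerically stable row-wise softmax. The sketch below mirrors that computation on made-up precomputed arrays (Kt as the aggregated train-by-test kernel, one coefficient vector and bias per class); ovr_softmax_proba is a hypothetical name. The resulting probabilities are a softmax over raw decision values, so they are not explicitly calibrated. Taking the argmax of each row gives a hard label; that predict behaves this way is an assumption here, since predict itself is not shown on this page.

```python
from typing import List

import numpy as np


def ovr_softmax_proba(Kt: np.ndarray, betas: List[np.ndarray], biases: List[float]) -> np.ndarray:
    """Hypothetical helper: decision values Kt^T beta_c + b_c, then a row-wise softmax."""
    dec = np.column_stack([Kt.T @ beta + b for beta, b in zip(betas, biases)])
    exp_scores = np.exp(dec - dec.max(axis=1, keepdims=True))  # subtract row max for stability
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


# 3 training points x 2 test points (toy aggregated kernel), 2 classes
Kt = np.array([[1.0, 0.2], [0.5, 1.0], [0.1, 0.3]])
betas = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
biases = [0.0, 0.0]

probs = ovr_softmax_proba(Kt, betas, biases)
pred = probs.argmax(axis=1)  # hard labels as column indices into the class list
```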