def fit(
    self, X: Float[torch.Tensor, "n_samples n_manifolds"], y: Int[torch.Tensor, "n_samples"]
) -> ProductSpacePerceptron:
    """Train the one-vs-rest kernel perceptron on the provided data and labels.

    A single training kernel matrix is built from ``product_kernel``. Then one
    vector of dual (alpha) coefficients is learned per class with a batch
    perceptron update: on each epoch, every misclassified sample's coefficient
    is incremented by one.

    Training for a class stops when any of these holds:
    - it has no training errors;
    - ``self.max_epochs`` epochs have run;
    - the error count has not improved for ``self.patience`` epochs.

    Patience is tracked separately for each class.

    Args:
        X: Training data tensor.
        y: Class labels for the training data.

    Returns:
        self: Fitted perceptron model.
    """
    # Identify unique classes for multiclass (one-vs-rest) classification.
    self._store_classes(y)
    n_samples = X.shape[0]

    # Precompute the training kernel: a constant all-ones term plus the weighted
    # sum of the per-component kernels returned by product_kernel. The ones term
    # adds the same offset sum(alpha * y_binary) to every decision value, so it
    # behaves like a bias.
    # NOTE(review): strict=False silently truncates if the number of kernels and
    # len(self.weights) differ instead of raising; confirm this is intended.
    Ks, _ = product_kernel(self.pm, X, None)
    K = torch.ones((n_samples, n_samples), dtype=X.dtype, device=X.device)
    for K_m, w in zip(Ks, self.weights, strict=False):
        K += w * K_m

    # Store training data and labels; presumably needed at prediction time to
    # evaluate kernels against new points (verify against predict).
    self.X_train_ = X
    self.y_train_ = y

    # Alpha (dual) coefficients, keyed by the Python scalar class label.
    self.alpha = {}
    for class_label in self.classes_:
        class_label_item = class_label.item()
        # One-vs-rest targets: +1 for this class, -1 for every other class.
        y_binary = torch.where(y == class_label_item, 1, -1)  # Shape: (n_samples,)
        alpha = torch.zeros(n_samples, dtype=X.dtype, device=X.device)

        # Patience state is per class. Sharing it across classes (as before)
        # let one class's error history stop or prolong another class's training.
        best_epoch, least_errors = 0, n_samples + 1

        for epoch in range(self.max_epochs):
            # Decision values for all training points: f = K @ (alpha * y_binary).
            f = K @ (alpha * y_binary)  # Shape: (n_samples,)
            # sign(0) == 0 never equals +1 or -1, so points with a zero decision
            # value count as misclassified. On the first epoch alpha is all zeros,
            # so every point is.
            misclassified = torch.sign(f) != y_binary
            if not misclassified.any():
                break

            # Early stopping: give up once the error count has not improved for
            # `self.patience` epochs.
            n_errors = misclassified.sum().item()
            if n_errors < least_errors:
                best_epoch, least_errors = epoch, n_errors
            if epoch - best_epoch >= self.patience:
                break

            # Batch perceptron update over all misclassified samples at once.
            alpha[misclassified] += 1

        # The final coefficients are kept (not those from the best epoch).
        self.alpha[class_label_item] = alpha

    self.is_fitted_ = True
    return self