Decision Tree

manify.predictors.decision_tree

Decision tree and random forest predictors for product space manifolds.

For more information, see Chlenski et al. (2024): https://arxiv.org/abs/2410.13879

ProductSpaceDT(pm, max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, task='classification', use_special_dims=False, batch_size=None, n_features='d', ablate_midpoints=False, random_state=None, device=None)

Bases: BasePredictor

Decision tree in the product space to handle hyperbolic, Euclidean, and hyperspherical data.

Source code in manify/predictors/decision_tree.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
def __init__(
    self,
    pm: ProductManifold,
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
    min_samples_split: int = 2,
    min_impurity_decrease: float = 0.0,
    task: Literal["classification", "regression", "link_prediction"] = "classification",
    use_special_dims: bool = False,
    batch_size: int | None = None,
    n_features: Literal["d", "d_choose_2"] = "d",
    ablate_midpoints: bool = False,
    random_state: int | None = None,
    device: str | None = None,
):
    """Set up an unfitted decision tree over a product manifold.

    Args:
        pm: Product manifold the tree operates on. Must not be stereographic.
        max_depth: Depth budget later passed to the node-fitting routine. Any falsy value
            (``None`` and also ``0``) is stored as ``-1``.
        min_samples_leaf: Stored as a hyperparameter; consumed by the fitting code outside this method.
        min_samples_split: Stored as a hyperparameter; consumed by the fitting code outside this method.
        min_impurity_decrease: Stored as a hyperparameter; consumed by the fitting code outside this method.
        task: ``"classification"`` selects the ``"gini"`` criterion; any other accepted value selects
            ``"mse"``. ``"link_prediction"`` is rejected.
        use_special_dims: When true, ``fit``/``predict_proba`` first run ``_aggregate_special_dims`` on the input.
        batch_size: ``None`` means everything is processed at once (``batched=True``); any other
            value means chunked processing (``batched=False``).
        n_features: Stored as a hyperparameter; consumed outside this method.
        ablate_midpoints: Stored as a hyperparameter; consumed outside this method.
        random_state: Forwarded to the base class.
        device: Forwarded to the base class.

    Raises:
        ValueError: If ``pm.is_stereographic`` is truthy, or if ``task == "link_prediction"``.
    """
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)

    # Unsupported configurations are rejected up front
    if pm.is_stereographic:
        raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
    if task == "link_prediction":
        raise ValueError(
            "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
        )

    # Hyperparameters
    self.pm = pm
    # NOTE(review): truthiness test, so max_depth=0 is also turned into -1 — confirm this is intended
    self.max_depth = max_depth if max_depth else -1
    self.min_samples_leaf = min_samples_leaf
    self.min_samples_split = min_samples_split
    self.min_impurity_decrease = min_impurity_decrease
    self.use_special_dims = use_special_dims
    self.n_features = n_features
    self.ablate_midpoints = ablate_midpoints

    # "batched" == process everything at once; a concrete batch_size == process in chunks
    self.batch_size = batch_size
    self.batched = batch_size is None

    # Split criterion follows the task
    self.task = task
    if task == "classification":
        self.criterion = "gini"
    else:
        self.criterion = "mse"

    # Placeholders that are filled in once fit is called
    self.nodes: list[_DecisionNode] = []  # Fitted nodes
    self.permutations: Int[torch.Tensor, "n_classes"] | None = None  # Set when the tree belongs to a random forest
    self.angle2man: list[int] = []  # Preprocessed angle -> manifold index
    self.special_first: list[bool] = []  # Whether the first dimension of a projection is special
    self.angle_dims: list[tuple[int, int]] = []  # Preprocessed angle -> pair of dimension indices
    self.tree: _DecisionNode = _DecisionNode()  # Root node
    self.classes_: Float[torch.Tensor, "n_classes"] = torch.empty(0)  # Empty until fitted
    self.labels_: Int[torch.Tensor, "batch n_classes"] = torch.tensor([])  # sklearn-style labels
    self.signature: list[tuple[float, int]] = pm.signature  # Signature of the manifold

fit(X, y)

Reworked fit function for new version of ProductDT.

Parameters:
  • X (Float[Tensor, 'batch ambient_dim']) –

    (batch, ambient_dim) tensor of training data (ambient coordinate representation)

  • y (Real[Tensor, 'batch']) –

    (batch,) tensor of labels (integer representation)

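Returns:
  • ProductSpaceDT –

    The fitted estimator itself (the tree is fitted in place and `self` is returned)
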
Source code in manify/predictors/decision_tree.py
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
@torch.no_grad()  # type: ignore
def fit(self, X: Float[torch.Tensor, "batch ambient_dim"], y: Real[torch.Tensor, "batch"]) -> ProductSpaceDT:
    """Fit the tree on training data and return the fitted estimator.

    Args:
        X: (batch, ambient_dim) tensor of training data (ambient coordinate representation)
        y: (batch,) tensor of labels (integer representation)

    Returns:
        The estimator itself, with ``self.tree`` holding the fitted root and ``is_fitted_`` set to True.
    """
    train_X = X
    if self.use_special_dims:
        # Aggregating special dimensions yields new inputs *and* a new manifold, which replaces self.pm
        train_X, self.pm = self._aggregate_special_dims(X)

    # Turn ambient coordinates into the angle/label/comparison tensors the node fitter consumes
    angles, labels, comparisons = self._preprocess(X=train_X, y=y)

    # Grow the tree from the root; max_depth is the depth budget handed to the fitter
    self.tree = self._fit_node(angles=angles, labels=labels, comparisons=comparisons, depth=self.max_depth)

    self.is_fitted_ = True
    return self

predict_proba(X)

Predict class probabilities for samples in X.

Source code in manify/predictors/decision_tree.py
607
608
609
610
611
612
613
614
615
@torch.no_grad()  # type: ignore
def predict_proba(self, X: Float[torch.Tensor, "batch intrinsic_dim"]) -> Float[torch.Tensor, "batch n_classes"]:
    """Predict class probabilities for samples in X.

    Each sample is preprocessed into angles, routed through the fitted tree with ``_traverse``,
    and the ``probs`` of the node it reaches are stacked row-wise.

    Args:
        X: Input samples, one per row.

    Returns:
        Tensor with one row of probabilities per input sample.
    """
    if self.use_special_dims:
        # Same pre-preprocessing as in fit; the manifold returned here is not needed
        X, _ = self._aggregate_special_dims(X)

    angles, _labels, _comparisons = self._preprocess(X=X)

    # When the tree was trained on a column subset (permutations set, e.g. by a random forest),
    # pick the same angle columns before traversing
    if self.permutations is not None:
        angles = angles[:, self.permutations]

    reached_probs = []
    for sample_angles in angles:
        reached = self._traverse(sample_angles, self.tree)
        reached_probs.append(reached.probs)
    return torch.vstack(reached_probs)

ProductSpaceRF(pm, task='classification', use_special_dims=False, n_features='d', max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, ablate_midpoints=False, n_estimators=100, max_features='sqrt', max_samples=1.0, batch_size=None, random_state=None, n_jobs=-1, device=None)

Bases: BasePredictor

Random Forest in the product space.

Source code in manify/predictors/decision_tree.py
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
def __init__(
    self,
    pm: ProductManifold,
    task: Literal["classification", "regression"] = "classification",
    use_special_dims: bool = False,
    n_features: Literal["d", "d_choose_2"] = "d",
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
    min_samples_split: int = 2,
    min_impurity_decrease: float = 0.0,
    ablate_midpoints: bool = False,
    n_estimators: int = 100,
    max_features: Literal["sqrt", "log2", "none"] = "sqrt",
    max_samples: float = 1.0,
    batch_size: int | None = None,
    random_state: int | None = None,
    n_jobs: int = -1,
    device: str | None = None,
):
    """Set up an unfitted random forest of ``ProductSpaceDT`` trees.

    Args:
        pm: Product manifold shared by all trees. Must not be stereographic.
        task: Task passed to every tree; ``"link_prediction"`` is rejected.
        use_special_dims: Passed to every tree and also used by the forest's own ``fit``.
        n_features: Passed to every tree.
        max_depth: Depth budget; any falsy value (``None`` and also ``0``) is stored and forwarded as ``-1``.
        min_samples_leaf: Passed to every tree.
        min_samples_split: Passed to every tree.
        min_impurity_decrease: Passed to every tree.
        ablate_midpoints: Passed to every tree.
        n_estimators: Number of trees constructed.
        max_features: Stored on the forest; consumed outside this method.
        max_samples: Stored on the forest; consumed outside this method.
        batch_size: Passed to every tree. ``None`` means ``batched=True`` (all at once); any other
            value means ``batched=False`` (in chunks).
        random_state: Forwarded to the base class and stored on the forest.
        n_jobs: Stored on the forest; consumed outside this method.
        device: Forwarded to the base class.

    Raises:
        ValueError: If ``pm.is_stereographic`` is truthy, or if ``task == "link_prediction"``.
    """
    super().__init__(pm=pm, task=task, random_state=random_state, device=device)

    # Unsupported configurations are rejected up front
    if pm.is_stereographic:
        raise ValueError("Stereographic manifolds are not supported. Use a different representation.")
    if task == "link_prediction":
        raise ValueError(
            "Link prediction is not supported for decision trees. Please use utils.link_prediction to reframe as classification"
        )

    # Falsy max_depth (None or 0) becomes -1, both on the forest and on every tree
    depth_budget = max_depth if max_depth else -1

    # Hyperparameters mirrored on the forest itself ...
    self.pm = pm
    self.task = task
    self.max_depth = depth_budget
    self.min_samples_leaf = min_samples_leaf
    self.min_samples_split = min_samples_split
    self.min_impurity_decrease = min_impurity_decrease
    self.use_special_dims = use_special_dims
    self.n_features = n_features
    self.batch_size = batch_size
    self.ablate_midpoints = ablate_midpoints

    # ... and handed unchanged to every tree.
    # NOTE(review): random_state and device are not part of this dict, so the trees are built
    # with their own defaults for both — confirm this is intended
    tree_kwargs: Dict[str, Any] = {
        "pm": pm,
        "task": task,
        "max_depth": depth_budget,
        "min_samples_leaf": min_samples_leaf,
        "min_samples_split": min_samples_split,
        "min_impurity_decrease": min_impurity_decrease,
        "use_special_dims": use_special_dims,
        "n_features": n_features,
        "batch_size": batch_size,
        "ablate_midpoints": ablate_midpoints,
    }

    # "batched" == process everything at once; a concrete batch_size == process in chunks
    self.batched = batch_size is None

    # Forest-level hyperparameters
    self.n_estimators = n_estimators
    self.max_features = max_features
    self.max_samples = max_samples
    self.random_state = random_state
    self.n_jobs = n_jobs
    self.trees = [ProductSpaceDT(**tree_kwargs) for _ in range(n_estimators)]

    # sklearn-style attributes, unset until fit; everything else lives on the individual trees
    self.classes_: Float[torch.Tensor, "n_classes"] | None = None
    self.labels_: Int[torch.Tensor, "batch n_classes"] | None = None

fit(X, y)

Preprocess and fit an ensemble of trees on subsampled data.

Source code in manify/predictors/decision_tree.py
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
@torch.no_grad()  # type: ignore
def fit(self, X: Float[torch.Tensor, "batch ambient_dim"], y: Real[torch.Tensor, "batch"]) -> ProductSpaceRF:
    """Preprocess and fit an ensemble of trees on subsampled data."""
    # Pre-preprocessing step: aggregate special dimensions
    if self.use_special_dims:
        X, self.pm = self.trees[0]._aggregate_special_dims(X)
        for tree in self.trees:
            tree.pm = self.pm

    # Can use any tree to preprocess X and y
    angles, labels, comparisons = self.trees[0]._preprocess(X=X, y=y)

    # Also update angle2man and special_first
    for tree in self.trees:
        tree.angle2man = self.trees[0].angle2man
        tree.special_first = self.trees[0].special_first
        tree.classes_ = self.trees[0].classes_
    self.classes_ = self.trees[0].classes_

    # Use seed here
    if self.random_state is not None:
        torch.manual_seed(self.random_state)

    # Subsample - just the indices
    n, d = angles.shape
    idx_sample_all, idx_dim_all = self._generate_subsample(n_rows=n, n_cols=d, n_trees=self.n_estimators)

    # Fit trees
    for tree, idx_sample, idx_dim in zip(self.trees, idx_sample_all, idx_dim_all, strict=False):
        tree.permutations = idx_dim
        if self.batched:
            comparisons_subsample = comparisons[idx_sample][:, idx_dim][:, :, idx_sample]
        else:
            comparisons_subsample = comparisons
        tree.tree = tree._fit_node(
            angles=angles[idx_sample][:, idx_dim],
            labels=labels[idx_sample],
            comparisons=comparisons_subsample,
            depth=self.max_depth,
        )

    self.is_fitted_ = True
    return self

predict_proba(X)

Predict class probabilities for samples in X.

Source code in manify/predictors/decision_tree.py
749
750
751
752
@torch.no_grad()  # type: ignore
def predict_proba(self, X: Float[torch.Tensor, "batch intrinsic_dim"]) -> Float[torch.Tensor, "batch n_classes"]:
    """Predict class probabilities for samples in X.

    Every tree predicts on the same X and the per-tree outputs are averaged element-wise.

    Args:
        X: Input samples, one per row.

    Returns:
        Mean of the trees' ``predict_proba`` outputs.
    """
    per_tree_probs = [member.predict_proba(X) for member in self.trees]
    stacked = torch.stack(per_tree_probs)  # new leading axis indexes the trees
    return stacked.mean(dim=0)