Decision Tree
manify.predictors.decision_tree
Decision tree and random forest predictors for product space manifolds.
For more information, see Chlenski et al. (2024): https://arxiv.org/abs/2410.13879
ProductSpaceDT(pm, max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, task='classification', use_special_dims=False, batch_size=None, n_features='d', ablate_midpoints=False, random_state=None, device=None)
Bases: BasePredictor
Decision tree in the product space, handling hyperbolic, Euclidean, and hyperspherical data.
Source code in manify/predictors/decision_tree.py
fit(X, y)
Fit the product space decision tree on features X and targets y.
Source code in manify/predictors/decision_tree.py
predict_proba(X)
Predict class probabilities for samples in X.
Source code in manify/predictors/decision_tree.py
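As a rough illustration of what "class probabilities" usually means for a single decision tree, the sketch below returns the fraction of training labels of each class in the leaf a sample lands in. This is a generic, standard-library example and not manify's implementation; `leaf_class_probabilities` and the sample data are hypothetical names introduced only for illustration.

```python
from collections import Counter


def leaf_class_probabilities(leaf_labels, classes):
    """Return the fraction of training labels in a leaf that belong to each class."""
    counts = Counter(leaf_labels)
    total = len(leaf_labels)
    return [counts[c] / total for c in classes]


# Hypothetical leaf holding four training labels from a three-class problem.
classes = [0, 1, 2]
probs = leaf_class_probabilities([0, 0, 1, 0], classes)
print(probs)  # [0.75, 0.25, 0.0]
```

Each returned row sums to 1, and a class that never appears in the leaf gets probability 0.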
ProductSpaceRF(pm, task='classification', use_special_dims=False, n_features='d', max_depth=None, min_samples_leaf=1, min_samples_split=2, min_impurity_decrease=0.0, ablate_midpoints=False, n_estimators=100, max_features='sqrt', max_samples=1.0, batch_size=None, random_state=None, n_jobs=-1, device=None)
Bases: BasePredictor
Random forest in the product space.
Source code in manify/predictors/decision_tree.py
fit(X, y)
Preprocess and fit an ensemble of trees on subsampled data.
Source code in manify/predictors/decision_tree.py
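Fitting an ensemble on subsampled data can be sketched generically as follows. This is not manify's code: it assumes `max_samples` is a fraction of the training set and that rows are drawn with replacement, and the helpers `bootstrap_indices`, `fit_ensemble`, and `fit_majority` are hypothetical stand-ins written for this example.

```python
import random


def bootstrap_indices(n_samples, max_samples=1.0, rng=None):
    """Draw row indices with replacement; max_samples is treated as a fraction."""
    rng = rng or random.Random()
    n_draw = max(1, int(max_samples * n_samples))
    return [rng.randrange(n_samples) for _ in range(n_draw)]


def fit_ensemble(X, y, fit_one, n_estimators=3, max_samples=1.0, random_state=0):
    """Fit n_estimators models, each on its own bootstrap sample of (X, y)."""
    rng = random.Random(random_state)
    models = []
    for _ in range(n_estimators):
        idx = bootstrap_indices(len(X), max_samples, rng)
        models.append(fit_one([X[i] for i in idx], [y[i] for i in idx]))
    return models


def fit_majority(X_sub, y_sub):
    """Stand-in for a tree: remembers the most common label in its subsample."""
    return max(set(y_sub), key=y_sub.count)


X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 0, 0, 1]
models = fit_ensemble(X, y, fit_majority, n_estimators=5, max_samples=0.5)
print(len(models))  # 5
```

Seeding a single generator from `random_state` and drawing every subsample from it keeps the whole ensemble reproducible while still giving each estimator a different view of the data.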
predict_proba(X)
Predict class probabilities for samples in X.
Source code in manify/predictors/decision_tree.py
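A common way for a random forest to produce class probabilities is to average the per-tree probabilities for each sample. The sketch below shows that averaging step in isolation; whether manify aggregates in exactly this way is an assumption, and `average_probabilities` and the two example trees are hypothetical.

```python
def average_probabilities(per_tree_probs):
    """Average class-probability rows across trees.

    per_tree_probs[t][i][c] is tree t's probability that sample i has class c.
    """
    n_trees = len(per_tree_probs)
    n_samples = len(per_tree_probs[0])
    n_classes = len(per_tree_probs[0][0])
    return [
        [sum(tree[i][c] for tree in per_tree_probs) / n_trees for c in range(n_classes)]
        for i in range(n_samples)
    ]


# Two hypothetical trees scoring two samples over two classes.
tree_a = [[1.0, 0.0], [0.5, 0.5]]
tree_b = [[0.5, 0.5], [0.0, 1.0]]
avg = average_probabilities([tree_a, tree_b])
print(avg)  # [[0.75, 0.25], [0.25, 0.75]]
```

Because every tree's row sums to 1, the averaged rows do as well, so the output can be read directly as a probability distribution per sample.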