Computes δ-hyperbolicity from a distance matrix.
For each triplet of points (x,y,z) and reference point w, computes:
δ(x,y,z) = min((x,y)_w, (y,z)_w) - (x,z)_w
where (a,b)_w = ½(d(w,a) + d(w,b) - d(a,b)) is the Gromov product.
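To make the two formulas above concrete, here is a small pure-Python sketch that evaluates the Gromov product and δ(x, y, z) for a single triplet. The helpers `gromov_product` and `delta_triplet` are hypothetical illustrations, not part of manify, and the distance matrices are toy data.

```python
# Hypothetical helpers illustrating the formulas; not part of manify.
def gromov_product(d, a, b, w):
    """Gromov product (a, b)_w = 1/2 * (d(w, a) + d(w, b) - d(a, b))."""
    return 0.5 * (d[w][a] + d[w][b] - d[a][b])


def delta_triplet(d, x, y, z, w):
    """delta(x, y, z) = min((x, y)_w, (y, z)_w) - (x, z)_w for reference point w."""
    return min(gromov_product(d, x, y, w), gromov_product(d, y, z, w)) - gromov_product(d, x, z, w)


n = 4
# Shortest-path distances on the 4-cycle 0-1-2-3-0 with unit edge lengths.
cycle = [[min(abs(i - j), n - abs(i - j)) for j in range(n)] for i in range(n)]
# Points 0, 1, 2, 3 on a line (a tree metric), with d(i, j) = |i - j|.
line = [[abs(i - j) for j in range(n)] for i in range(n)]

print(delta_triplet(cycle, 1, 2, 3, w=0))  # 1.0
print(delta_triplet(line, 1, 2, 3, w=0))   # 0.0
```

On the 4-cycle the triplet (1, 2, 3) yields δ = 1.0, whereas on points of a line it yields 0.0, consistent with tree metrics being 0-hyperbolic.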
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `distance_matrix` | `Float[Tensor, 'n_points n_points']` | Pairwise distance matrix. | *required* |
| `samples` | `int \| None` | Number of triplets to sample. If `None`, computes the full δ tensor over all triplets. | `None` |
| `reference_idx` | `int` | Index of the reference point w. | `0` |
| `relative` | `bool` | Whether to normalize by the maximum distance. | `True` |
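When `samples` is set, δ is estimated on randomly drawn triplets instead of all n³ of them. The sketch below shows one way to do this in pure Python. `sample_deltas` is a hypothetical helper; drawing indices uniformly with replacement, and treating `relative=True` as division by the largest pairwise distance, are assumptions and not necessarily what `_sample_delta_values` does.

```python
import random


# Hypothetical helper; sampling scheme and normalization are assumptions.
def sample_deltas(d, samples, reference_idx=0, relative=True, seed=None):
    """Estimate delta(x, y, z) on `samples` uniformly drawn triplets (with replacement)."""
    rng = random.Random(seed)
    n = len(d)
    w = reference_idx
    # Assumed normalization: divide by the largest pairwise distance.
    scale = max(max(row) for row in d) if relative else 1.0

    def gp(a, b):
        # Gromov product (a, b)_w with respect to the fixed reference point w.
        return 0.5 * (d[w][a] + d[w][b] - d[a][b])

    values = []
    for _ in range(samples):
        x, y, z = (rng.randrange(n) for _ in range(3))
        values.append((min(gp(x, y), gp(y, z)) - gp(x, z)) / scale)
    return values


n = 6
# Shortest-path distances on a 6-cycle with unit edge lengths.
cycle = [[min(abs(i - j), n - abs(i - j)) for j in range(n)] for i in range(n)]
vals = sample_deltas(cycle, samples=100, reference_idx=0, seed=0)
print(len(vals))  # 100
```

Sampling trades exactness for cost: each draw is O(1), so the estimate scales with `samples` rather than with n³.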
Returns:

| Type | Description |
|---|---|
| `Float[Tensor, 'n_points n_points n_points'] \| Float[Tensor, 'samples']` | δ-hyperbolicity estimates: a tensor of shape `(samples,)` when `samples` is not `None`, or a tensor of shape `(n_points, n_points, n_points)` when `samples` is `None`. |
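With `samples=None`, the result holds one δ value per ordered triplet, indexed as `delta[x][y][z]`. A minimal pure-Python sketch of that full tensor (as nested lists rather than a `torch.Tensor`) follows; `full_delta_tensor` is a hypothetical name and this is not the library's `_compute_full_delta_tensor`.

```python
# Hypothetical sketch of the full (n, n, n) delta tensor; not manify's implementation.
def full_delta_tensor(d, reference_idx=0):
    """Return nested lists delta[x][y][z] of shape (n, n, n)."""
    n = len(d)
    w = reference_idx
    # Precompute the n x n matrix of Gromov products G[a][b] = (a, b)_w.
    G = [[0.5 * (d[w][a] + d[w][b] - d[a][b]) for b in range(n)] for a in range(n)]
    return [[[min(G[x][y], G[y][z]) - G[x][z] for z in range(n)] for y in range(n)] for x in range(n)]


n = 4
# Shortest-path distances on the 4-cycle 0-1-2-3-0 with unit edge lengths.
cycle = [[min(abs(i - j), n - abs(i - j)) for j in range(n)] for i in range(n)]
delta = full_delta_tensor(cycle)
print(len(delta), len(delta[0]), len(delta[0][0]))  # 4 4 4
print(delta[1][2][3])  # 1.0
```

Precomputing the n × n Gromov-product matrix makes each entry O(1), but the output itself still has n³ entries, which is why the sampled mode is attractive for larger point sets.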
Note
For global statistics, call .max() or other aggregation functions on the result.
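Following the note, a single global estimate can be obtained by reducing over all triplets, for example by taking the maximum. The sketch below does this in pure Python and, as an illustration of `relative=True`, divides by the largest pairwise distance; `global_delta` is a hypothetical helper and that normalization is an assumption.

```python
from itertools import product


# Hypothetical helper; the normalization by the largest distance is an assumption.
def global_delta(d, reference_idx=0, relative=True):
    """Maximum of delta(x, y, z) over all ordered triplets, optionally normalized."""
    n = len(d)
    w = reference_idx

    def gp(a, b):
        # Gromov product (a, b)_w.
        return 0.5 * (d[w][a] + d[w][b] - d[a][b])

    worst = max(min(gp(x, y), gp(y, z)) - gp(x, z) for x, y, z in product(range(n), repeat=3))
    # Assumed normalization: divide by the largest pairwise distance.
    return worst / max(max(row) for row in d) if relative else worst


n = 4
# Shortest-path distances on the 4-cycle 0-1-2-3-0 with unit edge lengths.
cycle = [[min(abs(i - j), n - abs(i - j)) for j in range(n)] for i in range(n)]
print(global_delta(cycle, relative=False))  # 1.0
print(global_delta(cycle))                  # 0.5
```

With a `torch.Tensor` result, the same reduction is simply `.max()` on the returned tensor.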
Source code in manify/curvature_estimation/delta_hyperbolicity.py
```python
def delta_hyperbolicity(
    distance_matrix: Float[torch.Tensor, "n_points n_points"],
    samples: int | None = None,
    reference_idx: int = 0,
    relative: bool = True,
) -> Float[torch.Tensor, "n_points n_points n_points"] | Float[torch.Tensor, "samples"]:
    r"""Computes δ-hyperbolicity from a distance matrix.

    For each triplet of points (x,y,z) and reference point w, computes:
    δ(x,y,z) = min((x,y)_w, (y,z)_w) - (x,z)_w
    where (a,b)_w = ½(d(w,a) + d(w,b) - d(a,b)) is the Gromov product.

    Args:
        distance_matrix: Pairwise distance matrix.
        samples: Number of triplets to sample. If None, computes full δ tensor over all triplets.
        reference_idx: Index of the reference point w.
        relative: Whether to normalize by maximum distance.

    Returns:
        δ-hyperbolicity estimates:
        - When samples is not None: torch.Tensor of shape (samples,)
        - When samples is None: torch.Tensor of shape (n_points, n_points, n_points)

    Note:
        For global statistics, call .max() or other aggregation functions on the result.
    """
    if not isinstance(distance_matrix, torch.Tensor):
        raise TypeError(f"distance_matrix must be a torch.Tensor, got {type(distance_matrix)}")
    D = distance_matrix.float()
    if samples is not None:
        return _sample_delta_values(D, samples, reference_idx, relative)
    else:
        return _compute_full_delta_tensor(D, reference_idx, relative)
```