Sectional Curvature

manify.curvature_estimation.sectional_curvature

Sectional curvature estimation for graphs.

This module implements the graph sectional curvature estimator from Gu et al., "Learning mixed-curvature representations in product spaces," ICLR 2019.

Estimates local curvature at nodes using a discrete triangle comparison theorem.
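For orientation, here is our restatement of the graph comparison term from Gu et al. (2019). Take a node m with two distinct neighbors b and c, any reference node a ≠ m, and the shortest-path distance d_G. The per-triangle estimate and its average over reference nodes are given below. This is a reading of the paper, not of manify's code. The helpers this page calls are not shown, and they may differ in normalization, for example through the `relative` option.

```latex
\xi_G(m; b, c; a) = \frac{1}{2\, d_G(a, m)} \left( d_G(a, m)^2 + \frac{d_G(b, c)^2}{4} - \frac{d_G(a, b)^2 + d_G(a, c)^2}{2} \right)

\xi_G(m; b, c) = \frac{1}{|V| - 1} \sum_{a \neq m} \xi_G(m; b, c; a)
```

In Euclidean space the bracketed term vanishes by Apollonius' theorem. A positive value therefore means the median from a to m is longer than the flat comparison predicts (sphere-like), and a negative value means it is shorter (tree- or hyperbolic-like).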

sectional_curvature(adjacency_matrix, distance_matrix, samples=None, relative=True)

Estimates sectional curvature of a graph.

Uses a discrete triangle comparison theorem to estimate local curvature. Positive values indicate spherical-like regions, negative values indicate hyperbolic-like regions, and zero indicates flat regions.
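You can check the sign convention on tiny graphs. The sketch below is plain Python and is not manify's `_compute_node_curvatures`, whose body is not shown on this page. It averages the Gu et al. comparison term over every neighbor pair (b, c) of a node m and every reference node a ≠ m. Several choices here are assumptions made for illustration: the names `triangle_xi` and `node_curvature`, the uniform averaging scheme, and returning NaN for nodes with fewer than two neighbors.

```python
from itertools import combinations


def triangle_xi(d, m, b, c, a):
    """Comparison term for triangle (a, b, c) with m between b and c (Gu et al. style)."""
    dam = d[a][m]
    # Zero in flat space by Apollonius' theorem; sign gives the curvature direction.
    bracket = dam**2 + d[b][c] ** 2 / 4 - (d[a][b] ** 2 + d[a][c] ** 2) / 2
    return bracket / (2 * dam)


def node_curvature(adj, d, m):
    """Hypothetical per-node estimate: mean of triangle_xi over neighbor pairs and all a != m."""
    n = len(adj)
    neighbors = [j for j in range(n) if adj[m][j]]
    pairs = list(combinations(neighbors, 2))
    if not pairs:
        return float("nan")  # fewer than two neighbors: no triangle to compare against
    refs = [a for a in range(n) if a != m]
    total = sum(triangle_xi(d, m, b, c, a) for b, c in pairs for a in refs)
    return total / (len(pairs) * len(refs))


# Star K_{1,3} (a small tree), centered at node 0.
STAR_ADJ = [[0, 1, 1, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]
STAR_D = [[0, 1, 1, 1], [1, 0, 2, 2], [1, 2, 0, 2], [1, 2, 2, 0]]
# 4-cycle.
CYCLE_ADJ = [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]
CYCLE_D = [[0, 1, 2, 1], [1, 0, 1, 2], [2, 1, 0, 1], [1, 2, 1, 0]]
# Path 0 - 1 - 2.
PATH_ADJ = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
PATH_D = [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
```

On these graphs the star center comes out negative (-1/3), a 4-cycle node comes out positive (1/3), and the middle of the path is exactly zero. This matches the hyperbolic/spherical/flat reading above.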

Parameters:
  • adjacency_matrix (Float[Tensor, 'n_points n_points']) –

    Binary adjacency matrix indicating graph connections.

  • distance_matrix (Float[Tensor, 'n_points n_points']) –

    Pairwise shortest path distance matrix.

  • samples (int | None, default: None ) –

    Number of triangle configurations to sample. If None, computes per-node curvatures.

  • relative (bool, default: True ) –

    Whether to normalize by maximum distance.
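The two matrix arguments describe the same graph. The following plain-Python sketch shows one way to derive the shortest-path matrix from the adjacency matrix by running one breadth-first search per node. It assumes an unweighted, undirected graph given as a 0/1 nested list, and unreachable pairs stay at infinity. `shortest_path_matrix` is a hypothetical helper, not part of manify.

```python
from collections import deque


def shortest_path_matrix(adj):
    """All-pairs hop distances of an unweighted graph via one BFS per source node."""
    n = len(adj)
    inf = float("inf")
    dist = [[inf] * n for _ in range(n)]
    for s in range(n):
        dist[s][s] = 0
        queue = deque([s])
        while queue:
            u = queue.popleft()
            for v in range(n):
                # First time v is reached from s is along a shortest path.
                if adj[u][v] and dist[s][v] == inf:
                    dist[s][v] = dist[s][u] + 1
                    queue.append(v)
    return dist
```

You could then wrap both nested lists with `torch.tensor(...)` before calling `sectional_curvature`. The function casts both inputs with `.float()` itself.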

Returns:
  • Float[Tensor, 'n_points'] | Float[Tensor, 'samples'] –

    Sectional curvature estimates:
      • When samples is not None: torch.Tensor of shape (samples,)
      • When samples is None: torch.Tensor of shape (n_points,) with per-node curvatures

Note

For global statistics, call .mean() or other aggregation functions on the result.

Source code in manify/curvature_estimation/sectional_curvature.py
def sectional_curvature(
    adjacency_matrix: Float[torch.Tensor, "n_points n_points"],
    distance_matrix: Float[torch.Tensor, "n_points n_points"],
    samples: int | None = None,
    relative: bool = True,
) -> Float[torch.Tensor, "n_points"] | Float[torch.Tensor, "samples"]:
    r"""Estimates sectional curvature of a graph.

    Uses a discrete triangle comparison theorem to estimate local curvature.
    Positive values indicate spherical-like regions, negative values indicate
    hyperbolic-like regions, and zero indicates flat regions.

    Args:
        adjacency_matrix: Binary adjacency matrix indicating graph connections.
        distance_matrix: Pairwise shortest path distance matrix.
        samples: Number of triangle configurations to sample. If None, computes per-node curvatures.
        relative: Whether to normalize by maximum distance.

    Returns:
        Sectional curvature estimates:
        - When samples is not None: torch.Tensor of shape (samples,)
        - When samples is None: torch.Tensor of shape (n_points,) with per-node curvatures

    Note:
        For global statistics, call .mean() or other aggregation functions on the result.
    """
    if not isinstance(adjacency_matrix, torch.Tensor) or not isinstance(distance_matrix, torch.Tensor):
        raise TypeError("Both adjacency_matrix and distance_matrix must be torch.Tensors")

    if adjacency_matrix.shape != distance_matrix.shape:
        raise ValueError("Adjacency matrix and distance matrix must have the same shape")

    A = adjacency_matrix.float()
    D = distance_matrix.float()

    if samples is not None:
        return _sample_curvatures(D, samples, relative)
    else:
        return _compute_node_curvatures(A, D, relative)
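When `samples` is given, the function passes only the distance matrix `D` to `_sample_curvatures`, so the sampled path cannot consult `A`. The sketch below shows one way such a sampler could work. It treats nodes at distance 1 as neighbors, which holds for unweighted graphs. Each draw picks a midpoint m with at least two neighbors, two distinct neighbors b and c, and a reference node a ≠ m, and the sketch returns one comparison value per draw. `sample_curvatures_sketch`, its `seed` argument, and the uniform sampling scheme are hypothetical. The real helper's body is not shown, and it may sample differently or apply the `relative` normalization.

```python
import random


def sample_curvatures_sketch(d, samples, seed=0):
    """Hypothetical sampled estimator: one comparison value per random (m, b, c, a)."""
    rng = random.Random(seed)
    n = len(d)
    # Only distances are available, so distance-1 nodes stand in for neighbors.
    nbrs = {m: [j for j in range(n) if d[m][j] == 1] for m in range(n)}
    centers = [m for m, ns in nbrs.items() if len(ns) >= 2]
    if not centers:
        raise ValueError("no node has two neighbors; nothing to sample")
    out = []
    for _ in range(samples):
        m = rng.choice(centers)
        b, c = rng.sample(nbrs[m], 2)  # two distinct neighbors of m
        a = rng.choice([x for x in range(n) if x != m])
        dam = d[a][m]
        bracket = dam**2 + d[b][c] ** 2 / 4 - (d[a][b] ** 2 + d[a][c] ** 2) / 2
        out.append(bracket / (2 * dam))
    return out
```

Each entry comes from a single triangle, so individual samples are coarse. For the star K_{1,3}, for instance, they are only 0 or -1. Averaging many samples, as the Note suggests with `.mean()`, gives a smoother global estimate.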