Kernel

manify.predictors._kernel

Kernel matrix calculation for product manifolds.

compute_kernel_and_norm_manifold(manifold, X_source, X_target)

Computes the kernel matrix for a single manifold.

Parameters:
  • manifold (Manifold) –

    The manifold for the computation.

  • X_source (Float[Tensor, 'n_points_source n_dim']) –

    Tensor of source points.

  • X_target (Float[Tensor, 'n_points_target n_dim'] | None) –

    Tensor of target points. If None, X_source is used as the target, giving the kernel of the source points with themselves.

Returns:
  • kernel_matrix( Float[Tensor, 'n_points_source n_points_target'] ) –

    The kernel matrix between source and target points.

  • norm_constant( Float[Tensor, ''] ) –

    Scalar normalization constant for the kernel.

Source code in manify/predictors/_kernel.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def compute_kernel_and_norm_manifold(
    manifold: Manifold,
    X_source: Float[torch.Tensor, "n_points_source n_dim"],
    X_target: Float[torch.Tensor, "n_points_target n_dim"] | None,
) -> tuple[Float[torch.Tensor, "n_points_source n_points_target"], Float[torch.Tensor, ""]]:
    """Computes the kernel matrix for a single manifold.

    The kernel is built from the manifold's pairwise inner products, scaled by
    ``manifold.scale`` and then transformed according to ``manifold.type``:

    * ``"E"``: the scaled inner products themselves.
    * ``"S"``: ``asin(clamp(ip * C_S, -1, 1)) * sqrt(C_S)`` with ``C_S = manifold.curvature``.
    * ``"H"``: ``asinh(ip / R**2) * sqrt(C_H)`` with ``C_H = |manifold.curvature|`` and
      ``R**2 = manifold.scale**2``.

    Args:
        manifold: The manifold for the computation.
        X_source: Tensor of source points.
        X_target: Tensor of target points. If None, ``X_source`` is used on both sides.

    Returns:
        kernel_matrix: The kernel matrix between source and target points.
        norm_constant: Scalar normalization constant for the kernel.

    Raises:
        ValueError: If ``manifold.type`` is not one of ``"E"``, ``"S"`` or ``"H"``.
    """
    X_target = X_source if X_target is None else X_target

    # Scale out of place: the tensor handed back by `manifold.inner` is not owned by this
    # function, so an in-place `*=` could corrupt a cached/shared tensor (and fails when the
    # scaling requires dtype promotion).
    ip = manifold.inner(X_source, X_target) * manifold.scale

    if manifold.type == "E":
        # K_E is just the (scaled) inner products; no normalization is needed.
        K = ip
        norm = torch.tensor(1.0)
    elif manifold.type == "S":
        # C_S is the curvature (see p.5 of Tabaghi paper).
        C_S = manifold.curvature
        # The argument is clamped to asin's domain [-1, 1] before the transform.
        K = torch.asin(torch.clamp(ip * C_S, -1, 1)) * C_S**0.5
        # norm is sqrt(C_S) (see p.16 of Tabaghi paper).
        norm = torch.tensor(C_S**0.5)
    elif manifold.type == "H":
        C_H = abs(manifold.curvature)
        # Only R**2 is used below, so the sign of R does not affect the result.
        R = -1 * manifold.scale
        # asinh is defined on all reals, so (unlike the spherical case) nothing is clamped.
        K = torch.asinh(ip / R**2) * C_H**0.5
        # norm is sqrt(|C_H|).
        norm = torch.tensor(C_H) ** 0.5
    else:
        raise ValueError("Invalid manifold type!")

    return K, norm

product_kernel(pm, X_source, X_target)

Computes the kernel matrix for a product manifold.

Parameters:
  • pm (ProductManifold) –

    The product manifold for the computation.

  • X_source (Float[Tensor, 'n_points_source n_dim']) –

    Tensor of source points.

  • X_target (Float[Tensor, 'n_points_target n_dim'] | None) –

    Tensor of target points. If None, X_source is used as the target, giving the kernels of the source points with themselves.

Returns:
  • kernel_matrices( list[Float[Tensor, 'n_points_source n_points_target']] ) –

    List of kernel matrices for each component manifold.

  • norm_constants( list[Float[Tensor, '']] ) –

    List of normalization constants for each kernel.

Source code in manify/predictors/_kernel.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def product_kernel(
    pm: ProductManifold,
    X_source: Float[torch.Tensor, "n_points_source n_dim"],
    X_target: Float[torch.Tensor, "n_points_target n_dim"] | None,
) -> tuple[list[Float[torch.Tensor, "n_points_source n_points_target"]], list[Float[torch.Tensor, ""]]]:
    """Computes per-component kernel matrices for a product manifold.

    The source and target points are split into per-component coordinates with
    ``pm.factorize`` and each component manifold in ``pm.P`` is handed to
    ``compute_kernel_and_norm_manifold`` together with its slice of the points.

    Args:
        pm: The product manifold for the computation.
        X_source: Tensor of source points.
        X_target: Tensor of target points. If None, ``X_source`` is used on both sides.

    Returns:
        kernel_matrices: List of kernel matrices for each component manifold.
        norm_constants: List of normalization constants for each kernel.
    """
    targets = X_source if X_target is None else X_target

    # One (kernel, norm) pair per component manifold, in the order of pm.P.
    per_component = [
        compute_kernel_and_norm_manifold(component, source_part, target_part)
        for component, source_part, target_part in zip(
            pm.P, pm.factorize(X_source), pm.factorize(targets), strict=False
        )
    ]

    # Split the pairs into the two parallel lists that callers expect.
    kernel_matrices = [kernel for kernel, _ in per_component]
    norm_constants = [norm for _, norm in per_component]
    return kernel_matrices, norm_constants