Computes the kernel matrix and its scalar normalization constant for a single manifold.
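**Parameters:**

- `manifold` (`Manifold`) – The manifold for the computation.
- `X_source` (`Float[Tensor, 'n_points_source n_dim']`) – Tensor of source points.
- `X_target` (`Float[Tensor, 'n_points_target n_dim'] | None`) – Tensor of target points. When `None`, `X_source` is used as the target as well.
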
**Returns:**

- `kernel_matrix` (`Float[Tensor, 'n_points_source n_points_target']`) – The kernel matrix between source and target points.
- `norm_constant` (`Float[Tensor, '']`) – Scalar normalization constant for the kernel.

Source code in manify/predictors/_kernel.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def compute_kernel_and_norm_manifold(
    manifold: Manifold,
    X_source: Float[torch.Tensor, "n_points_source n_dim"],
    X_target: Float[torch.Tensor, "n_points_target n_dim"] | None,
) -> tuple[Float[torch.Tensor, "n_points_source n_points_target"], Float[torch.Tensor, ""]]:
    """Build the kernel matrix and its normalization constant for one manifold.

    The pairwise inner products from ``manifold.inner`` are multiplied by
    ``manifold.scale`` and then mapped through a function selected by
    ``manifold.type``:

    * ``"E"``: the scaled inner products are returned unchanged.
    * ``"S"``: ``asin`` of the curvature-weighted inner products (clamped
      to ``[-1, 1]``), times the square root of the curvature.
    * ``"H"``: ``asinh`` of the inner products divided by ``scale ** 2``,
      times the square root of the absolute curvature.

    Args:
        manifold: The manifold for the computation. Its ``inner``, ``scale``,
            ``type`` and ``curvature`` members are read here.
        X_source: Tensor of source points.
        X_target: Tensor of target points. If ``None``, ``X_source`` is used
            for both sides.

    Returns:
        kernel_matrix: The kernel matrix between source and target points.
        norm_constant: Scalar normalization constant for the kernel.

    Raises:
        ValueError: If ``manifold.type`` is not one of ``"E"``, ``"S"``, ``"H"``.
    """
    if X_target is None:
        X_target = X_source

    # Scale the freshly computed inner-product matrix in place.
    scaled_inner = manifold.inner(X_source, X_target)
    scaled_inner *= manifold.scale

    geometry = manifold.type

    if geometry == "E":
        # Euclidean kernel: just the inner products, with unit normalization.
        return scaled_inner, torch.tensor(1.0)

    if geometry == "S":
        # Spherical kernel: asin(C_S * <x, y>) with C_S the curvature
        # (see p.5 of Tabaghi paper); the norm is sqrt(C_S) (see p.16).
        curvature = manifold.curvature
        sqrt_curvature = curvature**0.5
        # Clamp so the argument stays inside the domain of asin.
        kernel = (scaled_inner * curvature).clamp(-1, 1).asin() * sqrt_curvature
        return kernel, torch.tensor(sqrt_curvature)

    if geometry == "H":
        # Hyperbolic kernel: asinh(R^-2 * Lorentz inner products) * sqrt(-C_H).
        abs_curvature = abs(manifold.curvature)
        radius = -1 * manifold.scale
        kernel = torch.asinh(scaled_inner / radius**2) * abs_curvature**0.5
        # The norm is the square root of the curvature magnitude.
        return kernel, torch.tensor(abs_curvature) ** 0.5

    raise ValueError("Invalid manifold type!")
|