Compute the midpoint between two angular coordinates given the manifold type.
This function selects the midpoint calculation based on the manifold type. Hyperbolic and Euclidean manifolds use their dedicated midpoint formulas when special_first is True; in every other case, including all spherical manifolds, the spherical midpoint is used.
Parameters:
- u (Float[Tensor, '']) – The first angular coordinate.
- v (Float[Tensor, '']) – The second angular coordinate.
- manifold (Manifold) – An object representing the manifold; its type attribute selects the geometry.
- special_first (bool, default: False) – If True, uses the dedicated midpoint formula when the manifold is hyperbolic or Euclidean. If False, the spherical midpoint is used for every manifold type.

Returns:
- midpoint (Float[Tensor, '']) – The computed midpoint between u and v under the selected geometry. If u and v are numerically close (torch.isclose), u is returned unchanged.
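The dispatch rule can be summarized as a small decision table. The sketch below uses a hypothetical helper, midpoint_branch, that takes the manifold type as a plain string instead of a Manifold object. Rather than computing a midpoint, it returns the name of the branch that midpoint would take. Its closeness test reproduces torch.isclose with its default tolerances (rtol=1e-05, atol=1e-08), which is how midpoint calls it.

```python
def midpoint_branch(manifold_type: str, special_first: bool, u: float, v: float) -> str:
    """Return the name of the branch `midpoint` takes (hypothetical helper for illustration)."""
    # Same test as torch.isclose(u, v) with its default rtol=1e-05 and atol=1e-08
    if abs(u - v) <= 1e-08 + 1e-05 * abs(v):
        return "return u"
    # The dedicated formulas are only reached when special_first is True
    if manifold_type == "H" and special_first:
        return "hyperbolic_midpoint"
    if manifold_type == "E" and special_first:
        return "euclidean_midpoint"
    # Everything else, including hyperbolic/Euclidean with special_first=False, falls back to spherical
    return "spherical_midpoint"


for mtype in ("H", "E", "S"):
    for flag in (False, True):
        print(f"type={mtype} special_first={flag}: {midpoint_branch(mtype, flag, 0.5, 1.0)}")
```

Only two of the six combinations reach a geometry-specific formula. Every other combination falls back to the spherical midpoint.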
Source code in manify/predictors/_midpoint.py
def midpoint(
    u: Float[torch.Tensor, ""], v: Float[torch.Tensor, ""], manifold: Manifold, special_first: bool = False
) -> Float[torch.Tensor, ""]:
    """Compute the midpoint between two angular coordinates given the manifold type.

    This function selects the midpoint calculation based on the manifold type. Hyperbolic and Euclidean
    manifolds use their dedicated midpoint formulas when special_first is True; in every other case,
    including all spherical manifolds, the spherical midpoint is used.

    Args:
        u: The first angular coordinate.
        v: The second angular coordinate.
        manifold: An object representing the manifold; its type attribute selects the geometry.
        special_first: If True, uses the dedicated midpoint formula when the manifold is hyperbolic or
            Euclidean. If False, the spherical midpoint is used for every manifold type.

    Returns:
        midpoint: The computed midpoint between u and v under the selected geometry.
    """
    # Numerically indistinguishable angles: the midpoint is the angle itself
    if torch.isclose(u, v):
        return u
    elif manifold.type == "H" and special_first:
        return hyperbolic_midpoint(u, v)
    elif manifold.type == "E" and special_first:
        return euclidean_midpoint(u, v)
    # Spherical midpoint handles all spherical angles
    # *AND* any angles that don't involve figuring out where you hit the manifold
    else:
        return spherical_midpoint(u, v)