Riemannian Adan (Radan).
Radan is the Riemannian version of Adan, the Adaptive Nesterov Momentum algorithm.
This code is compatible with both Geoopt and Manify libraries, and is designed for Riemannian Fuzzy K-Means.
We recommend using betas = (0.7, 0.99, 0.99) for best performance.
For more details on the Radan algorithm, please refer to:
https://openreview.net/forum?id=9VmOgMN4Ie
If you find this work useful, please cite the paper as follows:
@article{Yuan2025,
title={Riemannian Fuzzy K-Means},
author={Anonymous},
journal={OpenReview},
year={2025},
url={https://openreview.net/forum?id=9VmOgMN4Ie}
}
If you're interested in Adan, you can see:
@ARTICLE{10586270,
author={Xie, Xingyu and Zhou, Pan and Li, Huan and Lin, Zhouchen and Yan, Shuicheng},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
title={Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
year={2024},
volume={46},
number={12},
pages={9508-9520},
keywords={Training;Convergence;Complexity theory;Deep learning;Computer architecture;Task analysis;Stochastic processes;Adaptive optimizer;fast DNN training;DNN optimizer},
doi={10.1109/TPAMI.2024.3423382}
}
If you have questions about the code, feel free to contact: yuanjinghuiiii@gmail.com.
RiemannianAdan(params, lr=0.001, betas=(0.98, 0.92, 0.99), eps=1e-08, weight_decay=0.0, max_grad_norm=0.0, no_prox=False, foreach=True, fused=False)
Bases: OptimMixin, Adan
Riemannian Adan with the same API as :class:`adan.Adan`.
| Attributes: |
-
param_groups
–
iterable of parameter groups, each containing parameters to optimize and optimization options
-
_default_manifold
–
the default manifold used for optimization if not specified in parameters
|
| Parameters: |
-
params
–
iterable of parameters to optimize or dicts defining parameter groups
-
lr
–
learning rate (default: 1e-3)
-
betas
–
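coefficients used for computing the running averages of the gradient (betas[0]), of the gradient difference between consecutive steps (betas[1]), and of the squared Nesterov-corrected gradient (betas[2]) (default: (0.98, 0.92, 0.99))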
-
eps
–
term added to the denominator to improve numerical stability (default: 1e-8)
-
weight_decay
–
weight decay (L2 penalty) (default: 0)
|
Source code in manify/optimizers/_adan.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
def __init__(
    self,
    params,
    lr=1e-3,
    betas=(0.98, 0.92, 0.99),
    eps=1e-8,
    weight_decay=0.0,
    max_grad_norm=0.0,
    no_prox=False,
    foreach: bool = True,
    fused: bool = False,
):
    """Validate the hyper-parameters and register them as per-group defaults.

    Args:
        params: Iterable of parameters to optimize, or dicts defining parameter groups.
        lr: Learning rate; must be non-negative.
        betas: Three coefficients, each of which must lie in ``[0, 1)``.
        eps: Term added to the denominator for numerical stability; must be non-negative.
        weight_decay: Weight-decay coefficient (not range-checked here).
        max_grad_norm: Gradient-norm threshold; must be non-negative.
        no_prox: Stored as a group default.
        foreach: Stored as a group default.
        fused: Stored as a group default; when true, fused-kernel availability
            is checked via ``_check_fused_available`` before anything is stored.

    Raises:
        ValueError: If ``max_grad_norm``, ``lr``, ``eps`` or one of the first three
            ``betas`` is out of range (the first offending value is reported).
    """
    # ``not 0.0 <= x`` is used instead of ``x < 0.0`` so that NaN is rejected too.
    if not 0.0 <= max_grad_norm:
        raise ValueError(f"Invalid Max grad norm: {max_grad_norm}")
    if not 0.0 <= lr:
        raise ValueError(f"Invalid learning rate: {lr}")
    if not 0.0 <= eps:
        raise ValueError(f"Invalid epsilon value: {eps}")
    for index in range(3):
        beta = betas[index]
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta parameter at index {index}: {beta}")

    if fused:
        _check_fused_available()

    group_defaults = {
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        "no_prox": no_prox,
        "foreach": foreach,
        "fused": fused,
    }
    super().__init__(params, group_defaults)
|
step(closure=None)
Performs a single optimization step.
| Parameters: |
-
closure
(Callable | None, default:
None
)
–
A closure that reevaluates the model and returns the loss.
|
| Returns: |
-
Float[Tensor, ''] | None
–
The loss value if closure is provided, otherwise None.
|
Source code in manify/optimizers/radan.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def step(self, closure: Callable | None = None) -> Float[torch.Tensor, ""] | None:
    """Performs a single optimization step.

    For every parameter that has a gradient:
      * an L2 weight-decay term is added to a copy of the Euclidean gradient,
        which is then converted with ``manifold.egrad2rgrad``;
      * the three Adan moment estimates are updated: ``exp_avg`` (gradient),
        ``exp_avg_diff`` (difference to the previous gradient) and
        ``exp_avg_sq`` (``component_inner`` of the Nesterov-corrected gradient);
      * the point is moved with ``manifold.retr_transp``, and ``exp_avg``,
        ``exp_avg_diff`` and ``last_grad`` are transported to the new point.
        ``exp_avg_sq`` is not transported.

    ``point.grad`` itself is not modified by the weight-decay term.

    Note:
        The group options ``max_grad_norm`` and ``no_prox`` are not read by this
        method.

    Args:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        The loss value if closure is provided, otherwise None.

    Raises:
        RuntimeError: If a parameter has a sparse gradient.
    """
    loss = None
    if closure is not None:
        loss = closure()
    with torch.no_grad():
        for group in self.param_groups:
            betas = group["betas"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            learning_rate = group["lr"]
            # Stabilize every ``stabilize`` steps; a missing key, None or 0 disables it.
            stabilize_every = group.get("stabilize")
            should_stabilize = False
            for point in group["params"]:
                grad = point.grad
                if grad is None:
                    continue
                if isinstance(point, ManifoldParameter | ManifoldTensor):
                    manifold = point.manifold
                else:
                    manifold = self._default_manifold
                if grad.is_sparse:
                    raise RuntimeError("RiemannianAdan does not support sparse gradients")
                state = self.state[point]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(point)
                    # Exponential moving average of squared (Nesterov-corrected) gradient values
                    state["exp_avg_sq"] = torch.zeros_like(point)
                    # Exponential moving average of gradient differences
                    state["exp_avg_diff"] = torch.zeros_like(point)
                    # Previous step's Riemannian gradient (transported to the current point)
                    state["last_grad"] = torch.zeros_like(point)
                state["step"] += 1
                step = state["step"]
                exp_avg = state["exp_avg"]
                exp_avg_diff = state["exp_avg_diff"]
                exp_avg_sq = state["exp_avg_sq"]
                last_grad = state["last_grad"]

                # L2 weight decay on an out-of-place copy, so the caller's
                # ``point.grad`` is left untouched (as torch.optim.Adam does).
                if weight_decay != 0:
                    grad = grad.add(point, alpha=weight_decay)
                grad = manifold.egrad2rgrad(point, grad)

                grad_last_diff = grad - last_grad
                exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                exp_avg_diff.mul_(betas[1]).add_(grad_last_diff, alpha=1 - betas[1])
                # z_t = g_t + beta_2 * (g_t - g_{t-1})
                zt = grad_last_diff.mul(betas[1]).add_(grad)
                exp_avg_sq.mul_(betas[2]).add_(manifold.component_inner(point, zt), alpha=1 - betas[2])

                bias_correction1 = 1 - betas[0] ** step
                bias_correction2 = 1 - betas[1] ** step
                bias_correction3 = 1 - betas[2] ** step
                denom = exp_avg_sq.div(bias_correction3).sqrt_().add_(eps)
                # Descent direction: bias-corrected m_t + beta_2 * bias-corrected diff_t, preconditioned.
                direction = (
                    exp_avg.div(bias_correction1).add_(exp_avg_diff.div(bias_correction2), alpha=betas[1])
                ) / denom

                # Retract along the direction and transport the first moment in one call.
                new_point, exp_avg_new = manifold.retr_transp(point, -learning_rate * direction, exp_avg)
                # Tangent buffers must live at the new point before the next step.
                last_grad.copy_(manifold.transp(point, new_point, grad))
                exp_avg_diff.copy_(manifold.transp(point, new_point, exp_avg_diff))
                exp_avg.copy_(exp_avg_new)
                point.copy_(new_point)

                if stabilize_every and step % stabilize_every == 0:
                    should_stabilize = True
            if should_stabilize:
                self.stabilize_group(group)
    return loss
|
stabilize_group(group)
Stabilizes the parameters in the group by projecting them onto their respective manifolds.
| Parameters: |
-
group
(dict[str, Any])
–
A dictionary containing the parameters and their states.
|
Source code in manify/optimizers/radan.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
@torch.no_grad()  # type: ignore
def stabilize_group(self, group: dict[str, Any]) -> None:
    """Re-project a group's manifold parameters and their tangent buffers.

    Only parameters that are ``ManifoldParameter``/``ManifoldTensor`` instances
    and already have optimizer state are touched. Each point is first projected
    with ``manifold.projx``. The buffers ``exp_avg``, ``exp_avg_diff`` and
    ``last_grad`` are then projected onto the tangent space at the projected
    point with ``manifold.proju``. ``exp_avg_sq`` is left unchanged.

    Args:
        group: A parameter group, as stored in ``self.param_groups``.

    Returns:
        None
    """
    for param in group["params"]:
        if not isinstance(param, ManifoldParameter | ManifoldTensor):
            continue
        param_state = self.state[param]
        # Parameters that never received a gradient have no buffers yet.
        if not param_state:
            continue
        manifold = param.manifold
        tangent_buffers = [param_state[name] for name in ("exp_avg", "exp_avg_diff", "last_grad")]
        # Project the point first so the buffers are projected onto the correct tangent space.
        param.copy_(manifold.projx(param))
        for buffer in tangent_buffers:
            buffer.copy_(manifold.proju(param, buffer))
|