Optimizers

manify.optimizers

New Riemannian Adan optimizer implementation.

RiemannianAdan(params, lr=0.001, betas=(0.98, 0.92, 0.99), eps=1e-08, weight_decay=0.0, max_grad_norm=0.0, no_prox=False, foreach=True, fused=False)

Bases: OptimMixin, Adan

Riemannian Adan with the same API as :class:`adan.Adan`.

Attributes:
  • param_groups

    iterable of parameter groups, each containing parameters to optimize and optimization options

  • _default_manifold

    the default manifold used for optimization if not specified in parameters

Parameters:
  • params

    iterable of parameters to optimize or dicts defining parameter groups

  • lr

    learning rate (default: 1e-3)

  • betas

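    coefficients used for computing the exponential moving averages of the gradient, of the gradient difference, and of the squared update term, in that order (default: (0.98, 0.92, 0.99))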

  • eps

    term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay

    weight decay (L2 penalty) (default: 0)

Source code in manify/optimizers/_adan.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def __init__(
    self,
    params,
    lr=1e-3,
    betas=(0.98, 0.92, 0.99),
    eps=1e-8,
    weight_decay=0.0,
    max_grad_norm=0.0,
    no_prox=False,
    foreach: bool = True,
    fused: bool = False,
):
    """Validates the hyperparameters and registers them as the optimizer defaults.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate; must be non-negative.
        betas: Three coefficients, each required to lie in ``[0, 1)``.
        eps: Term added to the denominator; must be non-negative.
        weight_decay: Weight decay coefficient (stored as given, not validated here).
        max_grad_norm: Must be non-negative (stored in the defaults).
        no_prox: Stored in the defaults as given.
        foreach: Stored in the defaults as given.
        fused: When truthy, ``_check_fused_available()`` is called before the
            defaults are built.

    Raises:
        ValueError: If ``max_grad_norm``, ``lr`` or ``eps`` is not ``>= 0``, or if any of
            the first three ``betas`` is outside ``[0, 1)``.
    """
    # Checks run in this order so that the first offending argument is the one reported.
    lower_bounded = (
        ("Invalid Max grad norm: {}", max_grad_norm),
        ("Invalid learning rate: {}", lr),
        ("Invalid epsilon value: {}", eps),
    )
    for template, value in lower_bounded:
        # Written as "not 0.0 <= value" (rather than "value < 0.0") so NaN is rejected too.
        if not 0.0 <= value:
            raise ValueError(template.format(value))

    beta_templates = (
        "Invalid beta parameter at index 0: {}",
        "Invalid beta parameter at index 1: {}",
        "Invalid beta parameter at index 2: {}",
    )
    for index, template in enumerate(beta_templates):
        if not 0.0 <= betas[index] < 1.0:
            raise ValueError(template.format(betas[index]))

    if fused:
        _check_fused_available()

    defaults = {
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        "no_prox": no_prox,
        "foreach": foreach,
        "fused": fused,
    }
    super().__init__(params, defaults)

step(closure=None)

Performs a single optimization step.

Parameters:
  • closure (Callable | None, default: None ) –

    A closure that reevaluates the model and returns the loss.

Returns:
  • Float[Tensor, ''] | None

    The loss value if closure is provided, otherwise None.

Source code in manify/optimizers/radan.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def step(self, closure: Callable | None = None) -> Float[torch.Tensor, ""] | None:
    """Performs a single optimization step.

    For every parameter with a gradient, the (weight-decayed) Euclidean gradient is
    converted with ``manifold.egrad2rgrad``, the three running averages are updated, the
    parameter is moved with ``manifold.retr_transp``, and the tangent-space state
    (``exp_avg``, ``exp_avg_diff``, ``last_grad``) is transported to the new point.

    Weight decay is added out of place, so this method itself does not write the decay
    term into ``point.grad``. The ``max_grad_norm``, ``no_prox``, ``foreach`` and ``fused``
    group options are not read by this method.

    Args:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        The loss value if closure is provided, otherwise None.

    Raises:
        RuntimeError: If a parameter has a sparse gradient.
    """
    loss = None
    if closure is not None:
        loss = closure()

    with torch.no_grad():
        for group in self.param_groups:
            betas = group["betas"]
            beta1, beta2, beta3 = betas[0], betas[1], betas[2]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            learning_rate = group["lr"]
            needs_stabilize = False
            for point in group["params"]:
                grad = point.grad
                if grad is None:
                    continue
                if isinstance(point, ManifoldParameter | ManifoldTensor):
                    manifold = point.manifold
                else:
                    manifold = self._default_manifold

                if grad.is_sparse:
                    raise RuntimeError("RiemannianAdan does not support sparse gradients")

                state = self.state[point]

                # State initialization
                if not state:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(point)
                    # Exponential moving average of the squared update term z_t
                    state["exp_avg_sq"] = torch.zeros_like(point)
                    # Exponential moving average of gradient differences
                    state["exp_avg_diff"] = torch.zeros_like(point)
                    # Gradient of the previous step; starts at zero, so on the first
                    # step the gradient difference equals the gradient itself
                    state["last_grad"] = torch.zeros_like(point)

                state["step"] += 1
                step_count = state["step"]
                # make local variables for easy access
                exp_avg = state["exp_avg"]
                exp_avg_diff = state["exp_avg_diff"]
                exp_avg_sq = state["exp_avg_sq"]
                last_grad = state["last_grad"]

                # Apply weight decay on a copy: an in-place add would leave the decay
                # term in point.grad and compound it if the gradient is reused.
                if weight_decay != 0:
                    grad = grad.add(point, alpha=weight_decay)
                grad = manifold.egrad2rgrad(point, grad)
                # difference between the current and the previous gradient
                grad_last_diff = grad - last_grad
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_diff.mul_(beta2).add_(grad_last_diff, alpha=1 - beta2)
                # z_t = grad + beta2 * (grad - last_grad)
                zt = grad_last_diff.mul(beta2).add_(grad)
                # running average of <z_t, z_t> as given by the manifold
                exp_avg_sq.mul_(beta3).add_(manifold.component_inner(point, zt), alpha=1 - beta3)
                bias_correction1 = 1 - beta1**step_count
                bias_correction2 = 1 - beta2**step_count
                bias_correction3 = 1 - beta3**step_count

                denom = exp_avg_sq.div(bias_correction3).sqrt_()

                # Bias-corrected ascent direction; it is negated and scaled by the
                # learning rate when passed to the retraction below.
                direction = (
                    (exp_avg.div(bias_correction1)).add_((exp_avg_diff.div(bias_correction2)), alpha=beta2)
                ) / denom.add_(eps)

                # move the point and transport exp_avg to the new point in one call
                new_point, exp_avg_new = manifold.retr_transp(point, -learning_rate * direction, exp_avg)

                # The remaining tangent vectors must be transported before `point` is
                # overwritten, because transp needs the old base point.
                last_grad.copy_(manifold.transp(point, new_point, grad))
                exp_avg_diff.copy_(manifold.transp(point, new_point, exp_avg_diff))
                exp_avg.copy_(exp_avg_new)
                point.copy_(new_point)

                if group["stabilize"] is not None and state["step"] % group["stabilize"] == 0:
                    needs_stabilize = True
            if needs_stabilize:
                self.stabilize_group(group)
    return loss

stabilize_group(group)

Stabilizes the parameters in the group by projecting them onto their respective manifolds.

Parameters:
  • group (dict[str, Any]) –

    A dictionary containing the parameters and their states.

Returns:
  • None

    None

Source code in manify/optimizers/radan.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@torch.no_grad()  # type: ignore
def stabilize_group(self, group: dict[str, Any]) -> None:
    """Re-projects a group's manifold parameters and their tangent-space state.

    Each ``ManifoldParameter``/``ManifoldTensor`` in the group that already has optimizer
    state is overwritten with ``manifold.projx`` of itself; its ``exp_avg``,
    ``exp_avg_diff`` and ``last_grad`` buffers are then overwritten with
    ``manifold.proju`` evaluated at the projected parameter. Other parameters are skipped.

    Args:
        group: A parameter group; its ``"params"`` entry is iterated.

    Returns:
        None
    """
    for param in group["params"]:
        if isinstance(param, ManifoldParameter | ManifoldTensor):
            param_state = self.state[param]
            # Empty state: the parameter has never been stepped (its grad was None).
            if param_state:
                manifold = param.manifold
                tangent_buffers = [param_state[key] for key in ("exp_avg", "exp_avg_diff", "last_grad")]
                # Project the point first so the buffers are projected at the corrected point.
                param.copy_(manifold.projx(param))
                for buffer in tangent_buffers:
                    buffer.copy_(manifold.proju(param, buffer))

radan

Riemannian Adan (Radan).

Radan is the Riemannian version of the Adaptive Nesterov Momentum algorithm. This code is compatible with both Geoopt and Manify libraries, and is designed for Riemannian Fuzzy K-Means. We recommend using the parameters [0.7, 0.99, 0.99] for best performance.

For more details on the Radan algorithm, please refer to: https://openreview.net/forum?id=9VmOgMN4Ie

If you find this work useful, please cite the paper as follows:

@article{Yuan2025,
  title={Riemannian Fuzzy K-Means},
  author={Anonymous},
  journal={OpenReview},
  year={2025},
  url={https://openreview.net/forum?id=9VmOgMN4Ie}
}

If you're interested in Adan, you can see:

@ARTICLE{10586270,
  author={Xie, Xingyu and Zhou, Pan and Li, Huan and Lin, Zhouchen and Yan, Shuicheng},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  title={Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models},
  year={2024},
  volume={46},
  number={12},
  pages={9508-9520},
  keywords={Training;Convergence;Complexity theory;Deep learning;Computer architecture;Task analysis;Stochastic processes;Adaptive optimizer;fast DNN training;DNN optimizer},
  doi={10.1109/TPAMI.2024.3423382}
}

If you have questions about the code, feel free to contact: yuanjinghuiiii@gmail.com.

RiemannianAdan(params, lr=0.001, betas=(0.98, 0.92, 0.99), eps=1e-08, weight_decay=0.0, max_grad_norm=0.0, no_prox=False, foreach=True, fused=False)

Bases: OptimMixin, Adan

Riemannian Adan with the same API as :class:`adan.Adan`.

Attributes:
  • param_groups

    iterable of parameter groups, each containing parameters to optimize and optimization options

  • _default_manifold

    the default manifold used for optimization if not specified in parameters

Parameters:
  • params

    iterable of parameters to optimize or dicts defining parameter groups

  • lr

    learning rate (default: 1e-3)

  • betas

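    coefficients used for computing the exponential moving averages of the gradient, of the gradient difference, and of the squared update term, in that order (default: (0.98, 0.92, 0.99))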

  • eps

    term added to the denominator to improve numerical stability (default: 1e-8)

  • weight_decay

    weight decay (L2 penalty) (default: 0)

Source code in manify/optimizers/_adan.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def __init__(
    self,
    params,
    lr=1e-3,
    betas=(0.98, 0.92, 0.99),
    eps=1e-8,
    weight_decay=0.0,
    max_grad_norm=0.0,
    no_prox=False,
    foreach: bool = True,
    fused: bool = False,
):
    """Validates the hyperparameters and registers them as the optimizer defaults.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate; must be non-negative.
        betas: Three coefficients, each required to lie in ``[0, 1)``.
        eps: Term added to the denominator; must be non-negative.
        weight_decay: Weight decay coefficient (stored as given, not validated here).
        max_grad_norm: Must be non-negative (stored in the defaults).
        no_prox: Stored in the defaults as given.
        foreach: Stored in the defaults as given.
        fused: When truthy, ``_check_fused_available()`` is called before the
            defaults are built.

    Raises:
        ValueError: If ``max_grad_norm``, ``lr`` or ``eps`` is not ``>= 0``, or if any of
            the first three ``betas`` is outside ``[0, 1)``.
    """
    # Checks run in this order so that the first offending argument is the one reported.
    lower_bounded = (
        ("Invalid Max grad norm: {}", max_grad_norm),
        ("Invalid learning rate: {}", lr),
        ("Invalid epsilon value: {}", eps),
    )
    for template, value in lower_bounded:
        # Written as "not 0.0 <= value" (rather than "value < 0.0") so NaN is rejected too.
        if not 0.0 <= value:
            raise ValueError(template.format(value))

    beta_templates = (
        "Invalid beta parameter at index 0: {}",
        "Invalid beta parameter at index 1: {}",
        "Invalid beta parameter at index 2: {}",
    )
    for index, template in enumerate(beta_templates):
        if not 0.0 <= betas[index] < 1.0:
            raise ValueError(template.format(betas[index]))

    if fused:
        _check_fused_available()

    defaults = {
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        "no_prox": no_prox,
        "foreach": foreach,
        "fused": fused,
    }
    super().__init__(params, defaults)
step(closure=None)

Performs a single optimization step.

Parameters:
  • closure (Callable | None, default: None ) –

    A closure that reevaluates the model and returns the loss.

Returns:
  • Float[Tensor, ''] | None

    The loss value if closure is provided, otherwise None.

Source code in manify/optimizers/radan.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def step(self, closure: Callable | None = None) -> Float[torch.Tensor, ""] | None:
    """Performs a single optimization step.

    For every parameter with a gradient, the (weight-decayed) Euclidean gradient is
    converted with ``manifold.egrad2rgrad``, the three running averages are updated, the
    parameter is moved with ``manifold.retr_transp``, and the tangent-space state
    (``exp_avg``, ``exp_avg_diff``, ``last_grad``) is transported to the new point.

    Weight decay is added out of place, so this method itself does not write the decay
    term into ``point.grad``. The ``max_grad_norm``, ``no_prox``, ``foreach`` and ``fused``
    group options are not read by this method.

    Args:
        closure: A closure that reevaluates the model and returns the loss.

    Returns:
        The loss value if closure is provided, otherwise None.

    Raises:
        RuntimeError: If a parameter has a sparse gradient.
    """
    loss = None
    if closure is not None:
        loss = closure()

    with torch.no_grad():
        for group in self.param_groups:
            betas = group["betas"]
            beta1, beta2, beta3 = betas[0], betas[1], betas[2]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            learning_rate = group["lr"]
            needs_stabilize = False
            for point in group["params"]:
                grad = point.grad
                if grad is None:
                    continue
                if isinstance(point, ManifoldParameter | ManifoldTensor):
                    manifold = point.manifold
                else:
                    manifold = self._default_manifold

                if grad.is_sparse:
                    raise RuntimeError("RiemannianAdan does not support sparse gradients")

                state = self.state[point]

                # State initialization
                if not state:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(point)
                    # Exponential moving average of the squared update term z_t
                    state["exp_avg_sq"] = torch.zeros_like(point)
                    # Exponential moving average of gradient differences
                    state["exp_avg_diff"] = torch.zeros_like(point)
                    # Gradient of the previous step; starts at zero, so on the first
                    # step the gradient difference equals the gradient itself
                    state["last_grad"] = torch.zeros_like(point)

                state["step"] += 1
                step_count = state["step"]
                # make local variables for easy access
                exp_avg = state["exp_avg"]
                exp_avg_diff = state["exp_avg_diff"]
                exp_avg_sq = state["exp_avg_sq"]
                last_grad = state["last_grad"]

                # Apply weight decay on a copy: an in-place add would leave the decay
                # term in point.grad and compound it if the gradient is reused.
                if weight_decay != 0:
                    grad = grad.add(point, alpha=weight_decay)
                grad = manifold.egrad2rgrad(point, grad)
                # difference between the current and the previous gradient
                grad_last_diff = grad - last_grad
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_diff.mul_(beta2).add_(grad_last_diff, alpha=1 - beta2)
                # z_t = grad + beta2 * (grad - last_grad)
                zt = grad_last_diff.mul(beta2).add_(grad)
                # running average of <z_t, z_t> as given by the manifold
                exp_avg_sq.mul_(beta3).add_(manifold.component_inner(point, zt), alpha=1 - beta3)
                bias_correction1 = 1 - beta1**step_count
                bias_correction2 = 1 - beta2**step_count
                bias_correction3 = 1 - beta3**step_count

                denom = exp_avg_sq.div(bias_correction3).sqrt_()

                # Bias-corrected ascent direction; it is negated and scaled by the
                # learning rate when passed to the retraction below.
                direction = (
                    (exp_avg.div(bias_correction1)).add_((exp_avg_diff.div(bias_correction2)), alpha=beta2)
                ) / denom.add_(eps)

                # move the point and transport exp_avg to the new point in one call
                new_point, exp_avg_new = manifold.retr_transp(point, -learning_rate * direction, exp_avg)

                # The remaining tangent vectors must be transported before `point` is
                # overwritten, because transp needs the old base point.
                last_grad.copy_(manifold.transp(point, new_point, grad))
                exp_avg_diff.copy_(manifold.transp(point, new_point, exp_avg_diff))
                exp_avg.copy_(exp_avg_new)
                point.copy_(new_point)

                if group["stabilize"] is not None and state["step"] % group["stabilize"] == 0:
                    needs_stabilize = True
            if needs_stabilize:
                self.stabilize_group(group)
    return loss
stabilize_group(group)

Stabilizes the parameters in the group by projecting them onto their respective manifolds.

Parameters:
  • group (dict[str, Any]) –

    A dictionary containing the parameters and their states.

Returns:
  • None

    None

Source code in manify/optimizers/radan.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@torch.no_grad()  # type: ignore
def stabilize_group(self, group: dict[str, Any]) -> None:
    """Re-projects a group's manifold parameters and their tangent-space state.

    Each ``ManifoldParameter``/``ManifoldTensor`` in the group that already has optimizer
    state is overwritten with ``manifold.projx`` of itself; its ``exp_avg``,
    ``exp_avg_diff`` and ``last_grad`` buffers are then overwritten with
    ``manifold.proju`` evaluated at the projected parameter. Other parameters are skipped.

    Args:
        group: A parameter group; its ``"params"`` entry is iterated.

    Returns:
        None
    """
    for param in group["params"]:
        if isinstance(param, ManifoldParameter | ManifoldTensor):
            param_state = self.state[param]
            # Empty state: the parameter has never been stepped (its grad was None).
            if param_state:
                manifold = param.manifold
                tangent_buffers = [param_state[key] for key in ("exp_avg", "exp_avg_diff", "last_grad")]
                # Project the point first so the buffers are projected at the corrected point.
                param.copy_(manifold.projx(param))
                for buffer in tangent_buffers:
                    buffer.copy_(manifold.proju(param, buffer))