Ranger Gradient Release
warner-benjamin committed Feb 16, 2024
1 parent 2f97eba commit 4d0bb13
Showing 2 changed files with 183 additions and 99 deletions.
255 changes: 162 additions & 93 deletions optimi/ranger.py
@@ -14,19 +14,19 @@

import math
from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _default_to_fused_or_foreach
from torch.optim.optimizer import _default_to_fused_or_foreach
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

from optimi.utils import MIN_TORCH_2_1, debias, debias_beta
from optimi.optimizer import OptimiOptimizer
from optimi.utils import debias, debias_beta

__all__ = ["Ranger", "ranger"]


class Ranger(Optimizer):
class Ranger(OptimiOptimizer):
"""Ranger optimizer. RAdam with Lookahead.
Args:
@@ -47,6 +47,9 @@ class Ranger(Optimizer):
parameters (default: None)
foreach: Enables the foreach implementation. If unspecified, tries to use foreach over
for-loop implementation since it is significantly faster (default: None)
gradient_release: Fuses optimizer step and zero_grad as part of the parameter's backward
pass. Requires model hooks created with `register_gradient_release`. Incompatible with
closure (default: False)
"""

def __init__(
@@ -63,32 +66,14 @@ def __init__(
max_lr: float | None = None,
kahan_sum: bool | None = None,
foreach: bool | None = None,
gradient_release: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr=}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta1 parameter: {betas[0]=}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta2 parameter: {betas[1]=}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight decay: {weight_decay=}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon: {eps=}")
if decouple_lr and max_lr is None:
max_lr = lr
if max_lr is not None and not 0.0 <= max_lr:
raise ValueError(f"Invalid maximum learning rate: {max_lr=}")
if decouple_lr and weight_decay >= 1e-3:
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, fully decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)
if not MIN_TORCH_2_1:
if foreach:
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
foreach = False

defaults = dict(
lr=lr,
@@ -103,10 +88,26 @@ def __init__(
max_lr=max_lr,
kahan_sum=kahan_sum,
foreach=foreach,
gradient_release=gradient_release,
setup=False,
)
super().__init__(params, defaults)

def _init_state(self, group: dict[str, Any], state: dict[Tensor, Any], param: Tensor):
if len(state) <= 1:
state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
state["la_param"] = param.data.clone()

if (group["kahan_sum"] or group["kahan_sum"] is None) and param.dtype in [torch.float16, torch.bfloat16]:
state["kahan_comp"] = torch.zeros_like(param, memory_format=torch.preserve_format)
group["kahan_sum"] = True
else:
state["kahan_comp"] = None

if group["gradient_release"]:
state["step"] = torch.tensor(0, dtype=torch.int32)

def _init_group(
self,
group: dict[str, Any],
@@ -125,17 +126,7 @@ def _init_group(
grads.append(p.grad)
state = self.state[p]

# State initialization
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["la_param"] = p.data.clone()

if (group["kahan_sum"] or group["kahan_sum"] is None) and p.dtype in [torch.float16, torch.bfloat16]:
state["kahan_comp"] = torch.zeros_like(p, memory_format=torch.preserve_format)
group["kahan_sum"] = True
else:
state["kahan_comp"] = None
self._init_state(group, state, p)

exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
@@ -150,41 +141,74 @@ def _init_group(
_, group["foreach"] = _default_to_fused_or_foreach(params, False, False)

@torch.no_grad()
def step(self, closure: Callable | None = None):
"""Performs a single optimization step.
def step(self, closure: Callable | None = None, param: Tensor | None = None):
"""Performs a single optimization step on the whole model or individual parameter.
Args:
closure: A closure which reevaluates the model and returns the loss
closure: A closure which reevaluates the model and returns the loss. Incompatible with
performing an optimization step on a single `param`.
param: An individual parameter to perform a fused optimization step during the backward
pass. Requires optimizer to be initialized with `gradient_release=True` and model
hooks created with `register_gradient_release`. Incompatible with `closure`.
"""
loss = None
if closure is not None:
if closure is not None and param is None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params, grads, exp_avgs, exp_avg_sqs, la_params, kahan_comps = [], [], [], [], [], []
self._init_group(group, params, grads, exp_avgs, exp_avg_sqs, la_params, kahan_comps)
if param is None:
for group in self.param_groups:
params, grads, exp_avgs, exp_avg_sqs, la_params, kahan_comps = [], [], [], [], [], []
self._init_group(group, params, grads, exp_avgs, exp_avg_sqs, la_params, kahan_comps)

ranger(
params=params,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
la_params=la_params,
kahan_comps=kahan_comps,
lr=group["lr"],
beta1=group["beta1"],
beta2=group["beta2"],
weight_decay=group["weight_decay"],
eps=group["eps"],
k=group["k"],
alpha=group["alpha"],
step=group["step"],
decouple_wd=group["decouple_wd"],
decouple_lr=group["decouple_lr"],
max_lr=group["max_lr"],
kahan_sum=group["kahan_sum"],
foreach=group["foreach"],
gradient_release=False,
)
else:
state = self.state[param]
group = state["group"]
self._init_state(group, state, param)

ranger(
params=params,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
la_params=la_params,
kahan_comps=kahan_comps,
params=param,
grads=param.grad,
exp_avgs=state["exp_avg"],
exp_avg_sqs=state["exp_avg_sq"],
la_params=state["la_param"],
kahan_comps=state["kahan_comp"],
lr=group["lr"],
beta1=group["beta1"],
beta2=group["beta2"],
weight_decay=group["weight_decay"],
eps=group["eps"],
k=group["k"],
alpha=group["alpha"],
step=group["step"],
step=state["step"],
decouple_wd=group["decouple_wd"],
decouple_lr=group["decouple_lr"],
max_lr=group["max_lr"],
kahan_sum=group["kahan_sum"],
foreach=group["foreach"],
foreach=False,
gradient_release=True,
)

return loss
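To make the new per-parameter path concrete, below is a sketch of the kind of hook a `register_gradient_release`-style helper could install; it is an assumption, not code from this commit, and relies on PyTorch 2.1's `Tensor.register_post_accumulate_grad_hook`.

import torch
from optimi import Ranger

model = torch.nn.Linear(16, 4)
optimizer = Ranger(model.parameters(), lr=1e-3, gradient_release=True)

def make_gradient_release_hook(opt):
    def hook(param):
        opt.step(param=param)  # fused single-parameter step defined above
        param.grad = None      # fused zero_grad: free the gradient immediately
    return hook

for p in model.parameters():
    if p.requires_grad:
        # a real helper would also need to record each parameter's group in
        # opt.state[p]["group"], which the per-parameter branch of step() reads
        p.register_post_accumulate_grad_hook(make_gradient_release_hook(optimizer))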
@@ -211,6 +235,7 @@ def ranger(
max_lr: float | None = None,
kahan_sum: bool = False,
foreach: bool = False,
gradient_release: bool = False,
):
"""Functional API to apply a Ranger optimization step.
@@ -236,6 +261,7 @@
max_lr: Maximum scheduled learning rate for `decouple_lr`
kahan_sum: Enables Kahan summation for low precision parameters
foreach: Enables the faster foreach implementation
gradient_release: Fuses optimizer step as part of the parameter's backward pass
"""
# calculate debiased beta hat & complement terms
step.add_(1)
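The diff collapses the lines between here and the next hunk, which turn `step`, the betas, and `weight_decay` into the terms consumed by the per-parameter functions below (`beta1_comp`, `beta2_hat`, `rect`, and what appears to be a multiplicative decay factor, given the later `param.mul_(weight_decay)`). For reference, the standard Adam bias-correction and RAdam rectification formulas are sketched here; this is illustrative math, not necessarily the exact collapsed code.

import math

def debias_beta(beta: float, step: int) -> float:
    # beta_hat such that an EMA run with this decay equals the bias-corrected EMA
    return beta * (1 - beta ** (step - 1)) / (1 - beta**step)

def radam_rect(beta2: float, step: int) -> float | None:
    # RAdam variance rectification; None signals the unrectified fallback step
    rho_inf = 2 / (1 - beta2) - 1
    rho_t = rho_inf - 2 * step * beta2**step / (1 - beta2**step)
    if rho_t > 4:
        return math.sqrt(
            ((rho_t - 4) * (rho_t - 2) * rho_inf) / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
        )
    return None

beta1_comp = 1 - debias_beta(0.9, step=10)  # lerp weight for exp_avg
beta2_hat = debias_beta(0.99, step=10)      # decay for exp_avg_sq
rect = radam_rect(0.99, step=10)            # a float by step 10; None for the first few steps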
@@ -264,16 +290,18 @@

if foreach:
func = _foreach_ranger
elif gradient_release:
func = _single_param_ranger
else:
func = _single_ranger

func(
params=params,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
la_params=la_params,
kahan_comps=kahan_comps,
params,
grads,
exp_avgs,
exp_avg_sqs,
la_params,
kahan_comps,
lr=lr,
beta1_comp=beta1_comp,
beta2_hat=beta2_hat,
@@ -315,54 +343,95 @@ def _single_ranger(
la_param = la_params[i]
kahan_comp = kahan_comps[i]

# decoupled weight decay, fully decoupled weight decay, or L2 weight decay
if weight_decay != 0:
if decouple_wd:
param.mul_(weight_decay)
else:
grad.add_(param, alpha=weight_decay)
_single_param_ranger(
param=param,
grad=grad,
exp_avg=exp_avg,
exp_avg_sq=exp_avg_sq,
la_param=la_param,
kahan_comp=kahan_comp,
lr=lr,
beta1_comp=beta1_comp,
beta2_hat=beta2_hat,
weight_decay=weight_decay,
eps=eps,
rect=rect,
k=k,
alpha=alpha,
step=step,
decouple_wd=decouple_wd,
kahan_sum=kahan_sum,
)

# update gradient moving averages with debiased betas
exp_avg.lerp_(grad, weight=beta1_comp)
exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat)

if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
# RAdam step
if rect is not None:
kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect)
else:
kahan_comp.add_(exp_avg, alpha=-lr)
def _single_param_ranger(
param: Tensor,
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
la_param: Tensor,
kahan_comp: Tensor | None,
*,
lr: float,
beta1_comp: float,
beta2_hat: float,
weight_decay: float,
eps: float,
rect: float | None,
k: int,
alpha: float,
step: int,
decouple_wd: bool,
kahan_sum: bool = False,
):
# decoupled weight decay, fully decoupled weight decay, or L2 weight decay
if weight_decay != 0:
if decouple_wd:
param.mul_(weight_decay)
else:
grad.add_(param, alpha=weight_decay)

# update weights with kahan compensation using grad as temp buffer
grad.copy_(param.detach())
param.add_(kahan_comp)
# update gradient moving averages with debiased betas
exp_avg.lerp_(grad, weight=beta1_comp)
exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1 - beta2_hat)

# save error back to kahan compensation for next iteration
kahan_comp.add_(grad.sub_(param))
if kahan_sum and param.dtype in [torch.float16, torch.bfloat16]:
# RAdam step
if rect is not None:
kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect)
else:
kahan_comp.add_(exp_avg, alpha=-lr)

# Lookahead step
if step % k == 0:
kahan_comp.add_(param.sub_(la_param), alpha=alpha)
# update weights with kahan compensation using grad as temp buffer
grad.copy_(param.detach())
param.add_(kahan_comp)

# update weights with kahan compensation using grad as temp buffer
grad.copy_(la_param.detach())
la_param.add_(kahan_comp)
# save error back to kahan compensation for next iteration
kahan_comp.add_(grad.sub_(param))

# save error back to kahan compensation for next iteration
kahan_comp.add_(grad.sub_(la_param), alpha=alpha)
# Lookahead step
if step % k == 0:
kahan_comp.add_(param.sub_(la_param), alpha=alpha)

# update weights with kahan compensation using grad as temp buffer
grad.copy_(la_param.detach())
la_param.add_(kahan_comp)

# save error back to kahan compensation for next iteration
kahan_comp.add_(grad.sub_(la_param), alpha=alpha)

param.copy_(la_param)
param.copy_(la_param)
else:
# RAdam step
if rect is not None:
param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect)
else:
# RAdam step
if rect is not None:
param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr * rect)
else:
param.add_(exp_avg, alpha=-lr)
param.add_(exp_avg, alpha=-lr)

# Lookahead step
if step % k == 0:
la_param.add_(param.sub(la_param), alpha=alpha)
param.copy_(la_param)
# Lookahead step
if step % k == 0:
la_param.add_(param.sub(la_param), alpha=alpha)
param.copy_(la_param)
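The Kahan-compensated branch above is what lets bfloat16 and float16 parameters track a full-precision run; here is a standalone illustration of the same buffer-reuse pattern (not part of this commit).

import torch

param = torch.zeros(4, dtype=torch.bfloat16)
comp = torch.zeros_like(param)
naive = torch.zeros(4, dtype=torch.bfloat16)
update = torch.full((4,), 1e-4)

for _ in range(10_000):
    naive += update.to(naive.dtype)  # plain bfloat16 addition stalls once 1e-4 falls below half an ulp
    comp += update.to(param.dtype)   # accumulate the desired update
    buf = param.clone()              # temp buffer (the optimizer reuses grad for this)
    param += comp                    # apply the update plus the carried error
    comp += buf - param              # save what was rounded away for the next iteration

print(param[0].item(), naive[0].item())  # compensated result is ~1.0, the naive sum stalls far below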


def _foreach_ranger(
