SGD Gradient Release
warner-benjamin committed Feb 6, 2024
1 parent cecfc90 commit d7b421c
Showing 6 changed files with 386 additions and 75 deletions.
1 change: 1 addition & 0 deletions optimi/__init__.py
@@ -5,6 +5,7 @@
from .adam import Adam, adam
from .adamw import AdamW, adamw
from .adan import Adan, adan
from .gradientrelease import prepare_for_gradient_release, remove_gradient_release
from .lion import Lion, lion
from .radam import RAdam, radam
from .ranger import Ranger, ranger
60 changes: 60 additions & 0 deletions optimi/gradientrelease.py
@@ -0,0 +1,60 @@
from __future__ import annotations

from functools import partial
from warnings import warn

from torch import Tensor
from torch.nn import Module

from optimi.optimizer import OptimiOptimizer


def _gradient_release_hook(param: Tensor, optimizer: OptimiOptimizer):
optimizer.step(param=param)
optimizer.zero_grad(param=param)


def prepare_for_gradient_release(model: Module, optimizer: OptimiOptimizer, ignore_existing_hooks: bool = False):
"""Register post_accumulate_grad_hooks on parameters for the gradient release optimization step.
Args:
model: Model to register post_accumulate_grad_hooks. Only registers on parameters with
`requires_grad=True`.
optimizer: Optimizer providing the fused optimizer step during the backward pass. Requires
optimizer to be initialized with `gradient_release=True`
ignore_existing_hooks: If True, will not register post_accumulate_grad_hooks if the model
(default: False)
"""
if not isinstance(optimizer, OptimiOptimizer):
raise TypeError("`optimizer` must be an instance of `OptimiOptimizer`")
if not optimizer.defaults["gradient_release"]:
raise ValueError("`optimizer` must be initialized with `gradient_release=True`")

hooks = []
for p in model.parameters():
if p.requires_grad:
if (p._post_accumulate_grad_hooks is not None) and len(p._post_accumulate_grad_hooks) > 0 and (not ignore_existing_hooks):
for hook in hooks:
if hasattr(hook, "remove"):
hook.remove()
raise ValueError(
"Model already has post_accumulate_grad_hooks. If this is expected, rerun with `ignore_existing_hooks=True`."
)
hooks.append(p.register_post_accumulate_grad_hook(partial(_gradient_release_hook, optimizer=optimizer)))
model._gradient_release_hooks = hooks


def remove_gradient_release(model: Module):
"""Removes post_accumulate_grad_hooks created by `prepare_for_gradient_release`.
Args:
model: Model to remove gradient release post_accumulate_grad_hooks from.
"""
if not hasattr(model, "_gradient_release_hooks"):
warn("`model` does not have any gradient release post_accumulate_grad_hooks to remove.")
return

for hook in model._gradient_release_hooks:
if hasattr(hook, "remove"):
hook.remove()
del model._gradient_release_hooks
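
For context, here is a minimal usage sketch of these helpers (not part of this commit's diff). It assumes optimi exports an `SGD` optimizer that accepts `gradient_release=True`, as the commit title suggests; the toy model, learning rate, and training loop are purely illustrative. Each parameter is stepped and its gradient freed inside its post_accumulate_grad_hook during `loss.backward()`, so the loop does not call `optimizer.step()` or `optimizer.zero_grad()` itself.

```python
# Illustrative sketch only -- the optimizer class, model, and hyperparameters are assumptions.
import torch
from torch import nn

from optimi import SGD, prepare_for_gradient_release, remove_gradient_release

model = nn.Linear(20, 1)
# gradient_release=True tells the optimizer to expect per-parameter fused steps
optimizer = SGD(model.parameters(), lr=1e-3, gradient_release=True)
# register the post_accumulate_grad_hooks that call step/zero_grad per parameter
prepare_for_gradient_release(model, optimizer)

for _ in range(10):
    loss = model(torch.randn(20)).square().mean()
    # backward triggers the hooks: each parameter is updated and its grad reset
    # as soon as its gradient is accumulated, so no optimizer.step() call is needed
    loss.backward()

# remove the hooks once gradient release is no longer wanted
remove_gradient_release(model)
```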
70 changes: 70 additions & 0 deletions optimi/optimizer.py
@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

from optimi.utils import MIN_TORCH_2_1


class OptimiOptimizer(Optimizer):
"""Provides common functionality for optimi optimizers."""

def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str, Any]):
if not MIN_TORCH_2_1:
if defaults["foreach"]:
foreach = defaults["foreach"]
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
defaults["foreach"] = False

if defaults["decouple_lr"] and defaults["weight_decay"] >= 1e-3:
weight_decay = defaults["weight_decay"]
decouple_lr = defaults["decouple_lr"]
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, fully decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)

super().__init__(params, defaults)

if self.defaults["gradient_release"]:
for group in self.param_groups:
for p in group["params"]:
self.state[p]["group"] = group

def step(self, closure: Callable | None = None, param: Tensor | None = None):
"""Performs a single optimization step on the whole model or individual parameter.
Args:
closure: A closure which reevaluates the model and returns the loss. Incompatible with
performing an optimization step on a single `param`.
param: An individual parameter to perform a fused optimization step during the backward
pass. Requires optimizer to be initialized with `gradient_release=True` and model
hooks created with `register_gradient_release`. Incompatible with `closure`.
"""
raise NotImplementedError

@torch._disable_dynamo
def zero_grad(self, set_to_none: bool = True, param: Tensor | None = None):
"""Resets the gradients of all optimized parameters or individual parameter.
Args:
set_to_none: If True, the gradients will be deallocated after the call (default: True)
param: Resets the gradients of the passed `param`. For use with `gradient_release=True`.
"""
if param is None:
super().zero_grad(set_to_none=set_to_none)
else:
if param.grad is not None:
if set_to_none:
param.grad = None
else:
if param.grad.grad_fn is not None:
param.grad.detach_()
else:
param.grad.requires_grad_(False)
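
To illustrate how a concrete optimizer might consume this base class, here is a schematic, hypothetical subclass (not the SGD implementation from this commit): it shows the two paths of `step`, the usual full-model loop and the gradient release path where the hook passes a single parameter and its hyperparameters are read from the `group` cached in `self.state` by `OptimiOptimizer.__init__`. Momentum, weight decay, and the foreach implementation are omitted; the extra entries in `defaults` exist only to satisfy the base class checks in this sketch.

```python
# Hypothetical sketch -- not the optimizer implementation from this commit.
from __future__ import annotations

import torch
from torch import Tensor

from optimi.optimizer import OptimiOptimizer


class PlainSGD(OptimiOptimizer):
    def __init__(self, params, lr: float, gradient_release: bool = False):
        # defaults consumed by OptimiOptimizer.__init__; weight_decay/decouple_lr/foreach
        # are included only so the base class checks in this commit pass
        defaults = dict(lr=lr, weight_decay=0.0, decouple_lr=False, foreach=False,
                        gradient_release=gradient_release)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None, param: Tensor | None = None):
        loss = None
        if closure is not None and param is None:
            with torch.enable_grad():
                loss = closure()

        if param is None:
            # full-model step: iterate over every parameter group
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group["lr"])
        else:
            # gradient release step: the post_accumulate_grad_hook passes one parameter,
            # and its group (hence its lr) was cached in self.state during __init__
            group = self.state[param]["group"]
            if param.grad is not None:
                param.add_(param.grad, alpha=-group["lr"])
        return loss
```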