Commit d7b421c (1 parent: cecfc90)
Showing 6 changed files with 386 additions and 75 deletions.
@@ -0,0 +1,60 @@
from __future__ import annotations

from functools import partial
from warnings import warn

from torch import Tensor
from torch.nn import Module

from optimi.optimizer import OptimiOptimizer


def _gradient_release_hook(param: Tensor, optimizer: OptimiOptimizer):
    optimizer.step(param=param)
    optimizer.zero_grad(param=param)


def prepare_for_gradient_release(model: Module, optimizer: OptimiOptimizer, ignore_existing_hooks: bool = False):
    """Register post_accumulate_grad_hooks on parameters for the gradient release optimization step.

    Args:
        model: Model to register post_accumulate_grad_hooks. Only registers on parameters with
            `requires_grad=True`.
        optimizer: Optimizer providing the fused optimizer step during the backward pass. Requires
            optimizer to be initialized with `gradient_release=True`.
        ignore_existing_hooks: If True, registers post_accumulate_grad_hooks even if the model's
            parameters already have existing post_accumulate_grad_hooks (default: False)
    """
    if not isinstance(optimizer, OptimiOptimizer):
        raise TypeError("`optimizer` must be an instance of `OptimiOptimizer`")
    if not optimizer.defaults["gradient_release"]:
        raise ValueError("`optimizer` must be initialized with `gradient_release=True`")

    hooks = []
    for p in model.parameters():
        if p.requires_grad:
            if (p._post_accumulate_grad_hooks is not None) and len(p._post_accumulate_grad_hooks) > 0 and (not ignore_existing_hooks):
                # remove any hooks registered so far so a failed call leaves no partial state behind
                for hook in hooks:
                    if hasattr(hook, "remove"):
                        hook.remove()
                raise ValueError(
                    "Model already has post_accumulate_grad_hooks. If this is expected, rerun with `ignore_existing_hooks=True`."
                )
            hooks.append(p.register_post_accumulate_grad_hook(partial(_gradient_release_hook, optimizer=optimizer)))
    model._gradient_release_hooks = hooks


def remove_gradient_release(model: Module):
    """Removes post_accumulate_grad_hooks created by `prepare_for_gradient_release`.

    Args:
        model: Model to remove gradient release post_accumulate_grad_hooks from.
    """
    if not hasattr(model, "_gradient_release_hooks"):
        warn("`model` does not have any gradient release post_accumulate_grad_hooks to remove.")
        return

    for hook in model._gradient_release_hooks:
        if hasattr(hook, "remove"):
            hook.remove()
    del model._gradient_release_hooks
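
A minimal training-loop sketch of how these helpers might be wired up. The import paths, the AdamW optimizer accepting `gradient_release=True`, and the toy model and data are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch -- import paths, model, and data are assumptions.
import torch
from torch import nn

from optimi import AdamW, prepare_for_gradient_release, remove_gradient_release

model = nn.Linear(16, 4)
# gradient_release=True enables the per-parameter fused step used by the hooks
opt = AdamW(model.parameters(), lr=1e-3, gradient_release=True)

# register post_accumulate_grad_hooks on all trainable parameters
prepare_for_gradient_release(model, opt)

for _ in range(10):
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)
    # opt.step(param=...) and opt.zero_grad(param=...) run inside the hooks as
    # each parameter's gradient finishes accumulating during backward
    loss.backward()

# detach the hooks once gradient release is no longer wanted
remove_gradient_release(model)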
@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

from optimi.utils import MIN_TORCH_2_1


class OptimiOptimizer(Optimizer):
    """Provides common functionality for optimi optimizers."""

    def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str, Any]):
        if not MIN_TORCH_2_1:
            if defaults["foreach"]:
                foreach = defaults["foreach"]
                raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
            else:
                defaults["foreach"] = False

        if defaults["decouple_lr"] and defaults["weight_decay"] >= 1e-3:
            weight_decay = defaults["weight_decay"]
            decouple_lr = defaults["decouple_lr"]
            warn(
                f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
                f"decay, fully decoupled weight decay does not reduce weight decay by the learning rate.",
                category=UserWarning,
            )

        super().__init__(params, defaults)

        if self.defaults["gradient_release"]:
            # cache each parameter's group so the single-parameter step can look up its hyperparameters
            for group in self.param_groups:
                for p in group["params"]:
                    self.state[p]["group"] = group

    def step(self, closure: Callable | None = None, param: Tensor | None = None):
        """Performs a single optimization step on the whole model or individual parameter.

        Args:
            closure: A closure which reevaluates the model and returns the loss. Incompatible with
                performing an optimization step on a single `param`.
            param: An individual parameter to perform a fused optimization step during the backward
                pass. Requires optimizer to be initialized with `gradient_release=True` and model
                hooks created with `prepare_for_gradient_release`. Incompatible with `closure`.
        """
        raise NotImplementedError

    @torch._disable_dynamo
    def zero_grad(self, set_to_none: bool = True, param: Tensor | None = None):
        """Resets the gradients of all optimized parameters or individual parameter.

        Args:
            set_to_none: If True, the gradients will be deallocated after the call (default: True)
            param: Resets the gradients of the passed `param`. For use with `gradient_release=True`.
        """
        if param is None:
            super().zero_grad(set_to_none=set_to_none)
        else:
            if param.grad is not None:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad.grad_fn is not None:
                        param.grad.detach_()
                    else:
                        param.grad.requires_grad_(False)
                    # zero the retained gradient so set_to_none=False actually clears it,
                    # mirroring torch.optim.Optimizer.zero_grad
                    param.grad.zero_()
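
OptimiOptimizer leaves `step` abstract. The sketch below is a hypothetical minimal subclass, not one of the commit's actual optimizers: it uses a plain SGD update only to show how the whole-model path and the single-parameter gradient release path (reading the group cached in `self.state[p]["group"]` above) could both be handled.

# Hypothetical subclass sketch -- a plain SGD update for illustration only.
import torch
from torch import Tensor
from typing import Callable

from optimi.optimizer import OptimiOptimizer


class SketchSGD(OptimiOptimizer):
    def __init__(self, params, lr: float = 1e-3, gradient_release: bool = False):
        # keys that OptimiOptimizer.__init__ expects to find in defaults
        defaults = dict(lr=lr, weight_decay=0.0, decouple_lr=False,
                        foreach=False, gradient_release=gradient_release)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable | None = None, param: Tensor | None = None):
        loss = None
        if closure is not None and param is None:
            with torch.enable_grad():
                loss = closure()

        if param is None:
            # whole-model path: update every parameter in every group
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-group["lr"])
        else:
            # gradient release path: fetch the cached group and update only this parameter
            group = self.state[param]["group"]
            if param.grad is not None:
                param.add_(param.grad, alpha=-group["lr"])
        return loss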