Add warning, version, and value checks
warner-benjamin committed Oct 30, 2023
1 parent c7815be commit 92a640e
Showing 7 changed files with 93 additions and 14 deletions.
24 changes: 19 additions & 5 deletions optimi/adam.py
@@ -13,13 +13,14 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _default_to_fused_or_foreach, required
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

from optimi.utils import debias_beta
from optimi.utils import MIN_TORCH_2_1, debias_beta

__all__ = ["Adam", "adam"]

@@ -57,21 +58,34 @@ def __init__(
decouple_wd: bool = False,
decouple_lr: bool = False,
max_lr: float | None = None,
foreach: bool | None = None,
kahan_sum: bool | None = None,
foreach: bool | None = None,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr=}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon: {eps=}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta1 parameter: {betas[0]=}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta2 parameter: {betas[1]=}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight decay: {weight_decay=}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon: {eps=}")
if decouple_lr and max_lr is None:
max_lr = lr
if max_lr is not None and not 0.0 <= max_lr:
raise ValueError(f"Invalid maximum learning rate: {max_lr=}")
if decouple_lr and weight_decay >= 1e-3:
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, learning rate decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)
if not MIN_TORCH_2_1:
if foreach:
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
foreach = False

defaults = dict(
lr=lr,
@@ -82,8 +96,8 @@ def __init__(
decouple_wd=decouple_wd,
decouple_lr=decouple_lr,
max_lr=max_lr,
foreach=foreach,
kahan_sum=kahan_sum,
foreach=foreach,
setup=False,
)
super().__init__(params, defaults)
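As a quick illustration of the new checks (a minimal sketch, not part of the commit, assuming Adam is importable from optimi.adam as defined above): passing decouple_lr=True with a weight decay at or above 1e-3 should now emit the UserWarning added above.

import warnings

import torch
from optimi.adam import Adam

model = torch.nn.Linear(4, 2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # weight_decay >= 1e-3 combined with decouple_lr=True triggers the new warning
    Adam(model.parameters(), lr=1e-3, weight_decay=1e-2, decouple_lr=True)

assert any(issubclass(w.category, UserWarning) for w in caught)

On PyTorch older than 2.1, the same constructor raises a ValueError if foreach=True is passed explicitly and otherwise falls back to foreach=False.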
16 changes: 15 additions & 1 deletion optimi/adan.py
@@ -16,13 +16,14 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _default_to_fused_or_foreach, required
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

from optimi.utils import debias_beta
from optimi.utils import MIN_TORCH_2_1, debias_beta

__all__ = ["Adan", "adan"]

@@ -78,6 +79,19 @@ def __init__(
raise ValueError(f"Invalid epsilon: {eps=}")
if decouple_lr and max_lr is None:
max_lr = lr
if max_lr is not None and not 0.0 <= max_lr:
raise ValueError(f"Invalid maximum learning rate: {max_lr=}")
if decouple_lr and weight_decay >= 1e-3:
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, learning rate decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)
if not MIN_TORCH_2_1:
if foreach:
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
foreach = False

defaults = dict(
lr=lr,
16 changes: 16 additions & 0 deletions optimi/lion.py
@@ -13,12 +13,15 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _default_to_fused_or_foreach, required
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

from optimi.utils import MIN_TORCH_2_1

__all__ = ["Lion", "lion"]


@@ -64,6 +67,19 @@ def __init__(
raise ValueError(f"Invalid weight decay: {weight_decay=}")
if decouple_lr and max_lr is None:
max_lr = lr
if max_lr is not None and not 0.0 <= max_lr:
raise ValueError(f"Invalid maximum learning rate: {max_lr=}")
if decouple_lr and weight_decay >= 1e-3:
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, learning rate decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)
if not MIN_TORCH_2_1:
if foreach:
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
foreach = False

defaults = dict(
lr=lr,
20 changes: 18 additions & 2 deletions optimi/sgd.py
@@ -16,12 +16,15 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _default_to_fused_or_foreach, required
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

from optimi.utils import MIN_TORCH_2_1

__all__ = ["SGD", "sgd"]


@@ -72,6 +75,19 @@ def __init__(
raise ValueError(f"Invalid weight decay: {weight_decay=}")
if decouple_lr and max_lr is None:
max_lr = lr
if max_lr is not None and not 0.0 <= max_lr:
raise ValueError(f"Invalid maximum learning rate: {max_lr=}")
if decouple_lr and weight_decay >= 1e-3:
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, learning rate decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)
if not MIN_TORCH_2_1:
if foreach:
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
foreach = False

defaults = dict(
lr=lr,
@@ -181,9 +197,9 @@ def sgd(
grads (list): Parameter gradients
exp_avgs (list): Momentum buffers
kahan_comps (list, optional): Kahan summation compensations
weight_decay (float): Weight decay coefficient
lr (float): Learning rate
momentum (float): Momentum factor
weight_decay (float): Weight decay coefficient
dampening (bool): Use dampening for momentum update
decouple_wd (bool): Apply decoupled weight decay
decouple_lr (bool): Apply learning rate decoupled weight decay
@@ -256,7 +272,7 @@ def _single_sgd(
# SGD with Momentum step
kahan_comp.add_(exp_avg, alpha=-lr)

# update weights with kahan compensation
# update weights with kahan compensation using grad as temp buffer
grad.copy_(param.detach())
param.add_(kahan_comp)

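The updated comment above ("using grad as temp buffer") refers to the Kahan summation pattern in which the gradient tensor is reused as scratch space for the old parameter value. A standalone, hedged sketch of that pattern (illustrative names, not taken verbatim from the commit):

import torch

def kahan_update_(param: torch.Tensor, update: torch.Tensor, kahan_comp: torch.Tensor, temp: torch.Tensor):
    """Apply update to a low-precision param, tracking rounding error in kahan_comp."""
    kahan_comp.add_(update)            # fold the full-precision update into the compensation buffer
    temp.copy_(param)                  # stash the old parameter value (grad plays this role above)
    param.add_(kahan_comp)             # apply the compensated update; the result rounds to param's dtype
    kahan_comp.add_(temp.sub_(param))  # keep whatever part of the update was lost to rounding

The retained compensation is folded back in on the next step, which is what lets low-precision parameters track a full-precision update sequence.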
24 changes: 19 additions & 5 deletions optimi/stableadamw.py
@@ -13,13 +13,14 @@
from __future__ import annotations

from typing import Any, Callable, Iterable
from warnings import warn

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _default_to_fused_or_foreach, required
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

from optimi.utils import debias_beta
from optimi.utils import MIN_TORCH_2_1, debias_beta

__all__ = ["StableAdamW", "stableadamw"]

@@ -55,21 +56,34 @@ def __init__(
eps: float = 1e-6,
decouple_lr: bool = False,
max_lr: float | None = None,
foreach: bool | None = None,
kahan_sum: bool | None = None,
foreach: bool | None = None,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr=}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon: {eps=}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta1 parameter: {betas[0]=}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta2 parameter: {betas[1]=}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight decay: {weight_decay=}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon: {eps=}")
if decouple_lr and max_lr is None:
max_lr = lr
if max_lr is not None and not 0.0 <= max_lr:
raise ValueError(f"Invalid maximum learning rate: {max_lr=}")
if decouple_lr and weight_decay >= 1e-3:
warn(
f"You are using {weight_decay=} which is potentially high for {decouple_lr=}. Unlike decoupled weight "
f"decay, learning rate decoupled weight decay does not reduce weight decay by the learning rate.",
category=UserWarning,
)
if not MIN_TORCH_2_1:
if foreach:
raise ValueError(f"{foreach=} requires PyTorch 2.1 or later. Set foreach=False or upgrade PyTorch.")
else:
foreach = False

defaults = dict(
lr=lr,
@@ -79,8 +93,8 @@ def __init__(
weight_decay=weight_decay,
decouple_lr=decouple_lr,
max_lr=max_lr,
foreach=foreach,
kahan_sum=kahan_sum,
foreach=foreach,
setup=False,
)
super().__init__(params, defaults)
5 changes: 5 additions & 0 deletions optimi/utils.py
@@ -1,5 +1,10 @@
# Optimizer utilities

import torch
from packaging.version import parse

MIN_TORCH_2_1 = parse(torch.__version__) >= parse("2.1")


def debias(beta: float, step: int) -> float:
"""Adam-style debias correction. Returns `1 - beta ** step`."""
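The new MIN_TORCH_2_1 flag compares versions with packaging.version.parse rather than raw string comparison. A small hedged illustration (not part of the commit) of why that matters:

from packaging.version import parse

assert "2.2.0" > "2.10.0"                    # lexicographic string comparison mis-orders multi-digit components
assert parse("2.2.0") < parse("2.10.0")      # packaging compares release components numerically
assert parse("2.1.0+cu118") >= parse("2.1")  # local build suffixes on torch.__version__ are handled

This is also why packaging>=21.3 is added as a dependency in pyproject.toml below.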
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -17,7 +17,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
'Topic :: Scientific/Engineering :: Artificial Intelligence'
]
dependencies = ["torch>=1.13"]
dependencies = ["torch>=1.13", "packaging>=21.3"]

[project.optional-dependencies]
dev = ["pytest>=7.4.3", "ruff>=0.1.3"]
