Skip to content

Commit

Permalink
Set foreach=False if gradient_release=True
Browse files Browse the repository at this point in the history
  • Loading branch information
warner-benjamin committed Feb 16, 2024
1 parent a7c4d6d commit a76fa25
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optimi/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str

super().__init__(params, defaults)

# if gradient_release is enabled, disable foreach step so normal optimizer step won't error
if self.defaults["gradient_release"]:
self.defaults["foreach"] = False
for group in self.param_groups:
group["foreach"] = False
for p in group["params"]:
self.state[p]["group"] = group

Expand Down

0 comments on commit a76fa25

Please sign in to comment.