From a76fa2577215bd427bb0dea6727567d59501c60c Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 15 Feb 2024 19:20:20 -0600 Subject: [PATCH] Set foreach=False if gradient_release=True --- optimi/optimizer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/optimi/optimizer.py b/optimi/optimizer.py index f78409f..7f488a8 100644 --- a/optimi/optimizer.py +++ b/optimi/optimizer.py @@ -40,8 +40,11 @@ def __init__(self, params: Iterable[Tensor] | Iterable[dict], defaults: dict[str super().__init__(params, defaults) + # if gradient_release is enabled, disable foreach step so normal optimizer step won't error if self.defaults["gradient_release"]: + self.defaults["foreach"] = False for group in self.param_groups: + group["foreach"] = False for p in group["params"]: self.state[p]["group"] = group