58 questions
1
vote
1
answer
25
views
jax and flax not playing nicely with each other
I want to implement a neural network with multiple LSTM gates stacked one after the other. I set the hidden states to 0, as suggested here. When I try to run the code, I get
JaxTransformError: Jax ...
0
votes
0
answers
26
views
VAE Loss Decreases But Reconstruction Doesn't Improve
I am having trouble with the reconstruction images not working at all. Here's how I'm performing a single update step. I understand that a simpler implementation is possible for a regular VAE, but due ...
1
vote
1
answer
95
views
Restoring flax model checkpoints using orbax throws ValueError
The following code blocks are used to save the train state of the model during training and to restore the state back into memory.
from flax.training import orbax_utils
import orbax....
0
votes
0
answers
59
views
Restoring an optimizer state across multiple devices in JAX
I'm training a model using jax and optax on four GPUs and need to save and restore the optimizer state, but I'm running into a problem loading it. The optimizer state is initialized --
optimizer = ...
0
votes
0
answers
37
views
Flax.linen.conv unexpected behavior
I'm experiencing an unexpected output when using flax.linen.Conv. The output from my conv layer has very odd stats. The mean is around 100-110 and is sometimes nan. I tested the same against TensorFlow ...
0
votes
0
answers
30
views
Print Output Value of Neural Network in Flax
I use Flax to solve a neural differential equation i.e. part of my PDE is represented by a NN. Doesn't really matter, just for context. Assume we have a neural network like this
import flax.linen as ...
1
vote
1
answer
90
views
Why is Flax Linear layer not identical to matrix multiplication?
Due to the novelty of Flax, NNX, and JAX, there aren’t many resources available. I’m running into the following peculiarity:
x = jnp.random.normal((1,512), key=KEY)
layer = nnx.Linear(512, 512, ...
1
vote
0
answers
51
views
Initialize carry of RNNCells now that it is an instance function
I've been using an older version of Flax, and was able to incorporate existing RNNCells into a custom RNNCell where I can keep its carry in case I want to:
class CustomRNNCell(nn.Module):
    RNNCell: ...
2
votes
1
answer
306
views
how to vmap over multiple Dense instances in flax model? trying to avoid looping over a list of Dense instances
from jax import random, vmap
from jax import numpy as jnp
import pprint
def f(s, layers, do, dx):
    x = jnp.zeros((do, dx))
    for i, layer in enumerate(layers):
        x = x.at[i].set(layer(s[i]))
...
0
votes
1
answer
837
views
How to restore an orbax checkpoint with jax/flax?
I saved an orbax checkpoint with the code below:
check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp....
1
vote
1
answer
137
views
Flax neural network with nans in the outputs
I am training a neural network using Flax. My training data has a significant number of nans in the outputs. I want to ignore these and only use the non-nan values for training. To achieve this, I ...
2
votes
1
answer
186
views
AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'
I'm trying to run a model written in jax, https://github.com/lindermanlab/S5. However, I ran into some error that says
Traceback (most recent call last):
File "/Path/run_train.py", line ...
0
votes
1
answer
290
views
How can I convert a flax.linen.Module to a torch.nn.Module?
I would like to convert a flax.linen.Module, taken from here and replicated below this post, to a torch.nn.Module.
However, I find it extremely hard to figure out how I need to replace
The flax.linen....
2
votes
1
answer
667
views
Using Orbax to checkpoint flax `TrainState` with new `CheckpointManager` API
Context
The Flax docs describe how to checkpoint a flax.training.train_state.TrainState with orbax. In a nutshell, you set up an orbax.checkpoint.CheckpointManager which keeps track of checkpoints. ...
1
vote
1
answer
155
views
Getting derivatives of NNs according to its inputs by batches in JAX
There is a neural network that takes two variables as input: net(x, t), where x is usually d-dim, and t is a scalar. The NN outputs a vector of length d. x and t might be batches, so x is of ...
0
votes
0
answers
110
views
@nn.compact usage in FLAX
What is the use of @nn.compact in FLAX?
class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        self.n = 1
        return x
The same can be done by using:
class Net(nn.Module):
    def setup(self):
...
0
votes
0
answers
215
views
Flax JIT error for inherited nn.Module class methods
Based on this answer I am trying to make a class jit compatible by creating a pytree node, but I get:
TypeError: Cannot interpret value of type <class '__main__.TestModel'> as an abstract array; ...
0
votes
1
answer
254
views
How do you use flax.linen.checkpoint with static_argnums to allow a boolean argument to __call__?
I have a subclassed flax.linen.Module that takes a boolean argument in its __call__ method. I want to use gradient checkpointing to reduce the GPU memory footprint of this layer, so I am using flax....
0
votes
0
answers
375
views
NaN gradients in Jax softmax
I have encountered some mysterious NaN gradient issues in my training process of a model implemented through flax. With the help of jax_debug_nans, I am able to identify that it comes from the ...
0
votes
0
answers
100
views
NaN values in all parameters after training a custom RNN with Flax/JAX
I've implemented a custom RNN cell using Flax and JAX, and after training for only a small number of epochs, all model parameters turn to NaN. I'm seeking advice on potential causes and solutions.
params = ...
0
votes
0
answers
38
views
GPU allocation explodes when logging scalar to flax's tensorboard
I noticed that when I use flax's tensorboard (from flax.metrics import tensorboard) to log the loss, the GPU allocation explodes.
To compute the loss metrics I use has_aux, as explained here.
...
1
vote
1
answer
859
views
How to use FLAX LSTM in 2023
I am wondering if anyone here knows how to get FLAX LSTM layers to work in 2023. I have tried some of the code snippets on the actual Flax documentation, such as:
https://flax.readthedocs.io/en/latest/...
1
vote
1
answer
287
views
Should models be trained using fori_loop?
When optimizing weights and biases of a model, does it make sense to replace:
for _ in range(epochs):
    w, b = step(w, b)
With:
w, b = lax.fori_loop(0, epochs, lambda i, wb: step(wb[0], wb[1]), (w, b))...
0
votes
0
answers
394
views
conditionally call vs. don't call a function using flax.linen
Based on a boolean flag, I want to either 1) call or 2) not call the following function (which operates on a flax linen module).
def true_fn(module, carry, inputs):
    carry, output = flax.linen....
1
vote
1
answer
240
views
AttributeError: module 'flax.linen' has no attribute 'transforms'
I received an error from flax 0.7.5. Could you help me?
File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\__init__.py:5
from jVMC.nets.rnn import *
File ~\AppData\Roaming\Python\...
1
vote
1
answer
283
views
How to select between different function based on a value of a parameter in flax?
I am iterating through each head and applying either function f1 or f2 depending on the value of the parameter self.alpha.
I only want to evaluate either function f1 or f2, not both, and then select ...
0
votes
1
answer
158
views
Prefetching an iterator of 128-dim array to device
I'm having trouble using flax.jax_utils.prefetch_to_device for the simple function below. I'm loading the SIFT 1M dataset and converting the array to a jnp array.
I then want to prefetch the iterator ...
1
vote
1
answer
1k
views
Computing the gradient of a batched function using JAX
I would need to compute the gradient of a batched function using JAX. The following is a minimal example of what I would like to do:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
...
0
votes
1
answer
238
views
Computing dot product of gradients with itself for a neural network model in JAX
I have the following piece of JAX code with my neural network model -- model:
(loss, (inner_state, logits)), grad = jax.value_and_grad(
lambda m: forward_and_loss(m, true_gradient=True), ...
1
vote
0
answers
4k
views
How to convert .safetensors or .ckpt files and use them in FlaxStableDiffusionImg2ImgPipeline?
I am trying to convert a .safetensors model to a diffusers model using the Python script found at https://github.com/huggingface/diffusers/blob/main/scripts/...
1
vote
1
answer
44
views
data_format in JAX/FLAX
I did not find any settings for data_format=channels_first or data_format=channels_last in FLAX modules (which are based on JAX).
On the contrary, TensorFlow does have that designation. Does the ...
1
vote
1
answer
90
views
Passing JAX tracers to Huggingface CLIP transformer for calculating loss
I'm working on a vision task using JAX, and I'm facing an issue with passing intermediate JAX tracer objects as images to the CLIP model for calculating the loss. The CLIP model expects NumPy arrays ...
1
vote
1
answer
63
views
The exact meaning of n_jitted_steps=5
I have tried to run the code. Here, there is a command called n_jitted_steps=5, which, according to the authors, can accumulate several steps. Since the code is rather complicated, it might be ...
0
votes
0
answers
43
views
Data download issue with official Flax Image Net example
I am still trying to understand this official Flax Example. For the convenience of the experiment, I have created my own copy.
In the section on running locally, it seems that there is no download ...
0
votes
1
answer
457
views
How to unroll the training loop so that Jax can train multiple steps in GPU/TPU
When using powerful hardware, especially TPU, it is often preferable to train multiple steps. For example, in TensorFlow, this is possible.
with strategy.scope():
    model = create_model()
...
1
vote
1
answer
2k
views
No module named 'jax.experimental.global_device_array' when running the official Flax Example on Colab with V100
I have been trying to understand this official flax example, based on a Colab Pro+ account with V100. When I execute the command python main.py --workdir=./imagenet --config=configs/v100_x8.py, the ...
1
vote
1
answer
226
views
Fail to understand the usage of partial argument in Flax Resnet Official Example
I have been trying to understand this official example. However, I am very confused about the use of partial in two places.
For example, in line 94, we have the following:
conv = partial(self.conv, ...
1
vote
0
answers
199
views
Why does run_t5_mlm_flax.py not produce a model weight file, etc.?
I was trying to reproduce this Hugging Face tutorial on T5-like span masked-language-modeling.
I have the following code tokenizing_and_configing.py:
import datasets
from t5_tokenizer_model import ...
1
vote
1
answer
825
views
access a submodule of a flax class/module without calling model.apply()
I have a module/class of this kind:
class autoencoder(nn.Module):
    hidden_dim: int
    z_dim: int
    output_dim: int
    def setup(self):
        self.encoder = encoder(self....
1
vote
0
answers
262
views
pip install gives ResolutionImpossible: the user requested flax 0.6.8 but t5x depends on flax 0.6.8
I am trying to use a requirements file that depends on versions of the packages flax and t5x at specific commits.
The problem can be reproduced with the following command:
pip install "flax @ git+...
0
votes
1
answer
351
views
Getting incorrect output from the flax model's init call
I am trying to create a simple neural network using flax, as shown below.
However, the params frozen dict I receive as the output of model.init is empty instead of having the parameters of the ...
2
votes
0
answers
432
views
How to convert flax saved checkpoints to model?
https://github.com/google-research/scenic/tree/main/scenic/projects/mbt
I am trying to use the pretrained model in the Git repository, which is basically a Flax checkpoint. I want to convert it back to ...
0
votes
0
answers
294
views
Flax implementation of padding_idx from torch.nn.embedding
I have been rewriting some of my pytorch models in jax/flax and came across the issue of converting torch.nn.Embedding to flax.linen.Embed.
There does not appear to be a direct translation for pytorch'...
0
votes
0
answers
107
views
Vanishing parameters in MAML JAX (Meta Learning)
I am working on an implementation of MAML (see https://arxiv.org/pdf/1703.03400.pdf) in Jax.
When training on a distribution of simple linear regression tasks it seems to perform fine (takes a while ...
6
votes
1
answer
2k
views
AttributeError: module 'flax' has no attribute 'nn'
I'm trying to run RegNeRF, which requires flax. On installing the latest version of flax==0.6.0, I got an error stating flax has no attribute optim. This answer suggested to downgrade flax to 0.5.1. ...
-1
votes
1
answer
2k
views
Why does JAX throw an unfiltered stack trace?
I need to jit the train step, but when I do, I get this error
import jax_resnet
import jax
import jax.numpy as jnp
from flax import linen as nn
import tensorflow_datasets as tfds
from flax.training ...
2
votes
2
answers
2k
views
Jax - vmap over batch of dataclasses
In JAX, I am looking to vmap a function over a fixed length list of dataclasses, for example:
import jax, chex
from flax import struct
@struct.dataclass
class EnvParams:
    max_steps: int = 500
...
0
votes
1
answer
2k
views
How can I initialize the hidden state (carry) of a (flax linen) GRUCell as a learnable parameter (e.g. using model.init)
I create a GRU model in Jax using Flax and I initialize the model parameters using model.init as follows:
import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import ...
4
votes
1
answer
4k
views
AttributeError: module 'flax' has no attribute 'optim'
My code is as follows:
!pip install flax
init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params']
print(f'Model parameters: {n_params(init_params):,}')
optim = flax.optim....
1
vote
1
answer
11k
views
I am trying to assign a JAX Tracer object to a NumPy array that requires concrete values - workaround needed please
I am new to Jax.
I am implementing a variational autoencoder (VAE) using Jax and Flax. During training, I sample a latent code (from the distribution inferred by the encoder, which I implement using ...