1 vote
1 answer

jax and flax not playing nicely with each other

I want to implement a neural network with multiple LSTM gates stacked one after the other. I set the hidden states to 0, as suggested here. When I try to run the code, I get JaxTransformError: Jax ...
Dan Leonte
0 votes
0 answers

VAE Loss Decreases But Reconstruction Doesn't Improve

I am having trouble with the reconstruction images not working at all. Here's how I'm performing a single update step. I understand that a simpler implementation is possible for a regular VAE, but due ...
m4thphobia
1 vote
1 answer

Restoring flax model checkpoints using orbax throws ValueError

The following code blocks are being utilized to save the train state of the model during training and to restore the state back into memory. from import orbax_utils import orbax....
yash (reputation 380)
0 votes
0 answers

Restoring an optimizer state across multiple devices in JAX

I'm training a model using jax and optax on four GPUs and need to save and restore the optimizer state, but I'm running into a problem loading it. The optimizer state is initialized -- optimizer = ...
Jamie Mahowald
0 votes
0 answers

Flax.linen.conv unexpected behavior

I'm experiencing an unexpected output when using flax.linen.Conv. My output from the conv layer has very odd stats: the mean is around 100-110 and is sometimes NaN. I tested the same against TensorFlow ...
Nithish M
0 votes
0 answers

Print Output Value of Neural Network in Flax

I use Flax to solve a neural differential equation i.e. part of my PDE is represented by a NN. Doesn't really matter, just for context. Assume we have a neural network like this import flax.linen as ...
xotix (reputation 520)
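The asker's network is truncated in the excerpt. As a hedged sketch, a hypothetical plain-JAX `forward` function stands in for the module's `__call__`: under `jax.jit` a plain `print` shows only an abstract tracer, while `jax.debug.print` prints the runtime values each time the compiled function runs.

```python
import jax
import jax.numpy as jnp

@jax.jit
def forward(w, x):
    """Hypothetical stand-in for a model's __call__: one tanh layer, summed."""
    h = jnp.tanh(x @ w)
    # print(h) here would show an abstract tracer once, at trace time;
    # jax.debug.print shows the concrete values every time forward runs.
    jax.debug.print("hidden activations: {h}", h=h)
    return h.sum()

w = jnp.ones((3, 2))
x = jnp.array([[0.1, 0.2, 0.3]])
out = forward(w, x)   # prints the (1, 2) activation array
```

Because `jax.debug.print` is an ordinary JAX operation, the same call should also work inside a Flax module method.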
1 vote
1 answer

Why is Flax Linear layer not identical to matrix multiplication?

Due to the novelty of Flax, NNX, and JAX, there aren’t many resources available. I’m running into the following peculiarity: x = jnp.random.normal((1,512), key=KEY) layer = nnx.Linear(512, 512, ...
Value_Investor
1 vote
0 answers

Initialize carry of RNNCells now that it is an instance function

I've been using an older version of Flax, and was able to incorporate existing RNNCells into a custom RNNCell where I can keep its carry in case I want to: class CustomRNNCell(nn.Module): RNNCell: ...
Tusike (reputation 31)
2 votes
1 answer

how to vmap over multiple Dense instances in flax model? trying to avoid looping over a list of Dense instances

from jax import random,vmap from jax import numpy as jnp import pprint def f(s,layers,do,dx): x = jnp.zeros((do,dx)) for i,layer in enumerate(layers):[i].set( layer( s[i] ) ) ...
user137146
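The excerpt's loop is cut off, but a common way to avoid looping over a list of identical layers is to stack their parameters along a new leading axis and `vmap` one layer function over it. A minimal plain-JAX sketch, where `dense`, the sizes, and the random weights are hypothetical stand-ins for the Flax `Dense` instances:

```python
import jax
import jax.numpy as jnp

n_layers, d_in, d_out = 4, 3, 2   # hypothetical sizes

def dense(params, x):
    """A single dense layer: x @ W + b."""
    W, b = params
    return x @ W + b

keys = jax.random.split(jax.random.PRNGKey(0), n_layers)
# One weight matrix and bias per layer, stacked along a new leading axis.
Ws = jax.vmap(lambda k: jax.random.normal(k, (d_in, d_out)))(keys)  # (n_layers, d_in, d_out)
bs = jnp.zeros((n_layers, d_out))

s = jnp.ones((n_layers, d_in))   # s[i] is the input to layer i

# Apply layer i to s[i] for every i at once instead of looping over a list.
x = jax.vmap(dense)((Ws, bs), s)   # (n_layers, d_out)
```

At the module level, Flax also offers a lifted `flax.linen.vmap` transform, whose `variable_axes` and `split_rngs` arguments control how parameters and RNGs are batched.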
0 votes
1 answer

How to restore an orbax checkpoint with jax/flax?

I saved an orbax checkpoint with the code below: check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True) check_path = Path(os.getcwd(), out_dir, 'checkpoint') checkpoint_manager = ocp....
Dmitry J (reputation 143)
1 vote
1 answer

Flax neural network with nans in the outputs

I am training a neural network using Flax. My training data has a significant number of nans in the outputs. I want to ignore these and only use the non-nan values for training. To achieve this, I ...
rhombidodecahedron
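The excerpt stops before the asker's own method, so here is one common approach, sketched under the assumption that the NaNs appear only in the targets: mask the loss, and replace the NaNs before any arithmetic so they cannot leak into the gradient. `masked_mse` and the toy arrays are hypothetical.

```python
import jax
import jax.numpy as jnp

def masked_mse(pred, target):
    """Mean squared error over the non-NaN entries of `target` only."""
    mask = ~jnp.isnan(target)
    # Replace NaNs *before* the arithmetic: jnp.where(mask, (pred - target)**2, 0)
    # alone gives the right forward value but NaN gradients, because the
    # backward pass still multiplies a zero cotangent by (pred - NaN).
    safe_target = jnp.where(mask, target, 0.0)
    sq_err = jnp.where(mask, (pred - safe_target) ** 2, 0.0)
    return sq_err.sum() / jnp.maximum(mask.sum(), 1)

pred = jnp.array([1.0, 2.0, 3.0])
target = jnp.array([1.5, jnp.nan, 2.0])
loss, grad = jax.value_and_grad(masked_mse)(pred, target)
```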
2 votes
1 answer

AttributeError: module 'flax.traverse_util' has no attribute 'unfreeze'

I'm trying to run a model written in JAX. However, I ran into an error that says Traceback (most recent call last): File "/Path/", line ...
WillWu (reputation 97)
0 votes
1 answer

How can I convert a flax.linen.Module to a torch.nn.Module?

I would like to convert a flax.linen.Module, taken from here and replicated below this post, to a torch.nn.Module. However, I find it extremely hard to figure out how I need to replace The flax.linen....
0xbadf00d (reputation 18.2k)
2 votes
1 answer

Using Orbax to checkpoint flax `TrainState` with new `CheckpointManager` API

Context: The Flax docs describe how to checkpoint a with orbax. In a nutshell, you set up an orbax.checkpoint.CheckpointManager, which keeps track of checkpoints. ...
Hylke (reputation 117)
1 vote
1 answer

Getting derivatives of NNs according to its inputs by batches in JAX

There is a neural network that takes two variables as input, net(x, t), where x is usually d-dimensional and t is a scalar. The NN outputs a vector of length d. x and t might be batches, so x is of ...
Michaela
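A hedged sketch of the usual pattern: differentiate the network for a single `(x, t)` pair with `jax.jacfwd`, then batch that function with `jax.vmap`. The toy `net` below is hypothetical and is elementwise only so the result is easy to check.

```python
import jax
import jax.numpy as jnp

d = 3  # hypothetical input dimension

def net(x, t):
    """Hypothetical stand-in for the asker's network: x has shape (d,), t is a scalar."""
    return jnp.tanh(x * t) + x ** 2

# Jacobian d net / d x for a single (x, t) pair: shape (d, d).
jac_x = jax.jacfwd(net, argnums=0)
# Derivative d net / d t for a single pair: shape (d,).
dnet_dt = jax.jacfwd(net, argnums=1)

# vmap over the leading batch axis of both x and t.
batched_jac_x = jax.vmap(jac_x, in_axes=(0, 0))
batched_dnet_dt = jax.vmap(dnet_dt, in_axes=(0, 0))

xs = jnp.arange(6.0).reshape(2, d)   # batch of 2 inputs
ts = jnp.array([0.5, 2.0])           # one t value per input
J = batched_jac_x(xs, ts)            # shape (2, d, d)
D = batched_dnet_dt(xs, ts)          # shape (2, d)
```

For a network with many outputs and few inputs `jax.jacfwd` is usually the cheaper choice; `jax.jacrev` computes the same Jacobian in reverse mode.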
0 votes
0 answers

@nn.compact usage in FLAX

What is the use of @nn.compact in FLAX? class Net(nn.Module): @nn.compact def __call__(self, x): self.n = 1 return x The same can be done by using : class Net(nn.Module): def setup(self): ...
akhil (reputation 37)
0 votes
0 answers

Flax JIT error for inherited nn.Module class methods

Based on this answer I am trying to make a class jit compatible by creating a pytree node, but I get: TypeError: Cannot interpret value of type <class '__main__.TestModel'> as an abstract array; ...
Momo (reputation 960)
0 votes
1 answer

How do you use flax.linen.checkpoint with static_argnums to allow a boolean argument to __call__?

I have a subclassed flax.linen.Module that takes a boolean argument in its __call__ method. I want to use gradient checkpointing to reduce the GPU memory footprint of this layer, so I am using flax....
Ian Holmes
0 votes
0 answers

NaN gradients in Jax softmax

I have encountered some mysterious NaN gradient issues in my training process of a model implemented through flax. With the help of jax_debug_nans, I am able to identify that it comes from the ...
Will (reputation 101)
0 votes
0 answers

NaN values in all parameters after training a custom RNN with Flax/JAX

I've implemented a custom RNN cell using Flax and JAX, and after training for only a few epochs, all model parameters turn to NaN. I'm seeking advice on potential causes and solutions. params = ...
PHANEE CHOWDARY
0 votes
0 answers

GPU allocation explodes when logging scalar to flax's tensorboard

I noticed that when I use flax's tensorboard (from flax.metrics import tensorboard) to log the loss, the GPU allocation explodes. To compute the loss metrics I use has_aux as explained here. ...
Chutlhu (reputation 357)
1 vote
1 answer

How to use FLAX LSTM in 2023

I am wondering if anyone here knows how to get FLAX LSTM layers to work in 2023. I have tried some of the code snippets on the actual Flax documentation, such as:
al_cc (reputation 105)
1 vote
1 answer

Should models be trained using fori_loop?

When optimizing weights and biases of a model, does it make sense to replace: for _ in range(epochs): w, b = step(w, b) With: w, b = lax.fori_loop(0, epochs, lambda wb: step(wb[0], wb[1]), (w, b))...
ldmat (reputation 1,041)
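Note that `jax.lax.fori_loop` calls its body as `body_fun(i, val)`, so the one-argument lambda in the excerpt would raise a `TypeError`. A corrected sketch, with a hypothetical `step` doing gradient descent on a toy quadratic, checked against the plain Python loop:

```python
from functools import partial

import jax
import jax.numpy as jnp

def step(w, b, lr=0.1):
    """Hypothetical update: one gradient step on (w - 3)^2 + (b + 1)^2."""
    loss = lambda w, b: (w - 3.0) ** 2 + (b + 1.0) ** 2
    gw, gb = jax.grad(loss, argnums=(0, 1))(w, b)
    return w - lr * gw, b - lr * gb

def train_python_loop(w, b, epochs):
    for _ in range(epochs):
        w, b = step(w, b)
    return w, b

@partial(jax.jit, static_argnums=2)
def train_fori(w, b, epochs):
    # fori_loop calls body(i, carry) and expects a carry back with the same
    # structure, shapes and dtypes; here the carry is the (w, b) tuple.
    body = lambda i, wb: step(*wb)
    return jax.lax.fori_loop(0, epochs, body, (w, b))

w0, b0 = jnp.float32(0.0), jnp.float32(0.0)
w_py, b_py = train_python_loop(w0, b0, 50)
w_fl, b_fl = train_fori(w0, b0, 50)
```

Compiling the whole loop removes per-epoch Python dispatch, but intermediate values stay on the device, which makes per-epoch logging or early stopping harder; whether it pays off likely depends on how expensive a single `step` is.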
0 votes
0 answers

conditionally call vs. don't call a function using flax.linen

Based on a boolean flag, I want to either 1) call or 2) not call the following function (which operates on a flax linen module). def true_fn(module, carry, inputs): carry, output = flax.linen....
Jabby (reputation 63)
1 vote
1 answer

AttributeError: module 'flax.linen' has no attribute 'transforms'

I received an error from flax 0.7.5, could you help me: File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\ from jVMC.nets.rnn import * File ~\AppData\Roaming\Python\...
Jiayu Weng
1 vote
1 answer

How to select between different function based on a value of a parameter in flax?

I am iterating through each head and applying either the f1 or the f2 function depending on the value of the parameter self.alpha. I only want to evaluate one of f1 and f2, not both, and then select ...
Naren Dhyani
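Assuming `self.alpha` is a scalar and the choice can be written as a boolean predicate (the `0.5` threshold and the bodies of `f1` and `f2` below are hypothetical), `jax.lax.cond` executes only the selected branch:

```python
import jax
import jax.numpy as jnp

def f1(x):
    return jnp.sin(x)

def f2(x):
    return x ** 2

@jax.jit
def apply_head(alpha, x):
    # lax.cond takes a scalar boolean predicate and runs only the chosen
    # branch; both branches must return the same shapes and dtypes.
    return jax.lax.cond(alpha > 0.5, f1, f2, x)

x = jnp.array([1.0, 2.0])
y_hi = apply_head(0.9, x)   # uses f1
y_lo = apply_head(0.1, x)   # uses f2
```

Two caveats: under `jax.vmap` with a batched predicate, `cond` is converted to a `select`, so both branches run; and for more than two alternatives, `jax.lax.switch(index, branches, *operands)` picks one branch by integer index.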
0 votes
1 answer

Prefetching an iterator of 128-dim array to device

I'm having trouble using flax.jax_utils.prefetch_to_device for the simple function below. I'm loading the SIFT 1M dataset, and converting the array to jnp array. I then want to prefetch the iterator ...
jeffreyveon (reputation 13.8k)
1 vote
1 answer

Computing the gradient of a batched function using JAX

I need to compute the gradient of a batched function using JAX. The following is a minimal example of what I would like to do: import jax import jax.numpy as jnp import matplotlib.pyplot as plt ...
al_cc (reputation 105)
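The excerpt's function is cut off; for a scalar-valued function of one example, per-example gradients are typically obtained with `jax.vmap(jax.grad(f))`. A sketch with a hypothetical `f`:

```python
import jax
import jax.numpy as jnp

def f(x):
    """Hypothetical scalar function of a single input vector x of shape (2,)."""
    return jnp.sum(jnp.sin(x) * x)

xs = jnp.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])  # batch of 3 inputs

# Gradient for one input, vmapped over the leading batch axis.
per_example_grads = jax.vmap(jax.grad(f))(xs)  # shape (3, 2)
```

When examples do not interact, `jax.grad(lambda xs: jax.vmap(f)(xs).sum())(xs)` gives the same array, because each output depends only on its own input.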
0 votes
1 answer

Computing dot product of gradients with itself for a neural network model in JAX

I have the following piece of code of JAX with my neural network model -- model: (loss, (inner_state, logits)), grad = jax.value_and_grad( lambda m: forward_and_loss(m, true_gradient=True), ...
abc (reputation 221)
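If `grad` is a pytree of arrays (nested dicts, module state, and so on), its dot product with itself is the sum of `vdot` over its leaves. A sketch with a hypothetical gradient dict:

```python
import jax
import jax.numpy as jnp

def tree_sq_norm(tree):
    """Sum of g . g over every leaf of a gradient pytree (its squared L2 norm)."""
    return sum(jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(tree))

grads = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.full((3,), 2.0)}}
sq_norm = tree_sq_norm(grads)   # 6 * 1^2 + 3 * 2^2
```

Optax's `optax.global_norm` computes the square root of the same quantity.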
1 vote
0 answers

How to Convert .safetensors or .ckpt Files and Use Them in FlaxStableDiffusionImg2ImgPipeline?

I am trying to convert a .safetensors model to a diffusers model using the Python script found at
Aero Wang (reputation 9,177)
1 vote
1 answer

data_format in JAX/FLAX

I did not find any settings for data_format=channels_first or data_format=channels_last in FLAX modules (which are based on JAX). By contrast, TensorFlow does have that designation. Does the ...
ujjwalnur
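As far as the Flax docs describe it, `flax.linen.Conv` takes inputs shaped `(batch, spatial..., features)`, i.e. channels last, without a `data_format` switch. One level down, `jax.lax.conv_general_dilated` accepts an explicit layout through `dimension_numbers`; this sketch (random data, arbitrary sizes) runs the same convolution in both layouts:

```python
import jax
import jax.numpy as jnp

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x_nchw = jax.random.normal(key_x, (1, 3, 8, 8))   # batch, channels, height, width
w_oihw = jax.random.normal(key_w, (4, 3, 3, 3))   # out, in, kernel_h, kernel_w

# Channels-first convolution, with the layout stated via dimension_numbers.
y_nchw = jax.lax.conv_general_dilated(
    x_nchw, w_oihw, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"))

# The same convolution in the channels-last layout that flax.linen.Conv uses.
x_nhwc = jnp.transpose(x_nchw, (0, 2, 3, 1))
w_hwio = jnp.transpose(w_oihw, (2, 3, 1, 0))
y_nhwc = jax.lax.conv_general_dilated(
    x_nhwc, w_hwio, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))
```

Channels-first data can therefore either be transposed once before it enters Flax layers, or handled by a custom layer built on `conv_general_dilated`.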
1 vote
1 answer

Passing JAX tracers to Huggingface CLIP transformer for calculating loss

I'm working on a vision task using JAX, and I'm facing an issue with passing intermediate JAX tracer objects as images to the CLIP model for calculating the loss. The CLIP model expects NumPy arrays ...
Kian (reputation 15)
1 vote
1 answer

The exact meaning of n_jitted_steps=5

I have tried to run the code. Here, there is a command called n_jitted_steps=5, which according to the authors, can accumulate several steps. Since the code is rather complicated, it might be ...
RanWang (reputation 320)
0 votes
0 answers

Data download issue with official Flax Image Net example

I am still trying to understand this official Flax Example. For the convenience of the experiment, I have created my own copy. In the section on running locally, it seems that there is no download ...
RanWang (reputation 320)
0 votes
1 answer

How to unroll the training loop so that Jax can train multiple steps in GPU/TPU

When using powerful hardware, especially TPU, it is often preferable to train multiple steps. For example, in TensorFlow, this is possible. with strategy.scope(): model = create_model() ...
RanWang (reputation 320)
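The TensorFlow snippet is cut off, but the usual JAX counterpart of running several steps per device call is to put the step inside `jax.lax.scan` under `jax.jit`. A sketch with a hypothetical linear-regression loss and learning rate, where the leading axis of `xs` and `ys` indexes the steps:

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate

def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

def sgd_step(w, batch):
    loss, g = jax.value_and_grad(loss_fn)(w, batch)
    return w - LR * g, loss   # (new carry, per-step output)

@jax.jit
def train_n_steps(w, xs, ys):
    """Run one SGD step per leading-axis slice of (xs, ys) in a single compiled call."""
    return jax.lax.scan(sgd_step, w, (xs, ys))

# Hypothetical data: 20 steps of 16 examples each from a noiseless linear model.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (20, 16, 2))
true_w = jnp.array([[2.0], [-1.0]])
ys = xs @ true_w
w_final, losses = train_n_steps(jnp.zeros((2, 1)), xs, ys)   # losses: shape (20,)
```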
1 vote
1 answer

No module named 'jax.experimental.global_device_array' when running the official Flax Example on Colab with V100

I have been trying to understand this official flax example, using a Colab Pro+ account with a V100. When I execute the command python --workdir=./imagenet --config=configs/ , the ...
RanWang (reputation 320)
1 vote
1 answer

Fail to understand the usage of partial argument in Flax Resnet Official Example

I have been trying to understand this official example. However, I am very confused about the use of partial in two places. For example, in line 94, we have the following: conv = partial(self.conv, ...
RanWang (reputation 320)
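The arguments bound in the example are cut off in this excerpt, so this plain-Python sketch uses a hypothetical `make_conv` in place of the Flax layer class. `functools.partial` fixes some arguments once; each later call supplies the rest, and keyword arguments given at call time override the pre-bound ones:

```python
from functools import partial

def make_conv(features, kernel_size, strides=(1, 1), use_bias=True, dtype="float32"):
    """Hypothetical stand-in for a layer constructor such as a Conv class."""
    return {"features": features, "kernel_size": kernel_size,
            "strides": strides, "use_bias": use_bias, "dtype": dtype}

# Bind the arguments shared by every conv in the network once...
conv = partial(make_conv, use_bias=False, dtype="bfloat16")

# ...then each call only supplies what differs between layers.
stem = conv(64, (7, 7), strides=(2, 2))
block = conv(128, (3, 3))
override = conv(32, (1, 1), use_bias=True)   # call-time keyword wins
```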
1 vote
0 answers

Why does not produces model weight file etc?

I was trying to reproduce this Hugging Face tutorial on T5-like span masked-language-modeling. I have the following code import datasets from t5_tokenizer_model import ...
littleworth (reputation 5,169)
1 vote
1 answer

access a submodule of a flax class/module without calling model.apply()

I have module/class of this kind: class autoencoder(nn.Module): hidden_dim: int z_dim: int output_dim: int def setup(self): self.encoder = encoder(self....
Jabby (reputation 63)
1 vote
0 answers

pip install gives ResolutionImpossible: the user requested flax 0.6.8 but t5x depends on flax 0.6.8

I am trying to a requirements file that depends on versions of the packages flax and t5x at specific commits. The problem can be reproduced with the following command: pip install "flax @ git+...
BioGeek (reputation 22.8k)
0 votes
1 answer

Getting incorrect output from the flax model's init call

I am trying to create a simple neural network using flax, as shown below. However, the params frozen dict I receive as the output of model.init is empty instead of having the parameters of the ...
Bunny Rabbit (reputation 8,401)
2 votes
0 answers

How to convert flax saved checkpoints to model?

I am trying to use the pretrained model present in the git, which is basically a Flax Checkpoint. I want to convert it back to ...
Verma Sushant
0 votes
0 answers

Flax implementation of padding_idx from torch.nn.embedding

I have been rewriting some of my pytorch models in jax/flax and came across the issue of converting torch.nn.Embedding to flax.linen.Embed. There does not appear to be a direct translation for pytorch'...
d_a_science
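In PyTorch, a freshly initialized `padding_idx` row is all zeros and never receives gradient updates. One way to get the same observable behavior in JAX is to mask the lookup result; `PAD_ID`, `embed`, and the toy table are hypothetical, and the same mask could be multiplied onto the output of `flax.linen.Embed`:

```python
import jax
import jax.numpy as jnp

PAD_ID = 0  # hypothetical padding index

def embed(table, ids, pad_id=PAD_ID):
    """Embedding lookup whose padding positions are zero vectors with no gradient."""
    vectors = jnp.take(table, ids, axis=0)                  # (..., embed_dim)
    keep = (ids != pad_id)[..., None].astype(vectors.dtype)
    return vectors * keep

table = jnp.arange(12.0).reshape(4, 3)   # vocab of 4, embedding dim 3
ids = jnp.array([[2, 0, 1]])
out = embed(table, ids)
# Gradient w.r.t. the table: the PAD_ID row gets zero even though it was looked up.
grad_table = jax.grad(lambda t: embed(t, ids).sum())(table)
```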
0 votes
0 answers

Vanishing parameters in MAML JAX (Meta Learning)

I am working on an implementation of MAML (see in Jax. When training on a distribution of simple linear regression tasks it seems to perform fine (takes a while ...
Sefton de Pledge
6 votes
1 answer

AttributeError: module 'flax' has no attribute 'nn'

I'm trying to run RegNeRF, which requires flax. On installing the latest version of flax==0.6.0, I got an error stating flax has no attribute optim. This answer suggested to downgrade flax to 0.5.1. ...
Nagabhushan S N
-1 votes
1 answer

Why does JAX throw an unfiltered stack trace?

I need to jit the train step but when I do I get this error import jax_resnet import jax import jax.numpy as jnp from flax import linen as nn import tensorflow_datasets as tfds from ...
Christopher Rae
2 votes
2 answers

Jax - vmap over batch of dataclasses

In JAX, I am looking to vmap a function over a fixed length list of dataclasses, for example: import jax, chex from flax import struct @struct.dataclass class EnvParams: max_steps: int = 500 ...
EmptyJackson
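A common trick is to turn the list of pytrees into a single pytree whose leaves carry a leading batch axis, then `vmap` over that. This sketch swaps the asker's `flax.struct.dataclass` for a `NamedTuple` (also a pytree) so it needs only JAX; the `gravity` field and `rollout_length` function are hypothetical. A `flax.struct.dataclass` should stack the same way, as long as the stacked fields are pytree nodes rather than static fields.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class EnvParams(NamedTuple):
    """Stand-in for the asker's flax.struct.dataclass; NamedTuples are pytrees too."""
    max_steps: int
    gravity: float

def rollout_length(params: EnvParams) -> jnp.ndarray:
    # Hypothetical per-environment computation.
    return params.max_steps * params.gravity

param_list = [EnvParams(500, 9.8), EnvParams(200, 1.6), EnvParams(100, 3.7)]

# One EnvParams whose fields are arrays with a leading batch axis of length 3.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *param_list)

results = jax.vmap(rollout_length)(stacked)   # shape (3,)
```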
0 votes
1 answer

How can I initialize the hidden state (carry) of a (flax linen) GRUCell as a learnable parameter (e.g. using model.init)

I create a GRU model in Jax using Flax and I initialize the model parameters using model.init as follows: import jax.numpy as np from jax import random import flax.linen as nn from jax.nn import ...
Jabby (reputation 63)
4 votes
1 answer

AttributeError: module 'flax' has no attribute 'optim'

My code is as follows: !pip install flax init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params'] print(f'Model parameters: {n_params(init_params):,}') optim = flax.optim....
Md Tawsif Mostafiz 170021031
1 vote
1 answer

I am trying to assign a JAX Tracer object to a NumPy array that requires concrete values - workaround needed please

I am new to Jax. I am implementing a variational autoencoder (VAE) using Jax and Flax. During training, I sample a latent code (from the distribution inferred by the encoder, which I implement using ...
Jabby (reputation 63)
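JAX arrays are immutable and tracers cannot be converted to NumPy under `jit`, so writing a sampled value into a NumPy buffer fails. A sketch of the functional alternative with `.at[...].set(...)`, where `write_sample` and the shapes are hypothetical:

```python
import jax
import jax.numpy as jnp

@jax.jit
def write_sample(buffer, z, idx):
    # buffer[idx] = z (NumPy-style in-place assignment) fails for JAX arrays,
    # and np.asarray(z) fails under jit because z is an abstract tracer.
    # .at[idx].set(z) is the functional equivalent: it returns an updated copy.
    return buffer.at[idx].set(z)

buffer = jnp.zeros((4, 2))
z = jnp.array([1.0, 2.0])        # e.g. one sampled latent code (hypothetical)
new_buffer = write_sample(buffer, z, 2)
```

The original `buffer` is left unchanged; under `jit`, XLA can often perform the update in place anyway.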