698 questions
1
vote
1
answer
24
views
jax and flax not playing nicely with each other
I want to implement a neural network with multiple LSTM gates stacked one after the other. I set the hidden states to 0, as suggested here. When I try to run the code, I get
JaxTransformError: Jax ...
1
vote
1
answer
38
views
Efficient custom array creation routines in JAX
I'm still getting a handle on best practices in jax. My broad question is the following:
What are best practices for the implementation of custom array creation routines in jax?
For instance, I want ...
1
vote
1
answer
46
views
How to handle PRNG splitting in a jax.vmap context?
I have a function which simulates a stochastic differential equation. Currently, without stochastic noise, my invocation of simulating the process up to time t looks like this (and, yeah, I need to ...
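A common pattern for the situation described above is to split one key into a batch of keys and map over them together with the data. The sketch below is illustrative only: `noisy_step`, the noise scale, and the array shapes are assumptions, not the asker's SDE code.

```python
import jax
import jax.numpy as jnp

def noisy_step(key, x):
    # Hypothetical single-sample update: add Gaussian noise drawn from `key`.
    return x + 0.1 * jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 8)   # one independent key per batch element
xs = jnp.zeros((8, 3))            # batch of 8 states, each of dimension 3

# vmap maps over the leading axis of both `keys` and `xs`.
out = jax.vmap(noisy_step)(keys, xs)
print(out.shape)
```

Each batch element receives its own key, so the noise differs across rows instead of being repeated.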
0
votes
0
answers
38
views
Jax DNN library initialization failed
Main issue
I am trying to run a simple python test file that uses Jax with the following code:
import jax
key = jax.random.PRNGKey(0)
print(jax.device_count())
When I try to run this I get the error
...
1
vote
1
answer
34
views
How to avoid "trace buffers dropped" in TensorBoard trace analysis?
I'm profiling JAX code on TPU using jax.profiler.start_trace. I tried to reduce as much as possible the duration of the trace (running on a single v5e Chip, ...) but I still have more than 3/4 of the ...
0
votes
1
answer
46
views
I have a problem when I install jax, jaxlib and optax
I have an issue when I install JAX and jaxlib version 0.4.23 and then install Optax. I've tried different versions, such as 0.1.5, but the GPU support no longer works, and I get this message when ...
0
votes
0
answers
37
views
Very slow Quantum Neural Networks
The following code is very slow...
How can I speed it up?
Below you will see all the code I have written and how I used JAX and PennyLane. You will see the circuit I created and the model I built, ...
0
votes
0
answers
19
views
XLA PJRT plugin on Mac reveals only 4 CPUs
On my macOS M3 machine, I have compiled the pjrt_c_api_cpu_plugin.so and I'm using it with JAX.
My MacBook has 12 CPUs but with a simple Python script "jax.devices()" the pjrt plugin reveals just 4 ...
0
votes
0
answers
26
views
VAE Loss Decreases But Reconstruction Doesn't Improve
I am having trouble with the reconstruction images not working at all. Here's how I'm performing a single update step. I understand that a simpler implementation is possible for a regular VAE, but due ...
0
votes
0
answers
30
views
How to set non-trainable weights?
The method keras.Model.set_weights seems to only take trainable weights. Non-trainable weights such as those from normalization layers cannot be imported this way. This is problematic, since in Keras ...
1
vote
2
answers
142
views
Hello World for jaxtyping?
I can't find any instructions or tutorials for getting started with jaxtyping. I tried the simplest possible program and it fails to parse. I'm on Python 3.11. I don't see anything on GitHub jaxtyping ...
1
vote
1
answer
45
views
Why is JAX's jit compilation slower on the second run in my example?
I am new to using JAX, and I’m still getting familiar with how it works. From what I understand, when using Just-In-Time (JIT) compilation (jax.jit), the first execution of a function might be slower ...
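The usual expectation (first call slow because of tracing and compilation, later calls fast) can be checked with a small timing sketch like the one below. The function `f` and the array size are made up for illustration, and `block_until_ready()` is used so that asynchronous dispatch does not distort the measurement.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Hypothetical workload: any jittable computation would do.
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.ones((1000, 1000))

t0 = time.perf_counter()
f(x).block_until_ready()   # first call: trace, compile, then run
t1 = time.perf_counter()
f(x).block_until_ready()   # second call: reuses the cached executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
```

If the second call is not faster, a likely cause is that the input shape, dtype, or a static argument changed between calls, which triggers recompilation.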
-3
votes
1
answer
58
views
Installing jax and jaxlib for CUDA V9.1.85 [closed]
I am using Python 3.8.0 and CUDA 9.1.85 (CUDA compilation tools, release 9.1) (or not? see below). Unfortunately, I do not have administrative privileges, so I am unable to upgrade either of them.
I ...
1
vote
0
answers
24
views
one bug from kfac_jax
I am writing neural network Quantum Monte Carlo software. I have already finished all the modules. However, when I try to use kfac_jax to optimize the parameters, I hit a bug:
jaxlib.xla_extension....
1
vote
1
answer
66
views
Spooky behaviour of JAX
This is a follow-up to my previous question. I am implementing a Parameterized Quantum Circuit as a Quantum Neural Network, where the optimization loop is jitted. Although there's no error, everything ...
1
vote
1
answer
95
views
Restoring flax model checkpoints using orbax throws ValueError
The following code blocks are being utilized to save the train state of the model during training and to restore the state back into memory.
from flax.training import orbax_utils
import orbax....
1
vote
2
answers
123
views
Tracking test/val loss when training a model with JAX
When JAX is used to train a machine learning model, we typically only minimize the training loss.
In my case, however, in order to choose the number of epochs or to avoid over-training, I ...
1
vote
1
answer
92
views
Storing and jax.vmap() over Pytrees
I've run into an issue with Jax that will make me rewrite an entire 20000-line application if I don't solve it.
I have a non-ML application which relies on pytrees to store data, and the pytrees are ...
1
vote
1
answer
110
views
JIT: partial or with static argnums? Non hashable input, but hashable partial
I am a bit lost on what exactly is going on and what option to choose. Let's go through an example:
import jax
from functools import partial
from typing import List
def dummy(a: int, b: List[str]):
...
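The two options named in the title can be compared with a sketch like the following. The body of `dummy` is elided in the excerpt, so the one here is a hypothetical stand-in; the point is only that a list is not hashable and therefore cannot be a static argument, while a `functools.partial` binding or a tuple can be used instead.

```python
from functools import partial
import jax
import jax.numpy as jnp

def dummy(a, b):
    # Hypothetical body: scale `a` by the number of labels in `b`.
    return a * len(b)

labels = ["x", "y", "z"]

# Option 1: bind the list with functools.partial so jit never sees it as an argument.
dummy_partial = jax.jit(partial(dummy, b=labels))
out_partial = dummy_partial(jnp.float32(2.0))

# Option 2: mark `b` as static, which requires a hashable value such as a tuple.
dummy_static = jax.jit(dummy, static_argnames="b")
out_static = dummy_static(jnp.float32(2.0), b=tuple(labels))

print(out_partial, out_static)
```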
2
votes
1
answer
77
views
precision of JAX
I have a question regarding the precision of float in JAX. For the following code,
import numpy as np
import jax.numpy as jnp
print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10))
print('np.arctan(10) ...
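A difference like this is usually explained by default dtypes: JAX computes in 32-bit floats unless 64-bit mode is enabled, while NumPy defaults to 64-bit. The sketch below only inspects the dtypes; it assumes a default JAX configuration (no `jax_enable_x64` flag set).

```python
import numpy as np
import jax.numpy as jnp

x32 = jnp.arctan(10.0)   # float32 under JAX's default configuration
x64 = np.arctan(10.0)    # float64 in NumPy

print(x32.dtype, x64.dtype)
# The two values agree only to roughly float32 precision (about 7 digits).
print(abs(float(x32) - float(x64)))
```

Printing 60 decimal places of a float32 value therefore shows digits that carry no information beyond the first seven or so.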
0
votes
0
answers
147
views
How does one install an ARM build of Python on macOS?
Background: I develop and run physics simulations in Python using JAX, which supports GPUs, including on Apple M-series machines (see jax-metal).
Following the instructions there results in an error ...
1
vote
1
answer
68
views
jax register_pytree_node_class and register_dataclass returns non consistent datatype: list and tuple accordingly
I am writing a custom class, which is basically a wrapper around a list, with a custom setitem method. I would like this class to participate in jax.jit code, and while doing that I found the following problem: during ...
1
vote
0
answers
58
views
inter core interconnect Checking in simple Slices of TPU
ICI (inter-core interconnect) offers very fast connectivity between TPUs (which are connected to different hosts) and thus also increases the total memory available for TPU calculations (I guess!). ...
1
vote
1
answer
172
views
Batched matrix multiplication with JAX on GPU faster with larger matrices
I'm trying to perform batched matrix multiplication with JAX on GPU, and noticed that it is ~3x faster to multiply shapes (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) than it is to multiply (1000, 1000, ...
0
votes
0
answers
162
views
solving large amount (1 million) of individual small nonlinear systems of equations using JAX
I have some technical inquiries regarding the capabilities of JAX in addressing a substantial number (1 million) of individual small nonlinear systems of equations. Currently, my approach involves ...
0
votes
0
answers
61
views
How does the polynomial kernel in tinygp work?
I am trying to learn to use the tinygp (v 0.3.0) package (Python version 3.11.10 on macOS Sonoma 14.5) but am encountering a problem with the linear kernel. I am following one of their tutorials and this ...
2
votes
0
answers
51
views
Function initialization issue using jax for flexible variables
The first run of a function using jax requires significantly more time than subsequent calls. This appears to be because the function is initialized during the first call. The tricky thing is if we change ...
1
vote
1
answer
106
views
Jax tracer leaks
I have a JAX tracer leak problem when I start using vmap. I've got two functions grad_loss and its batched equivalent grad_loss_batch.
@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
...
1
vote
1
answer
37
views
Compiled JAX functions slowing down for no reason
I am using Jax for scientific computing, specifically, calculating a pairwise interaction force across all the particles in my system in parallel. This is for a simulation of the dynamics of a ...
1
vote
1
answer
85
views
Computing gradient using JAX of a function that outputs a list of arrays
I have a function which returns a list of arrays, and I need to find its derivative with respect to a single parameter. For instance, let's say we have
def fun(x):
    ...
    return [a,b,c]
where a,b,c and ...
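One way to differentiate a function with a list-valued output with respect to a scalar is `jax.jacfwd`, which returns a result with the same list structure as the output. The body of `fun` below is invented for illustration, since the original is elided.

```python
import jax
import jax.numpy as jnp

def fun(x):
    # Hypothetical outputs of different shapes, all depending on scalar x.
    a = jnp.stack([x, x ** 2])
    b = jnp.sin(x) * jnp.ones(3)
    c = x ** 3
    return [a, b, c]

x0 = jnp.asarray(2.0)
# The result is a list [da/dx, db/dx, dc/dx], each with the shape of its output.
da, db, dc = jax.jacfwd(fun)(x0)
print(da, db, dc)
```

Forward mode suits this case because there is a single input and many outputs; `jax.grad` would not apply directly, since it expects a scalar output.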
1
vote
2
answers
85
views
How to use jax.custom_vjp with functions that take non-JAX types (e.g., SymPy expressions) as inputs?
I'm trying to use JAX's custom_vjp to define custom gradient computations for a function that takes a SymPy expression as an input. However, I'm encountering errors because JAX doesn't support non-JAX ...
0
votes
0
answers
239
views
Kohya_ss Colab fine tune Jax error / torch compatibility
The finetune notebook worked previously with !pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Now this no longer works. I have ...
1
vote
1
answer
39
views
TypeError: 'Add' object is not iterable when iterating over SymPy expressions in scqubits with JAX custom_vjp
I'm trying to modify the scqubits Python package to use JAX's custom_vjp for differentiable programming. While doing so, I encountered the following error:
TypeError: 'Add' object is not iterable
Here'...
0
votes
1
answer
51
views
How to install and use bayesnf package in windows with CPU
I want to use a Python package associated with this GitHub page. I installed it on my system (as suggested on the GitHub page, using Python 3.10 in a new environment), which is Windows 10 ...
2
votes
1
answer
75
views
Modifying multiple dimensions of Jax array simultaneously
When using the jax_array.at[idx] function, I wish to be able to set values at both a set of specified rows and columns within the jax_array to another jax_array containing values in the same shape. ...
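Setting a block of values at chosen rows and columns can be sketched with broadcasting index arrays, much like `np.ix_` in NumPy. The array sizes and index values below are made up for illustration.

```python
import jax.numpy as jnp

a = jnp.zeros((4, 5))
rows = jnp.array([0, 2])
cols = jnp.array([1, 3, 4])
vals = jnp.arange(6.0).reshape(2, 3)   # one value per (row, col) pair

# rows[:, None] has shape (2, 1) and cols[None, :] has shape (1, 3);
# together they broadcast to the (2, 3) block that `vals` fills.
b = a.at[rows[:, None], cols[None, :]].set(vals)
print(b)
```

The update is out of place: `a` is unchanged and `b` is a new array.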
2
votes
1
answer
75
views
Mapping Over Arrays of Functions in JAX
What is the most performant, idiomatic way of mapping over arrays of functions in JAX?
Context: This GitHub issue shows a way to apply vmap to several functions using lax.switch. The example is ...
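The `lax.switch` approach mentioned above can be sketched as follows. The three branch functions and the inputs are placeholders; all branches must return outputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical set of functions to choose between, indexed 0..2.
branches = [jnp.sin, jnp.cos, jnp.tanh]

def apply_one(i, x):
    # Select branch `i` and apply it to `x`.
    return jax.lax.switch(i, branches, x)

idx = jnp.array([0, 1, 2, 0])
xs = jnp.array([0.5, 0.5, 0.5, 1.0])

# Map over (index, input) pairs.
out = jax.vmap(apply_one)(idx, xs)
print(out)
```

As far as I understand, when the index is batched under `vmap`, every branch may be evaluated for every element and the result selected afterwards, so this pattern is convenient but not necessarily the most performant one.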
1
vote
1
answer
68
views
How to handle errors caused by non-JAX objects (such as scipy.sparse.csr_matrix) in JAX custom_vjp?
I am using JAX to implement custom backpropagation (custom_vjp), but in my function, one of the input parameters is a sparse matrix of type scipy.sparse.csr_matrix. Since JAX expects all parameters to ...
1
vote
1
answer
100
views
JAX TypeError: 'Device' object is not callable
I found a piece of JAX code from a few years ago.
import jax
import jax.random as rand
device_cpu = None
def do_on_cpu(f):
    global device_cpu
    if device_cpu is None:
        device_cpu = jax....
1
vote
0
answers
65
views
Forward-over-reverse mode Hessian-vector product in Jax: how smart is jax.jvp at re-using computations?
How smart is Jax at re-using intermediate computations when computing Hessian-vector products via forward-over-reverse mode automatic differentiation via jax.jvp over jax.grad?
For example, something ...
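For reference, the forward-over-reverse construction being asked about is usually written as a `jax.jvp` of `jax.grad`. The test function below is an arbitrary choice for illustration; it says nothing about how much intermediate computation JAX actually shares.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar function whose Hessian is diag(6 * x).
    return jnp.sum(x ** 3)

def hvp(f, x, v):
    # Forward-over-reverse: push the tangent v through the gradient function.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])
hv = hvp(f, x, v)
print(hv)
```

This avoids forming the full Hessian: only one gradient evaluation and one forward-mode pass through it are needed.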
0
votes
0
answers
75
views
Have trouble running JAX on Metal
I tried installing JAX for Metal in a venv as described at https://developer.apple.com/metal/jax/, and when I try to verify it, I get a really large error.
Here's what I used to install it:
python3 -m ...
1
vote
1
answer
497
views
Using Jax Jit on a method as decorator versus applying jit function directly
I guess most people familiar with jax have seen this example in the documentation and know that it does not work:
import jax.numpy as jnp
from jax import jit
class CustomClass:
    def __init__(self, x:...
0
votes
0
answers
60
views
JAX 3d convolution kernel speedup
I am trying to solve a diffusion kernel with JAX and this is my JAX port of existing GPU CUDA code. JAX gives the correct answer, but it is about 5x slower than CUDA. How can I speed this up further? ...
0
votes
0
answers
63
views
Orbax checkpointing using jit
I am running a Reinforcement Learning agent implemented in Jax, using jit. Following this tutorial, I have tried to implement checkpoints of the agent's training every few training steps.
However, I ...
1
vote
1
answer
596
views
jax library error jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed
I am working on code that uses the Jax library, and I run into this error over and over again no matter how I try to configure my environment:
2024-08-20 16:26:58.037892: E external/xla/xla/...
0
votes
0
answers
261
views
Advice on how to speed up jax compilation times?
I want to implement the Legendre approximation of Dahlke & Pacheco 2020 to compute the entropy of a Gaussian Mixture Model.
I want to incorporate this into some deep learning I have in Jax, so I ...
1
vote
1
answer
92
views
transform DeviceArray into Array for jax
I have a .pkl file I downloaded from a public GitHub repository and when I read it using pickle.load, i.e. using
with open('filename.pkl', 'rb') as f:
    file_content = pickle.load(f)
I get ...
0
votes
0
answers
59
views
Restoring an optimizer state across multiple devices in JAX
I'm training a model using jax and optax on four GPUs and need to save and restore the optimizer state, but I'm running into a problem loading it. The optimizer state is initialized --
optimizer = ...
0
votes
0
answers
38
views
How to sample variable number of iterations of a for loop in numpyro?
I'm using numpyro to try to sample several variables in a model, one of which is the number of iterations of a for loop. I have shown an analogous toy model here.
def model():
    mu = 0.
    sigma = ...
1
vote
1
answer
86
views
Zero length error of non-zero length array
I'm writing an environment for RL agent training.
My env.step method takes an array of length 3 as its action
def scan(self, f, init, xs, length=None):
    if xs is None:
        xs = [None] * ...
0
votes
0
answers
37
views
Flax.linen.conv unexpected behavior
I'm experiencing unexpected output when using flax.linen.Conv. The output from my conv layer has very odd stats. The mean is around 100-110 and is sometimes NaN. I tested the same against TensorFlow ...