1 vote
1 answer
24 views

jax and flax not playing nicely with each other

I want to implement a neural network with multiple LSTM gates stacked one after the other. I set the hidden states to 0, as suggested here. When I try to run the code, I get JaxTransformError: Jax ...
User: Dan Leonte
1 vote
1 answer
38 views

Efficient custom array creation routines in JAX

I'm still getting a handle on best practices in JAX. My broad question is the following: what are the best practices for implementing custom array creation routines in JAX? For instance, I want ...
User: Ben (reputation 293)
1 vote
1 answer
46 views

How to handle PRNG splitting in a jax.vmap context?

I have a function which simulates a stochastic differential equation. Currently, without stochastic noise, my invocation of simulating the process up to time t looks like this (and, yeah, I need to ...
User: 0xbadf00d (reputation 18.2k)
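A common pattern for per-trajectory noise under vmap is to split one parent key into as many subkeys as there are trajectories and map over the keys together with the other batched inputs. The sketch below assumes a hypothetical Ornstein-Uhlenbeck-style drift and a fixed step count; simulate, n_steps, dt and sigma are illustrative names, not the asker's code.

```python
import jax
import jax.numpy as jnp


def simulate(key, x0, n_steps=100, dt=0.01, sigma=0.1):
    """Euler-Maruyama for the hypothetical SDE dX = -X dt + sigma dW."""
    # Draw all Brownian increments for this trajectory from its own key.
    dw = jax.random.normal(key, (n_steps,)) * jnp.sqrt(dt)

    def step(x, dw_t):
        x_next = x - x * dt + sigma * dw_t
        return x_next, x_next

    _, path = jax.lax.scan(step, x0, dw)
    return path


key = jax.random.PRNGKey(0)
n_paths = 8
keys = jax.random.split(key, n_paths)  # one independent subkey per path
x0s = jnp.ones(n_paths)
paths = jax.vmap(simulate)(keys, x0s)  # map over keys and initial states together
# paths.shape -> (8, 100)
```

Splitting outside the vmap keeps every trajectory reproducible from the parent key; passing the same key to every mapped call would instead give identical noise on all paths.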
0 votes
0 answers
38 views

Jax DNN library initialization failed

Main issue: I am trying to run a simple Python test file that uses JAX with the following code: import jax key = jax.random.PRNGKey(0) print(jax.device_count()) When I try to run this I get the error ...
User: user28547199
1 vote
1 answer
34 views

How to avoid "trace buffers dropped" in TensorBoard trace analysis?

I'm profiling JAX code on TPU using jax.profiler.start_trace. I tried to reduce the duration of the trace as much as possible (running on a single v5e chip, ...), but I still have more than 3/4 of the ...
User: Damg (reputation 11)
0 votes
1 answer
46 views

I have a problem when I install jax, jaxlib and optax

I have an issue when I install JAX and jaxlib version 0.4.23 and then install Optax. I've tried different versions, such as 0.1.5, but the GPU support no longer works, and I get this message when ...
User: Alessandro Castelli
0 votes
0 answers
37 views

Very slow Quantum Neural Networks

The following code is very slow... How can I speed it up? Below you will see all the code I have written and how I used JAX and PennyLane. You will see the circuit I created and the model I built, ...
User: Alessandro Castelli
0 votes
0 answers
19 views

XLA PJRT plugin on Mac reveals only 4 CPUs

On my M3 Mac running macOS, I have compiled pjrt_c_api_cpu_plugin.so and I'm using it with JAX. My MacBook has 12 CPUs, but with a simple Python script calling jax.devices(), the PJRT plugin reveals just 4 ...
User: lordav (reputation 187)
0 votes
0 answers
26 views

VAE Loss Decreases But Reconstruction Doesn't Improve

I am having trouble with the reconstruction images not working at all. Here's how I'm performing a single update step. I understand that a simpler implementation is possible for a regular VAE, but due ...
User: m4thphobia
0 votes
0 answers
30 views

How to set non-trainable weights?

The method keras.Model.set_weights seems to only take trainable weights. Non-trainable weights such as those from normalization layers cannot be imported this way. This is problematic, since in Keras ...
User: Value_Investor
1 vote
2 answers
142 views

Hello World for jaxtyping?

I can't find any instructions or tutorials for getting started with jaxtyping. I tried the simplest possible program and it fails to parse. I'm on Python 3.11. I don't see anything on GitHub jaxtyping ...
User: dspyz (reputation 5,460)
1 vote
1 answer
45 views

Why is JAX's jit compilation slower on the second run in my example?

I am new to using JAX, and I’m still getting familiar with how it works. From what I understand, when using Just-In-Time (JIT) compilation (jax.jit), the first execution of a function might be slower ...
User: kernel123
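When timing jitted code, it helps to separate compilation from execution and to wait for the asynchronously dispatched result with block_until_ready(); otherwise a measurement may cover only the dispatch. A minimal sketch, where f and its workload are hypothetical stand-ins for the asker's function:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Hypothetical workload; any jitted function could go here.
    return jnp.tanh(x @ x.T).sum()


def timed(fn, *args):
    start = time.perf_counter()
    result = fn(*args).block_until_ready()  # wait for the device to finish
    return result, time.perf_counter() - start


x = jnp.ones((512, 512))
out, t_first = timed(f, x)  # includes tracing and compilation
_, t_second = timed(f, x)   # same shape and dtype: reuses the compiled executable
_, t_third = timed(f, x)
print(f"first={t_first:.4f}s second={t_second:.4f}s third={t_third:.4f}s")
```

Small differences between warm calls are often just timing noise; a new input shape or dtype, on the other hand, does trigger a fresh compilation.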
-3 votes
1 answer
58 views

Installing jax and jaxlib for CUDA V9.1.85 [closed]

I am using Python 3.8.0 and CUDA 9.1.85 (CUDA compilation tools, release 9.1) (or not? see below). Unfortunately, I do not have administrative privileges, so I am unable to upgrade either of them. I ...
User: Alessandro Castelli
1 vote
0 answers
24 views

A bug from kfac_jax

I am writing neural-network quantum Monte Carlo software, and I have already finished all the modules. However, when I try to use kfac_jax to optimize the parameters, I run into a bug: jaxlib.xla_extension....
User: Yongda Huang
1 vote
1 answer
66 views

Spooky behaviour of JAX

This is a follow-up to my previous question. I am implementing a Parameterized Quantum Circuit as a Quantum Neural Network, where the optimization loop is jitted. Although there's no error, everything ...
User: Sup (reputation 331)
1 vote
1 answer
95 views

Restoring flax model checkpoints using orbax throws ValueError

The following code blocks are used to save the train state of the model during training and to restore the state back into memory. from flax.training import orbax_utils import orbax....
User: yash (reputation 380)
1 vote
2 answers
123 views

Tracking test/val loss when training a model with JAX

When JAX is used to train a machine learning model, we only minimize the training loss. However, in order to assess the number of epochs or to avoid overtraining, I ...
User: Sup (reputation 331)
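One way to track held-out loss is to evaluate the same loss function on validation data after each update without differentiating it; only the training step takes a gradient. The sketch below uses a hypothetical linear least-squares model and synthetic data in place of the asker's quantum model.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def loss_fn(w, x, y):
    # Mean squared error of a hypothetical linear model.
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    # Gradient step on the training loss only.
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LEARNING_RATE * grads, loss


eval_loss = jax.jit(loss_fn)  # forward pass only, no gradient

k_train, k_val = jax.random.split(jax.random.PRNGKey(0))
true_w = jnp.array([1.0, -2.0])
x_train = jax.random.normal(k_train, (80, 2))
y_train = x_train @ true_w
x_val = jax.random.normal(k_val, (20, 2))
y_val = x_val @ true_w

w = jnp.zeros(2)
history = []  # (train_loss, val_loss) per epoch
for _ in range(50):
    w, train_loss = train_step(w, x_train, y_train)
    val_loss = eval_loss(w, x_val, y_val)
    history.append((float(train_loss), float(val_loss)))
```

Early stopping then amounts to watching history for the epoch at which the validation loss stops improving.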
1 vote
1 answer
92 views

Storing and jax.vmap() over Pytrees

I've run into an issue with JAX that will force me to rewrite an entire 20,000-line application if I don't solve it. I have a non-ML application which relies on pytrees to store data, and the pytrees are ...
User: MRiabov (reputation 31)
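jax.vmap maps over every leaf of a pytree, so a batch of pytrees is usually represented as one pytree whose leaves carry a leading batch axis. The sketch below stacks a hypothetical list of per-item dicts into that form with jax.tree_util.tree_map and then vmaps a function written for a single item; the field names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-item pytrees, as an application might store them.
items = [{"pos": jnp.full(3, float(i)), "mass": jnp.float32(i + 1)} for i in range(4)]

# Stack leaf by leaf: every leaf gains a leading axis of size len(items).
batched = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *items)


def kinetic_like(item):
    # Written for ONE item; vmap adds the batch dimension.
    return 0.5 * item["mass"] * jnp.sum(item["pos"] ** 2)


energies = jax.vmap(kinetic_like)(batched)
# batched["pos"].shape -> (4, 3); energies -> [0., 3., 18., 54.]
```

Going back to a single item is the inverse operation, for example jax.tree_util.tree_map(lambda leaf: leaf[i], batched).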
1 vote
1 answer
110 views

JIT: partial or with static argnums? Non hashable input, but hashable partial

I am a bit lost on what exactly is going on and which option to choose. Let's go through an example: import jax from functools import partial from typing import List def dummy(a: int, b: List[str]): ...
User: Evgenii Egorov
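The two options differ in where the non-array argument lives. With static_argnums the argument becomes part of the compilation cache key, so it must be hashable (a tuple rather than a list) and each new value triggers a retrace; with functools.partial the value is closed over and baked in when the partial is traced. The body of dummy is elided in the question, so the one below is a hypothetical stand-in.

```python
from functools import partial
from typing import List, Sequence

import jax
import jax.numpy as jnp


def dummy(a, b: Sequence[str]):
    # Hypothetical body: b is only used at trace time (its length).
    return a * len(b)


# Option 1: mark b static. Static arguments must be hashable, so pass a tuple;
# a list here would be rejected as a non-hashable static argument.
dummy_static = jax.jit(dummy, static_argnums=1)
r1 = dummy_static(jnp.float32(2.0), ("x", "y", "z"))  # -> 6.0

# Option 2: close over b with partial; only `a` remains a traced argument.
names: List[str] = ["x", "y"]
dummy_closed = jax.jit(partial(dummy, b=names))
r2 = dummy_closed(jnp.float32(2.0))  # -> 4.0
```

One caveat of the partial route: if the list is mutated after the first call, the already-compiled function keeps using the value it saw at trace time.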
2 votes
1 answer
77 views

precision of JAX

I have a question regarding the precision of float in JAX. For the following code, import numpy as np import jax.numpy as jnp print('jnp.arctan(10) is:','%.60f' % jnp.arctan(10)) print('np.arctan(10) ...
User: funpy (reputation 45)
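A likely source of the difference is dtype: by default JAX computes in 32-bit floats, while NumPy uses float64, so printing 60 digits mostly exposes float32 rounding. A minimal check; 64-bit mode can be switched on with jax.config.update("jax_enable_x64", True) at program start.

```python
import numpy as np
import jax.numpy as jnp

x_jax = jnp.arctan(10)  # float32 unless x64 mode is enabled
x_np = np.arctan(10)    # float64

print(x_jax.dtype, x_np.dtype)
# The gap is expected to be on the order of float32 rounding.
gap = abs(float(x_jax) - float(x_np))
print(gap, jnp.finfo(jnp.float32).eps)
```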
0 votes
0 answers
147 views

How does one install an ARM build of Python on macos?

Background: I develop and run physics simulations in Python using JAX, which supports GPUs, including on Apple M-series machines (see jax-metal). Following the instructions there results in an error ...
User: aklmn (reputation 11)
1 vote
1 answer
68 views

jax register_pytree_node_class and register_dataclass return inconsistent datatypes: list and tuple, respectively

I am writing a custom class, which is basically a wrapper around a list, with a custom __setitem__ method. I would like this class to participate in jax.jit code, and while doing so I found the following problem: during ...
User: Evgenii Egorov
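One way to make the container type consistent is to normalize whatever sequence tree_unflatten receives back into a list inside the class, so it no longer matters whether the children arrive as a tuple or a list. A sketch with a hypothetical ArrayList class, not the asker's code:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ArrayList:
    """Hypothetical list wrapper that participates in jit as a pytree."""

    def __init__(self, items):
        self.items = list(items)  # normalize tuple/list children to a list

    def __getitem__(self, i):
        return self.items[i]

    def __setitem__(self, i, value):
        self.items[i] = value

    def tree_flatten(self):
        # Children are the array leaves; no static auxiliary data.
        return tuple(self.items), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children)


@jax.jit
def bump_first(xs):
    # Inside jit this mutates a fresh, traced ArrayList, not the caller's object.
    xs[0] = xs[0] + 1.0
    return xs


out = bump_first(ArrayList([jnp.zeros(2), jnp.ones(2)]))
# out.items is a list; out[0] -> [1., 1.]
```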
1 vote
0 answers
58 views

Checking the inter-core interconnect in simple TPU slices

ICI (inter-core interconnect) offers very fast connectivity between TPUs (including those connected to different hosts) and thus also increases the total memory available for TPU calculations (I guess!). ...
User: Krishna Mohan
1 vote
1 answer
172 views

Batched matrix multiplication with JAX on GPU faster with larger matrices

I'm trying to perform batched matrix multiplication with JAX on GPU, and noticed that it is ~3x faster to multiply shapes (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) than it is to multiply (1000, 1000, ...
User: Nin17 (reputation 3,442)
0 votes
0 answers
162 views

Solving a large number (1 million) of individual small nonlinear systems of equations using JAX

I have some technical questions about the capabilities of JAX in addressing a substantial number (1 million) of individual small nonlinear systems of equations. Currently, my approach involves ...
User: funpy (reputation 45)
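A common way to batch many small independent systems is to write a solver for one system and vmap it over the per-system parameters. The sketch below assumes a fixed number of Newton iterations with no convergence check and a hypothetical 2x2 residual; residual and newton_solve are illustrative names.

```python
import jax
import jax.numpy as jnp


def residual(x, p):
    # Hypothetical 2x2 system: x0**2 + x1 = p and x0 = x1.
    return jnp.array([x[0] ** 2 + x[1] - p, x[0] - x[1]])


def newton_solve(p, x0, n_iter=20):
    def f(x):
        return residual(x, p)

    jac = jax.jacfwd(f)

    def body(x, _):
        # One Newton step: solve J(x) dx = f(x), then update.
        return x - jnp.linalg.solve(jac(x), f(x)), None

    x, _ = jax.lax.scan(body, x0, None, length=n_iter)
    return x


params = jnp.array([2.0, 6.0, 12.0])  # one parameter per system
x_init = jnp.ones(2)                  # shared starting point
solve_all = jax.jit(jax.vmap(newton_solve, in_axes=(0, None)))
solutions = solve_all(params, x_init)
# Positive roots of t**2 + t - p: p=2 -> 1, p=6 -> 2, p=12 -> 3.
```

The same vmapped solver applies unchanged to a much longer params array; a real application would likely also want a residual-based convergence check.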
0 votes
0 answers
61 views

How does the polynomial kernel in tinygp work?

I am trying to learn to use the tinygp (v 0.3.0) package (Python version 3.11.10 on macOS Sonoma 14.5) but encountering a problem with the linear kernel. I am following one of their tutorials and this ...
User: Alessandro Ruggieri
2 votes
0 answers
51 views

Function initialization issue using jax for flexible variables

The first run of a function using JAX takes significantly more time than subsequent calls. This appears to be because the function is initialized during the first call. The tricky thing is that if we change ...
User: VGEorge (reputation 21)
1 vote
1 answer
106 views

Jax tracer leaks

I have a JAX tracer leak problem when I start using vmap. I've got two functions grad_loss and its batched equivalent grad_loss_batch. @eqx.filter_value_and_grad def grad_loss(model, ti, yi): ...
User: thmo (reputation 199)
1 vote
1 answer
37 views

Compiled JAX functions slowing down for no reason

I am using Jax for scientific computing, specifically, calculating a pairwise interaction force across all the particles in my system in parallel. This is for a simulation of the dynamics of a ...
User: Yigithan Gediz
1 vote
1 answer
85 views

Computing gradient using JAX of a function that outputs a list of arrays

I have a function which returns a list of arrays, and I need to find its derivative with respect to a single parameter. For instance, let's say we have def fun(x): ... return [a,b,c] where a,b,c and ...
User: Physics437
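For a function of a single scalar parameter, forward mode is a natural fit: jax.jacfwd, or jax.jvp with a unit tangent, returns derivatives with the same pytree structure as the output, so a list of arrays comes back as a list of arrays of the same shapes. The body of fun below is a hypothetical stand-in for the elided one.

```python
import jax
import jax.numpy as jnp


def fun(x):
    # Hypothetical outputs of different shapes.
    a = jnp.array([x, x ** 2])
    b = jnp.sin(x) * jnp.ones((2, 2))
    c = x ** 3
    return [a, b, c]


x = jnp.float32(2.0)

# Option 1: full Jacobian w.r.t. the scalar x (same structure as the output).
da, db, dc = jax.jacfwd(fun)(x)

# Option 2: a single JVP with tangent 1.0 yields the same derivatives.
outputs, tangents = jax.jvp(fun, (x,), (jnp.float32(1.0),))
# da -> [1., 4.], db -> cos(2) * ones((2, 2)), dc -> 12.
```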
1 vote
2 answers
85 views

How to use jax.custom_vjp with functions that take non-JAX types (e.g., SymPy expressions) as inputs?

I'm trying to use JAX's custom_vjp to define custom gradient computations for a function that takes a SymPy expression as an input. However, I'm encountering errors because JAX doesn't support non-JAX ...
User: James Yong
0 votes
0 answers
239 views

Kohya_ss Colab fine tune Jax error / torch compatibility

The finetune notebook worked previously with !pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html Now this no longer works. I have ...
User: Sarah Peterson
1 vote
1 answer
39 views

TypeError: 'Add' object is not iterable when iterating over SymPy expressions in scqubits with JAX custom_vjp

I'm trying to modify the scqubits Python package to use JAX's custom_vjp for differentiable programming. While doing so, I encountered the following error: TypeError: 'Add' object is not iterable Here'...
User: James Yong
0 votes
1 answer
51 views

How to install and use bayesnf package in windows with CPU

I want to use a Python package associated with this GitHub page. I installed it on my system (as suggested on the GitHub page, using Python 3.10 in a new environment), which is Windows 10 ...
User: User (reputation 31)
2 votes
1 answer
75 views

Modifying multiple dimensions of Jax array simultaneously

When using the jax_array.at[idx] function, I wish to be able to set values at both a set of specified rows and columns within the jax_array to another jax_array containing values in the same shape. ...
User: NotCoding
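To update the block formed by a set of rows and a set of columns, rather than pairing the indices elementwise, the row and column index arrays can be broadcast against each other, for example with jnp.ix_, with the values array shaped like the block. The sizes below are hypothetical.

```python
import jax.numpy as jnp

x = jnp.zeros((4, 5))
rows = jnp.array([0, 2])
cols = jnp.array([1, 3, 4])
vals = jnp.arange(6.0).reshape(2, 3)  # one value per (row, col) pair in the block

# jnp.ix_ builds broadcastable index arrays of shapes (2, 1) and (1, 3).
y = x.at[jnp.ix_(rows, cols)].set(vals)

# Equivalent explicit broadcasting of the index arrays.
y2 = x.at[rows[:, None], cols[None, :]].set(vals)
```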
2 votes
1 answer
75 views

Mapping Over Arrays of Functions in JAX

What is the most performant, idiomatic way of mapping over arrays of functions in JAX? Context: This GitHub issue shows a way to apply vmap to several functions using lax.switch. The example is ...
User: Inquisitive
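The lax.switch approach mentioned in the question can be sketched as follows: keep the functions in a Python list, carry one integer index per element, and vmap over both. Under vmap with a batched index, switch is turned into a select, so every branch is evaluated for every element; that is cheap for small functions but can be wasteful for expensive ones. The branch functions here are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical "array of functions": a Python list indexed by an integer code.
branches = [jnp.sin, jnp.cos, lambda x: x ** 2]


def apply_one(idx, x):
    # lax.switch clamps out-of-range indices into [0, len(branches) - 1].
    return lax.switch(idx, branches, x)


idxs = jnp.array([0, 1, 2, 0])
xs = jnp.array([0.0, 0.0, 3.0, 1.0])
out = jax.jit(jax.vmap(apply_one))(idxs, xs)
# out -> [sin(0), cos(0), 3**2, sin(1)]
```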
1 vote
1 answer
68 views

How to handle errors caused by non-JAX objects (such as scipy.sparse.csr_matrix) in JAX custom_vjp?

I am using JAX to implement custom backpropagation (custom_vjp), but in my function, one of the input parameters is a sparse matrix of type scipy.sparse.csr_matrix. Since JAX expects all parameters to ...
User: James Yong
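One workaround for non-JAX inputs is to keep them out of the custom_vjp signature entirely: capture or convert them in an outer factory function and let the custom_vjp function close over them, so only JAX arrays flow through the differentiated arguments. The sketch uses a dense NumPy matrix as a stand-in for the scipy.sparse.csr_matrix, and make_matvec is a hypothetical name; for keeping the operand sparse, jax.experimental.sparse may be worth checking.

```python
import numpy as np
import jax
import jax.numpy as jnp


def make_matvec(matrix):
    """Hypothetical factory: close over an operand that is never traced."""
    A = jnp.asarray(matrix)  # converted once, outside any JAX transformation

    @jax.custom_vjp
    def matvec(x):
        return A @ x

    def matvec_fwd(x):
        return A @ x, None  # no residuals needed; A is captured by the closure

    def matvec_bwd(_, g):
        return (A.T @ g,)  # cotangent for the single differentiable input x

    matvec.defvjp(matvec_fwd, matvec_bwd)
    return matvec


A_np = np.array([[1.0, 2.0], [3.0, 4.0]])
matvec = make_matvec(A_np)
grad = jax.grad(lambda x: jnp.sum(matvec(x)))(jnp.ones(2))
# grad -> A.T @ [1, 1] = [4., 6.]
```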
1 vote
1 answer
100 views

JAX TypeError: 'Device' object is not callable

I found a piece of JAX code from a few years ago. import jax import jax.random as rand device_cpu = None def do_on_cpu(f): global device_cpu if device_cpu is None: device_cpu = jax....
User: Raptor (reputation 54.1k)
1 vote
0 answers
65 views

Forward-over-reverse mode Hessian-vector product in Jax: how smart is jax.jvp at re-using computations?

How smart is JAX at reusing intermediate computations when computing Hessian-vector products with forward-over-reverse automatic differentiation, i.e. jax.jvp over jax.grad? For example, something ...
User: Nick Alger (reputation 1,094)
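For reference, the forward-over-reverse Hessian-vector product in question is usually written as a jvp of grad, as in the JAX autodiff cookbook. The sketch uses a hypothetical cubic test function whose Hessian is diagonal, so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp


def hvp(f, x, v):
    # Forward-over-reverse: differentiate grad(f) along the direction v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


def f(x):
    # Hypothetical test function with Hessian diag(6 * x).
    return jnp.sum(x ** 3)


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])
result = jax.jit(hvp, static_argnums=0)(f, x, v)
# result -> 6 * x * v = [6., 12.]
```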
0 votes
0 answers
75 views

Have trouble running JAX on Metal

I tried installing JAX for Metal in a venv as described at https://developer.apple.com/metal/jax/, and when I try to verify it, I get a really large error. Here's what I used to install it: python3 -m ...
User: bendemonium
1 vote
1 answer
497 views

Using Jax Jit on a method as decorator versus applying jit function directly

I guess most people familiar with jax have seen this example in the documentation and know that it does not work: import jax.numpy as jnp from jax import jit class CustomClass: def __init__(self, x:...
User: Stackerexp
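One pattern along the lines of the JAX FAQ is to register the class as a pytree, so that self can pass through jit like any other argument: array attributes become traced children and configuration flags become static auxiliary data. The attribute names x and mul are assumptions, since the question's snippet is truncated.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class CustomClass:
    def __init__(self, x, mul):
        self.x = x      # array data: traced
        self.mul = mul  # Python bool: static auxiliary data

    @jax.jit
    def calc(self, y):
        # self arrives as a pytree, so Python control flow on the static flag is fine.
        if self.mul:
            return self.x * y
        return y

    def tree_flatten(self):
        return (self.x,), (self.mul,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


r_mul = CustomClass(jnp.float32(2.0), True).calc(3.0)   # -> 6.0
r_id = CustomClass(jnp.float32(2.0), False).calc(3.0)   # -> 3.0
```

With this layout, a new value of x reuses the compiled function, while flipping mul changes the pytree structure and therefore triggers a recompile.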
0 votes
0 answers
60 views

JAX 3d convolution kernel speedup

I am trying to solve a diffusion kernel with JAX and this is my JAX port of existing GPU CUDA code. JAX gives the correct answer, but it is about 5x slower than CUDA. How can I speed this up further? ...
User: Chiel (reputation 6,204)
0 votes
0 answers
63 views

Orbax checkpointing using jit

I am running a reinforcement learning agent implemented in JAX, using jit. Following this tutorial, I have tried to implement checkpointing of the agent's training every few training steps. However, I ...
User: amavrits
1 vote
1 answer
596 views

jax library error jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed

I am working on a code that uses the Jax library, and I run into this error over and over again no matter how I tried to configure my environment: 2024-08-20 16:26:58.037892: E external/xla/xla/...
User: Mika Bell
0 votes
0 answers
261 views

Advice on how to speed up jax compilation times?

I want to implement the Legendre approximation of Dahlke & Pacheco (2020) to compute the entropy of a Gaussian Mixture Model. I want to incorporate this into some deep learning I have in JAX, so I ...
User: LudvigH (reputation 4,645)
1 vote
1 answer
92 views

transform DeviceArray into Array for jax

I have a .pkl file I downloaded from a public GitHub repository, and when I read it using pickle.load, i.e. using with open('filename.pkl', 'rb') as f: file_content = pickle.load(f), I get ...
User: johnhenry (reputation 1,333)
0 votes
0 answers
59 views

Restoring an optimizer state across multiple devices in JAX

I'm training a model using jax and optax on four GPUs and need to save and restore the optimizer state, but I'm running into a problem loading it. The optimizer state is initialized -- optimizer = ...
User: Jamie Mahowald
0 votes
0 answers
38 views

How to sample variable number of iterations of a for loop in numpyro?

I'm using numpyro to try to sample several variables in a model, one of which is the number of iterations of a for loop. I have shown an analogous toy model here. def model(): mu = 0. sigma = ...
User: sangeetpaul
1 vote
1 answer
86 views

Zero length error of non-zero length array

I'm writing an environment for RL agent training. My env.step method takes as its action an array of length 3: def scan(self, f, init, xs, length=None): if xs is None: xs = [None] * ...
User: dkagramanyan
0 votes
0 answers
37 views

flax.linen.Conv unexpected behavior

I'm experiencing unexpected output when using flax.linen.Conv. The output of my conv layer has very odd statistics: the mean is around 100-110 and is sometimes NaN. I tested the same against TensorFlow ...
User: Nithish M
