
All Questions

0 votes · 0 answers · 162 views

Solving a large number (1 million) of individual small nonlinear systems of equations using JAX

I have some technical inquiries regarding the capabilities of JAX in addressing a substantial number (1 million) of individual small nonlinear systems of equations. Currently, my approach involves ...
funpy • 45
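The usual JAX answer to many independent small solves is to write one scalar solver and `vmap` it over the batch. A minimal sketch, assuming each system is the illustrative residual f(x; a) = x² − a solved by a fixed number of Newton steps (the residual, `newton_solve`, and all parameters are assumptions, not from the question):

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar problem: solve x**2 - a = 0 with fixed-count Newton steps.
def newton_solve(a, x0=1.0, steps=20):
    f = lambda x: x**2 - a
    df = jax.grad(f)
    def body(_, x):
        return x - f(x) / df(x)          # one Newton update
    return jax.lax.fori_loop(0, steps, body, x0)

# vmap turns the scalar solver into one batched XLA kernel,
# instead of a Python loop over a million separate solves.
batched_solve = jax.jit(jax.vmap(newton_solve))

a = jnp.arange(1.0, 5.0)                 # small demo batch; scales to 1e6
roots = batched_solve(a)                 # ≈ sqrt(a)
```

Because every system runs the same fixed number of iterations, the whole batch compiles to a single program with no per-system Python overhead.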
1 vote · 1 answer · 37 views

Compiled JAX functions slowing down for no reason

I am using JAX for scientific computing; specifically, calculating a pairwise interaction force across all the particles in my system in parallel. This is for a simulation of the dynamics of a ...
Yigithan Gediz
0 votes · 0 answers · 97 views

Low utilization of multiple GPUs in single-program multiple-data (SPMD) auto-parallelization in JAX

I am trying to use 4 GPUs to solve a physics-informed neural network (PINN) problem. I find that when I use one GPU, its utilization can reach 100% and the training speed is high (200 it/s), ...
WANG Jacques
1 vote · 1 answer · 87 views

Parallelizing over huge for loop involving a generator

Is there a way to compute in parallel the maximum value of an array so big that I need to use a generator? I am computing the maximum value of a function over a very large discrete set of ...
Alex Albors
1 vote · 0 answers · 320 views

Parallelization in JAX when there are more data than devices/cores

I am trying parallel computing in JAX. However, in my case, I have 100 data points but only 8 cores/devices. It looks like jax.pmap() does not support this case. I just want to know the easiest way ...
Saber Artoria
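The standard workaround when the batch exceeds the device count is to pad the data, reshape it to a leading axis of length `device_count`, and nest `vmap` inside `pmap`. A minimal sketch, where `f` is a stand-in for the question's per-example computation:

```python
import jax
import jax.numpy as jnp

def f(x):                                     # stand-in per-example function
    return x * 2.0

n_dev = jax.local_device_count()
data = jnp.arange(100.0)                      # 100 examples, e.g. 8 devices

# Pad so the batch divides evenly across devices.
pad = (-len(data)) % n_dev
padded = jnp.concatenate([data, jnp.zeros(pad)])
sharded = padded.reshape(n_dev, -1)           # (devices, per_device)

# pmap splits the leading axis across devices; vmap vectorizes within each.
result = jax.pmap(jax.vmap(f))(sharded)
result = result.reshape(-1)[:len(data)]       # flatten and drop padding
```

The padding entries are computed and discarded, which is usually cheaper than any attempt to give devices unequal chunks.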
1 vote · 1 answer · 462 views

Why are matrix multiplication results with JAX different if the data is sharded differently on the GPU?

I am running a tutorial on matrix multiplication with JAX with data sharded in different ways across multiple GPUs. I found that not only is the computation time different for different ways of sharding, ...
Dmitry J • 143
2 votes · 1 answer · 920 views

Parallelize with JAX over all GPU cores

In order to minimize the function x^2+y^2, I tried to implement the Adam optimizer from scratch with JAX: @jax.jit def fit(X, batches, params=[0.001, 0.9, 0.99, 1e-8]): global fun # batches ...
ilikenoodles
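For a dense objective like x² + y², JAX already parallelizes the array arithmetic across GPU cores; no explicit per-core dispatch is needed. A hedged sketch of a hand-rolled Adam step (the update rule is the standard Adam algorithm; the hyperparameters mirror the question's `[0.001, 0.9, 0.99, 1e-8]` list, and `loss`/`adam_step` are illustrative names):

```python
import jax
import jax.numpy as jnp

def loss(p):
    return jnp.sum(p ** 2)                    # x**2 + y**2 for p = [x, y]

@jax.jit
def adam_step(p, m, v, t, lr=1e-3, b1=0.9, b2=0.99, eps=1e-8):
    g = jax.grad(loss)(p)
    m = b1 * m + (1 - b1) * g                 # first-moment estimate
    v = b2 * v + (1 - b2) * g ** 2            # second-moment estimate
    mhat = m / (1 - b1 ** t)                  # bias correction
    vhat = v / (1 - b2 ** t)
    return p - lr * mhat / (jnp.sqrt(vhat) + eps), m, v

p = jnp.array([1.0, -1.0])
m = v = jnp.zeros_like(p)
for t in range(1, 2001):
    p, m, v = adam_step(p, m, v, t)
```

Keeping the optimizer state explicit (rather than in globals) is what lets `jax.jit` cache one compiled step and reuse it every iteration.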
1 vote · 0 answers · 201 views

Parallel RNG with JAX sharding

What is the correct approach for generating pseudorandom numbers in parallel using sharding in JAX? The following doesn't work (it samples the same chain): sharding = jax.sharding....
DavidJ • 408
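Whatever the sharding layout, the usual fix for "every chain samples the same values" is to split the PRNG key once per chain and map the sampler over the keys. A minimal sketch (chain count and sample shape are illustrative):

```python
import jax
import jax.numpy as jnp

n_chains = 8
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, n_chains)        # one independent key per chain

# Each call sees its own key, so each row is a distinct stream.
sample = jax.vmap(lambda k: jax.random.normal(k, (4,)))(keys)
```

Reusing a single key across chains is what produces identical samples; splitting first makes the streams independent regardless of how the result is later sharded across devices.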
1 vote · 0 answers · 339 views

Parallelism inside sequential loop with JAX

How is the following data-location parallelism translated to a per-device implementation with collective communication using JAX? import os os.environ["XLA_FLAGS"] = ( f'--...
DavidJ • 408
8 votes · 1 answer · 4k views

JAX vmap vs pmap vs Python multiprocessing

I am rewriting some code from pure Python to JAX. I have gotten to the point where in my old code, I was using Python's multiprocessing module to parallelize the evaluation of a function over all of ...
Jim Raynor
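The typical migration from `multiprocessing` is to replace the process pool's `map` with a single `jax.vmap` call (or `pmap` for one chunk per device). A minimal sketch, where `per_item` stands in for the original work function:

```python
import jax
import jax.numpy as jnp

def per_item(x):                  # stand-in for the function once fed to Pool.map
    return jnp.sin(x) ** 2

# Old style:  results = Pool().map(per_item, items)
items = jnp.linspace(0.0, 1.0, 1000)
results = jax.jit(jax.vmap(per_item))(items)   # one vectorized kernel, no processes
```

Unlike `multiprocessing`, this pays no serialization or process-startup cost: `vmap` compiles the per-item function into one batched kernel on a single device, and `pmap` is the analogous tool when the batch should be split across devices.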
1 vote · 1 answer · 988 views

What is the correct way to define a vectorized (jax.vmap) function in a class?

I want to add a function, which is vectorized by jax.vmap, as a class method. However, I am not sure where to define this function within the class. My main goal is to avoid that the function is ...
yuki • 775
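One common layout (an illustrative sketch, not the question's actual class) is to apply `jax.vmap` to the single-example method once in `__init__`, so the vectorized function is built a single time rather than re-created on every call:

```python
import jax
import jax.numpy as jnp

class Model:
    def __init__(self, scale):
        self.scale = scale
        # vmap the single-example method once; the bound method closes over self.
        self._batched = jax.vmap(self._apply_one)

    def _apply_one(self, x):
        return self.scale * jnp.sum(x)        # per-example computation

    def apply(self, batch):
        return self._batched(batch)           # maps over the leading batch axis

m = Model(2.0)
out = m.apply(jnp.ones((3, 4)))               # three examples of size 4
```

Calling `jax.vmap(self._apply_one)` inside `apply` instead would rebuild the transformed function on every invocation, which is the repeated work the question is trying to avoid.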
0 votes · 2 answers · 630 views

Nested vmap in pmap - JAX

I can currently run simulations in parallel on one GPU using vmap. To speed things up, I want to batch the simulations across multiple GPU devices using pmap. However, when pmapping the vmapped function ...
Anton B • 43