All Questions
Tagged with jax parallel-processing
12 questions
0 votes, 0 answers, 162 views
Solving a large number (1 million) of individual small nonlinear systems of equations using JAX
I have some technical inquiries regarding the capabilities of JAX in addressing a substantial number (1 million) of individual small nonlinear systems of equations. Currently, my approach involves ...
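A common JAX pattern for this kind of workload is to write a solver for one small system and let `jax.vmap` batch it, so a single compiled program handles every system. The sketch below is a minimal illustration under stated assumptions: the 2x2 system in `residual`, the fixed iteration count, and the name `newton_solve` are hypothetical stand-ins, not the asker's code.

```python
import jax
import jax.numpy as jnp


def residual(x, p):
    # Hypothetical 2x2 system F(x; p) = 0: x0^2 + x1^2 = p and x0 = x1.
    return jnp.array([x[0] ** 2 + x[1] ** 2 - p, x[0] - x[1]])


def newton_solve(x0, p, num_iters=20):
    jac = jax.jacfwd(residual)                     # Jacobian dF/dx, shape (2, 2)

    def step(_, x):
        dx = jnp.linalg.solve(jac(x, p), -residual(x, p))
        return x + dx

    # Fixed iteration count for simplicity; a lax.while_loop with a tolerance is an alternative.
    return jax.lax.fori_loop(0, num_iters, step, x0)


# One compiled program solves every system: vmap over the batch axis of x0 and p.
batched_solve = jax.jit(jax.vmap(newton_solve))

n = 10_000                                         # scale toward 1_000_000 on an accelerator
p = jnp.linspace(1.0, 4.0, n)
x0 = jnp.ones((n, 2))
x = batched_solve(x0, p)                           # shape (n, 2); exact root is sqrt(p / 2)
```

Because each system is tiny, the batch dimension supplies the parallelism; the per-system code stays ordinary scalar-looking JAX.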
1 vote, 1 answer, 37 views
Compiled JAX functions slowing down for no reason
I am using JAX for scientific computing, specifically to calculate a pairwise interaction force across all the particles in my system in parallel. This is for a simulation of the dynamics of a ...
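The excerpt does not show the cause of the slowdown, but two general JAX behaviors often distort timings: the first call to a jitted function includes tracing and compilation (repeated whenever input shapes or dtypes change), and JAX dispatches work asynchronously, so a timer can stop before the computation finishes. The sketch below assumes a hypothetical softened inverse-square repulsion computed by broadcasting and times it with `block_until_ready()`.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def pairwise_forces(pos, eps=1e-3):
    # Hypothetical softened inverse-square repulsion between every pair of particles.
    diff = pos[:, None, :] - pos[None, :, :]          # (N, N, 3): r_i - r_j
    dist2 = jnp.sum(diff ** 2, axis=-1) + eps         # (N, N): softened squared distances
    inv_r3 = dist2 ** -1.5
    inv_r3 = inv_r3 * (1.0 - jnp.eye(pos.shape[0]))   # zero out self-interaction (i == j)
    return jnp.sum(diff * inv_r3[..., None], axis=1)  # (N, 3): net force on each particle


pos = jax.random.normal(jax.random.PRNGKey(0), (1024, 3))

# The first call traces and compiles; later calls with the same shape/dtype reuse it.
pairwise_forces(pos).block_until_ready()

# JAX dispatches work asynchronously, so block before stopping the clock.
t0 = time.perf_counter()
forces = pairwise_forces(pos).block_until_ready()
print(f"{time.perf_counter() - t0:.4f} s")
```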
0 votes, 0 answers, 97 views
Low utilization of multiple GPUs in single-program multiple-data (SPMD) with automatic parallelization in JAX
I am trying to use 4 GPUs to solve a physics-informed neural network (PINN) problem. I find that when I use one GPU, the GPU utilization can reach 100% and the training speed is high (200 [it/s]), ...
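For reference, the usual data-parallel SPMD setup in JAX places the batch on a device `Mesh` with a `NamedSharding` and lets `jax.jit` partition the computation. This sketch does not diagnose the asker's utilization issue; it only shows the pattern, assuming a hypothetical one-layer `forward` function and per-device batch size. One general consideration is that a per-device batch that is too small can leave each device underused. Without multiple GPUs, setting `XLA_FLAGS=--xla_force_host_platform_device_count=4` before importing JAX creates several CPU devices to experiment with.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))           # 1-D mesh over all devices

batch_sharding = NamedSharding(mesh, P("batch"))      # split the leading axis across devices
replicated = NamedSharding(mesh, P())                 # full copy on every device

n_dev = len(devices)
per_device_batch = 1024                                # hypothetical per-device batch size
x = jax.device_put(jnp.ones((n_dev * per_device_batch, 2)), batch_sharding)
w = jax.device_put(jnp.ones((2, 64)), replicated)


@jax.jit
def forward(w, x):
    # Hypothetical layer; jit propagates the input shardings through the computation.
    return jnp.tanh(x @ w)


y = forward(w, x)
print(y.sharding)   # inspect how the output is laid out across devices
```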
1 vote, 1 answer, 87 views
Parallelizing over huge for loop involving a generator
Is there a way to compute, in parallel, the maximum value of an array so big that I need to use a generator? I am computing the maximum value of a function over a very large discrete set of ...
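One way to avoid a Python generator is to generate each chunk of the domain from integer indices inside a compiled loop, evaluate the chunk with `jax.vmap`, and keep a running maximum. The sketch below assumes the discrete set can be indexed by an integer and uses a hypothetical `objective`; with int32 indices it covers domains of up to about 2**31 points, beyond which 64-bit mode (`jax_enable_x64`) or a two-level index would be needed.

```python
from functools import partial

import jax
import jax.numpy as jnp


def objective(i):
    # Hypothetical function over a discrete domain indexed by integer i.
    x = i.astype(jnp.float32) / 1000.0
    return jnp.sin(x) * jnp.exp(-0.01 * x)


@partial(jax.jit, static_argnums=(0, 1))
def chunked_max(n_chunks, chunk_size):
    offsets = jnp.arange(chunk_size)

    def body(c, best):
        idx = c * chunk_size + offsets              # indices for this chunk, built on the fly
        vals = jax.vmap(objective)(idx)             # evaluate the whole chunk in parallel
        return jnp.maximum(best, jnp.max(vals))     # running maximum across chunks

    return jax.lax.fori_loop(0, n_chunks, body, jnp.array(-jnp.inf, dtype=jnp.float32))


best = chunked_max(100, 100_000)   # 10 million points, never materialized all at once
print(best)
```

Only one chunk lives in memory at a time, so `chunk_size` trades memory for per-step parallelism.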
1 vote, 0 answers, 320 views
Parallelization in JAX when there are more data items than devices/cores
I am trying out parallel computing in JAX. However, in my case, I have 100 data items but only 8 cores/devices. It looks like jax.pmap() does not support this case.
I just want to know what is the easiest way ...
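`jax.pmap` maps over a leading axis whose size must not exceed the device count, so a common workaround is to pad the data to a multiple of the device count, reshape it to `(n_devices, per_device, ...)`, and apply `jax.pmap(jax.vmap(f))`. This is a minimal sketch; `f` and `map_over_devices` are hypothetical names, and padding with zeros assumes `f` is safe to evaluate on zero inputs.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical per-item computation.
    return jnp.sum(x ** 2)


def map_over_devices(f, data):
    n_dev = jax.local_device_count()
    n = data.shape[0]
    per_dev = -(-n // n_dev)                        # ceil(n / n_dev)
    pad = per_dev * n_dev - n
    padded = jnp.concatenate([data, jnp.zeros((pad,) + data.shape[1:], data.dtype)])
    shaped = padded.reshape((n_dev, per_dev) + data.shape[1:])
    # pmap over devices, vmap over the items assigned to each device.
    # For repeated calls, build jax.pmap(jax.vmap(f)) once instead of on every call.
    out = jax.pmap(jax.vmap(f))(shaped)
    return out.reshape((n_dev * per_dev,) + out.shape[2:])[:n]  # drop the padding


data = jnp.arange(400.0).reshape(100, 4)            # 100 items, any number of devices
result = map_over_devices(f, data)                  # shape (100,)
```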
1 vote, 1 answer, 462 views
Why matrix multiplication results with JAX are different if the data is sharded differently on the GPU
I am running a tutorial on matrix multiplication with JAX with data sharded in different ways across multiple GPUs. I found that not only is the computation time different for different ways of sharding, ...
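A plausible contributor to such differences (not confirmed by the excerpt) is that floating-point addition is not associative, and different shardings can group the partial sums of a matmul differently. The sketch below shows both effects: a three-term float32 example and a matmul whose contraction dimension is split into four hypothetical "shards".

```python
import jax
import jax.numpy as jnp

# Floating-point addition is not associative: the grouping changes the rounded result.
a, b, c = jnp.float32(1e8), jnp.float32(-1e8), jnp.float32(1.0)
left = (a + b) + c    # 0 + 1 -> 1.0
right = a + (b + c)   # -1e8 + 1 rounds back to -1e8 in float32, so the sum is 0.0
print(left, right)

# Splitting the contraction dimension (as a sharding over it would) turns one long
# sum into per-block partial products that are added afterwards: same math, new grouping.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (256, 4096))
w = jax.random.normal(k2, (4096, 256))
full = x @ w
blocked = sum(x[:, i:i + 1024] @ w[i:i + 1024, :] for i in range(0, 4096, 1024))
print(jnp.max(jnp.abs(full - blocked)))  # small, but not guaranteed to be zero
```

Results that agree to within float32 tolerance, rather than bitwise, are the expected outcome of such regrouping.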
2 votes, 1 answer, 920 views
Parallelize with JAX over all GPU cores
In order to minimize the function x^2+y^2, I tried to implement the Adam optimizer from scratch with JAX:
@jax.jit
def fit(X, batches, params=[0.001, 0.9, 0.99, 1e-8]):
global fun
# batches ...
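For comparison, here is a minimal from-scratch Adam sketch for f(x, y) = x^2 + y^2 that passes the objective through a closure instead of a `global`, and runs the whole optimization loop inside one `jax.lax.scan` under `jax.jit`. The hyperparameter names mirror the question's `[0.001, 0.9, 0.99, 1e-8]`, but the learning rate here is raised to 0.01 (an assumption) so that the demo converges in fewer steps. A single jitted loop like this runs on one device; spreading it across GPUs would need sharding or `pmap` on top.

```python
import jax
import jax.numpy as jnp


def fun(p):
    # f(x, y) = x^2 + y^2, with p = [x, y]
    return jnp.sum(p ** 2)


def fit(p0, num_steps=2000, lr=0.01, b1=0.9, b2=0.99, eps=1e-8):
    grad_fn = jax.grad(fun)

    def step(carry, t):
        p, m, v = carry
        g = grad_fn(p)
        m = b1 * m + (1 - b1) * g                  # first-moment (mean) estimate
        v = b2 * v + (1 - b2) * g ** 2             # second-moment estimate
        m_hat = m / (1 - b1 ** t)                  # bias corrections; t starts at 1
        v_hat = v / (1 - b2 ** t)
        p = p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
        return (p, m, v), fun(p)

    ts = jnp.arange(1, num_steps + 1, dtype=jnp.float32)
    init = (p0, jnp.zeros_like(p0), jnp.zeros_like(p0))
    (p, _, _), losses = jax.lax.scan(step, init, ts)
    return p, losses


# num_steps sets the scan length, so it must be static under jit.
fit_jit = jax.jit(fit, static_argnames="num_steps")

p_opt, losses = fit_jit(jnp.array([1.0, -2.0]))
```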
1 vote, 0 answers, 201 views
Parallel RNG with JAX sharding
What is the correct approach for generating pseudorandom numbers in parallel using sharding in JAX?
The following doesn't work (due to sampling the same chain)
sharding = jax.sharding....
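A widely used JAX pattern is to give every chain its own key with `jax.random.split`, shard the array of keys across devices, and `vmap` the sampler over keys, so no two chains reuse the same stream. This sketch assumes one key per chain, a hypothetical chain count of 8 per device, and standard-normal draws as a stand-in for the real sampler.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_chains = 8 * len(jax.devices())                          # hypothetical: 8 chains per device
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)   # independent key per chain

mesh = Mesh(np.array(jax.devices()), ("chains",))
keys = jax.device_put(keys, NamedSharding(mesh, P("chains")))  # split keys across devices


@jax.jit
def sample(keys):
    # Each chain draws from its own key, so the chains do not repeat each other.
    return jax.vmap(lambda k: jax.random.normal(k, (1000,)))(keys)


draws = sample(keys)   # shape (n_chains, 1000)
```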
1 vote, 0 answers, 339 views
Parallelism inside sequential loop with JAX
How is the following data-location parallelism translated to a per-device implementation with collective communication using jax?
import os
os.environ["XLA_FLAGS"] = (
f'--...
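As a general illustration of per-device code with collectives inside a sequential loop, the sketch below uses `jax.pmap` with an `axis_name` and calls `jax.lax.psum` in every iteration of a `jax.lax.fori_loop`: each device holds a shard of a hypothetical least-squares dataset, and the replicated weights are updated with the all-reduced gradient. This is a stand-in problem, not the asker's code; newer JAX versions also offer `shard_map` for the same per-device style.

```python
import jax
import jax.numpy as jnp

AXIS = "devices"


def train(w, X, y, num_steps=200, lr=0.1):
    # Runs once per device: X and y are this device's shard, w is replicated.
    n_total = jax.lax.psum(X.shape[0], AXIS)              # total rows across all devices

    def step(_, w):
        resid = X @ w - y
        g_local = X.T @ resid                              # this shard's gradient contribution
        g = jax.lax.psum(g_local, AXIS) / n_total          # all-reduce, then average
        return w - lr * g

    # Sequential loop; every iteration performs one collective (psum).
    return jax.lax.fori_loop(0, num_steps, step, w)


p_train = jax.pmap(train, axis_name=AXIS, in_axes=(None, 0, 0))

n_dev = jax.local_device_count()
rows_per_dev = 64
w_true = jnp.array([2.0, -3.0])
X = jax.random.normal(jax.random.PRNGKey(0), (n_dev, rows_per_dev, 2))
y = X @ w_true                                            # (n_dev, rows_per_dev)
w_out = p_train(jnp.zeros(2), X, y)                       # (n_dev, 2): same w on every device
```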
8 votes, 1 answer, 4k views
JAX vmap vs pmap vs Python multiprocessing
I am rewriting some code from pure Python to JAX. I have gotten to the point where in my old code, I was using Python's multiprocessing module to parallelize the evaluation of a function over all of ...
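Where the old code used something like a multiprocessing pool's `map` to evaluate a function over many inputs, the usual first step in JAX is `jax.vmap`, which traces the function once and applies it to the whole batch as array operations on one device (`pmap` or sharding then split that batch across devices). The function `f` below is a hypothetical stand-in for the asker's per-point function.

```python
import jax
import jax.numpy as jnp


def f(p):
    # Hypothetical per-point function of a 3-vector, the kind of thing Pool.map would evaluate.
    return jnp.sum(p ** 2) - jnp.prod(jnp.cos(p))


points = jax.random.uniform(jax.random.PRNGKey(0), (100_000, 3))

# multiprocessing style: pool.map(f, points) -> one Python-level call per point.
# vmap style: f is traced once and applied to the whole batch as array operations.
batched_f = jax.jit(jax.vmap(f))
values = batched_f(points)   # shape (100_000,)
```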
1 vote, 1 answer, 988 views
What is the correct way to define a vectorized (jax.vmap) function in a class?
I want to add a function, which is vectorized by jax.vmap, as a class method. However, I am not sure where to define this function within the class. My main goal is to avoid that the function is ...
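One possible design, sketched here as an assumption rather than the canonical answer, is to write the single-example computation as a static method, build the `jax.jit(jax.vmap(...))` version once in `__init__`, and pass the object's state in as explicit arguments. Building it once avoids re-creating the transformed function on every call, and passing state explicitly keeps the compiled function from silently baking in stale attribute values. The class `Model` and its computation are hypothetical.

```python
import jax
import jax.numpy as jnp


class Model:
    """Hypothetical class applying a per-sample computation to a batch."""

    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift
        # Build the vectorized, jitted function once; scale and shift are shared (None).
        self._batched = jax.jit(jax.vmap(self._single, in_axes=(0, None, None)))

    @staticmethod
    def _single(x, scale, shift):
        # Works on one sample x of shape (d,).
        return jnp.dot(scale * x + shift, x)

    def __call__(self, xs):
        return self._batched(xs, self.scale, self.shift)


m = Model(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 1.0, 0.0]))
out = m(jnp.ones((4, 3)))   # shape (4,)
```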
0 votes, 2 answers, 630 views
Nested vmap in pmap - JAX
I currently can run simulations in parallel on one GPU using vmap. To speed things up, I want to batch the simulations over multiple GPU devices using pmap. However, when pmapping the vmapped function ...
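When nesting `vmap` inside `pmap`, the inputs need a leading device axis followed by a per-device batch axis, and arguments shared by every simulation can be broadcast with `in_axes=None` at both levels. This sketch assumes a hypothetical random-walk `simulate` function, one PRNG key per simulation, and 256 simulations per device.

```python
import jax
import jax.numpy as jnp


def simulate(key, dt, n_steps=100):
    # Hypothetical simulation: Gaussian random walk with step variance dt; returns final position.
    steps = jax.random.normal(key, (n_steps,)) * jnp.sqrt(dt)
    return jnp.sum(steps)


n_dev = jax.local_device_count()
sims_per_dev = 256
keys = jax.random.split(jax.random.PRNGKey(42), n_dev * sims_per_dev)
keys = keys.reshape(n_dev, sims_per_dev, -1)   # (devices, sims per device, key data)

# Inner vmap: over simulations on one device; dt is shared (in_axes=None).
# Outer pmap: over devices; dt is broadcast to every device (in_axes=None).
run = jax.pmap(jax.vmap(simulate, in_axes=(0, None)), in_axes=(0, None))
finals = run(keys, 0.01)                       # shape (n_dev, sims_per_dev)
```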