
I've run into an issue with JAX that will force me to rewrite an entire 20,000-line application if I can't solve it.

I have a non-ML application that relies on pytrees to store data, and the pytrees are deep: about 6-7 layers of nesting (class1 stores class2, which stores an array of class3, and so on).

I've used Python lists to store pytrees and hoped to vmap over them, but it turns out that JAX can't vmap over lists.

(One solution is to rewrite every single dataclass as a structured array and work from there, possibly putting all 6-7 layers of data into one mega-array.)

Is there a way to avoid the rewrite? Is there a way to store pytree classes in a vmappable state so that everything works as before?

I have my classes marked with flax.struct.dataclass if that helps.

1 Answer


jax.vmap is designed to work with a struct-of-arrays pattern, and it sounds like you have an array-of-structs pattern. From your description, it sounds like you have a sequence of nested structs that look something like this:

import jax
import jax.numpy as jnp
from flax.struct import dataclass

@dataclass
class Params:
  x: jax.Array
  y: jax.Array


@dataclass
class AllParams:
  p: list[Params]


params_list = [AllParams([Params(4, 2), Params(4, 3)]),
               AllParams([Params(3, 5), Params(2, 4)]),
               AllParams([Params(3, 2), Params(6, 3)])]

Then you have a function that you want to apply to each element of the list; something like this:

def some_func(params):
  a, b = params.p
  return a.x * b.y - b.x * a.y

[some_func(params) for params in params_list]
[4, 2, -3]

But as you found, if you try to do this with vmap, you get an error:

jax.vmap(some_func)(params_list)
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

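The issue is that vmap treats the list as a pytree container and maps over the leading axis of each array leaf inside it; it does not map over the elements of the list itself. Here every leaf is a scalar with shape (), so there is no axis 0 to map over.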
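To make that behavior concrete, here is a minimal sketch that uses plain lists rather than the dataclasses above. The function name add_pair and the sample values are made up for illustration. When the leaves are rank-1 arrays, vmap maps over the leading axis of each leaf and hands the function a list of scalars; when the leaves are scalars, it raises the same kind of ValueError shown above.

```python
import jax
import jax.numpy as jnp


def add_pair(pair):
    # `pair` is a two-element list; inside vmap each entry is one slice
    # taken along axis 0 of the corresponding leaf.
    a, b = pair
    return a + b


# The list is a pytree container with two leaves, each of shape (3,).
# vmap maps over axis 0 of both leaves, not over the two list entries.
batched_pair = [jnp.arange(3), jnp.arange(3) * 10]
out = jax.vmap(add_pair)(batched_pair)

# With scalar leaves there is no axis 0, so vmap refuses the input.
scalar_leaves_failed = False
try:
    jax.vmap(add_pair)([1.0, 2.0])
except ValueError:
    scalar_leaves_failed = True
```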

To address this, you can often transform your data structure from an array-of-structs into a struct-of-arrays, and then apply vmap over this. For example:

params_array = jax.tree.map(lambda *vals: jnp.array(vals), *params_list)
print(params_array)
AllParams(p=[
  Params(x=Array([4, 3, 3], dtype=int32), y=Array([2, 5, 2], dtype=int32)),
  Params(x=Array([4, 2, 6], dtype=int32), y=Array([3, 4, 3], dtype=int32))
])

Notice that rather than a list of structures, this is now a single structure with the batching pushed all the way down to the leaves. This is the "struct-of-arrays" pattern that vmap is designed to work with, and so vmap will work correctly:

jax.vmap(some_func)(params_array)
Array([ 4,  2, -3], dtype=int32)

Now, this assumes that every dataclass in your list has identical structure: if not, then vmap will not be applicable, because by design it must map over computations with identical structure.
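If you are unsure whether every element shares the same structure, one way to check before stacking is to compare tree definitions. The following is only a sketch: stack_pytrees is a hypothetical helper, not a JAX API, and it uses plain dicts in place of your dataclasses to stay self-contained. It relies on jax.tree.structure and jax.tree.map, and it additionally assumes that corresponding leaves have equal shapes, since jnp.stack requires that.

```python
import jax
import jax.numpy as jnp


def stack_pytrees(trees):
    """Hypothetical helper: array-of-structs -> struct-of-arrays.

    Raises ValueError if the pytrees do not all share one structure.
    """
    first = jax.tree.structure(trees[0])
    for i, tree in enumerate(trees[1:], start=1):
        if jax.tree.structure(tree) != first:
            raise ValueError(f"pytree {i} has a different structure")
    # Stack corresponding leaves along a new leading (batch) axis.
    return jax.tree.map(lambda *leaves: jnp.stack(leaves), *trees)


records = [
    {"x": jnp.array([1.0, 2.0]), "y": jnp.array(3.0)},
    {"x": jnp.array([4.0, 5.0]), "y": jnp.array(6.0)},
]
stacked = stack_pytrees(records)  # leaves now have shapes (2, 2) and (2,)

# Sum of x plus y for each record, computed in one batched call.
totals = jax.vmap(lambda r: r["x"].sum() + r["y"])(stacked)
```

A helper like this fails early with a readable message instead of surfacing a structure mismatch from inside jax.tree.map.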

  • Right, that's what I've done too. Now, I've read about padding and even shape_poly, but I have the following issue: I have jnp.ndarrays with shapes like this: [(50, 6, 6), (100, 6, 6), (1000, 6, 6), (1, 6, 6)]. The array shapes will be known at compile time and will stay the same, but each individual array will have a different shape. How do I work with this? lax.scan(func_to_compute, list_of_arrays, length=len(list_of_arrays))? Is it impossible to make it parallel? I'm positive this is an answered issue, but still. (func_to_compute is pure.)
    – MRiabov
    Commented Oct 27 at 12:41
  • You cannot currently use scan or vmap in such cases. They are designed for single-program-multi-data with static shapes, and your data do not meet that requirement. Dynamic shape computations are still in development and not fully supported. Your best bet would be to use a Python list comprehension; asynchronous dispatch will lead to execution in parallel where possible.
    – jakevdp
    Commented Oct 27 at 13:03
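
The list-comprehension approach suggested in the comment above could look roughly like this. It is a sketch under assumptions: the body of func_to_compute is a made-up placeholder (the real function is not shown in the thread), and the shapes are the ones quoted in the question. Each distinct input shape triggers its own jit compilation, and the calls are dispatched asynchronously, so the Python loop does not wait for one result before launching the next.

```python
import jax
import jax.numpy as jnp


@jax.jit
def func_to_compute(mats):
    # Placeholder body (an assumption): sum of the traces of a stack of
    # 6x6 matrices. Substitute the real pure function here.
    return jnp.trace(mats, axis1=-2, axis2=-1).sum()


# Arrays whose leading dimension differs, as in the question.
list_of_arrays = [jnp.ones((n, 6, 6)) for n in (50, 100, 1000, 1)]

# One call per array; each call returns without waiting for its result.
results = [func_to_compute(arr) for arr in list_of_arrays]

# Wait for all outstanding computations before reading the values.
results = jax.block_until_ready(results)
```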
