jax.vmap is designed to work with a struct-of-arrays pattern, and it sounds like you have an array-of-structs pattern. From your description, it sounds like you have a sequence of nested structs that look something like this:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

@dataclass
class Params:
    x: jax.Array
    y: jax.Array

@dataclass
class AllParams:
    p: list[Params]

params_list = [AllParams([Params(4, 2), Params(4, 3)]),
               AllParams([Params(3, 5), Params(2, 4)]),
               AllParams([Params(3, 2), Params(6, 3)])]
Then you have a function that you want to apply to each element of the list; something like this:
def some_func(params):
    a, b = params.p
    return a.x * b.y - b.x * a.y
[some_func(params) for params in params_list]
[4, 2, -3]
But as you found, if you try to do this with vmap, you get an error:
jax.vmap(some_func)(params_list)
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
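The issue is that vmap maps over the leading axis of every array leaf in the pytree you pass to it, not over the elements of a Python list. Here each leaf is a scalar with shape (), so there is no axis 0 to map over, which is exactly what the error message reports.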
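To see what vmap is actually handed in the array-of-structs case, here is a minimal sketch that flattens a small list of records into its leaves. Plain dicts are used as a stand-in for the dataclasses above (an assumption made only to keep the snippet free of the flax dependency), and the name records is hypothetical:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a list of structs; dicts are pytrees too.
records = [{"x": 4, "y": 2}, {"x": 3, "y": 5}, {"x": 3, "y": 2}]

# Flattening the list yields one leaf per scalar field, not one per record.
leaves = jax.tree.leaves(records)

# Every leaf is rank 0, so none of them has an axis 0 to map over.
leaf_ranks = [jnp.asarray(leaf).ndim for leaf in leaves]

print(len(leaves))  # 6
print(leaf_ranks)   # [0, 0, 0, 0, 0, 0]
```

From vmap's point of view, the list is just another container level in the pytree, so the three records never appear as a batch dimension.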
To address this, you can often transform your data structure from an array-of-structs into a struct-of-arrays, and then apply vmap over this. For example:
params_array = jax.tree.map(lambda *vals: jnp.array(vals), *params_list)
print(params_array)
AllParams(p=[
Params(x=Array([4, 3, 3], dtype=int32), y=Array([2, 5, 2], dtype=int32)),
Params(x=Array([4, 2, 6], dtype=int32), y=Array([3, 4, 3], dtype=int32))
])
Notice that rather than a list of structures, this is now a single structure with the batching pushed all the way down to the leaves. This is the "struct-of-arrays" pattern that vmap is designed to work with, and so vmap will work correctly:
jax.vmap(some_func)(params_array)
Array([ 4, 2, -3], dtype=int32)
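If you need this conversion in more than one place, it can be wrapped in a small helper. The following is a sketch rather than a library API: stack_pytrees and cross are hypothetical names, and plain dicts again stand in for the dataclasses so the snippet runs without flax. It also checks that the vmapped result agrees with the plain Python loop:

```python
import jax
import jax.numpy as jnp

def stack_pytrees(trees):
    """Hypothetical helper: stack matching leaves along a new leading axis."""
    return jax.tree.map(
        lambda *leaves: jnp.stack([jnp.asarray(leaf) for leaf in leaves]),
        *trees,
    )

def cross(pair):
    # Same computation as some_func, written for dict-based records.
    a, b = pair["p"]
    return a["x"] * b["y"] - b["x"] * a["y"]

pairs = [
    {"p": [{"x": 4, "y": 2}, {"x": 4, "y": 3}]},
    {"p": [{"x": 3, "y": 5}, {"x": 2, "y": 4}]},
    {"p": [{"x": 3, "y": 2}, {"x": 6, "y": 3}]},
]

stacked = stack_pytrees(pairs)

# Each leaf now carries the batch dimension: shape (3,).
leaf_shapes = [leaf.shape for leaf in jax.tree.leaves(stacked)]

batched = jax.vmap(cross)(stacked)     # one vectorized call
looped = [cross(p) for p in pairs]     # reference: plain Python loop

print(leaf_shapes)
print(batched.tolist(), looped)
```

Both paths compute the same values; the difference is that the stacked version runs as a single batched computation instead of one call per list element.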
Now, this assumes that every dataclass in your list has identical structure: if not, then vmap will not be applicable, because by design it must map over computations with identical structure.
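One way to check that assumption before stacking is to compare the tree structures of the list elements. This is a minimal sketch with a hypothetical helper name, same_structure, and dict-based example data. Note that it only compares the container structure; the leaves would also need matching shapes and dtypes for the stacking step to succeed:

```python
import jax

def same_structure(trees):
    """Hypothetical helper: True if every pytree shares the first one's treedef."""
    first = jax.tree.structure(trees[0])
    return all(jax.tree.structure(tree) == first for tree in trees[1:])

# Every element has a two-entry list under "p".
uniform = [{"p": [1, 2]}, {"p": [3, 4]}]

# The second element has three entries, so the structures differ.
ragged = [{"p": [1, 2]}, {"p": [3, 4, 5]}]

print(same_structure(uniform))  # True
print(same_structure(ragged))   # False
```

When the check fails, the data has to be padded or grouped into uniform batches before vmap can be applied.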