I am unsure what the best way is to vectorize objects in JAX (Python). In particular, I want to write code that handles calling a method both on a single instance of a class and on multiple (vectorized) instances of that class. Below is a simple example of what I would like to achieve.
```python
import jax
import jax.numpy as jnp
import jax.random as random

class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def to_pytree(self):
        # children (the array leaves) and auxiliary data (none here)
        return (self.x, self.key), None

    def get_noisy_x(self):
        # advance the stored key, then draw Gaussian noise with the fresh subkey
        self.key, subkey = random.split(self.key)
        return self.x + random.normal(subkey, self.x.shape)

    @staticmethod
    def from_pytree(auxiliary, pytree):
        return Dummy(*pytree)

jax.tree_util.register_pytree_node(Dummy,
                                   Dummy.to_pytree,
                                   Dummy.from_pytree)
```
The class `Dummy` contains some information, `x`, and a key, `key`, and has a method, `get_noisy_x`. The following code works as expected:
```python
key = random.PRNGKey(0)
dummy = Dummy(jnp.array([1., 2., 3.]), key)
dummy.get_noisy_x()
```
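Because `Dummy` is registered as a pytree, JAX transformations such as `jax.vmap` never operate on the object itself, only on its leaves (`x` and `key`), and rebuild a `Dummy` from them afterwards. A minimal sketch of that flatten/unflatten round trip, re-declaring the class so it runs on its own (without `get_noisy_x`) and assuming the default legacy `uint32` PRNG keys, whose shape is `(2,)`:

```python
import jax
import jax.numpy as jnp
import jax.random as random


class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def to_pytree(self):
        return (self.x, self.key), None

    @staticmethod
    def from_pytree(auxiliary, pytree):
        return Dummy(*pytree)


jax.tree_util.register_pytree_node(Dummy, Dummy.to_pytree, Dummy.from_pytree)

dummy = Dummy(jnp.array([1., 2., 3.]), random.PRNGKey(0))

# Flattening exposes the two array leaves that jit/vmap actually transform.
leaves, treedef = jax.tree_util.tree_flatten(dummy)
print([leaf.shape for leaf in leaves])  # [(3,), (2,)] with legacy uint32 keys

# Unflattening calls from_pytree to build a *new* Dummy from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(rebuilt).__name__)  # Dummy
```

So any transformation that maps over a `Dummy` maps over `x` and `key` separately and returns a single `Dummy` holding the transformed arrays.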
I would like `get_noisy_x` to also work on a vectorized version of the `Dummy` object:
```python
key = random.PRNGKey(0)
key, subkey = random.split(key)
key_batch = random.split(subkey, 100)
dummy_vmap = jax.vmap(lambda x: Dummy(jnp.array([1., 2., 3.]), x))(key_batch)
```
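A quick way to see what this `vmap` call returns is to inspect the leaves of the result. The sketch below assumes the default legacy `uint32` keys and simplifies the key setup to a single `random.split`; the indexing with `jax.tree_util.tree_map` is just one way to pull out the i-th instance:

```python
import jax
import jax.numpy as jnp
import jax.random as random


class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def to_pytree(self):
        return (self.x, self.key), None

    @staticmethod
    def from_pytree(auxiliary, pytree):
        return Dummy(*pytree)


jax.tree_util.register_pytree_node(Dummy, Dummy.to_pytree, Dummy.from_pytree)

key_batch = random.split(random.PRNGKey(0), 100)
dummy_vmap = jax.vmap(lambda k: Dummy(jnp.array([1., 2., 3.]), k))(key_batch)

# One Dummy whose leaves carry a leading batch axis ("struct of arrays"),
# not a Python list of 100 Dummy objects.
print(type(dummy_vmap).__name__)  # Dummy
print(dummy_vmap.x.shape)         # (100, 3): the unbatched x is broadcast by vmap
print(dummy_vmap.key.shape)       # (100, 2): one raw key per instance

# Indexing every leaf recovers the i-th "instance" as an ordinary Dummy.
first = jax.tree_util.tree_map(lambda leaf: leaf[0], dummy_vmap)
print(first.x.shape, first.key.shape)  # (3,) (2,)
```

This is why calling a method written for a single instance directly on `dummy_vmap` sees a batch of keys instead of one key.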
I would expect `dummy_vmap` to be an array of `Dummy` objects; instead, `dummy_vmap` turns out to be a single `Dummy` with vectorized `x` and `key`. This is not ideal for me because it changes the behavior of the code. For example, if I call `dummy_vmap.get_noisy_x()`, I get an error saying that `self.key, subkey = random.split(self.key)` does not work because `self.key` is not a single key. While this error could be solved in several ways (and in this example vectorization is not really needed), my goal is to understand how to write code in an object-oriented way that correctly handles both
```python
dummy = Dummy(jnp.array([1., 2., 3.]), key)
dummy.get_noisy_x()
```

and

```python
vectorized_dummy = .... ?
vectorized_dummy.get_noisy_x()
```
Note that the example I have given could be made to work in several ways without involving vectorization. What I am looking for, however, is a more generic way to deal with vectorization in much more complicated scenarios.
Update
I have found out that I need to vectorize `get_noisy_x` as well:
```python
dummy_vmap = jax.vmap(lambda x: Dummy(jnp.array([1., 2., 3.]), x))(key_batch)
jax.vmap(lambda self: Dummy.get_noisy_x(self))(dummy_vmap)  # this function call works exactly as expected.
```
However, this solution seems a bit counter-intuitive and not very scalable, since in a larger project I would need to vectorize every function of interest.
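One way to avoid wrapping each method by hand is to centralize the `vmap` in a small wrapper object. The sketch below uses a hypothetical `Batched` class (not part of JAX) that forwards every method call through `jax.vmap` over axis 0 of a batched pytree, so the call site keeps the single-instance syntax. It assumes all methods take the instance as their only batched input; any extra arguments are closed over and shared across the batch:

```python
import jax
import jax.numpy as jnp
import jax.random as random


class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def to_pytree(self):
        return (self.x, self.key), None

    def get_noisy_x(self):
        self.key, subkey = random.split(self.key)
        return self.x + random.normal(subkey, self.x.shape)

    @staticmethod
    def from_pytree(auxiliary, pytree):
        return Dummy(*pytree)


jax.tree_util.register_pytree_node(Dummy, Dummy.to_pytree, Dummy.from_pytree)


class Batched:
    """Hypothetical helper: wraps a batched pytree object and turns each
    method call into a jax.vmap over the leading (batch) axis."""

    def __init__(self, obj):
        self._obj = obj

    def __getattr__(self, name):
        attr = getattr(type(self._obj), name, None)
        if not callable(attr):
            # Data attributes (e.g. x, key) are returned still batched.
            return getattr(self._obj, name)

        def call(*args, **kwargs):
            # Extra arguments are closed over, i.e. shared, not batched.
            return jax.vmap(lambda single: attr(single, *args, **kwargs))(self._obj)

        return call


key_batch = random.split(random.PRNGKey(0), 100)
dummy_vmap = jax.vmap(lambda k: Dummy(jnp.array([1., 2., 3.]), k))(key_batch)

vectorized_dummy = Batched(dummy_vmap)
noisy = vectorized_dummy.get_noisy_x()  # same call syntax as the single instance
print(noisy.shape)  # (100, 3)

# Caveat: vmap rebuilds each Dummy from the leaves, so the `self.key = ...`
# update inside get_noisy_x only touches that temporary copy; the stored
# keys are NOT advanced and a second call returns the same noise.
```

The caveat in the last comment applies equally to the `jax.vmap(lambda self: Dummy.get_noisy_x(self))(dummy_vmap)` call in the update: unlike the un-jitted single-instance call, the key mutation does not propagate out of the transformation, so methods that update state would need to return the new state explicitly.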