I have a function that simulates a stochastic differential equation. Currently, without stochastic noise, my code for simulating the process up to time t looks like this (and, yeah, I need to use JAX):

import jax

def evolve(u, t):
    # the full SDE step would be: u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
    return u + dt * b(t, u)  # drift-only for now

def simulate(x, t):
    k = jax.numpy.floor(t / dt).astype(int)
    u = jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt), x)
    return u

Now, the pain comes with the noise. I'm a C++ guy who only occasionally needs to use Python for research/scientific work, and I really don't understand how I need to (or should) implement PRNG splitting here. I guess I would change evolve to

def evolve(u, t, key):
    noise = jax.random.multivariate_normal(key, jax.numpy.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

But that will not work properly, I guess. If I got it right, I need to use jax.random.split to split the key, because if I don't, I end up with correlated samples. But how and where do I need to split?

Also: I guess I would need to modify simulate to def simulate(x, t, key). But then, should simulate also return the modified key?

And to make it even more complicated: I actually wrap simulate in a batch_simulate function that uses jax.vmap to process a whole batch of x's and t's. How do I pass the PRNG key to that batch_simulate function, how do I pass it (and broadcast it) to jax.vmap, and what should batch_simulate return? At first glance, it seems it should take a single PRNG key and split it into many (due to the vmap). But what does the caller of batch_simulate do then ...

Completely lost on this. Any help is highly appreciated!

1 Answer

If I understand your setup correctly, you should make both evolve and simulate accept a key, and within simulate, use jax.random.fold_in to derive a unique key for each loop iteration:

def evolve(u, t, key):
    ...

def simulate(x, t, key):
    k = jax.numpy.floor(t / dt).astype(int)
    u = jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i)), x)
    return u
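
If it helps to see the whole picture, here is what evolve might look like with the noise term filled in — just a sketch, assuming b, sigma, dt, sqrt_dt, d, n, and covariance_matrix are defined as in your question:

def evolve(u, t, key):
    # `key` is the per-iteration subkey produced by fold_in above, so
    # the noise drawn here is independent across iterations
    noise = jax.random.multivariate_normal(
        key, jax.numpy.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

Note that simulate does not need to return a key: each subkey is derived deterministically from the single key passed in, and is consumed inside the loop.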

Then if you want to vmap over simulate, you can split the key and map over it:

key = ...      # your top-level PRNG key
x_batch = ...  # your batched x inputs
t_batch = ...  # your batched t inputs
key_batch = jax.random.split(key, x_batch.shape[0])

batch_result = jax.vmap(simulate)(x_batch, t_batch, key_batch)
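
As for what the caller of batch_simulate does: the same rule applies one level up — every key is used exactly once, and whenever randomness is needed in more than one place, you split first. A sketch of a call site (eval_key is a hypothetical second consumer, just for illustration):

key = jax.random.key(0)  # single top-level seed for the whole program

# split once per independent consumer of randomness; after this,
# `key` itself is never used again
sim_key, eval_key = jax.random.split(key)

batch_result = jax.vmap(simulate)(
    x_batch, t_batch, jax.random.split(sim_key, x_batch.shape[0]))
# eval_key remains available for some other, independent purpose
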
  • Thank you for your answer! So I don't need to update the key "inside" the loop? This is done (somehow) by fori_loop and fold_in automatically?
    – 0xbadf00d
    Commented Nov 30 at 9:14
  • I'm not sure I understand your question, but fold_in is used inside the fori_loop in order to create a new independent key at each iteration. Does that make sense?
    – jakevdp
    Commented Nov 30 at 13:26
  • What I mean is: I pass a key to simulate, and inside we use jax.random.fold_in(key, i). Now what does the caller of simulate do? What happens if they use jax.random.fold_in(key, i) again in another method? Do they need to split the key beforehand? Also: what do we do with the return value of jax.random.fold_in? Don't we need to return that to the caller of simulate? Everything you wrote here is pretty clear to me, but I don't understand how the story goes on from there — how we should proceed with the keys, etc.
    – 0xbadf00d
    Commented Dec 1 at 14:49
  • The caller of simulate is responsible for only using the key once. The return value of jax.random.fold_in is already consumed within the function, so there's nothing more that needs to be done with it. (See the sketch after these comments for an illustration.)
    – jakevdp
    Commented Dec 1 at 17:05
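
To illustrate the point from the last comment with a minimal, self-contained sketch: fold_in is a pure function — it derives a fresh key from (key, data) and leaves the original untouched — so the only question is who consumes which key, and inside simulate the folded keys are consumed immediately by evolve:

import jax

key = jax.random.key(42)
k0 = jax.random.fold_in(key, 0)  # subkey for iteration 0
k1 = jax.random.fold_in(key, 1)  # subkey for iteration 1
# k0 and k1 are distinct, independent keys; `key` itself is unchanged.
# Nothing needs to be handed back to the caller, but the caller must
# treat `key` as spent and not reuse it anywhere else.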
