I have a function which simulates a stochastic differential equation. Currently, without stochastic noise, my invocation for simulating the process up to time `t` looks like this (and yes, I do need to use JAX):
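```python
import jax
import jax.numpy as jnp

def evolve(u, t):
    # with noise this becomes: u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
    return u + dt * b(t, u)

def simulate(x, t):
    k = jnp.floor(t / dt).astype(int)  # number of Euler steps up to time t
    return jax.lax.fori_loop(0, k, lambda i, u: evolve(u, i * dt), x)
```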
Now, the pain comes with the noise. I'm a C++ guy who only occasionally needs to use Python for research/scientific work, and I really don't understand how I need to (or should) implement PRNG key splitting here. I guess I would change `evolve` to
```python
def evolve(u, t, key):
    # draw n samples from N(0, covariance_matrix), so noise has shape (n, d)
    noise = jax.random.multivariate_normal(key, jax.numpy.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise
```
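A quick way to see why the key matters: `jax.random` functions are pure, so the same key always produces the same sample. The sketch below fills in hypothetical placeholders for `b`, `sigma`, `d`, `n`, `dt` and `covariance_matrix` (a drift of `-u`, a constant diffusion of `0.5`, identity covariance); these are assumptions for illustration only, not the real model.

```python
import jax
import jax.numpy as jnp

# Hypothetical problem setup (placeholders for the real b, sigma, d, n, ...)
d, n = 2, 3
dt = 0.01
sqrt_dt = jnp.sqrt(dt)
covariance_matrix = jnp.eye(d)

def b(t, u):
    return -u          # hypothetical drift

def sigma(t, u):
    return 0.5         # hypothetical constant diffusion

def evolve(u, t, key):
    # one Euler-Maruyama step; the key fully determines the noise sample
    noise = jax.random.multivariate_normal(
        key, jnp.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

u0 = jnp.zeros((n, d))
key = jax.random.PRNGKey(0)
step_a = evolve(u0, 0.0, key)
step_b = evolve(u0, 0.0, key)   # same key, so bit-identical noise
```

Reusing `key` for every step would therefore inject the exact same noise vector at each step of the trajectory.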
But I guess that will not work properly. If I understand it right, I need to use `jax.random.split` to split the key, because otherwise every call reuses the same key and I end up with identical (not merely correlated) noise samples. But how and where do I need to split?
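One common pattern, sketched here with the same hypothetical placeholders (drift `-u`, diffusion `0.5`, identity covariance), is to carry the key through the `fori_loop` alongside the state and split it once per step: one half is consumed by `evolve`, the other is carried to the next iteration, so no key is ever used twice.

```python
import jax
import jax.numpy as jnp

# Hypothetical problem setup (placeholders for the real model)
d, n = 2, 3
dt = 0.01
sqrt_dt = jnp.sqrt(dt)
covariance_matrix = jnp.eye(d)

def b(t, u):
    return -u          # hypothetical drift

def sigma(t, u):
    return 0.5         # hypothetical constant diffusion

def evolve(u, t, key):
    noise = jax.random.multivariate_normal(
        key, jnp.zeros(d), covariance_matrix, shape=(n,))
    return u + dt * b(t, u) + sigma(t, u) * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)  # number of Euler steps up to time t

    def body(i, carry):
        u, key = carry
        # subkey is consumed by this step; the new key is carried forward
        key, subkey = jax.random.split(key)
        return evolve(u, i * dt, subkey), key

    u, _ = jax.lax.fori_loop(0, k, body, (x, key))
    return u
```

Because the key is part of the loop carry, the whole trajectory is a deterministic function of `(x, t, key)`, which keeps it reproducible and compatible with `jit`.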
Also: I guess I would need to modify `simulate` to `def simulate(x, t, key)`. But then, should `simulate` also return the modified key?
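Regarding returning the key: a simple convention is that `simulate` consumes the key it is given and returns only the state, while the caller splits off a fresh subkey before each call and never reuses it. The sketch below also shows an alternative to splitting inside the loop: deriving the per-step key with `jax.random.fold_in(key, i)`. To keep it short it swaps the multivariate normal for `jax.random.normal` and hard-codes a hypothetical drift `-u` and diffusion `0.5`.

```python
import jax
import jax.numpy as jnp

dt = 0.01
sqrt_dt = jnp.sqrt(dt)

def evolve(u, t, key):
    # hypothetical model: drift -u, diffusion 0.5, standard normal noise
    noise = jax.random.normal(key, u.shape)
    return u + dt * (-u) + 0.5 * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)
    # derive a distinct key for step i from the single input key
    body = lambda i, u: evolve(u, i * dt, jax.random.fold_in(key, i))
    return jax.lax.fori_loop(0, k, body, x)

# Caller side: keep a "master" key and split off a fresh subkey per call
key = jax.random.PRNGKey(42)
x = jnp.zeros(3)
results = []
for _ in range(4):
    key, subkey = jax.random.split(key)
    results.append(simulate(x, 0.1, subkey))
```

Since the caller never touches `subkey` again after passing it in, `simulate` has no reason to hand a key back.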
And to make it even more complicated: I actually wrap `simulate` into a `batch_simulate` function which uses `jax.vmap` to process a whole batch of `x`'s and `t`'s. How do I pass the PRNG key to that `batch_simulate` function, how do I pass it (and broadcast it) to `jax.vmap`, and what should `batch_simulate` return? At first glance, it seems to me that it would take a single PRNG key and split it into many (due to the `vmap`). But what does the caller of `batch_simulate` do then?
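For the batched case, one way (assuming every batch element should get independent noise) is for `batch_simulate` to take a single key, split it into one key per batch element with `jax.random.split(key, batch_size)`, and `vmap` over those keys together with `x` and `t`. The drift, diffusion and noise model are again hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

dt = 0.01
sqrt_dt = jnp.sqrt(dt)

def evolve(u, t, key):
    noise = jax.random.normal(key, u.shape)   # hypothetical noise model
    return u + dt * (-u) + 0.5 * sqrt_dt * noise

def simulate(x, t, key):
    k = jnp.floor(t / dt).astype(int)

    def body(i, carry):
        u, key = carry
        key, subkey = jax.random.split(key)
        return evolve(u, i * dt, subkey), key

    u, _ = jax.lax.fori_loop(0, k, body, (x, key))
    return u

def batch_simulate(xs, ts, key):
    # one independent key per batch element; vmap maps axis 0 of all three inputs
    keys = jax.random.split(key, xs.shape[0])
    return jax.vmap(simulate)(xs, ts, keys)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)   # caller keeps `key` for later calls
xs = jnp.zeros((5, 3))
ts = jnp.linspace(0.0, 0.1, 5)
out = batch_simulate(xs, ts, subkey)
```

`batch_simulate` then only needs to return the batch of states. Broadcasting a single key with `in_axes=(0, 0, None)` would be a mistake here, since every batch element would then see exactly the same noise sequence. The caller treats `batch_simulate` like any other consumer of randomness: split the master key, pass one half in, and keep the other for the next call.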
Completely lost on this. Any help is highly appreciated!