I use Flax to solve a neural differential equation, i.e. part of my PDE is represented by a NN. That doesn't really matter; it's just for context. Assume we have a neural network like this:
import flax.linen as nn

# External force function neural network
class ForceMLP(nn.Module):
    def setup(self):
        self.input = nn.Dense(256)
        self.dense1 = nn.Dense(256)
        self.dense2 = nn.Dense(256)
        self.output = nn.Dense(1)

    def __call__(self, t):
        x = self.input(t)
        x = nn.tanh(x)
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        x = nn.relu(x)
        F = self.output(x)
        return F
Now I would like to record the value of F so I can check how my predicted force evolves and spot potential issues, but I just don't know how because JAX turns everything into tracers.
Maybe I misread the docs but they basically just tell me "you can't"? Note that the model is implicitly used by a diffrax PDE solver:
import diffrax as dfx
import jax.numpy as jnp

# ODE solver using diffrax
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))
where equations is a Python function implementing all the maths for the PDE, and that's where I call the model.