
I use Flax to solve a neural differential equation, i.e. part of my PDE is represented by a neural network. The details don't really matter; this is just for context. Assume we have a network like this:

import flax.linen as nn

# External force function neural network
class ForceMLP(nn.Module):
    def setup(self):
        self.input = nn.Dense(256)
        self.dense1 = nn.Dense(256)
        self.dense2 = nn.Dense(256)
        self.output = nn.Dense(1)

    def __call__(self, t):
        x = self.input(t)
        x = nn.tanh(x)
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        x = nn.relu(x)
        F = self.output(x)

        return F

Now I would like to record the value of F so I can check how my predicted force evolves and spot potential issues, but I don't know how, because JAX turns everything into tracers.
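A minimal sketch of one way around the tracer problem, assuming a reasonably recent JAX: jax.debug.callback schedules a host-side Python function that receives concrete array values even inside jit-compiled code (jax.debug.print("F = {F}", F=F) is the print-only variant). The names recorded_F, record_force and force below are hypothetical stand-ins; in the real code force would be model.apply(params, t). If the function runs inside an adaptive Runge-Kutta solver such as Tsit5, expect one callback per vector-field evaluation (several stages per step, rejected steps included), so it is worth recording t alongside F.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side buffer that the callback appends to.
recorded_F = []

def record_force(t, F):
    # Runs on the host with concrete values, not tracers.
    recorded_F.append((float(t), np.asarray(F).copy()))

def force(t):
    # Hypothetical stand-in for model.apply(params, t); output shape (1,).
    return jnp.tanh(t) * jnp.ones((1,))

@jax.jit
def equations(t, y, args):
    # Same (t, y, args) signature that dfx.ODETerm expects.
    F = force(t)
    # Emit (t, F) to the host; this survives jit tracing.
    jax.debug.callback(record_force, t, F)
    return -y + F

for t in [0.0, 0.5, 1.0]:
    equations(jnp.asarray(t), jnp.zeros((1,)), None)

# Block until all pending callbacks have run.
jax.effects_barrier()
```

Callbacks are not guaranteed to arrive in order, which is another reason to store t with each value and sort afterwards.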

Maybe I misread the docs, but they basically just tell me "you can't"? Note that the model is called implicitly by a diffrax solver:

import diffrax as dfx
import jax.numpy as jnp

# ODE solver using diffrax
# (t_start, t_end, num_points and equations are defined elsewhere)
solver = dfx.Tsit5()  # Tsitouras 5th order method
stepsize_controller = dfx.PIDController(rtol=1e-6, atol=1e-6)
term = dfx.ODETerm(equations)
saveAt = dfx.SaveAt(ts=jnp.linspace(t_start, t_end, num_points))

Here equations is a Python function implementing all the maths for the PDE, and that's where I call the model.
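Another option, sketched under the assumption that only the force at the SaveAt times is needed: evaluate the model again after the solve, outside the solver, where everything is concrete. Because ForceMLP depends only on t, model.apply(params, sol.ts[:, None]) would already give shape (num_points, 1). The helper force_along_solution and toy_force below are hypothetical; the helper also covers the case where the force depends on the state by vmapping over the saved (t, y) pairs.

```python
import jax
import jax.numpy as jnp

def force_along_solution(force_fn, ts, ys):
    """Evaluate force_fn(t, y) at every saved (t, y) pair.

    ts and ys stand in for sol.ts and sol.ys from dfx.diffeqsolve(..., saveat=saveAt).
    For ForceMLP, force_fn could be lambda t, y: model.apply(params, t[None])
    (hypothetical wiring).
    """
    # vmap maps over the leading (time) axis of both arrays.
    return jax.vmap(force_fn)(ts, ys)

# Hypothetical stand-ins for the saved solution.
ts = jnp.linspace(0.0, 1.0, 5)
ys = jnp.stack([jnp.cos(ts), jnp.sin(ts)], axis=-1)  # shape (5, 2)

def toy_force(t, y):
    # Returns shape (1,), like ForceMLP's output.
    return jnp.tanh(t) * y[:1]

F_saved = force_along_solution(toy_force, ts, ys)
print(F_saved.shape)  # (5, 1)
```

This reuses the same params as the solve, so the values match what the solver saw at those times, without touching the jitted solver code.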
