
All Questions

1 vote
1 answer
155 views

Getting derivatives of NNs according to its inputs by batches in JAX

There is a neural network that takes two inputs, net(x, t), where x is usually d-dimensional and t is a scalar. The network outputs a vector of length d. Both x and t may be batched, so x is of ...
Michaela
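
For the setup described in this question, one common pattern is to write the derivative for a single example and then map it over the batch. The sketch below assumes a hypothetical stand-in `net` (the asker's actual network is not shown) that maps a `(d,)` vector and a scalar to a `(d,)` vector, and assumes both inputs are batched along axis 0.

```python
import jax
import jax.numpy as jnp

def net(x, t):
    # Hypothetical stand-in for the network: (d,), scalar -> (d,)
    return jnp.tanh(x) * t

# Single-example derivatives: d(net)/dx has shape (d, d), d(net)/dt has shape (d,)
jac_x = jax.jacfwd(net, argnums=0)
jac_t = jax.jacfwd(net, argnums=1)

# Map both over the leading batch axis of x and t
batched_jac_x = jax.vmap(jac_x, in_axes=(0, 0))
batched_jac_t = jax.vmap(jac_t, in_axes=(0, 0))

x = jnp.ones((4, 3))        # batch of 4, d = 3 (assumed sizes)
t = jnp.arange(1.0, 5.0)    # one scalar t per example

dx = batched_jac_x(x, t)    # shape (4, 3, 3)
dt = batched_jac_t(x, t)    # shape (4, 3)
```

If t is shared across the batch rather than batched, `in_axes=(0, None)` would be the corresponding choice.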
3 votes
2 answers
1k views

Purpose of stop gradient in `jax.nn.softmax`?

jax.nn.softmax is defined as: def softmax(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = -1, where: Optional[Array] = None, initial: Optional[Array] = None)...
Jay Mody
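
As a rough illustration of what this question is about, here is a minimal sketch of a max-shifted softmax with `jax.lax.stop_gradient` applied to the max. The function name `softmax_sketch` is hypothetical and this is not the library's full implementation (it ignores `where` and `initial`). Since softmax is unchanged by subtracting a constant from every input, the derivative should come out the same whether or not gradients flow through the max, which the comparison against `jax.nn.softmax` is meant to show.

```python
import jax
import jax.numpy as jnp

def softmax_sketch(x, axis=-1):
    # Subtract the max for numerical stability; treat it as a constant
    # during differentiation via stop_gradient.
    x_max = jnp.max(x, axis=axis, keepdims=True)
    unnormalized = jnp.exp(x - jax.lax.stop_gradient(x_max))
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0])

out = softmax_sketch(x)
jac_sketch = jax.jacobian(softmax_sketch)(x)   # Jacobian of the sketch
jac_ref = jax.jacobian(jax.nn.softmax)(x)      # Jacobian of the library function
```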