All Questions
2 questions
1 vote, 1 answer, 155 views
Getting derivatives of NNs according to its inputs by batches in JAX
There is a neural network that takes two variables as input: net(x, t), where x is usually d-dimensional and t is a scalar. The NN outputs a vector of length d. x and t might be batches, so x is of ...
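For a setup like the one described, one possible approach is to write the derivative for a single sample and let `jax.vmap` map it over the batch. The sketch below is only an illustration: `net` here is a hypothetical toy function standing in for the real network, and the batch size and d are made-up values.

```python
import jax
import jax.numpy as jnp

def net(x, t):
    # Hypothetical stand-in for the network: x has shape (d,), t is a scalar,
    # and the output has shape (d,).
    return jnp.tanh(x) * t

# Per-sample derivatives: Jacobian w.r.t. x has shape (d, d),
# derivative w.r.t. the scalar t has shape (d,).
dnet_dx = jax.jacfwd(net, argnums=0)
dnet_dt = jax.jacfwd(net, argnums=1)

# Map both over the leading (batch) axis of x and t.
batched_dx = jax.vmap(dnet_dx, in_axes=(0, 0))
batched_dt = jax.vmap(dnet_dt, in_axes=(0, 0))

x = jnp.ones((4, 3))   # assumed batch of 4 samples, d = 3
t = jnp.arange(4.0)    # one scalar t per sample

jac_x = batched_dx(x, t)  # shape (4, 3, 3)
jac_t = batched_dt(x, t)  # shape (4, 3)
```

Writing the function for one sample and batching it afterwards avoids computing cross-sample derivatives, which are zero when samples do not interact.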
3 votes, 2 answers, 1k views
Purpose of stop gradient in `jax.nn.softmax`?
jax.nn.softmax is defined as:
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None)...
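As a hedged illustration of what the question's title is about: softmax is invariant to shifting its input by a constant, so subtracting the maximum for numerical stability should not change the gradient whether or not the maximum is wrapped in `jax.lax.stop_gradient`. The two helper functions below are hypothetical simplified versions written for this sketch, not the library's actual source.

```python
import jax
import jax.numpy as jnp

def softmax_plain(x):
    # Gradient is allowed to flow through the max.
    shifted = x - jnp.max(x)
    e = jnp.exp(shifted)
    return e / jnp.sum(e)

def softmax_stopgrad(x):
    # The max is treated as a constant during differentiation.
    shifted = x - jax.lax.stop_gradient(jnp.max(x))
    e = jnp.exp(shifted)
    return e / jnp.sum(e)

x = jnp.array([1.0, 2.0, 3.0])

# Full Jacobians of both variants at the same point.
jac_plain = jax.jacobian(softmax_plain)(x)
jac_stop = jax.jacobian(softmax_stopgrad)(x)

# Up to floating-point error, the two Jacobians agree.
same = bool(jnp.allclose(jac_plain, jac_stop, atol=1e-6))
```

In this sketch the stop gradient does not alter the result; it only removes the max from the differentiated computation, whose contribution cancels anyway.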