All Questions
13 questions
1 vote, 1 answer, 462 views
JAX jax.grad on simple function that takes an array: `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`
I'm trying to implement this function and use JAX to automatically build the gradient function:
$f(x) = \sum\limits_{k=1}^{n-1} [100 (x_{k+1} - x_k^2)^2 + (1 - x_k)^2]$
(sorry, I don't know how to ...
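As an illustration of the formula above, here is a minimal sketch that writes the sum with vectorized `jax.numpy` slicing and differentiates it with `jax.grad`. The name `rosenbrock` and the test inputs are assumptions made for this example. This is not the asker's code, and it does not attempt to reproduce the reported error.

```python
import jax
import jax.numpy as jnp

def rosenbrock(x):
    # x[:-1] holds x_1..x_{n-1}; x[1:] holds x_2..x_n (hypothetical helper name).
    head, tail = x[:-1], x[1:]
    return jnp.sum(100.0 * (tail - head ** 2) ** 2 + (1.0 - head) ** 2)

# jax.grad returns a function that computes df/dx for a scalar-valued f.
grad_rosenbrock = jax.grad(rosenbrock)

x_min = jnp.ones(5)               # all-ones is the global minimum of this function
value_at_min = rosenbrock(x_min)
grad_at_min = grad_rosenbrock(x_min)

grad_at_origin = grad_rosenbrock(jnp.zeros(2))
```

Slicing keeps every operation on whole arrays, so nothing in the function needs a concrete Python value while JAX traces it.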
1 vote, 1 answer, 155 views
Getting derivatives of NNs according to its inputs by batches in JAX
There is a neural network that takes two variables as input: net(x, t), where x is usually d-dimensional and t is a scalar. The NN outputs a vector of length d. x and t might be batches, so x is of ...
1 vote, 2 answers, 534 views
PyTorch autograd: Efficient computation of Jacobian and Jacobian-Vector-product of scalar function over range of inputs
I have a function which takes 5 values as arguments and returns a scalar. That is, a mapping of the form f: R^5 -> R.
Hence, its Jacobian J is a matrix with dimension (1x5) and might as well have been ...
3 votes, 2 answers, 1k views
Purpose of stop gradient in `jax.nn.softmax`?
jax.nn.softmax is defined as:
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None)...
0 votes, 1 answer, 241 views
Cannot compute simple gradient of lambda function in JAX
I'm trying to compute the gradient of a lambda function that involves other gradients of functions, but the computation is hanging and I do not understand why.
In particular, the code below ...
1 vote, 1 answer, 492 views
JAX return types after transformations
Why is the return type of jax.grad different to other jax transformations in the following scenario?
Consider a function to be transformed by JAX which takes a custom container as an argument
import ...
1 vote, 1 answer, 137 views
jax empty element in the parameters of neural network
I am working on the implementation of a very small neural network. My network is as follows:
init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)
I ...
3 votes, 1 answer, 258 views
Automatic Differentiation with respect to rank-based computations
I'm new to automatic differentiation programming, so this may be a naive question. Below is a simplified version of what I'm trying to solve.
I have two input arrays - a vector A of size N and a matrix ...
1 vote, 1 answer, 709 views
How to use grad convolution in google-jax?
Thanks for reading my question!
I was just learning about custom grad functions in JAX, and I found the approach JAX takes to defining custom functions quite elegant.
One thing troubles me though.
...
2 votes, 0 answers, 208 views
Can OpenMDAO co-operate with autograd or jax?
Could the autograd or jax packages be used to generate the equivalent of analytic derivatives for OpenMDAO explicit components? i.e. something more accurate than finite differences (or perhaps more ...
0 votes, 0 answers, 666 views
Training neural network on gradient of input with pytorch
I am currently trying to train a neural network with PyTorch, where I try to match the input on the input derivative. I want to do this because it ensures a conservative vector field. (Done in ...
3 votes, 2 answers, 5k views
Conditional update in JAX?
In autograd/numpy I could do:
q[q<0] = 0.0
How can I do the same thing in JAX?
I tried import numpy as onp and using that to create arrays, but that doesn't seem to work.
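One common way to express this kind of masked assignment is `jnp.where`, which builds a new array instead of mutating in place, since JAX arrays are immutable. The sketch below uses a made-up array `q` and a hypothetical result name `q_clipped`. It is one possible approach, not necessarily the one given in the answers.

```python
import jax.numpy as jnp

# Example data (an assumption for illustration).
q = jnp.array([-1.0, 2.0, -3.0, 4.0])

# Where the condition q < 0 holds, take 0.0; elsewhere keep the value from q.
# This returns a new array and leaves q unchanged.
q_clipped = jnp.where(q < 0, 0.0, q)
```

Because `jnp.where` is an ordinary array operation, the same line also works inside functions transformed by `jax.jit` or `jax.grad`.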
4 votes, 1 answer, 2k views
Efficient way to compute Jacobian x Jacobian.T
Assume J is the Jacobian of some function f with respect to some parameters. Are there efficient ways (in PyTorch or perhaps Jax) to have a function that takes two inputs (x1 and x2) and computes J(x1)...