
All Questions

1 vote
1 answer
462 views

JAX jax.grad on simple function that takes an array: `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`

I'm trying to implement this function and use JAX to automatically build the gradient function: $f(x) = \sum\limits_{k=1}^{n-1} [100 (x_{k+1} - x_k^2)^2 + (1 - x_k)^2]$ (sorry, I don't know how to ...
clay • 20.3k
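
A common trigger for this error is using traced values where Python needs concrete ones (e.g. inside `range` or plain NumPy calls). For reference, a vectorized sketch of the summed function above, written with array slicing so `jax.grad` traces it without issue (the input values are illustrative):

```python
import jax
import jax.numpy as jnp

def rosenbrock(x):
    # f(x) = sum_k [100*(x[k+1] - x[k]**2)**2 + (1 - x[k])**2],
    # expressed with slices so the whole body stays traceable.
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

grad_f = jax.grad(rosenbrock)
x = jnp.array([1.5, 2.0, 0.5])
print(rosenbrock(x), grad_f(x))
```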
1 vote
1 answer
155 views

Getting derivatives of NNs with respect to their inputs by batches in JAX

There is a neural network that takes two variables as input: net(x, t), where x is usually d-dimensional and t is a scalar. The NN outputs a vector of length d. x and t might be batches, so x is of ...
Michaela
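
A typical pattern here (a sketch, assuming per-example derivatives are what is wanted) is to take the Jacobian of the network with respect to x and t for a single example and lift it over the batch with jax.vmap; `net`, the parameters, and the shapes below are illustrative stand-ins for the question's network:

```python
import jax
import jax.numpy as jnp

def net(params, x, t):
    # Stand-in for the question's d-dimensional network output.
    return jnp.tanh(params["W"] @ x + params["b"] * t)

# Jacobians w.r.t. x (argnums=1) and t (argnums=2) for one example...
d_dx = jax.jacrev(net, argnums=1)
d_dt = jax.jacrev(net, argnums=2)

# ...vectorized over a batch of (x, t) pairs; params are shared (in_axes=None).
batched_d_dx = jax.vmap(d_dx, in_axes=(None, 0, 0))
batched_d_dt = jax.vmap(d_dt, in_axes=(None, 0, 0))

d, batch = 3, 8
params = {"W": jnp.eye(d), "b": jnp.ones(d)}
xs = jnp.ones((batch, d))
ts = jnp.linspace(0.0, 1.0, batch)
print(batched_d_dx(params, xs, ts).shape)  # (batch, d, d)
print(batched_d_dt(params, xs, ts).shape)  # (batch, d)
```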
1 vote
2 answers
534 views

PyTorch autograd: Efficient computation of Jacobian and Jacobian-Vector-product of scalar function over range of inputs

I have a function which takes 5 values as arguments and returns a scalar. That is a mapping of the form f: R^5 -> R. Hence, its Jacobian J is a matrix with dimension (1x5) and might as well have been ...
Landscape • 257
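
For a scalar-valued function the Jacobian is just the gradient, and a Jacobian-vector product is a directional derivative; a minimal PyTorch sketch using the built-in helpers (the function f below is an illustrative stand-in for the question's mapping):

```python
import torch
from torch.autograd.functional import jacobian, jvp

def f(x):
    # Illustrative scalar function f: R^5 -> R.
    return (x ** 2).sum()

x = torch.randn(5)
v = torch.randn(5)

J = jacobian(f, x)            # shape (5,): the 1x5 Jacobian of a scalar function
out, Jv = jvp(f, (x,), (v,))  # directional derivative J(x) @ v without forming J
print(J, Jv)
```

For many inputs at once these calls can be wrapped in a loop or, on recent PyTorch versions, vectorized with torch.func.vmap.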
3 votes
2 answers
1k views

Purpose of stop gradient in `jax.nn.softmax`?

jax.nn.softmax is defined as: def softmax(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = -1, where: Optional[Array] = None, initial: Optional[Array] = None)...
Jay Mody • 4,013
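
For context, jax.nn.softmax subtracts the (stop-gradiented) maximum before exponentiating purely for numerical stability; since softmax is shift-invariant, that term contributes nothing to the derivative. A simplified sketch mirroring that structure (not a verbatim copy of the library code):

```python
import jax
import jax.numpy as jnp
from jax import lax

def softmax_sketch(x, axis=-1):
    # Softmax is shift-invariant, so subtracting the max only improves numerical
    # stability; stop_gradient keeps that shift out of the backward pass, since
    # mathematically it contributes nothing to the derivative.
    shifted = x - lax.stop_gradient(x.max(axis, keepdims=True))
    unnormalized = jnp.exp(shifted)
    return unnormalized / unnormalized.sum(axis, keepdims=True)

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.allclose(softmax_sketch(x), jax.nn.softmax(x)))
```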
0 votes
1 answer
241 views

Cannot compute simple gradient of lambda function in JAX

I'm trying to compute the gradient of a lambda function that involves other gradients of functions, but the computation is hanging and I do not understand why. In particular, the code below ...
Marco • 25
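
The hanging code isn't shown in the excerpt, but composing jax.grad with lambdas that call other gradient functions is itself supported; a minimal sketch of that pattern, which runs without issue:

```python
import jax
import jax.numpy as jnp

g = lambda x: jnp.sin(x) ** 2
dg = jax.grad(g)

# A lambda that itself calls the gradient of another function.
h = lambda x: dg(x) + x ** 2
dh = jax.grad(h)

print(dh(1.0))  # d/dx [g'(x) + x^2] = g''(x) + 2x
```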
1 vote
1 answer
492 views

JAX return types after transformations

Why is the return type of jax.grad different from that of other JAX transformations in the following scenario? Consider a function to be transformed by JAX which takes a custom container as an argument: import ...
DavidJ • 408
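
The custom container from the question isn't shown, but for registered pytrees jax.grad packs one gradient per leaf into the same container structure as the differentiated argument; a minimal illustration with a dict:

```python
import jax
import jax.numpy as jnp

# jax.grad returns one gradient per leaf of the differentiated argument,
# packed into the same pytree structure as that argument.
params = {"w": jnp.ones(3), "b": jnp.array(0.5)}

def loss(p, x):
    return jnp.sum(p["w"] * x + p["b"])

grads = jax.grad(loss)(params, jnp.arange(3.0))
print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True
```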
1 vote
1 answer
137 views

JAX: empty element in the parameters of a neural network

I am working on the implementation of a very small neural network. My network is as follows: init_random_params, predict = stax.serial( Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax) I ...
MMH • 111
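
The empty entries come from parameter-free layers: in stax, layers such as Relu and LogSoftmax contribute an empty tuple to the parameter list. A sketch (module path as in recent JAX releases; older versions exposed this as jax.experimental.stax):

```python
import jax
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax

init_random_params, predict = stax.serial(
    Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax
)

_, params = init_random_params(jax.random.PRNGKey(0), (-1, 784))
# Parameter-free layers such as Relu and LogSoftmax contribute empty tuples,
# which is why "empty elements" appear in the parameter list.
print([len(p) for p in params])  # e.g. [2, 0, 2, 0, 2, 0]
```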
3 votes
1 answer
258 views

Automatic Differentiation with respect to rank-based computations

I'm new to automatic differentiation programming, so this may be a naive question. Below is a simplified version of what I'm trying to solve. I have two input arrays - a vector A of size N and a matrix ...
P JMU • 69
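
Ranks are integer-valued and piecewise constant, so gradients do not flow through them; a small illustration of that behaviour (the usual remedy is a smooth surrogate for the rank, which is outside this sketch):

```python
import jax
import jax.numpy as jnp

def rank_based(a):
    # Ranks are integers and piecewise constant in a, so any dependence through
    # them contributes nothing to the gradient.
    ranks = jnp.argsort(jnp.argsort(a))
    return jnp.sum(a * ranks)

a = jnp.array([0.3, -1.2, 2.5])
print(jax.grad(rank_based)(a))  # gradient flows through `a`, not through `ranks`
```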
1 vote
1 answer
709 views

How to use grad convolution in google-jax?

Thanks for reading my question! I was just learning about custom grad functions in JAX, and I found the approach JAX takes to defining custom functions quite elegant. One thing troubles me though. ...
TIM • 125
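
The mechanism the question refers to is jax.custom_vjp / jax.custom_jvp; a minimal, non-convolutional sketch of the custom_vjp pattern, using an illustrative clipped ReLU:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_relu(x):
    return jnp.clip(x, 0.0, 1.0)

def clip_relu_fwd(x):
    # Save whatever the backward pass needs as "residuals".
    return clip_relu(x), x

def clip_relu_bwd(res, g):
    x = res
    # Pass the cotangent g through only where the primal was in the linear region.
    return (g * ((x > 0.0) & (x < 1.0)),)

clip_relu.defvjp(clip_relu_fwd, clip_relu_bwd)

print(jax.grad(lambda x: clip_relu(x).sum())(jnp.array([-0.5, 0.3, 2.0])))
```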
2 votes
0 answers
208 views

Can OpenMDAO co-operate with autograd or jax?

Could the autograd or jax packages be used to generate the equivalent of analytic derivatives for OpenMDAO explicit components? i.e. something more accurate than finite differences (or perhaps more ...
Jacob Schwartz
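
One possible pattern (a sketch, not an established integration) is to let JAX produce the Jacobian and hand it to compute_partials of an OpenMDAO ExplicitComponent; the component and function below are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np
import openmdao.api as om

def f(x):
    # Hypothetical analysis function to differentiate with JAX.
    return jnp.sum(jnp.sin(x) * x)

jac_f = jax.jacobian(f)

class JaxComp(om.ExplicitComponent):
    """Hypothetical explicit component whose partials come from jax.jacobian."""

    def setup(self):
        self.add_input("x", val=np.ones(3))
        self.add_output("y", val=0.0)
        self.declare_partials("y", "x")

    def compute(self, inputs, outputs):
        outputs["y"] = float(f(jnp.asarray(inputs["x"])))

    def compute_partials(self, inputs, partials):
        # Analytic (AD-generated) partials instead of finite differences.
        partials["y", "x"] = np.asarray(jac_f(jnp.asarray(inputs["x"])))
```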
0 votes
0 answers
666 views

Training neural network on gradient of input with pytorch

I am currently trying to train a neural network with PyTorch, where I try to match the input against the input derivative. I want to do this because it ensures a conservative vector field. (Done in ...
not_converging
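
A common way to put the derivative of the network with respect to its input into the training loss is torch.autograd.grad with create_graph=True; a minimal sketch with illustrative shapes and targets:

```python
import torch

net = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1))

x = torch.randn(16, 3, requires_grad=True)  # inputs
target_grad = torch.randn(16, 3)            # illustrative derivative targets

y = net(x).sum()
# create_graph=True keeps this derivative differentiable so it can appear in a loss.
input_grad, = torch.autograd.grad(y, x, create_graph=True)

loss = ((input_grad - target_grad) ** 2).mean()
loss.backward()  # backpropagates through the input-gradient computation
```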
3 votes
2 answers
5k views

Conditional update in JAX?

In autograd/numpy I could do: q[q<0] = 0.0 How can I do the same thing in JAX? I tried import numpy as onp and using that to create arrays, but that doesn't seem to work.
Andriy Drozdyuk
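
JAX arrays are immutable, so the masked in-place assignment is usually replaced by jnp.where (general) or, for this particular clamp, jnp.maximum; a quick sketch:

```python
import jax.numpy as jnp

q = jnp.array([-1.0, 0.5, -0.2, 2.0])

# Functional equivalents of the in-place `q[q < 0] = 0.0`:
q1 = jnp.where(q < 0, 0.0, q)   # elementwise select
q2 = jnp.maximum(q, 0.0)        # same result for this particular update
print(q1, q2)
```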
4 votes
1 answer
2k views

Efficient way to compute Jacobian x Jacobian.T

Assume J is the Jacobian of some function f with respect to some parameters. Are there efficient ways (in PyTorch or perhaps Jax) to have a function that takes two inputs (x1 and x2) and computes J(x1)...
Milad • 5,450
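
For small dimensions the straightforward option is to materialize both Jacobians and multiply; a JAX sketch with an illustrative f (for large problems, composing jax.jvp/jax.vjp avoids forming J explicitly):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative vector-valued function standing in for the question's f.
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

def jjt(x1, x2):
    # Forms both Jacobians explicitly; fine when input/output dimensions are small.
    return jax.jacobian(f)(x1) @ jax.jacobian(f)(x2).T

x1 = jnp.array([1.0, 2.0, 3.0])
x2 = jnp.array([0.5, -1.0, 4.0])
print(jjt(x1, x2))  # shape (2, 2)
```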