All Questions

1 vote
1 answer

JAX jax.grad on simple function that takes an array: `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`

I'm trying to implement this function and use JAX to automatically build the gradient function: $f(x) = \sum\limits_{k=1}^{n-1} [100 (x_{k+1} - x_k^2)^2 + (1 - x_k)^2]$ (sorry, I don't know how to ...
clay
  • 20.3k
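
The excerpt above is cut off, so the cause of the `ConcretizationTypeError` is not visible here. As a minimal sketch, the quoted formula can be written with array slices rather than an index loop, which is one form `jax.grad` can differentiate. The names `rosenbrock`, `grad_rosenbrock`, and `x_min` are illustrative and not taken from the question.

```python
import jax
import jax.numpy as jnp


def rosenbrock(x):
    # x[1:] plays the role of x_{k+1} and x[:-1] the role of x_k, for k = 1..n-1.
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)


# jax.grad returns a function that evaluates the gradient with respect to x.
grad_rosenbrock = jax.grad(rosenbrock)

# The all-ones vector makes every term of the sum zero.
x_min = jnp.ones(4)
value_at_min = rosenbrock(x_min)
grad_at_min = grad_rosenbrock(x_min)
```

At the all-ones point both the function value and the gradient should be zero, which gives a quick sanity check on the implementation.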
1 vote
1 answer

Getting derivatives of NNs according to its inputs by batches in JAX

There is a neural network that takes two variables as input: net(x, t), where x is usually d-dim, and t is a scalar. The NN outputs a vector of length d. x and t might be batches, so x is of ...
Michaela
1 vote
2 answers

PyTorch autograd: Efficient computation of Jacobian and Jacobian-Vector-product of scalar function over range of inputs

I have a function which takes 5 values as arguments and returns a scalar. That is a mapping of the form f:R^5 -> R. Hence, its Jacobian J is a matrix with dimension (1x5) and might as well have been ...
Landscape
  • 257
3 votes
2 answers

Purpose of stop gradient in `jax.nn.softmax`?

jax.nn.softmax is defined as: def softmax(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = -1, where: Optional[Array] = None, initial: Optional[Array] = None)...
Jay Mody
  • 4,013
0 votes
1 answer

Cannot compute simple gradient of lambda function in JAX

I'm trying to compute the gradient of a lambda function that involves other gradients of functions, but the computation is hanging and I do not understand why. In particular, the code below ...
Marco
  • 25
1 vote
1 answer

JAX return types after transformations

Why is the return type of jax.grad different to other jax transformations in the following scenario? Consider a function to be transformed by JAX which takes a custom container as an argument import ...
DavidJ
  • 408
1 vote
1 answer

jax empty element in the parameters of neural network

I am working on the implementation of a very small neural network. My network is as follows: init_random_params, predict = stax.serial( Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax) I ...
MMH
  • 111
3 votes
1 answer

Automatic Differentiation with respect to rank-based computations

I'm new to automatic differentiation programming, so this may be a naive question. Below is a simplified version of what I'm trying to solve. I have two input arrays - a vector A of size N and a matrix ...
P JMU
  • 69
1 vote
1 answer

How to use grad convolution in google-jax?

Thanks for reading my question! I was just learning about custom grad functions in Jax, and I found the approach JAX took with defining custom functions is quite elegant. One thing troubles me though. ...
TIM
  • 125
2 votes
0 answers

Can OpenMDAO co-operate with autograd or jax?

Could the autograd or jax packages be used to generate the equivalent of analytic derivatives for OpenMDAO explicit components? i.e. something more accurate than finite differences (or perhaps more ...
Jacob Schwartz
0 votes
0 answers

Training neural network on gradient of input with pytorch

I am currently trying to train a neural network with pytorch, where I try to match the input on the input derivative. I want to do this because this is ensuring a conservative vector field. (Done in ...
not_converging
3 votes
2 answers

Conditional update in JAX?

In autograd/numpy I could do: q[q<0] = 0.0 How can I do the same thing in JAX? I tried import numpy as onp and using that to create arrays, but that doesn't seem to work.
Andriy Drozdyuk
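
JAX arrays are immutable, so the in-place assignment `q[q<0] = 0.0` quoted above has no direct equivalent. One common functional substitute, shown here as a sketch and not as the accepted answer to the question, is `jnp.where`, which builds a new array. The sample values in `q` are made up for illustration.

```python
import jax.numpy as jnp

# Illustrative input; the question does not give concrete values.
q = jnp.array([-1.0, 2.0, -3.0, 4.0])

# Select 0.0 where the condition holds and the original entry elsewhere.
# This returns a new array and leaves q unchanged.
clipped = jnp.where(q < 0, 0.0, q)
```

Because `jnp.where` is a pure function, the same line can also be used inside code transformed by `jax.jit` or `jax.grad`.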
4 votes
1 answer

Efficient way to compute Jacobian x Jacobian.T

Assume J is the Jacobian of some function f with respect to some parameters. Are there efficient ways (in PyTorch or perhaps Jax) to have a function that takes two inputs (x1 and x2) and computes J(x1)...
Milad
  • 5,450