All Questions
Tagged with jax optimization
18 questions
1 vote, 1 answer, 87 views
Parallelizing over huge for loop involving a generator
Is there a way to compute, in parallel, the maximum value of an array so large that I need to use a generator? I am computing the maximum value of a function over a very large discrete set of ...
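A minimal sketch, assuming the discrete set can be enumerated by an integer index so that each chunk can be generated on the device instead of by a Python generator. The objective f, the chunk size, and the helper names are hypothetical stand-ins. jax.vmap evaluates a whole chunk at once, and a running maximum is carried across chunks.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical objective standing in for the function in the question.
    return -((x - 123_456.0) ** 2)


@partial(jax.jit, static_argnums=2)
def chunk_max(start, n_total, chunk_size):
    # Build this chunk's indices on the device instead of pulling them from a generator.
    idx = start + jnp.arange(chunk_size)
    vals = jax.vmap(f)(idx.astype(jnp.float32))
    # Indices past the end of the set are masked so a partial last chunk is handled.
    return jnp.max(jnp.where(idx < n_total, vals, -jnp.inf))


def global_max(n_total, chunk_size=65_536):
    best = -jnp.inf
    for start in range(0, n_total, chunk_size):
        # Running maximum across chunks; only one chunk is in memory at a time.
        best = jnp.maximum(best, chunk_max(start, n_total, chunk_size))
    return best


result = global_max(200_000)
```

Because start and n_total are traced, chunk_max compiles once per chunk_size; the chunk size trades device memory against the number of Python-level dispatches.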
1 vote, 1 answer, 270 views
Slow JAX Optimization with ScipyBoundedMinimize and Optax - Seeking Speedup Strategies
I'm working on optimizing a model in JAX that involves fitting a large observational dataset (4800 data points) with a complex model containing interpolation. The current optimization process using ...
0 votes, 0 answers, 398 views
Using JAXopt for constrained (non-negative) optimization
I am trying to optimize my loss function (using JAX), which comes from a basic physics model of two layers growing at different rates. Since the growth rates are always positive, I ...
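One way to keep the growth rates non-negative is projected gradient descent: take an ordinary gradient step, then clip the rates back onto r >= 0. JAXopt's ProjectedGradient solver combined with jaxopt.projection.projection_non_negative is built around the same idea; the sketch below writes the loop by hand so it depends only on JAX. The two-term model, the synthetic data, the step size, and the iteration count are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def loss(rates, t, y):
    # Hypothetical two-term growth model: y is approximately r1 * t + r2 * t**2.
    pred = rates[0] * t + rates[1] * t ** 2
    return jnp.mean((pred - y) ** 2)


grad_loss = jax.grad(loss)


@jax.jit
def fit_nonneg(rates0, t, y):
    lr, steps = 0.5, 5000  # illustrative values, small enough for this loss

    def body(_, rates):
        # Gradient step, then projection onto the feasible set rates >= 0.
        return jnp.maximum(rates - lr * grad_loss(rates, t, y), 0.0)

    return jax.lax.fori_loop(0, steps, body, rates0)


t = jnp.linspace(0.0, 1.0, 20)
y = 2.0 * t - 0.5 * t ** 2  # synthetic data whose unconstrained fit has r2 < 0
rates = fit_nonneg(jnp.array([1.0, 1.0]), t, y)
```

The synthetic data are chosen so that the unconstrained fit would want a negative second rate; the projection holds it at exactly zero while the first rate settles at its best value given that constraint. Reparameterizing (for example r = exp(theta)) is a common alternative when the rates are known to be strictly positive.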
2 votes, 0 answers, 509 views
How to implement np.lib.stride_tricks.sliding_window_view efficiently in JAX?
I have implemented an algorithm in which I calculate the Pearson correlation coefficient between a vector in one image and every vector in another image within a given window around the equivalent ...
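JAX arrays do not expose strides, so there is no zero-copy equivalent of np.lib.stride_tricks.sliding_window_view. A gather with broadcast index arrays produces the same (nh, nw, wh, ww) layout, at the cost of materializing every window. A sketch, with the helper name chosen for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp


def sliding_window_view_2d(img, wh, ww):
    """Return windows with out[i, j, a, b] == img[i + a, j + b]."""
    H, W = img.shape
    rows = jnp.arange(H - wh + 1)[:, None] + jnp.arange(wh)[None, :]  # (nh, wh)
    cols = jnp.arange(W - ww + 1)[:, None] + jnp.arange(ww)[None, :]  # (nw, ww)
    # Broadcast both index grids to (nh, nw, wh, ww) and gather in one step.
    return img[rows[:, None, :, None], cols[None, :, None, :]]


# Window sizes determine output shapes, so they must be static under jit.
windows = jax.jit(sliding_window_view_2d, static_argnums=(1, 2))(
    jnp.arange(20.0).reshape(4, 5), 2, 3
)
reference = np.lib.stride_tricks.sliding_window_view(np.arange(20.0).reshape(4, 5), (2, 3))
```

The result holds wh * ww times as many elements as the input, so memory, not indexing speed, is usually the limiting factor for large windows.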
1 vote, 1 answer, 412 views
Vectorized minimization and root finding in JAX
I have a family of functions f(x, args), parameterized by args, and want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My ...
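Running a fixed number of Newton iterations on f'(x) = 0 is one simple way to vectorize the minimization over args, since every batch element then executes the same loop. A sketch, assuming f is smooth and convex in x. The family cosh(x) - a*x is a hypothetical example whose minimizer is x* = arcsinh(a), which makes the result easy to check. Both derivatives come from jax.grad here, although a hand-written derivative could be substituted for df.

```python
import jax
import jax.numpy as jnp


def f(x, a):
    # Hypothetical convex family; its minimizer over x is arcsinh(a).
    return jnp.cosh(x) - a * x


df = jax.grad(f)    # derivative with respect to x
d2f = jax.grad(df)  # second derivative with respect to x


def newton_min(a, iters=30):
    def step(_, x):
        # Newton step on the root-finding problem f'(x) = 0.
        return x - df(x, a) / d2f(x, a)

    return jax.lax.fori_loop(0, iters, step, jnp.zeros_like(a))


batched_min = jax.jit(jax.vmap(newton_min))  # one minimization per value of args
args = jnp.linspace(-3.0, 3.0, 1000)
x_star = batched_min(args)
```

A jax.lax.while_loop with a tolerance also works under vmap; the batch then keeps iterating until every element has converged.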
2 votes, 1 answer, 920 views
Parallelize with JAX over all GPU cores
In order to minimize the function x^2+y^2, I tried to implement the Adam optimizer from scratch with JAX:
@jax.jit
def fit(X, batches, params=[0.001, 0.9, 0.99, 1e-8]):
    global fun
    # batches ...
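The question's code is truncated, so the following is not a repair of it but a minimal sketch of Adam written from scratch for x^2 + y^2, using the hyperparameters from its params default (learning rate 0.001, beta1 0.9, beta2 0.99, epsilon 1e-8). The whole loop runs inside one jit-compiled lax.scan. With only two parameters there is little work per step for a GPU to parallelize; batching independent runs with jax.vmap is one way to give it more.

```python
from functools import partial

import jax
import jax.numpy as jnp


def fun(p):
    # Objective from the question: x^2 + y^2.
    return jnp.sum(p ** 2)


@partial(jax.jit, static_argnames="steps")
def adam_fit(p0, steps=5000, lr=1e-3, b1=0.9, b2=0.99, eps=1e-8):
    grad_fn = jax.grad(fun)

    def step(carry, t):
        p, m, v = carry
        g = grad_fn(p)
        m = b1 * m + (1 - b1) * g       # running mean of gradients
        v = b2 * v + (1 - b2) * g ** 2  # running mean of squared gradients
        m_hat = m / (1 - b1 ** t)       # bias corrections for the zero initialization
        v_hat = v / (1 - b2 ** t)
        p = p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
        return (p, m, v), fun(p)

    ts = jnp.arange(1, steps + 1, dtype=p0.dtype)  # step counter t = 1, 2, ...
    init = (p0, jnp.zeros_like(p0), jnp.zeros_like(p0))
    (p, _, _), losses = jax.lax.scan(step, init, ts)
    return p, losses


p, losses = adam_fit(jnp.array([1.0, -1.5]))
starts = jnp.array([[1.0, -1.5], [0.5, 2.0], [-2.0, 0.25]])
ps, _ = jax.vmap(adam_fit)(starts)  # independent runs batched into one program
```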
0 votes, 0 answers, 341 views
Optimisation with JAX using vmap or fori_loop
I'm still encountering a challenge with optimizing my code using the JAX library. The purpose of this code is to apply a blur to an intensity array by creating a normalized Gaussian for each pixel and ...
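A sketch of one reading of the excerpt: every source pixel spreads its intensity with its own normalized Gaussian, whose width sigma would be derived from the question's data (here it is simply passed in and must be positive). jax.vmap replaces the per-pixel Python loop. The function name and the exact blur model are assumptions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def variable_gaussian_blur(intensity, sigma):
    """Spread each pixel's intensity with its own normalized Gaussian of width sigma[k, l]."""
    H, W = intensity.shape
    yy, xx = jnp.meshgrid(
        jnp.arange(H, dtype=intensity.dtype), jnp.arange(W, dtype=intensity.dtype), indexing="ij"
    )

    def contribution(k, l, value, s):
        g = jnp.exp(-((yy - k) ** 2 + (xx - l) ** 2) / (2.0 * s ** 2))
        # Normalizing over the image grid means each pixel's total intensity is kept.
        return value * g / jnp.sum(g)

    # One (H, W) contribution per source pixel: shape (H * W, H, W), summed at the end.
    contribs = jax.vmap(contribution)(yy.ravel(), xx.ravel(), intensity.ravel(), sigma.ravel())
    return contribs.sum(axis=0)


img = jax.random.uniform(jax.random.PRNGKey(0), (6, 7))
blurred = variable_gaussian_blur(img, jnp.full((6, 7), 1.5))
sharp = variable_gaussian_blur(img, jnp.full((6, 7), 0.1))  # tiny sigma: nearly unchanged
```

The stacked contributions need (H * W) * H * W floats, which is fine for small images; for large ones, accumulating the contributions in a lax.fori_loop, or restricting each Gaussian to a fixed-size neighborhood, keeps memory proportional to the image.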
2 votes, 2 answers, 245 views
Optimization and speed enhancement of function using JAX
The code performs Gaussian blurring on the image intensityRefracted2DF using a Gaussian kernel centered at each pixel. The Gaussian kernel is determined by the values in the darkField array, where ...
0 votes, 1 answer, 621 views
How to write an exponential learning rate decay algorithm while using JAX's version of Adam Optimizer
I am writing a Physics-Informed Neural Network (PINN) code and have been using the JAX library for it. I have been using the 'Adam' optimizer from JAX's 'example_libraries.optimizers' module. I ...
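The optimizers in jax.example_libraries.optimizers, including adam, accept either a constant step size or a schedule (a function of the step index), and the module provides an exponential_decay helper that returns step_size * decay_rate ** (i / decay_steps). A minimal sketch with a hypothetical quadratic standing in for the PINN loss; the decay constants are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Learning rate 1e-2 * 0.9 ** (i / 1000): multiplied by 0.9 every 1000 steps, applied smoothly.
schedule = optimizers.exponential_decay(step_size=1e-2, decay_steps=1000, decay_rate=0.9)
opt_init, opt_update, get_params = optimizers.adam(schedule)


def loss(params):
    # Hypothetical stand-in for the PINN loss.
    return jnp.sum((params - 3.0) ** 2)


@jax.jit
def train_step(i, opt_state):
    grads = jax.grad(loss)(get_params(opt_state))
    return opt_update(i, grads, opt_state)  # the schedule is evaluated at step i


opt_state = opt_init(jnp.zeros(2))
for i in range(3000):
    opt_state = train_step(i, opt_state)
params = get_params(opt_state)
```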
1 vote, 1 answer, 617 views
Error using JAX, Array slice indices must have static start/stop/step
I'll be happy to help you with your code. If I understand correctly, you want to create a 2D Gaussian patch for each value in the darkField array. The size of the patch should ideally be calculated as ...
2 votes, 1 answer, 132 views
Issues while using JAX to minimize the Lennard-Jones potential for two points and the force (gradient of the potential)--result doesn't match
I am trying to use the minimization function in JAX to find the separation of two points that minimizes the Lennard-Jones potential E = 2(1/r^4-1/r^2), and I can successfully get the result: [-0.20710678 1....
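For this potential the answer can be checked by hand: dE/dr = 2(-4/r^5 + 2/r^3) vanishes when r^2 = 2, so the separation at the minimum is r* = sqrt(2), about 1.41421, and E(r*) = 2(1/4 - 1/2) = -0.5. Since the force is F = -dE/dx, it should be essentially zero there. A sketch using jax.scipy.optimize.minimize with BFGS, assuming the two points lie on a line with positions x[0] and x[1]; the starting positions are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def energy(x):
    # Two points on a line: E(r) = 2 (1/r^4 - 1/r^2) with r = |x1 - x0|.
    r = jnp.abs(x[1] - x[0])
    return 2.0 * (1.0 / r ** 4 - 1.0 / r ** 2)


def force(x):
    # The force on each point is minus the gradient of the energy.
    return -jax.grad(energy)(x)


res = minimize(energy, jnp.array([0.0, 1.5]), method="BFGS")
r_star = jnp.abs(res.x[1] - res.x[0])
```

Only the separation is determined; the pair can sit anywhere on the line, so the individual coordinates depend on the starting point.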
1 vote, 1 answer, 239 views
JAX code for minimizing Lennard-Jones potential for 2 points in Python gives unexpected results
I am trying to practice using JAX for optimization problems, so I am working on a simple one: minimizing the Lennard-Jones potential for just 2 points, and I set both epsilon and sigma in ...
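Assuming the question uses the standard 12-6 form E(r) = 4ε[(σ/r)^12 - (σ/r)^6] with ε = σ = 1, setting dE/dr = 0 gives r* = 2^(1/6)σ, about 1.12246, and E(r*) = -ε. That known answer is a useful check on any JAX minimizer. A minimal sketch using plain gradient descent on the separation r, with a step size and starting point picked for illustration:

```python
import jax
import jax.numpy as jnp


def lj(r, epsilon=1.0, sigma=1.0):
    # Standard 12-6 Lennard-Jones pair potential.
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)


dlj = jax.grad(lj)  # dE/dr; the pair force along r is -dE/dr


@jax.jit
def descend(r0):
    lr = 0.01  # E''(r*) is about 57 here, so lr * E'' stays well below 2
    return jax.lax.fori_loop(0, 500, lambda _, r: r - lr * dlj(r), r0)


r_star = descend(jnp.float32(1.3))
```

Starting on the soft side of the well (r > r*) keeps the iterates away from the very steep repulsive wall, where a fixed step size is easiest to get wrong.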
1 vote, 1 answer, 289 views
JAX: Canceled future for execute_request message before replies were done
I have an optimization problem that I am trying to solve with Newton's method. To calculate the Jacobian matrix, I use jax.jacobian.
My objective function is called calc_ms. I use Newton's method to find ...
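A compact sketch of Newton's method for a square system F(x) = 0, with the Jacobian from jax.jacobian and each step solved with jnp.linalg.solve rather than an explicit inverse, all inside one jit-compiled loop. The residual is a hypothetical two-equation system standing in for calc_ms, and the fixed iteration count is an assumption; a convergence check with lax.while_loop is a natural refinement.

```python
import jax
import jax.numpy as jnp


def residual(x):
    # Hypothetical system: x0^2 + x1^2 = 4 and x0 = x1.
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])


jac = jax.jacobian(residual)


@jax.jit
def newton(x0):
    def step(_, x):
        # Solve J dx = -F(x) instead of forming the inverse Jacobian.
        return x + jnp.linalg.solve(jac(x), -residual(x))

    return jax.lax.fori_loop(0, 20, step, x0)


root = newton(jnp.array([1.0, 0.5]))
```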
1 vote, 1 answer, 450 views
Python function minimization not changing optimization variables
I need to minimize a simple function that divides two values. The optimization parameter x is an (n, m) NumPy array from which I calculate a float.
# An initial value
normX0 = calculate_normX(x_start)
...
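scipy.optimize.minimize expects a one-dimensional x0, so an (n, m) array has to be flattened on the way in and reshaped inside the objective. The sketch below does that and also supplies an exact JAX gradient (jac=True) so the solver does not rely on finite differences. The ratio objective and the target array are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

n, m = 3, 4
target = jnp.linspace(0.0, 1.0, n * m).reshape(n, m)  # hypothetical data


def ratio(X):
    # Hypothetical objective: a quotient of two quantities computed from the (n, m) array.
    return jnp.sum((X - target) ** 2) / (1.0 + jnp.sum(X ** 2))


value_and_grad = jax.jit(jax.value_and_grad(ratio))


def scipy_objective(x_flat):
    # SciPy passes a flat float64 vector; restore the (n, m) shape for the JAX code.
    X = jnp.asarray(x_flat.reshape(n, m), dtype=jnp.float32)
    val, grad = value_and_grad(X)
    # Return the value and the flattened gradient, as SciPy expects with jac=True.
    return float(val), np.asarray(grad, dtype=np.float64).ravel()


x_start = np.zeros((n, m))
res = minimize(scipy_objective, x_start.ravel(), jac=True, method="L-BFGS-B")
X_opt = res.x.reshape(n, m)
```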
4 votes, 2 answers, 559 views
How to improve Julia's performance using just in time compilation (JIT)
I have been playing with JAX (automatic differentiation library in Python) and Zygote (the automatic differentiation library in Julia) to implement the Gauss-Newton minimisation method.
I came upon the @...
1 vote, 1 answer, 615 views
Mapping a vector to a matrix in JAX
I want to use JAX to optimize the elements of a vector with a loss function that depends on a matrix built from the elements of that vector. Specifically, element (n, m) of the matrix corresponds to ...
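The excerpt does not show how element (n, m) is built from the vector, so the sketch below assumes a hypothetical rule, M[n, m] = v[|n - m|] (a symmetric Toeplitz matrix). What it illustrates is that building the matrix with integer-array indexing keeps the whole loss differentiable with respect to v, so jax.grad can drive the optimization.

```python
import jax
import jax.numpy as jnp


def build_matrix(v):
    # Hypothetical mapping: M[n, m] = v[|n - m|]. Integer-array indexing is differentiable in v.
    k = v.shape[0]
    idx = jnp.abs(jnp.arange(k)[:, None] - jnp.arange(k)[None, :])
    return v[idx]


def loss(v, target):
    # Example loss defined on the matrix rather than on v directly.
    return jnp.sum((build_matrix(v) - target) ** 2)


target = build_matrix(jnp.array([2.0, -1.0, 0.5, 0.0]))
step = jax.jit(lambda v: v - 0.01 * jax.grad(loss)(v, target))

v = jnp.zeros(4)
for _ in range(500):
    v = step(v)
```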
1 vote, 1 answer, 1k views
vmap ops.index_update in JAX
I have the code below, which uses a simple for loop. Is there a way to vmap it? Here is the original code:
import numpy as np
import jax.numpy as jnp
import jax....
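jax.ops.index_update was deprecated and later removed from JAX; the replacement is the indexed-update syntax x.at[idx].set(y). When each loop iteration writes an independent row, the loop can often be replaced by jax.vmap over the index. A sketch with a hypothetical per-row computation:

```python
import jax
import jax.numpy as jnp


def row_value(i, n_cols):
    # Hypothetical per-iteration computation standing in for the loop body in the question.
    return jnp.arange(n_cols) * i + 1.0


def fill_loop(n, m):
    out = jnp.zeros((n, m))
    for i in range(n):
        # Functional update replacing jax.ops.index_update(out, i, ...).
        out = out.at[i].set(row_value(i, m))
    return out


def fill_vmap(n, m):
    # Rows are independent, so vmap computes them all at once and stacks them.
    return jax.vmap(lambda i: row_value(i, m))(jnp.arange(n))
```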
5 votes, 2 answers, 8k views
Why is this function slower in JAX vs numpy?
I have the NumPy function below, which I'm trying to speed up with JAX, but for whatever reason the JAX version is slower.
Could someone point out what I can do to improve the performance here? ...
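Common reasons a JAX port looks slower than NumPy are timing the first call (which includes compilation), not waiting for JAX's asynchronous dispatch, converting NumPy inputs on every call, and looping in Python over array elements. A sketch of a fairer timing harness; the function body is a hypothetical stand-in, and whether JAX comes out ahead still depends on array size, device, and the operations involved.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def f_np(x):
    # Hypothetical NumPy function standing in for the one in the question.
    return np.sum(np.sin(x) ** 2 + np.cos(x) ** 2 * x)


@jax.jit
def f_jax(x):
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2 * x)


x_np = np.random.default_rng(0).standard_normal(1_000_000).astype(np.float32)
x_jax = jnp.asarray(x_np)  # transfer once, outside the timed region

f_jax(x_jax).block_until_ready()  # first call compiles; keep it out of the timing

t0 = time.perf_counter()
f_jax(x_jax).block_until_ready()  # wait for the asynchronously dispatched result
t_jax = time.perf_counter() - t0

t0 = time.perf_counter()
f_np(x_np)
t_np = time.perf_counter() - t0
print(f"JAX (jit): {t_jax:.4f}s  NumPy: {t_np:.4f}s")
```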