All Questions

1 vote, 1 answer, 87 views

Parallelizing over huge for loop involving a generator

Is there a way to compute, in parallel, the maximum value of an array so large that I need to use a generator? I am computing the maximum value of a function over a very large discrete set of ...
Alex Albors
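`jax.jit` cannot trace a Python generator. One workaround is to build each chunk of the discrete set on the device from an integer index range, evaluate the chunk with `jax.vmap`, and keep a running maximum. The sketch below assumes the set can be enumerated by an integer index. `f`, `chunk_max`, and `streaming_max` are hypothetical names, and `f` is a toy stand-in for the real function.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(i):
    # Hypothetical objective at integer index i (stand-in for the real function).
    x = i.astype(jnp.float32)
    return -(x - 1234.0) ** 2


@partial(jax.jit, static_argnums=2)
def chunk_max(start, n, chunk_size):
    # Build one chunk of indices on the device instead of pulling them from a generator.
    idx = start + jnp.arange(chunk_size)
    vals = jax.vmap(f)(idx)
    # Mask indices past the end of the set so the final partial chunk is handled.
    vals = jnp.where(idx < n, vals, -jnp.inf)
    return jnp.max(vals)


def streaming_max(n, chunk_size=1 << 16):
    # Only one chunk is in memory at a time; each chunk is evaluated in parallel.
    best = -jnp.inf
    for start in range(0, n, chunk_size):
        best = jnp.maximum(best, chunk_max(start, n, chunk_size))
    return best


print(float(streaming_max(1000, chunk_size=256)))  # prints -55225.0
```

Each chunk runs as one vectorized XLA computation, and the chunk size controls peak memory. Spreading chunks over several devices would be a further step and is not shown here.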
1 vote, 1 answer, 270 views

Slow JAX Optimization with ScipyBoundedMinimize and Optax - Seeking Speedup Strategies

I'm working on optimizing a model in JAX that involves fitting a large observational dataset (4800 data points) with a complex model containing interpolation. The current optimization process using ...
eng (reputation 97)
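jaxopt's ScipyBoundedMinimize calls into `scipy.optimize.minimize` under the hood. One strategy worth trying is to call SciPy directly with a loss and gradient compiled once through `jax.jit(jax.value_and_grad(...))`. Passing `jac=True` then makes each iteration cost a single compiled evaluation. This swaps in a different approach from the question's code. The toy loss and the bounds below are assumptions standing in for the interpolation model.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# Toy loss: its unconstrained minimum [3, -1] lies outside the bounds used below.
target = jnp.array([3.0, -1.0])


def loss(p):
    return jnp.sum((p - target) ** 2)


# Compile loss and gradient together, once.
loss_and_grad = jax.jit(jax.value_and_grad(loss))


def scipy_objective(x):
    # SciPy passes float64 NumPy arrays; convert in and out explicitly.
    value, grad = loss_and_grad(jnp.asarray(x, dtype=jnp.float32))
    return float(value), np.asarray(grad, dtype=np.float64)


result = minimize(
    scipy_objective,
    x0=np.zeros(2),
    jac=True,  # scipy_objective returns (value, gradient)
    method="L-BFGS-B",
    bounds=[(0.0, 2.0), (0.0, 2.0)],
)
print(result.x)
```

Because the input shape and dtype stay fixed, the jitted function compiles on the first call and is reused on every later iteration.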
0 votes, 0 answers, 398 views

Using JAXopt on constrained optimization(non-negative)

I am trying to do an optimization (using JAX) on my loss function, which comes from a basic physics model of two layers growing at different rates. Since the growth rates are always positive, I ...
Heng Yuan
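A common trick here is to avoid constraining the optimizer at all. You optimize unconstrained parameters and map them through a positive function such as softplus, so the growth rates are positive by construction. The sketch below uses plain gradient descent. The linear two-layer growth model and the data are hypothetical stand-ins for the question's physics model.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: two layers whose thickness grows linearly at rates 0.5 and 2.0.
t = jnp.linspace(0.0, 1.0, 20)
observed = jnp.stack([0.5 * t, 2.0 * t])


def rates_from_raw(raw):
    # softplus maps any real number to a positive value, so the constraint always holds.
    return jax.nn.softplus(raw)


def loss(raw):
    rates = rates_from_raw(raw)
    predicted = rates[:, None] * t[None, :]
    return jnp.mean((predicted - observed) ** 2)


@jax.jit
def fit(raw0):
    grad_fn = jax.grad(loss)

    def step(_, raw):
        # Plain gradient descent on the unconstrained parameters (step size 1.0).
        return raw - 1.0 * grad_fn(raw)

    return jax.lax.fori_loop(0, 1000, step, raw0)


raw = fit(jnp.zeros(2))
print(rates_from_raw(raw))  # both entries are positive
```

Softplus can only approach zero, never reach it. If the optimum may sit exactly on the boundary, clipping the rates at zero after each step (projected gradient descent) is an alternative.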
2 votes, 0 answers, 509 views

How to implement np.lib.stride_tricks.sliding_window_view efficiently in JAX?

I have implemented an algorithm in which I calculate the Pearson correlation coefficient between a vector in one image and every vector in another image within a given window around the equivalent ...
Nin17 (reputation 3,442)
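JAX arrays expose no strides, but the same windows can be produced with a single advanced-indexing gather whose index arrays come from `jnp.arange`. The sketch below covers the 2-D case and assumes the window shape is known at trace time (it is marked static). `sliding_window_view_2d` is a hypothetical helper, checked here against NumPy.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames="window_shape")
def sliding_window_view_2d(x, window_shape):
    # Output shape (H - wh + 1, W - ww + 1, wh, ww) with out[i, j] == x[i:i + wh, j:j + ww],
    # matching numpy.lib.stride_tricks.sliding_window_view for a 2-D window.
    wh, ww = window_shape
    h, w = x.shape
    rows = jnp.arange(h - wh + 1)[:, None, None, None] + jnp.arange(wh)[None, None, :, None]
    cols = jnp.arange(w - ww + 1)[None, :, None, None] + jnp.arange(ww)[None, None, None, :]
    # The two index arrays broadcast together, so this is one gather.
    return x[rows, cols]


x = np.arange(30.0).reshape(5, 6)
windows = sliding_window_view_2d(jnp.asarray(x), (2, 3))
reference = np.lib.stride_tricks.sliding_window_view(x, (2, 3))
```

Unlike NumPy's view, the result is a materialized array holding every window. For large images it may therefore be better to reduce each window (for example, computing the correlation) inside the same jitted function, where XLA can sometimes fuse the gather into that computation.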
1 vote, 1 answer, 412 views

vectorized minimization and root finding in jax

I have a family of functions parameterized by args f(x, args) and want to determine the minimum of f over x for N = 1000 values of args. I have access to both the function and its derivative. My ...
Dan Leonte
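When each problem is a smooth 1-D minimization and the derivative is available, one way to batch N = 1000 of them is to run Newton's method on f'(x) = 0 for a fixed number of iterations and `jax.vmap` it over the parameters. The sketch below assumes a strictly convex toy family `f(x, a) = cosh(x - a)`. This family is hypothetical; its minimizer is x = a, which makes the result easy to check.

```python
import jax
import jax.numpy as jnp


def f(x, a):
    # Hypothetical family of functions; the minimum over x is at x = a.
    return jnp.cosh(x - a)


df = jax.grad(f)    # derivative with respect to x (argument 0)
d2f = jax.grad(df)  # second derivative with respect to x


def newton_argmin(a, iters=20):
    # Newton's method on f'(x) = 0 with a fixed iteration count, starting from x = 0.
    def body(_, x):
        return x - df(x, a) / d2f(x, a)

    return jax.lax.fori_loop(0, iters, body, jnp.zeros((), dtype=jnp.float32))


# vmap over the N parameter values; jit compiles the whole batch as one program.
batched_argmin = jax.jit(jax.vmap(newton_argmin))

a_values = jnp.linspace(-3.0, 3.0, 1000)
x_star = batched_argmin(a_values)
```

A fixed iteration count keeps every batch element on the same control flow. A tolerance-based stop would need `jax.lax.while_loop`, which under `vmap` keeps iterating until every element has met its condition.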
2 votes, 1 answer, 920 views

Parallelize with JAX over all GPU cores

In order to minimize the function x^2+y^2, I tried to implement the Adam optimizer from scratch with JAX: @jax.jit def fit(X, batches, params=[0.001, 0.9, 0.99, 1e-8]): global fun # batches ...
ilikenoodles
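Keeping the Adam state (parameters, first moment, second moment) in a loop carry rather than a `global` lets `jax.jit` compile the whole run with `jax.lax.scan`. On a GPU, the parallelism then comes from XLA running each array operation across the device, not from the Python loop. The sketch below minimizes x^2 + y^2. The learning rate 0.1 and b2 = 0.999 are assumptions chosen for this toy problem, not the values in the question.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss(p):
    # Objective from the question: x^2 + y^2 with p = [x, y].
    return jnp.sum(p ** 2)


@partial(jax.jit, static_argnames="steps")
def adam_minimize(p0, lr=0.1, b1=0.9, b2=0.999, eps=1e-8, steps=1000):
    grad_fn = jax.grad(loss)

    def step(carry, t):
        p, m, v = carry
        g = grad_fn(p)
        m = b1 * m + (1 - b1) * g        # first-moment estimate
        v = b2 * v + (1 - b2) * g ** 2   # second-moment estimate
        m_hat = m / (1 - b1 ** t)        # bias corrections (t starts at 1)
        v_hat = v / (1 - b2 ** t)
        p = p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
        return (p, m, v), loss(p)

    init = (p0, jnp.zeros_like(p0), jnp.zeros_like(p0))
    ts = jnp.arange(1, steps + 1, dtype=jnp.float32)
    (p, _, _), history = jax.lax.scan(step, init, ts)
    return p, history


p_final, history = adam_minimize(jnp.array([1.0, -2.0]))
```

`history` holds the loss after every step, which is handy for checking convergence without leaving the compiled loop.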
0 votes, 0 answers, 341 views

Optimisation with JAX using vmap or foriloop

I'm still encountering a challenge with optimizing my code using the JAX library. The purpose of this code is to apply a blur to an intensity array by creating a normalized Gaussian for each pixel and ...
PortorogasDS
2 votes, 2 answers, 245 views

Optimization and speed enhancement of function using JAX

The code performs Gaussian blurring on the image intensityRefracted2DF using a Gaussian kernel centered at each pixel. The Gaussian kernel is determined by the values in the darkField array, where ...
PortorogasDS
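Below is a sketch of this per-pixel Gaussian blur in vectorized JAX, written in "scatter" form: each pixel spreads its own intensity with its own normalized Gaussian. It assumes a fixed truncation radius (static, so shapes are known at compile time) and a per-pixel width array `sigma` derived from darkField. `spatially_varying_blur` and the 1e-3 floor on sigma are hypothetical choices.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="radius")
def spatially_varying_blur(intensity, sigma, radius=3):
    # Each source pixel spreads its intensity with its own normalized Gaussian,
    # whose width is taken from sigma (e.g. derived from the darkField array).
    h, w = intensity.shape
    offsets = jnp.arange(-radius, radius + 1)
    dy, dx = jnp.meshgrid(offsets, offsets, indexing="ij")
    r2 = (dy ** 2 + dx ** 2).astype(intensity.dtype)
    s2 = jnp.maximum(sigma, 1e-3) ** 2  # avoid division by zero when sigma == 0
    # weights[i, j, y, x]: share of pixel (y, x) sent to offset (i - radius, j - radius)
    weights = jnp.exp(-r2[:, :, None, None] / (2.0 * s2[None, None, :, :]))
    weights = weights / weights.sum(axis=(0, 1), keepdims=True)
    contrib = intensity[None, None, :, :] * weights

    # Accumulate shifted contributions on a zero-padded canvas, then crop.
    out = jnp.zeros((h + 2 * radius, w + 2 * radius), intensity.dtype)
    for i in range(2 * radius + 1):  # static loops over offsets, unrolled at trace time
        for j in range(2 * radius + 1):
            out = out.at[i:i + h, j:j + w].add(contrib[i, j])
    return out[radius:radius + h, radius:radius + w]


img = jnp.zeros((11, 11)).at[5, 5].set(1.0)
blurred = spatially_varying_blur(img, jnp.ones((11, 11)))
```

The Python loops run over kernel offsets, not pixels, so they unroll into (2 * radius + 1)^2 vectorized additions. Intensity spread past the image border is dropped by the final crop.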
0 votes, 1 answer, 621 views

How to write an exponential learning rate decay algorithm while using JAX's version of Adam Optimizer

I am writing Physics Informed Neural Network (PINN) code and have been using the JAX library for it. I have been using the 'Adam' optimizer from JAX's 'example_libraries.optimizers' module. I ...
Sumanta Roy
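In `jax.example_libraries.optimizers`, the step size passed to `adam` can be a schedule function of the iteration number. The module provides `exponential_decay(step_size, decay_steps, decay_rate)` for exactly this case. The toy loss and the schedule constants in the sketch below are assumptions standing in for the PINN setup.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Learning rate 0.1 * 0.5 ** (i / 100): it halves every 100 steps.
# Written by hand, the same schedule is: lambda i: 0.1 * 0.5 ** (i / 100)
schedule = optimizers.exponential_decay(step_size=0.1, decay_steps=100, decay_rate=0.5)
opt_init, opt_update, get_params = optimizers.adam(schedule)


def loss(params):
    # Toy loss standing in for the PINN loss.
    return jnp.sum((params["w"] - 3.0) ** 2)


@jax.jit
def train_step(i, opt_state):
    params = get_params(opt_state)
    grads = jax.grad(loss)(params)
    # opt_update evaluates schedule(i) internally to get this step's learning rate.
    return opt_update(i, grads, opt_state)


opt_state = opt_init({"w": jnp.zeros(2)})
for i in range(300):
    opt_state = train_step(i, opt_state)
```

Because `opt_update` receives the step index `i`, the decay follows the index you pass in. That index should keep increasing across epochs rather than restart at zero.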
1 vote, 1 answer, 617 views

Error using JAX, Array slice indices must have static start/stop/step

I'll be happy to help you with your code. If I understand correctly, you want to create a 2D Gaussian patch for each value in the darkField array. The size of the patch should ideally be calculated as ...
PortorogasDS
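This error appears when slice bounds are computed from traced values, such as a patch size derived from darkField. Under `jax.jit`, those values cannot define array shapes. One workaround is to fix a maximum patch size as a static argument and use `jax.lax.dynamic_slice` for the traced start position. A varying effective width can then come from the Gaussian weights inside the fixed window. `patches_at` below is a hypothetical helper.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="size")
def patches_at(img, rows, cols, size=5):
    # size is static (a Python int); rows and cols may be traced arrays.
    half = size // 2
    padded = jnp.pad(img, half)  # zero padding keeps border patches in bounds

    def one(r, c):
        # padded[r + half, c + half] is img[r, c], so this patch is centred on it.
        return jax.lax.dynamic_slice(padded, (r, c), (size, size))

    # One fixed-size patch per (row, col) pair.
    return jax.vmap(one)(rows, cols)


img = jnp.arange(25.0).reshape(5, 5)
patches = patches_at(img, jnp.array([2, 0]), jnp.array([2, 4]), size=3)
```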
2 votes, 1 answer, 132 views

Issues while using JAX to minimize the Lennard-Jones potential for two points and the force (gradient of the potential)--result doesn't match

I am trying to use the minimization function in JAX to find the distance between two points interacting through the Lennard-Jones potential E = 2(1/r^4 - 1/r^2), and I can successfully get the result: [-0.20710678 1....
Heng Yuan
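For this potential, dE/dr = 2(-4/r^5 + 2/r^3) vanishes at r = √2, where E = -1/2. That gives a closed-form target to test against. The force on each point is minus the gradient of E with respect to its position, so a sign slip between gradient and force is one thing worth ruling out. The sketch below relaxes the two positions and compares the autodiff force with the analytic one. The step size, iteration count, and starting positions are assumptions.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # x = [x1, x2]: positions of the two points on a line; E = 2(1/r^4 - 1/r^2).
    r2 = (x[1] - x[0]) ** 2
    return 2.0 * (1.0 / r2 ** 2 - 1.0 / r2)


def force(x):
    # Force on each point is minus the gradient of the energy w.r.t. its position.
    return -jax.grad(energy)(x)


def analytic_force_on_second(r):
    # -dE/dr = 2(4/r^5 - 2/r^3); a positive value pushes the points apart.
    return 2.0 * (4.0 / r ** 5 - 2.0 / r ** 3)


@jax.jit
def relax(x0):
    # Move each point along its force (gradient descent on E with step 0.1).
    return jax.lax.fori_loop(0, 200, lambda _, x: x + 0.1 * force(x), x0)


x_min = relax(jnp.array([0.0, 1.3]))
r_min = x_min[1] - x_min[0]  # expected to approach sqrt(2)
```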
1 vote, 1 answer, 239 views

JAX code for minimizing Lennard-Jones potential for 2 points in Python gives unexpected results

I am practicing using JAX for optimization problems, so I am trying a simple one: minimizing the Lennard-Jones potential for just 2 points. I set both epsilon and sigma in ...
Heng Yuan
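The excerpt is cut off before the parameter values. Assuming epsilon = sigma = 1, the standard 12-6 potential V(r) = 4(r^-12 - r^-6) has its minimum at r = 2^(1/6) ≈ 1.1225 with V = -1, which makes a convenient check. The sketch below uses 3-D coordinates and plain gradient descent. The positions, step size, and iteration count are assumptions for this toy case.

```python
import jax
import jax.numpy as jnp


def lj_energy(positions, epsilon=1.0, sigma=1.0):
    # positions: (2, 3) array; standard 12-6 Lennard-Jones energy of the pair.
    r = jnp.linalg.norm(positions[1] - positions[0])
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)


@jax.jit
def relax_lj(positions):
    # Small fixed step: the position-space curvature at the minimum is about 114,
    # so gradient descent is unstable for steps above roughly 0.0175.
    grad_fn = jax.grad(lj_energy)
    return jax.lax.fori_loop(0, 500, lambda _, p: p - 0.005 * grad_fn(p), positions)


p0 = jnp.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]])
p_min = relax_lj(p0)
r_min = jnp.linalg.norm(p_min[1] - p_min[0])  # expected to approach 2 ** (1 / 6)
```

With a fixed-step method, a step above that stability limit is one possible source of unexpected results, because the iterates oscillate or diverge instead of settling at the minimum.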
1 vote, 1 answer, 289 views

JAX: Canceled future for execute_request message before replies were done

I have an optimization problem that I am trying to solve with Newton's method. To calculate the Jacobian matrix, I use jax.jacobian. My objective function is called calc_ms. I use Newton's method to find ...
Alina Ozhegova
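Here is a minimal Newton iteration built on `jax.jacobian`. It solves J(x) dx = -F(x) with `jnp.linalg.solve` inside a jitted `fori_loop`. The excerpt does not show `calc_ms`, so `F` below is a hypothetical 2x2 system with known root (√2, √2). The sketch does not address the kernel message itself.

```python
import jax
import jax.numpy as jnp


def F(x):
    # Hypothetical system: x0^2 + x1^2 = 4 and x0 = x1, with root (sqrt(2), sqrt(2)).
    return jnp.array([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])


jac_F = jax.jacobian(F)


@jax.jit
def newton(x0):
    def step(_, x):
        # Solve J(x) dx = -F(x) instead of forming the inverse Jacobian.
        return x + jnp.linalg.solve(jac_F(x), -F(x))

    return jax.lax.fori_loop(0, 20, step, x0)


root = newton(jnp.array([1.0, 0.5]))
```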
1 vote, 1 answer, 450 views

Python function minimization not changing optimization variables

I need to minimize a simple function that divides two values. The optimization parameter x is an (n,m) numpy array from which I calculate a float. # An initial value normX0 = calculate_normX(x_start) ...
agentsmith (reputation 1,306)
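One way to keep an (n, m) parameter while giving the optimizer the 1-D vector it expects is to reshape inside the objective. Keeping the objective entirely in `jnp` also lets gradients reach `x`. The sketch below uses `jax.scipy.optimize.minimize`. `calculate_normX` is assumed here to be the Frobenius norm, and the ratio being minimized is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

n, m = 3, 4
x_start = jnp.full((n, m), 2.0)


def calculate_normX(x):
    # Assumed definition: Frobenius norm of the (n, m) array.
    return jnp.sqrt(jnp.sum(x ** 2))


normX0 = calculate_normX(x_start)  # fixed reference value from the initial guess


def objective(x_flat):
    # The optimizer works on a 1-D vector; reshape back to (n, m) here.
    x = x_flat.reshape(n, m)
    # Hypothetical ratio, kept in jnp so gradients flow to x.
    return jnp.sum((x - 1.0) ** 2) / normX0


result = minimize(objective, x_start.ravel(), method="BFGS")
x_opt = result.x.reshape(n, m)
```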
4 votes, 2 answers, 559 views

How to improve Julia's performance using just in time compilation (JIT)

I have been playing with JAX (automatic differentiation library in Python) and Zygote (the automatic differentiation library in Julia) to implement Gauss-Newton minimisation method. I came upon the @...
MOON (reputation 2,771)
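On the JAX side of this comparison, Gauss-Newton fits naturally into a jitted loop. `jax.jacfwd` gives the Jacobian of the residuals, and each step solves the normal equations. The question's model is not shown, so the sketch below uses a hypothetical noise-free fit y = a * exp(b * t). The starting point and iteration count are assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical noise-free data from y = a * exp(b * t) with a = 2, b = -1.5.
t = jnp.linspace(0.0, 2.0, 20)
y = 2.0 * jnp.exp(-1.5 * t)


def residuals(theta):
    a, b = theta
    return a * jnp.exp(b * t) - y


jac_r = jax.jacfwd(residuals)  # forward mode suits few parameters and many residuals


@jax.jit
def gauss_newton(theta0):
    def step(_, theta):
        J = jac_r(theta)
        r = residuals(theta)
        # Gauss-Newton step: solve (J^T J) delta = -J^T r.
        delta = jnp.linalg.solve(J.T @ J, -J.T @ r)
        return theta + delta

    return jax.lax.fori_loop(0, 20, step, theta0)


theta = gauss_newton(jnp.array([1.8, -1.3]))
```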
1 vote, 1 answer, 615 views

Mapping a vector to a matrix in JAX

I want to use JAX to optimize the elements of a vector with a loss function that is a function of a matrix built from the elements of that vector. Specifically, the element n,m of the matrix corresponds to ...
RealFatShady
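The exact mapping is cut off in the excerpt, so this sketch uses a hypothetical rule, M[n, m] = v[|n - m|]. The matrix is built by a gather with an integer index array. `jax.grad` then sums each vector element's gradient over every matrix entry that used it.

```python
import jax
import jax.numpy as jnp


def build_matrix(v):
    # Hypothetical mapping: M[n, m] = v[|n - m|], a symmetric Toeplitz matrix.
    k = v.shape[0]
    idx = jnp.abs(jnp.arange(k)[:, None] - jnp.arange(k)[None, :])
    return v[idx]


def loss(v):
    # Any differentiable function of the matrix works; the sum is a placeholder.
    return jnp.sum(build_matrix(v))


v = jnp.array([1.0, 2.0, 3.0])
M = build_matrix(v)     # [[1, 2, 3], [2, 1, 2], [3, 2, 1]]
g = jax.grad(loss)(v)   # each entry counts how often v[k] appears in M: [3, 4, 2]
```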
1 vote, 1 answer, 1k views

vmap ops.index_update in Jax

I have the following code below and it's using a simple for loop. I was just wondering if there was a way to vmap it? Here is the original code: import numpy as np import jax.numpy as jnp import jax....
DumbCoder21
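Current JAX has removed `jax.ops.index_update`; the replacement is the `x.at[idx].set(val)` syntax, which also works inside `jax.vmap`. The original loop is cut off, so the sketch below uses a hypothetical row-update loop. It shows the loop form, a single vectorized scatter, and a vmapped per-example update.

```python
import jax
import jax.numpy as jnp


def update_loop(x, rows, values):
    # Loop form: one functional update per iteration
    # (older jax.ops.index_update(x, i, v) corresponds to x.at[i].set(v)).
    for k in range(rows.shape[0]):
        x = x.at[rows[k]].set(values[k])
    return x


def update_vectorized(x, rows, values):
    # Same result in a single scatter, assuming the row indices are distinct.
    return x.at[rows].set(values)


# vmap applies a per-example update across a batch of arrays.
batched_update = jax.vmap(lambda x, i, v: x.at[i].set(v))

x = jnp.zeros((4, 3))
rows = jnp.array([0, 2])
values = jnp.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
```

The distinct-rows assumption matters. With repeated indices, JAX leaves the order of conflicting updates implementation-defined, so `.set` may not pick the value you expect.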
5 votes, 2 answers, 8k views

Why is this function slower in JAX vs numpy?

I have the following numpy function as seen below that I'm trying to optimize by using JAX but for whatever reason, it's slower. Could someone point out what I can do to improve the performance here? ...
DumbCoder21
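JAX often looks slower than NumPy for a few common reasons:

- the first call is timed, so compilation is included;
- asynchronous dispatch is not awaited with `block_until_ready()`;
- un-jitted functions run op by op;
- the arrays are too small for the per-call overhead to pay off.

The benchmarking sketch below controls for the first two. The workload is hypothetical, because the question's function is cut off.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def f_numpy(x):
    # Hypothetical workload standing in for the question's function.
    return np.mean(np.tanh(x) ** 2 + np.exp(-x ** 2))


@jax.jit
def f_jax(x):
    return jnp.mean(jnp.tanh(x) ** 2 + jnp.exp(-x ** 2))


x_np = np.linspace(-3.0, 3.0, 1_000_000, dtype=np.float32)
x_jax = jnp.asarray(x_np)  # transfer to the device once, outside the timed region

f_jax(x_jax).block_until_ready()  # first call compiles; keep it out of the timing

start = time.perf_counter()
for _ in range(10):
    f_jax(x_jax).block_until_ready()  # wait for the asynchronous result
jax_time = (time.perf_counter() - start) / 10

start = time.perf_counter()
for _ in range(10):
    f_numpy(x_np)
numpy_time = (time.perf_counter() - start) / 10

print(f"jax: {jax_time:.2e} s   numpy: {numpy_time:.2e} s")
```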