All Questions

1 vote
1 answer
90 views

Can jax.vmap() do a hstack()?

As the title says, I currently call hstack() manually over the first axis of a 3D array returned by jax.vmap(). In my code, the copy operation in hstack() is currently a speed bottleneck. Can I avoid this by ...
marnix (reputation 1,172)
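A minimal sketch of one option, assuming the mapped function returns an (m, k) block per input. Passing `out_axes=1` asks `jax.vmap` to put the batch axis second, so a reshape yields the same layout that `hstack` over the first axis produces. The function `f` is hypothetical. Whether this actually removes the copy depends on what XLA does with the reshape under `jit`; that is not guaranteed.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical per-example function: maps shape (2,) to an (m, k) = (3, 2) block.
    return jnp.stack([x, 2 * x, 3 * x])

xs = jnp.arange(8.0).reshape(4, 2)  # batch of 4 inputs, each of shape (2,)

# Manual approach: vmap gives (4, 3, 2); hstack joins the 4 blocks along axis 1.
manual = jnp.hstack(list(jax.vmap(f)(xs)))  # shape (3, 8)

# Alternative: batch axis at position 1 gives (3, 4, 2); merge the last two axes.
direct = jax.vmap(f, out_axes=1)(xs).reshape(3, -1)  # shape (3, 8)

print(jnp.array_equal(manual, direct))  # True
```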
1 vote
1 answer
368 views

Unexpected behavior of JAX `vmap` for multiple arguments

I have found that vmap in JAX does not behave as expected when applied to multiple arguments. For example, consider the function below: def f1(x, y, z): f = x[:, None, None] * z[None, None, :] + y[...
Jingyang Wang
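By default `jax.vmap` maps every positional argument over its leading axis, so all arguments must share that axis size. A minimal sketch with a hypothetical function `g` (the excerpt's `f1` is cut off): `in_axes` marks which arguments are batched (`0`) and which are passed unchanged to every call (`None`).

```python
import jax
import jax.numpy as jnp

def g(x, y, z):
    # Hypothetical example: x is a scalar, y and z are vectors.
    return x * jnp.dot(y, z)

xs = jnp.array([1.0, 2.0, 3.0])  # batched: one scalar per call
y = jnp.array([1.0, 0.0])        # shared across calls
z = jnp.array([4.0, 5.0])        # shared across calls

# Map only over x; y and z are broadcast whole to each call.
out = jax.vmap(g, in_axes=(0, None, None))(xs, y, z)
print(out)  # [ 4.  8. 12.]
```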
2 votes
1 answer
124 views

Why do I get different values from jnp.round and np.round?

I'm writing tests for some jax code and using np.testing.assert_array...-type functions and came across this difference in values that I didn't expect: import jax.numpy as jnp import numpy as np from ...
Bill (reputation 11.6k)
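One common cause, offered as a hedge since the excerpt is truncated: JAX computes in 32-bit floats by default (unless `jax_enable_x64` is enabled at startup), while NumPy uses 64-bit floats, so values near a rounding boundary can round differently. The sketch only shows the dtype difference; the input value is arbitrary.

```python
import jax.numpy as jnp
import numpy as np

x = 2.675  # arbitrary value close to a rounding boundary

r_np = np.round(np.array(x), 2)     # computed in float64
r_jnp = jnp.round(jnp.array(x), 2)  # computed in float32 unless x64 is enabled

print(r_np.dtype, r_jnp.dtype)  # float64 float32
print(r_np, r_jnp)  # may differ in the last decimal for boundary values like this one
```

In tests, comparing with a tolerance that matches float32 precision, or enabling x64 in JAX, usually makes the two libraries comparable.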
4 votes
2 answers
6k views

How to get value of jaxlib.xla_extension.ArrayImpl

Using type(z1[0]) I get jaxlib.xla_extension.ArrayImpl. Printing z1[0] I get Array(0.71530414, dtype=float32). How can I get the actual number 0.71530414? I tried z1[0][0] because z1[0] is a kind of ...
fabianod (reputation 693)
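A 0-d JAX array such as `z1[0]` cannot be indexed further, but it can be converted to a Python scalar. A short sketch, using a stand-in array since the question's `z1` is not shown:

```python
import jax.numpy as jnp

z1 = jnp.array([0.71530414, 0.5])  # stand-in for the array in the question

v = z1[0].item()  # Python float (from float32, so trailing digits may differ)
w = float(z1[0])  # same conversion via the float() builtin

print(type(v), type(w))  # <class 'float'> <class 'float'>
```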
1 vote
1 answer
580 views

index `jax` array with variable dimension

I am trying to write a general utility to update indices in a jax array that may have a different number of dimensions depending on the instance. I know that I have to use the .at[].set() methods, and ...
zephyrus (reputation 1,266)
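A minimal sketch, assuming each instance provides its index as a sequence with one integer per dimension: converting it to a tuple lets the same `.at[...].set(...)` call address a single element for arrays of any rank. The helper name `set_at` is hypothetical.

```python
import jax.numpy as jnp

def set_at(arr, index, value):
    # Hypothetical helper: `index` holds one integer per dimension of `arr`.
    # A tuple index selects one element regardless of arr.ndim.
    return arr.at[tuple(index)].set(value)

a2 = set_at(jnp.zeros((2, 3)), [1, 2], 5.0)
a3 = set_at(jnp.zeros((2, 3, 4)), [0, 1, 3], 7.0)
print(a2[1, 2], a3[0, 1, 3])  # 5.0 7.0
```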
0 votes
1 answer
301 views

Irregular/Inhomogeneous Arrays with JAX

What is the recommended approach to implementing array behaviour/methods on irregular/inhomogeneous data (which has some inherent dimensionality) within JAX? Two principal options come to mind: make ...
DavidJ (reputation 408)
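One frequently used option, not necessarily the best one for every workload, is to pad ragged rows to a common length and carry a boolean mask so that reductions ignore the padding. A sketch with made-up data:

```python
import jax.numpy as jnp

rows = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]  # made-up ragged data

max_len = max(len(r) for r in rows)
# Pad each row with zeros to max_len and record which entries are real.
data = jnp.array([r + [0.0] * (max_len - len(r)) for r in rows])
mask = jnp.array([[i < len(r) for i in range(max_len)] for r in rows])

# Masked per-row mean: padded entries contribute to neither the sum nor the count.
row_sum = jnp.where(mask, data, 0.0).sum(axis=1)
row_mean = row_sum / mask.sum(axis=1)
print(row_mean)  # [2.  4.  5.5]
```

The other common route is to flatten all rows into one array with a segment id per element and use segment reductions such as `jax.ops.segment_sum`.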
3 votes
1 answer
375 views

jax segment_sum along array dimension

I am fairly new to jax and have the following problem: I need to compute functions (sum/min/max maybe more complex stuff later) across an array given an index. To solve this problem I found the jnp....
Simon P. (reputation 125)
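`jax.ops.segment_sum` (like `jax.ops.segment_max` and `jax.ops.segment_min`) reduces along the leading axis of its data. A sketch assuming the segment ids label positions along axis 1 instead: `vmap` over the rows applies the same reduction to each row. Passing `num_segments` keeps the output shape static, which `jit` requires.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(10.0).reshape(2, 5)     # rows [0..4] and [5..9]
segment_ids = jnp.array([0, 0, 1, 1, 1])  # groups along axis 1

def per_row(row):
    # Sum the entries of one row that share a segment id.
    return jax.ops.segment_sum(row, segment_ids, num_segments=2)

sums = jax.vmap(per_row)(data)
# expected values: [[1, 9], [11, 24]]
```

An equivalent without `vmap` is to transpose, reduce along axis 0, and transpose back: `jax.ops.segment_sum(data.T, segment_ids, num_segments=2).T`.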
2 votes
1 answer
235 views

How to map the kronecker product along array dimensions?

Given two tensors A and B with the same dimension (d>=2) and shapes [A_{1},...,A_{d-2},A_{d-1},A_{d}] and [A_{1},...,A_{d-2},B_{d-1},B_{d}] (shapes of the first d-2 dimensions are identical). Is ...
bayes2021 (reputation 202)
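A minimal sketch of a batched Kronecker product over the shared leading d-2 dimensions: `einsum` forms every product A[..., i, j] * B[..., k, l], and a reshape merges (i, k) and (j, l) into the Kronecker layout. The helper name `batched_kron` is hypothetical; nested `vmap` over `jnp.kron` serves as a cross-check.

```python
import jax
import jax.numpy as jnp

def batched_kron(A, B):
    # A: (..., m, n), B: (..., p, q) with identical leading dimensions.
    # C[..., i, k, j, l] = A[..., i, j] * B[..., k, l]
    C = jnp.einsum('...ij,...kl->...ikjl', A, B)
    m, n = A.shape[-2:]
    p, q = B.shape[-2:]
    # Row index i*p + k and column index j*q + l give the Kronecker layout.
    return C.reshape(A.shape[:-2] + (m * p, n * q))

A = jnp.arange(24.0).reshape(2, 3, 2, 2)
B = jnp.arange(18.0).reshape(2, 3, 3, 1)

C = batched_kron(A, B)                       # shape (2, 3, 6, 2)
ref = jax.vmap(jax.vmap(jnp.kron))(A, B)     # same result via nested vmap
```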
3 votes
1 answer
513 views

Modify an array from indexes contained in another array

I have an array of the shape (2,10) such as: arr = jnp.ones(shape=(2,10)) * 2 or [[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.] [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]] and another array, for example [2,4]. I want the ...
Valentin Macé
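The goal is cut off in the excerpt, so this sketch assumes one possible reading: each entry of the second array is a column index for the corresponding row, and that element should be changed (here set to 0). Pairing a row range with the index array inside `.at[...]` does this in a single update.

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])  # assumed: one column index per row

# Pair row 0 with column 2 and row 1 with column 4, and set those entries to 0.
out = arr.at[jnp.arange(arr.shape[0]), idx].set(0.0)
print(out[0, 2], out[1, 4], out[0, 4])  # 0.0 0.0 2.0
```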
1 vote
1 answer
142 views

Check if 2D sub-array is ordered - Python JAX

Suppose we have an array ordered. We want to check whether the sub-arrays t and t_inv follow the same order as the one imposed by the inorder array. Reading from left to right: the first ...
relaxon (reputation 141)
4 votes
1 answer
1k views

vmap in Jax to loop over arguments

Let's suppose I have some function which returns a sum of inputs. @jit def some_func(a,r1,r2): return a + r1 + r2 Now I would like to loop over different values of r1 and r2, save the result and ...
Zohim (reputation 51)
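A sketch of evaluating `some_func` over every (r1, r2) pair with nested `vmap` instead of a Python loop. The sample values are made up. The inner `vmap` maps over r2 with r1 fixed; the outer one maps over r1 and passes the whole r2 vector through.

```python
import jax
import jax.numpy as jnp

@jax.jit
def some_func(a, r1, r2):
    return a + r1 + r2

a = 1.0
r1s = jnp.array([0.0, 10.0, 20.0])
r2s = jnp.array([0.0, 100.0])

inner = jax.vmap(some_func, in_axes=(None, None, 0))  # loop over r2
outer = jax.vmap(inner, in_axes=(None, 0, None))      # loop over r1
grid = outer(a, r1s, r2s)  # grid[i, j] = a + r1s[i] + r2s[j]
print(grid.shape)  # (3, 2)
```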
3 votes
2 answers
1k views

Vectorizing a function that takes multidimensional input over a multidimensional array in JAX

I have been trying to vectorize a function that takes two 2D arrays and returns a 2D array of the same shape, so that I can apply it elementwise to two 4D arrays. Here is an example: import jax.numpy ...
VerwirrterStudent
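A minimal sketch, assuming the 4D arrays have shape (a, b, m, n) and the function should be applied to each trailing (m, n) block: each `vmap` strips one leading axis, so two nested calls reach the 2D blocks. The function `f` is a hypothetical stand-in for the one in the question.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Hypothetical function: two (m, n) arrays in, one (m, n) array out.
    return x @ y.T @ x  # (m, n) @ (n, m) @ (m, n) -> (m, n)

A = jnp.arange(120.0).reshape(4, 5, 3, 2) / 120.0
B = jnp.cos(A)

out = jax.vmap(jax.vmap(f))(A, B)
print(out.shape)  # (4, 5, 3, 2)
```

`jnp.vectorize(f, signature='(m,n),(m,n)->(m,n)')` is another way to express the same mapping.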
3 votes
1 answer
3k views

Handle varying shapes in jax numpy arrays (jit compatible)

Important note: I need everything to be jit compatible here, otherwise my problem is trivial :) I have a jax numpy array such as: a = jnp.array([1,5,3,4,5,6,7,2,9]) First I filter it considering a ...
Valentin Macé
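A common jit-compatible pattern, sketched under the assumption that the filter is a simple comparison (the excerpt is cut off before the actual condition): keep the output length fixed and track how many entries are valid. `jnp.nonzero` accepts a static `size` and a `fill_value` for this purpose. The name `filter_greater` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def filter_greater(a, threshold):
    mask = a > threshold
    count = mask.sum()
    # size= fixes the output length so the shape is static under jit;
    # unused slots get index 0 and are zeroed out below.
    (idx,) = jnp.nonzero(mask, size=a.shape[0], fill_value=0)
    valid = jnp.arange(a.shape[0]) < count
    return jnp.where(valid, a[idx], 0), count

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
vals, count = filter_greater(a, 4)
print(vals, count)  # [5 5 6 7 9 0 0 0 0] 5
```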
1 vote
1 answer
305 views

swapaxes and how it is implemented?

I'm wondering if someone can explain this code to me? c = self.config assert len(pair_act.shape) == 3 assert len(pair_mask.shape) == 2 assert c.orientation in ['per_row', 'per_column'] if c....
user10713428
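`jnp.swapaxes(x, axis1, axis2)` follows NumPy semantics: it returns an array with the two named axes exchanged, which is the same as a transpose whose permutation swaps those two positions. Negative axes count from the end, so the call keeps working if extra leading dimensions are present. A small sketch with arbitrary shapes:

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
y = jnp.swapaxes(x, -2, -3)  # exchange the first two of the three axes

print(x.shape, y.shape)  # (2, 3, 4) (3, 2, 4)
print(bool(y[2, 1, 3] == x[1, 2, 3]))  # True
# Equivalent explicit permutation, and swapping again restores the original.
print(bool(jnp.array_equal(y, jnp.transpose(x, (1, 0, 2)))))  # True
print(bool(jnp.array_equal(jnp.swapaxes(y, -2, -3), x)))  # True
```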
4 votes
1 answer
342 views

Fastest way to multiply and sum 4D array with 2D array in python?

Here's my problem. I have two matrices A and B, with complex entries, of dimensions (n,n,m,m) and (n,n) respectively. Below is the operation I perform to get a matrix C - C = np.sum(B[:,:,None,None]*A,...
Prasad Mani
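The summation axes are cut off in the excerpt; assuming the sum runs over the first two axes (giving an (m, m) result), the broadcast-multiply-then-sum can be written as a single contraction with `einsum` or `tensordot`, which avoids building the (n, n, m, m) product as an explicit intermediate. A sketch in JAX with random complex data; whether it is faster on a given machine is not guaranteed.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
A = jnp.asarray(rng.normal(size=(n, n, m, m)) + 1j * rng.normal(size=(n, n, m, m)))
B = jnp.asarray(rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n)))

# Broadcast-and-sum form (assumed to sum over axes 0 and 1).
C_ref = jnp.sum(B[:, :, None, None] * A, axis=(0, 1))

# Same contraction: C[k, l] = sum over i, j of B[i, j] * A[i, j, k, l]
C_einsum = jnp.einsum('ij,ijkl->kl', B, A)
C_tensordot = jnp.tensordot(B, A, axes=([0, 1], [0, 1]))
```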
0 votes
1 answer
2k views

Is there a way to speed up indexing a vector with JAX?

I am indexing vectors and using JAX, but I have noticed a considerable slow-down compared to numpy when simply indexing arrays. For example, consider making a basic array in JAX numpy and ordinary ...
Danny Williams
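A likely contributor, stated as a hedge: each eagerly executed JAX operation carries dispatch overhead, which can dominate for small, cheap operations such as a single indexing call. Wrapping a larger computation in `jax.jit` spreads that overhead over the whole compiled function. Timing should call `block_until_ready()` because JAX dispatches work asynchronously. The function `gather_and_sum` is a hypothetical workload.

```python
import time
import jax
import jax.numpy as jnp

x = jnp.arange(100_000.0)
idx = jnp.arange(0, 100_000, 7)

@jax.jit
def gather_and_sum(x, idx):
    # Hypothetical workload: index the vector, then keep working in the same compiled call.
    return jnp.sum(x[idx] * 2.0)

gather_and_sum(x, idx).block_until_ready()  # first call triggers compilation

start = time.perf_counter()
result = gather_and_sum(x, idx).block_until_ready()
elapsed = time.perf_counter() - start
print(f"{elapsed:.6f} s")
```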
1 vote
1 answer
2k views

indexing into numpy array with jax array: faulty error messages

The following numpy code is perfectly fine: arr = np.arange(50) print(arr.shape) # (50,) indices = np.zeros((30,), dtype=int) print(indices.shape) # (30,) arr[indices] It also works after migrating ...
lhk (reputation 29.8k)
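A sketch of a straightforward workaround, assuming the goal is to keep `arr` as a NumPy array: convert the JAX index array with `np.asarray` first, so NumPy indexes with an ordinary integer ndarray. Alternatively, move `arr` to JAX with `jnp.asarray` and index there.

```python
import jax.numpy as jnp
import numpy as np

arr = np.arange(50)
indices = jnp.zeros((30,), dtype=int)  # JAX array of indices

picked_np = arr[np.asarray(indices)]    # NumPy indexing with a NumPy index array
picked_jnp = jnp.asarray(arr)[indices]  # or move arr to JAX and index there

print(picked_np.shape, picked_jnp.shape)  # (30,) (30,)
```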