All Questions
17 questions
1 vote · 1 answer · 90 views
Can jax.vmap() do a hstack()?
As the title says, I currently hstack() the first axis of a 3D array returned by jax.vmap() manually. In my code, the copy operation in hstack() is currently a speed bottleneck. Can I avoid this by ...
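A minimal sketch of one alternative, not taken from the posted answer: `jax.vmap` accepts an `out_axes` argument, so the batch axis can be placed at position 1 and merged with a reshape instead of building a Python list for `hstack()`. The per-example function `f` and the shapes are hypothetical. Whether XLA actually avoids a copy depends on the surrounding computation (for example, inside `jax.jit`), so treat this as something to benchmark rather than a guaranteed fix.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical per-example function returning a 2D (n, m) block.
    return jnp.outer(x, jnp.arange(3.0))  # shape (len(x), 3)


xs = jnp.arange(8.0).reshape(2, 4)  # a batch of 2 inputs, each of length 4

# Pattern from the question: default vmap stacks along axis 0, then hstack the slices.
stacked = jax.vmap(f)(xs)  # shape (2, 4, 3)
via_hstack = jnp.hstack([stacked[k] for k in range(stacked.shape[0])])  # shape (4, 6)

# Alternative: let vmap place the batch axis at position 1, then merge it with a reshape.
batched = jax.vmap(f, out_axes=1)(xs)  # shape (4, 2, 3)
via_reshape = batched.reshape(batched.shape[0], -1)  # shape (4, 6)
```

Both results hold the per-example blocks side by side in the same column order.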
1 vote · 1 answer · 368 views
Unexpected behavior of JAX `vmap` for multiple arguments
I have found that vmap in JAX does not behave as expected when applied to multiple arguments. For example, consider the function below:
def f1(x, y, z):
    f = x[:, None, None] * z[None, None, :] + y[...
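For context, a small sketch (with a hypothetical function `g`, since the question's `f1` is truncated): by default `jax.vmap` maps axis 0 of every argument in lockstep, like `zip`, and `in_axes=None` marks an argument that is passed unchanged to every call.

```python
import jax
import jax.numpy as jnp


def g(x, y):
    # Hypothetical elementwise function of two arguments.
    return x * y


xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([10.0, 20.0, 30.0])

# Default in_axes=0: both arguments are mapped together, giving shape (3,).
zipped = jax.vmap(g)(xs, ys)

# Map only x; y is passed whole to every call, giving shape (3, 3).
outer = jax.vmap(g, in_axes=(0, None))(xs, ys)
```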
2 votes · 1 answer · 124 views
Why do I get different values from jnp.round and np.round?
I'm writing tests for some jax code and using np.testing.assert_array...-type functions and came across this difference in values that I didn't expect:
import jax.numpy as jnp
import numpy as np
from ...
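One frequent source of such mismatches, assumed here and not confirmed by the truncated question, is the default dtype: NumPy uses float64, while JAX creates float32 arrays unless 64-bit mode is enabled (`jax.config.update("jax_enable_x64", True)`). The sketch compares the two results with a float32-appropriate tolerance instead of exact equality; the input values are made up.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([0.1234567, 1.98765, -3.14159])  # float64 in NumPy

r_np = np.round(x, 2)
r_jnp = jnp.round(jnp.asarray(x), 2)  # float32 unless jax_enable_x64 is enabled

# Exact comparison can fail purely because of the dtype difference,
# so compare with a relative tolerance suited to float32.
np.testing.assert_allclose(np.asarray(r_jnp), r_np, rtol=1e-6)
```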
4 votes · 2 answers · 6k views
How to get value of jaxlib.xla_extension.ArrayImpl
Using type(z1[0]) I get jaxlib.xla_extension.ArrayImpl. Printing z1[0] I get Array(0.71530414, dtype=float32). How can I get the actual number 0.71530414?
I tried z1[0][0] because z1[0] is a kind of ...
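A minimal sketch: `z1[0]` is a 0-d array (shape `()`), so it has no axis left to index, which is why `z1[0][0]` fails. `.item()` or `float()` returns the Python scalar. The array contents below are illustrative.

```python
import jax.numpy as jnp

z1 = jnp.array([0.71530414, 0.25])
v = z1[0]  # a 0-d jax Array with shape ()

value = v.item()       # Python float
same_value = float(v)  # also works for 0-d (size-1) arrays
```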
1 vote · 1 answer · 580 views
index `jax` array with variable dimension
I am trying to write a general utility to update indices in a jax array that may have a different number of dimensions depending on the instance.
I know that I have to use the .at[].set() methods, and ...
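A sketch under the assumption that each index is a sequence with one integer per dimension: converting it to a `tuple` makes `.at[...]` treat it as a multi-dimensional index regardless of `ndim`. The helper name `set_at` is hypothetical.

```python
import jax.numpy as jnp


def set_at(arr, idx, value):
    # idx: one integer per dimension (len(idx) == arr.ndim).
    # A tuple index selects a single element, whatever the number of dimensions.
    return arr.at[tuple(idx)].set(value)


a2 = set_at(jnp.zeros((2, 3)), [1, 2], 5.0)
a3 = set_at(jnp.zeros((2, 3, 4)), [0, 1, 3], 7.0)
```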
0 votes · 1 answer · 301 views
Irregular/Inhomogeneous Arrays with JAX
What is the recommended approach to implement array behaviour/methods on irregular/inhomogeneous data (that possesses some inherent dimensionality) within JAX?
Two principal options come to mind:
make ...
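One common option (padding plus a boolean mask) is sketched below; it is not necessarily either of the two options the question goes on to list. The data is made up; masked-out entries are replaced by the identity element of each reduction (0 for sums, -inf for maxima).

```python
import jax.numpy as jnp

rows = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]  # ragged input
max_len = max(len(r) for r in rows)

# Pad every row to the same length and record which entries are real.
padded = jnp.array([r + [0.0] * (max_len - len(r)) for r in rows])
mask = jnp.array([[j < len(r) for j in range(max_len)] for r in rows])

row_sums = jnp.where(mask, padded, 0.0).sum(axis=1)
row_means = row_sums / mask.sum(axis=1)
row_maxs = jnp.where(mask, padded, -jnp.inf).max(axis=1)
```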
3 votes · 1 answer · 375 views
jax segment_sum along array dimension
I am fairly new to jax and have the following problem:
I need to compute functions (sum/min/max maybe more complex stuff later) across an array given an index. To solve this problem I found the jnp....
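A minimal sketch, assuming the goal is to reduce along a non-leading axis: the `jax.ops.segment_*` functions reduce over the leading axis, so either `vmap` over the other axes or move the segmented axis to the front. The data and `segment_ids` are illustrative.

```python
import jax
import jax.numpy as jnp

data = jnp.array([[1.0, 2.0, 3.0, 4.0],
                  [10.0, 20.0, 30.0, 40.0]])
segment_ids = jnp.array([0, 0, 1, 1])  # groups along axis 1

# Option 1: map the leading-axis segment_sum over each row.
sums = jax.vmap(
    lambda row: jax.ops.segment_sum(row, segment_ids, num_segments=2)
)(data)

# Option 2: transpose so the segmented axis comes first, reduce, transpose back.
maxs = jax.ops.segment_max(data.T, segment_ids, num_segments=2).T
```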
2 votes · 1 answer · 235 views
How to map the kronecker product along array dimensions?
Given two tensors A and B with the same dimension (d>=2) and shapes [A_{1},...,A_{d-2},A_{d-1},A_{d}] and [A_{1},...,A_{d-2},B_{d-1},B_{d}] (shapes of the first d-2 dimensions are identical).
Is ...
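A sketch of one way to do this (not necessarily the posted answer): a Kronecker product over the last two axes is an outer product followed by a reshape, which `jnp.einsum` expresses with a shared `...` batch prefix. The helper name `batched_kron` is hypothetical; for a single leading axis, `jax.vmap(jnp.kron)` gives the same result.

```python
import jax
import jax.numpy as jnp


def batched_kron(A, B):
    # A: (..., p, q), B: (..., r, s) with identical leading dimensions.
    p, q = A.shape[-2:]
    r, s = B.shape[-2:]
    # prod[..., i, k, j, l] = A[..., i, j] * B[..., k, l]
    prod = jnp.einsum('...ij,...kl->...ikjl', A, B)
    # Row index i*r + k and column index j*s + l, matching jnp.kron's layout.
    return prod.reshape(A.shape[:-2] + (p * r, q * s))


A = jnp.arange(12.0).reshape(2, 2, 3)
B = jnp.arange(12.0).reshape(2, 3, 2) + 1.0
C = batched_kron(A, B)  # shape (2, 6, 6)
```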
3 votes · 1 answer · 513 views
Modify an array from indexes contained in another array
I have an array of the shape (2,10) such as:
arr = jnp.ones(shape=(2,10)) * 2
or
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
and another array, for example [2,4].
I want the ...
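The rest of the question is cut off, so the intended update is unknown. The sketch shows two plausible readings with `.at[...].set()`, using 0.0 as an arbitrary new value: updating the listed columns in every row, or using one index per row.

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])

# Reading 1: update columns 2 and 4 in every row.
cols_set = arr.at[:, idx].set(0.0)

# Reading 2: one index per row (row 0 -> column 2, row 1 -> column 4).
per_row_set = arr.at[jnp.arange(arr.shape[0]), idx].set(0.0)
```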
1 vote · 1 answer · 142 views
Check if 2D sub-array is ordered - Python JAX
Let us suppose that we have an array `ordered`. We want to check whether the sub-arrays t and t_inv follow the same order as the one imposed by the `ordered` array.
Reading from left to right: the first ...
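The question is truncated, so this sketch assumes a specific reading: `ordered` is a 1-D array listing each value once, and a row of `t` "follows the order" if its values occur at strictly increasing positions in `ordered`. The helper `rows_follow_order` and the example data are hypothetical; the same call would apply to `t_inv`.

```python
import jax.numpy as jnp


def rows_follow_order(t, ordered):
    # Position of each value of t inside `ordered`
    # (assumes every value of t occurs exactly once in `ordered`).
    pos = jnp.argmax(t[..., None] == ordered, axis=-1)
    # A row follows the order if positions strictly increase from left to right.
    return jnp.all(jnp.diff(pos, axis=-1) > 0, axis=-1)


ordered = jnp.array([3, 1, 4, 0, 2])
t = jnp.array([[3, 4, 2],
               [1, 3, 0]])
result = rows_follow_order(t, ordered)
```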
4 votes · 1 answer · 1k views
vmap in Jax to loop over arguments
Let's suppose I have a function that returns the sum of its inputs.
@jit
def some_func(a, r1, r2):
    return a + r1 + r2
Now I would like to loop over different values of r1 and r2, save the result and ...
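A minimal sketch of two nested `vmap` calls that evaluate `some_func` on every combination of `r1` and `r2` values; the values themselves are made up. For paired values of equal length, a single `jax.vmap(some_func, in_axes=(None, 0, 0))` would suffice.

```python
import jax
import jax.numpy as jnp


@jax.jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 1.0
r1s = jnp.array([0.0, 1.0, 2.0])
r2s = jnp.array([10.0, 20.0])

# Inner vmap maps over r2 only; outer vmap maps over r1 only.
inner = jax.vmap(some_func, in_axes=(None, None, 0))
grid = jax.vmap(inner, in_axes=(None, 0, None))(a, r1s, r2s)  # grid[i, j] uses r1s[i], r2s[j]
```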
3 votes · 2 answers · 1k views
Vectorizing a function that takes multidimensional input over a multidimensional array in JAX
I have been trying to vectorize a function that takes two 2D arrays and returns a 2D array of the same shape, so that I can apply it elementwise to two 4D arrays.
Here is an example:
import jax.numpy ...
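A sketch assuming the 2D function should be applied to matching `(n, n)` slices over the two leading axes: nesting `jax.vmap` twice maps over both. The function `f` (a matrix product) and the shapes are hypothetical stand-ins for the truncated example.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical 2D -> 2D function: a matrix product keeps the (n, n) shape.
    return x @ y


# Outer vmap maps over axis 0, inner vmap over the next axis.
batched_f = jax.vmap(jax.vmap(f))

X = jnp.arange(2 * 3 * 4 * 4, dtype=jnp.float32).reshape(2, 3, 4, 4)
Y = jnp.ones((2, 3, 4, 4))
out = batched_f(X, Y)  # shape (2, 3, 4, 4)
```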
3 votes · 1 answer · 3k views
Handle varying shapes in jax numpy arrays (jit compatible)
Important note: I need everything to be jit compatible here, otherwise my problem is trivial :)
I have a jax numpy array such as:
a = jnp.array([1,5,3,4,5,6,7,2,9])
First I filter it considering a ...
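Since the question's filtering condition is cut off, this sketch assumes a hypothetical "greater than a threshold" filter. Under `jax.jit` output shapes must be static, so the usual workarounds are to keep the full shape and fill rejected entries (`jnp.where`), or to request a fixed-size result from `jnp.nonzero` via its `size` and `fill_value` arguments. The function name `keep_greater_than` is hypothetical.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])


@jax.jit
def keep_greater_than(a, threshold):
    mask = a > threshold
    # Keep the static shape: rejected entries become 0.
    masked = jnp.where(mask, a, 0)
    # Fixed-size index buffer (size must be static); unused slots hold -1.
    (idx,) = jnp.nonzero(mask, size=a.shape[0], fill_value=-1)
    return masked, idx, mask.sum()


masked, idx, count = keep_greater_than(a, 4)
```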
1 vote · 1 answer · 305 views
swapaxes and how it is implemented?
I'm wondering if someone can explain this code to me?
c = self.config
assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']
if c....
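A small sketch of what `jnp.swapaxes` does: it is a transpose that exchanges two axes. A pattern like "swap, apply a row-wise operation, swap back" makes that operation act along columns; the helper `per_row` (a softmax) is hypothetical and only illustrates the pattern, not the code in the question.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
y = jnp.swapaxes(x, -1, -2)  # shape (2, 4, 3), with y[i, j, k] == x[i, k, j]


def per_row(m):
    # Hypothetical operation that always works along the last axis.
    return jax.nn.softmax(m, axis=-1)


m = jnp.arange(12.0).reshape(3, 4)
# Swap, apply the row-wise op, swap back: the op now acts along columns.
per_column = jnp.swapaxes(per_row(jnp.swapaxes(m, -1, -2)), -1, -2)
```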
4 votes · 1 answer · 342 views
Fastest way to multiply and sum 4D array with 2D array in python?
Here's my problem. I have two matrices A and B, with complex entries, of dimensions (n,n,m,m) and (n,n) respectively.
Below is the operation I perform to get a matrix C -
C = np.sum(B[:,:,None,None]*A,...
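The reduction axes in the question are cut off; assuming they are the first two (`axis=(0, 1)`), the product-and-sum is a tensor contraction that `np.einsum` or `np.tensordot` can compute without first building the full `(n, n, m, m)` product array. Sizes and random data are illustrative; `jnp.einsum` and `jnp.tensordot` accept the same call forms in JAX.

```python
import numpy as np

n, m = 3, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((n, n, m, m)) + 1j * rng.standard_normal((n, n, m, m))
B = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

# Original formulation, assuming the sum runs over the first two axes.
C_broadcast = np.sum(B[:, :, None, None] * A, axis=(0, 1))

# Same contraction: C[k, l] = sum_ij B[i, j] * A[i, j, k, l]
C_einsum = np.einsum('ij,ijkl->kl', B, A)
C_tensordot = np.tensordot(B, A, axes=2)  # contracts both axes of B with A's first two
```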
0 votes · 1 answer · 2k views
Is there a way to speed up indexing a vector with JAX?
I am indexing vectors and using JAX, but I have noticed a considerable slow-down compared to numpy when simply indexing arrays. For example, consider making a basic array in JAX numpy and ordinary ...
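A common explanation for slow eager indexing, assumed here rather than confirmed for the question's benchmark, is per-operation dispatch overhead from Python; wrapping the surrounding work in `jax.jit` amortizes it. Because JAX dispatches asynchronously, timings should call `.block_until_ready()`. The arrays and the function name `gather_and_sum` are hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1000.0)
idx = jnp.array([3, 14, 159])

# Eager indexing: each call is dispatched as a separate operation from Python.
eager = x[idx]


# Compiling the whole computation once amortizes dispatch overhead across it.
@jax.jit
def gather_and_sum(x, idx):
    return x[idx].sum()


# Block until the result is ready so any timing measures the actual work.
result = gather_and_sum(x, idx).block_until_ready()
```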
1 vote · 1 answer · 2k views
indexing into numpy array with jax array: faulty error messages
The following numpy code is perfectly fine:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
It also works after migrating ...
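A sketch of two ways to keep the array types consistent (the question's follow-up is truncated, so this does not reproduce its error message): convert the JAX indices with `np.asarray` before indexing a NumPy array, or, inside `jax.jit` where indices are tracers, index a `jnp` array instead. The function name `take_from` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

arr = np.arange(50)
indices = jnp.zeros((30,), dtype=int)

# Outside jit: convert the JAX index array to NumPy before indexing a NumPy array.
out = arr[np.asarray(indices)]


# Inside jit: indices are tracers, so index a JAX array rather than a NumPy one.
@jax.jit
def take_from(arr, indices):
    return jnp.asarray(arr)[indices]


out_jit = take_from(arr, indices)
```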