I'm having trouble using flax.jax_utils.prefetch_to_device with the simple function below. I'm loading the SIFT1M dataset and converting each array to a JAX array.

I then want to prefetch the resulting iterator of 128-dimensional arrays to the device.

import tensorflow_datasets as tfds
import tensorflow as tf
import jax
import jax.numpy as jnp
import itertools
import jax.dlpack
import jax.tools.colab_tpu
import flax

def _sift1m_iter():
    def prepare_tf_data(xs):
        def _prepare(x):
            dl_arr = tf.experimental.dlpack.to_dlpack(x)
            jax_arr = jax.dlpack.from_dlpack(dl_arr)
            return jax_arr

        return jax.tree_util.tree_map(_prepare, xs['embedding'])

    ds = tfds.load('sift1m', split='database')
    it = map(prepare_tf_data, ds)
    # it = flax.jax_utils.prefetch_to_device(it, 2)  # uncommenting this line raises the error below
    return it

However, when I uncomment the prefetch_to_device line, I get this error:

ValueError: len(shards) = 128 must equal len(devices) = 1.

I'm running this on a CPU-only machine, so there is a single device. From the error, it seems that the shape of the data I'm passing to prefetch_to_device is wrong.

Answer

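prefetch_to_device splits every array yielded by the iterator along its first axis and places one slice on each local device, so the first axis must equal the number of devices. In other words, the output of the _prepare(x) function should have the shape [num_devices, batch_size, 128]. Each SIFT1M element is a single embedding of shape [128], so its 128 values are treated as 128 shards, which is exactly what the error reports.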

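In your case there is only one device (the CPU), so each element should have the shape [1, batch_size, 128]. Even a single embedding reshaped to [1, 128] would pass the check, though batching several embeddings per step is the more common setup.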

Take a look at how it can be done here.

  • While this link may answer the question, it is better to include the essential parts of the answer here and provide the link for reference. Link-only answers can become invalid if the linked page changes. - From Review
    – Ben A.
    Commented Sep 7, 2023 at 17:26
