I'm having trouble using `flax.jax_utils.prefetch_to_device` with the simple function below. I'm loading the SIFT1M dataset and converting each example to a JAX array.
I then want to prefetch the resulting iterator of 128-dim arrays.
```python
import tensorflow_datasets as tfds
import tensorflow as tf
import jax
import jax.numpy as jnp
import itertools
import jax.dlpack
import jax.tools.colab_tpu
import flax

def _sift1m_iter():
    def prepare_tf_data(xs):
        def _prepare(x):
            # Zero-copy hand-off from a TF tensor to a JAX array via DLPack.
            dl_arr = tf.experimental.dlpack.to_dlpack(x)
            jax_arr = jax.dlpack.from_dlpack(dl_arr)
            return jax_arr
        return jax.tree_util.tree_map(_prepare, xs['embedding'])

    ds = tfds.load('sift1m', split='database')
    it = map(prepare_tf_data, ds)
    # it = flax.jax_utils.prefetch_to_device(it, 2)  # => this causes an error
    return it
```
However, when I run this code, I get an error:
```
ValueError: len(shards) = 128 must equal len(devices) = 1.
```
I'm running this on a CPU-only machine (so there is a single device). From the error, it looks like the shape of the data I'm passing into `prefetch_to_device` is wrong: the 128 in the message matches the embedding dimension.
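The numbers in the message hint at the cause. As I read Flax's source, `prefetch_to_device` passes every leaf of each element to `jax.device_put_sharded(list(leaf), devices)`, which splits the leaf along its leading axis and expects exactly one slice per device. A single `(128,)` embedding therefore turns into 128 shards, while there is only one CPU device. Below is a minimal NumPy sketch of one possible fix, assuming the goal is to prefetch mini-batches: group the rows and reshape each batch to `(num_devices, per_device_batch, 128)`. The `device_batches` helper and the synthetic rows are hypothetical stand-ins, not part of Flax or TFDS.

```python
import itertools

import numpy as np

EMBED_DIM = 128  # SIFT descriptors are 128-dimensional


def device_batches(rows, num_devices, per_device_batch):
    """Hypothetical helper: group an iterator of (128,) rows into arrays
    shaped (num_devices, per_device_batch, 128), so that axis 0 has one
    slice per device. Trailing rows that do not fill a batch are dropped."""
    batch_size = num_devices * per_device_batch
    rows = iter(rows)
    while True:
        chunk = list(itertools.islice(rows, batch_size))
        if len(chunk) < batch_size:
            return
        batch = np.stack(chunk)  # shape: (batch_size, 128)
        yield batch.reshape(num_devices, per_device_batch, EMBED_DIM)


# Splitting a single unbatched row along axis 0 yields 128 pieces,
# which is the "len(shards) = 128" from the error message.
shards_unbatched = len(list(np.zeros(EMBED_DIM, dtype=np.float32)))

# Synthetic stand-in for SIFT1M 'embedding' rows (row i is filled with i).
fake_rows = (np.full(EMBED_DIM, i, dtype=np.float32) for i in range(10))

# One CPU device: the leading axis must have length 1.
batches = list(device_batches(fake_rows, num_devices=1, per_device_batch=4))
shards_batched = len(list(batches[0]))  # one shard per device
```

With real data, `num_devices` would presumably come from `jax.local_device_count()`, the rows could come from `tfds.as_numpy(ds)` (taking `ex['embedding']` from each example dict), and the batched generator would then be passed to `flax.jax_utils.prefetch_to_device(..., 2)`. Since `device_put_sharded` should also accept NumPy arrays, the DLPack round trip may not be needed. Batching inside `tf.data` with `ds.batch(batch_size, drop_remainder=True)` and reshaping there is another option.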