If you’re using the Keras training loop, tf.keras.utils.Sequence is the way to go. The next method I’ll describe is more universal and works with tf.keras, Tensorflow and JAX.
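As a rough sketch, assuming your features and targets are already split into per-era lists (the names eras_x and eras_y below are made up for illustration), a Sequence that yields one era per batch could look something like this:

```python
import tensorflow as tf

class EraSequence(tf.keras.utils.Sequence):
    """Yields one era per batch; batches may have different row counts."""

    def __init__(self, eras_x, eras_y):
        # eras_x / eras_y: hypothetical lists of per-era feature/target arrays
        self.eras_x = eras_x
        self.eras_y = eras_y

    def __len__(self):
        # Number of batches per epoch == number of eras
        return len(self.eras_x)

    def __getitem__(self, idx):
        # Keras calls this with idx in [0, len(self)) to fetch one batch
        return self.eras_x[idx], self.eras_y[idx]

# model.fit(EraSequence(train_x_by_era, train_y_by_era), epochs=10)
```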
If you’re using Tensorflow without the Keras training loop, or even if you are using Keras: start by converting your data into a RaggedTensor first. You can then turn the RaggedTensor into a tf.data.Dataset using the from_tensor_slices method. And then, Bob’s your uncle!
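Here’s a minimal sketch of that (the era sizes and feature count are made up; tf.RaggedTensor.from_row_lengths is just one convenient way to build the ragged tensor):

```python
import numpy as np
import tensorflow as tf

# Hypothetical data: three eras with different numbers of rows, 4 features each.
eras = [np.random.rand(n, 4).astype(np.float32) for n in (100, 250, 80)]

# Pack the variable-length eras into a single RaggedTensor of shape [3, None, 4].
ragged = tf.RaggedTensor.from_row_lengths(
    values=np.concatenate(eras), row_lengths=[len(e) for e in eras]
)

# Slicing along the first axis gives one era per dataset element.
dataset = tf.data.Dataset.from_tensor_slices(ragged)

for era in dataset:
    print(era.shape)  # (100, 4), (250, 4), (80, 4)
```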
If you’re using JAX, use the as_numpy_iterator method on the tf.data.Dataset.
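Something along these lines, reusing the dataset from the sketch above (the training step is left as a stub):

```python
import jax.numpy as jnp

# `dataset` is the tf.data.Dataset of per-era slices from the previous sketch.
# as_numpy_iterator() yields plain NumPy arrays, which JAX consumes directly.
for epoch in range(10):
    for era in dataset.as_numpy_iterator():
        batch = jnp.asarray(era)  # shape: (rows_in_this_era, n_features)
        # ... run your (possibly jitted) training step on `batch` here ...
```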
One thing to bear in mind is that JAX’s jit caches its traces based on the shapes of the tensors passed in. So, if you’re using eras as batches with jit, your first epoch is going to be super slow because it has to trace your training loop for every era (since they all have a different number of rows).
This is the case with Tensorflow AutoGraph as well, but the experimental_relax_shapes argument to the tf.function decorator helps alleviate this issue.
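For instance, a sketch of a relaxed-shape training step (the body is a placeholder):

```python
import tensorflow as tf

# With experimental_relax_shapes=True, tf.function tries to generalize across
# input shapes instead of retracing for every new era length.
@tf.function(experimental_relax_shapes=True)
def train_step(batch):
    # ... forward pass, loss and gradient update would go here ...
    return tf.reduce_mean(batch)
```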
Using JAX without jit and Tensorflow in eager mode are perfectly feasible options. Empirically, I’ve observed that both are a bit faster than PyTorch, though I haven’t tried PyTorch’s jit.
In any case, you can efficiently mix and match Tensorflow, JAX (and even PyTorch, if you really want to) using DLPack tensors. All 3 frameworks support it. Heck, even XGBoost supports it.
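As a quick illustration of the Tensorflow-to-JAX direction (assuming reasonably recent versions of both; the other frameworks have analogous functions):

```python
import jax.dlpack
import tensorflow as tf

tf_tensor = tf.random.uniform((100, 4))

# TensorFlow -> JAX via DLPack (zero-copy where the devices allow it).
jax_array = jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_tensor))

# And back again: JAX -> TensorFlow.
tf_again = tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))
```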