What happens if texts in the dataset don't have equal lengths?
What would happen if the texts in the dataset had different lengths and the batch size was greater than 1? E.g. the first text might split into 2 segments and the second into 4. Would the loss be NaN for segments that contain only padding tokens? I think your training script assumes that every segment is a full sequence of 2048 non-padding tokens.
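To make the concern concrete, here is a minimal pure-Python sketch, assuming texts are split into fixed-length segments, shorter texts are padded with whole segments of padding, and the loss is a mean over non-padding tokens. `PAD_ID`, `segment_texts`, `naive_segment_loss` and `safe_step_loss` are hypothetical names, not part of the training script, and the loss values are made up. Averaging per segment gives 0/0 (NaN) for an all-padding segment, while pooling the token count across the batch for each segment step avoids that.

```python
import math
from typing import List, Optional

PAD_ID = 0  # hypothetical padding token id


def segment_texts(texts: List[List[int]], seg_len: int) -> List[List[List[int]]]:
    """Split each tokenized text into seg_len-token segments.

    The last segment of each text is right-padded, and shorter texts get extra
    all-padding segments so every text has the same segment count (a batch of
    size > 1 needs this).
    """
    n_segs = max(math.ceil(len(t) / seg_len) for t in texts)
    batch = []
    for t in texts:
        padded = t + [PAD_ID] * (n_segs * seg_len - len(t))
        batch.append([padded[i * seg_len:(i + 1) * seg_len] for i in range(n_segs)])
    return batch  # indexed as batch[text][segment][token]


def naive_segment_loss(token_losses: List[float], tokens: List[int]) -> float:
    """Mean loss over the non-padding tokens of a single segment.

    An all-padding segment has zero tokens to average over; this returns NaN
    to mirror the 0/0 that a masked mean would produce.
    """
    kept = [loss for loss, tok in zip(token_losses, tokens) if tok != PAD_ID]
    if not kept:
        return float("nan")
    return sum(kept) / len(kept)


def safe_step_loss(token_losses: List[List[float]],
                   tokens: List[List[int]]) -> Optional[float]:
    """Mean loss over non-padding tokens of one segment step across the batch.

    Pooling the count over the batch means a text that has already ended
    (all padding at this step) contributes nothing instead of NaN. Returns
    None when every text is padding at this step, so the caller can skip it.
    """
    total, count = 0.0, 0
    for losses, toks in zip(token_losses, tokens):
        for loss, tok in zip(losses, toks):
            if tok != PAD_ID:
                total += loss
                count += 1
    return total / count if count else None
```

For example, with `seg_len=2`, a 3-token text and a 7-token text both become 4 segments; the first text's last two segments are pure padding, so `naive_segment_loss` on either of them is NaN, whereas `safe_step_loss` for that step still averages the second text's tokens. Whether the actual script hits this case depends on how it masks the loss (and attention) for padded segments.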