From 32f6442f82927e342b9f9e117dfee0a3b03a80ff Mon Sep 17 00:00:00 2001 From: niujunhao Date: Tue, 2 Dec 2025 19:48:56 +0800 Subject: [PATCH] fix bs>1 in hf dataloader tnd. --- mindformers/dataset/causal_language_model_dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mindformers/dataset/causal_language_model_dataset.py b/mindformers/dataset/causal_language_model_dataset.py index e5452511b..25eb9d746 100644 --- a/mindformers/dataset/causal_language_model_dataset.py +++ b/mindformers/dataset/causal_language_model_dataset.py @@ -38,9 +38,15 @@ CAST_TO_INT_COLUMNS = ["input_ids", "labels"] def _use_compressed_eod_mask(data_loader): + """ + Determine whether the given data loader should use a compressed EOD (End-Of-Document) mask. + """ if (hasattr(data_loader, 'config') and data_loader.config and data_loader.config.create_compressed_eod_mask): # megatron dataset return True + if (hasattr(data_loader, 'create_compressed_eod_mask') and + data_loader.create_compressed_eod_mask): + return True if (hasattr(data_loader, 'adaptor_config') and data_loader.adaptor_config and data_loader.adaptor_config.compress_mask): # common dataloader return True -- Gitee