More problems with train_on_responses_only #1017

Open
LostRuins opened this issue Sep 12, 2024 · 3 comments

@LostRuins

I'm trying to finetune Mistral-Nemo-Base-2407 with a text dataset of long inputs. Usually, the SFTTrainer truncates them to fit the specified context size.

However, I get an error when using train_on_responses_only.

Running the same dataset without train_on_responses_only works fine and trains normally.
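
For context, here's a trimmed-down sketch of the relevant part of my notebook (dataset loading and the LoRA config are omitted; `model`, `tokenizer`, `dataset`, `max_seq_length` and the `"text"` field are placeholders for my actual setup, and the TrainingArguments are abbreviated):

from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth.chat_templates import train_on_responses_only

# Plain SFTTrainer over the dataset's text field; this alone trains fine.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,  # 4096 in my run
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        num_train_epochs = 1,
        output_dir = "outputs",
    ),
)

# Wrapping the trainer so the loss is only computed on the response tokens;
# this is the step after which trainer.train() raises the ValueError below.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "### Instruction:\n",
    response_part = "### Response:\n",
)

trainer_stats = trainer.train()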

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 2,068 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 258
 "-____-"     Number of trainable parameters = 57,016,320
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:762, in BatchEncoding.convert_to_tensors(self, tensor_type, prepend_batch_axis)
    761 if not is_tensor(value):
--> 762     tensor = as_tensor(value)
    764     # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
    765     # # at-least2d
    766     # if tensor.ndim > 2:
    767     #     tensor = tensor.squeeze(0)
    768     # elif tensor.ndim < 2:
    769     #     tensor = tensor[None, :]

File /usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:724, in BatchEncoding.convert_to_tensors.<locals>.as_tensor(value, dtype)
    723     return torch.tensor(np.array(value))
--> 724 return torch.tensor(value)

ValueError: expected sequence of length 4096 at dim 1 (got 327)

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[1], line 268
    258 trainer = train_on_responses_only(
    259     trainer,
    260     instruction_part = "### Instruction:\n",
    261     response_part = "### Response:\n",
    262 )
    264 # #sanity check
    265 # space = tokenizer(" ", add_special_tokens = False).input_ids[0]
    266 # tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[0]["labels"]])
--> 268 trainer_stats = trainer.train()
    270 used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    271 used_memory_for_lora = round(used_memory - start_gpu_memory, 3)

File <string>:145, in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)

File <string>:320, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File /usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py:550, in DataLoaderShard.__iter__(self)
    548 # We iterate one batch ahead to check when we are at the end
    549 try:
--> 550     current_batch = next(dataloader_iter)
    551 except StopIteration:
    552     yield

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    672 def _next_data(self):
    673     index = self._next_index()  # may raise StopIteration
--> 674     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    675     if self._pin_memory:
    676         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py:54, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     52 else:
     53     data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)

File /usr/local/lib/python3.10/dist-packages/transformers/data/data_collator.py:45, in DataCollatorMixin.__call__(self, features, return_tensors)
     43     return self.tf_call(features)
     44 elif return_tensors == "pt":
---> 45     return self.torch_call(features)
     46 elif return_tensors == "np":
     47     return self.numpy_call(features)

File /usr/local/lib/python3.10/dist-packages/transformers/data/data_collator.py:806, in DataCollatorForLanguageModeling.torch_call(self, examples)
    803 def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
    804     # Handle dict or lists with proper padding and conversion to tensor.
    805     if isinstance(examples[0], Mapping):
--> 806         batch = pad_without_fast_tokenizer_warning(
    807             self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
    808         )
    809     else:
    810         batch = {
    811             "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
    812         }

File /usr/local/lib/python3.10/dist-packages/transformers/data/data_collator.py:66, in pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs)
     63 tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
     65 try:
---> 66     padded = tokenizer.pad(*pad_args, **pad_kwargs)
     67 finally:
     68     # Restore the state of the warning.
     69     tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state

File /usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:3560, in PreTrainedTokenizerBase.pad(self, encoded_inputs, padding, max_length, pad_to_multiple_of, return_attention_mask, return_tensors, verbose)
   3557             batch_outputs[key] = []
   3558         batch_outputs[key].append(value)
-> 3560 return BatchEncoding(batch_outputs, tensor_type=return_tensors)

File /usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:227, in BatchEncoding.__init__(self, data, encoding, tensor_type, prepend_batch_axis, n_sequences)
    223     n_sequences = encoding[0].n_sequences
    225 self._n_sequences = n_sequences
--> 227 self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)

File /usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:778, in BatchEncoding.convert_to_tensors(self, tensor_type, prepend_batch_axis)
    773         if key == "overflowing_tokens":
    774             raise ValueError(
    775                 "Unable to create tensor returning overflowing tokens of different lengths. "
    776                 "Please see if a fast version of this tokenizer is available to have this feature available."
    777             ) from e
--> 778         raise ValueError(
    779             "Unable to create tensor, you should probably activate truncation and/or padding with"
    780             " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your"
    781             f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is"
    782             " expected)."
    783         ) from e
    785 return self

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`labels` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

Any help would be appreciated.
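
From the message it looks like the processed `input_ids` and `labels` can end up with different lengths, so the collator can't build a rectangular batch. Here's a quick check I could run to confirm that guess (just a sketch, using the field names the trainer produces):

# Compare per-row lengths of input_ids vs labels in the processed dataset.
# If train_on_responses_only leaves labels at a different length than the
# truncated input_ids, the collator error above would be the expected result.
mismatched = [
    (i, len(row["input_ids"]), len(row["labels"]))
    for i, row in enumerate(trainer.train_dataset)
    if len(row["input_ids"]) != len(row["labels"])
]
print(f"{len(mismatched)} rows where len(input_ids) != len(labels)")
print(mismatched[:5])  # (row index, len(input_ids), len(labels))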

@danielhanchen
Contributor

@LostRuins Apologies for the delay - it seems like the error is saying the labels are nested? Would it be possible to print out the first few rows of trainer.train_dataset? Thanks! Also, our Discord server can be more helpful for async help if that works!

@LostRuins
Author

Hi @danielhanchen, there are many rows, so I've trimmed the output to show the format:

trainer.train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1835
})

trainer.train_dataset[0]

{'input_ids': [1,
  1595,
  83779,
  1877,
  18746,  
  ...],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  ...],
  'labels': [-100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  1877,
  82236,
  1321,
  14969,
  5978,
   ...]}

What other commands should I run?
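
In the meantime I can also run the sanity check I had commented out in my cell above, in case decoding the masked labels is useful (the usual trick of swapping -100 for a space token before decoding):

# Decode the first row's labels, replacing the -100 mask with a space token,
# to eyeball which part of the example the loss is actually computed on.
space  = tokenizer(" ", add_special_tokens = False).input_ids[0]
labels = trainer.train_dataset[0]["labels"]
print(tokenizer.decode([space if x == -100 else x for x in labels]))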

@danielhanchen
Contributor

Ok, I'll check on my end and get back to you asap!
