Skip to content
Snippets Groups Projects
Commit cb574631 authored by Ludwig Forsberg's avatar Ludwig Forsberg
Browse files

customTrainer fix

parent 900f98d5
Branches
Tags
No related merge requests found
......@@ -72,12 +72,6 @@ from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.optimization import Adafactor, get_scheduler
from transformers.pytorch_utils import (
ALL_LAYERNORM_LAYERS,
is_torch_greater_or_equal_than_1_6,
is_torch_greater_or_equal_than_1_10,
is_torch_less_than_1_11,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
CallbackHandler,
......@@ -168,14 +162,6 @@ if is_in_notebook():
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
if is_torch_greater_or_equal_than_1_6:
_is_torch_generator_available = True
_is_native_cuda_amp_available = True
if is_torch_greater_or_equal_than_1_10:
_is_native_cpu_amp_available = True
if is_datasets_available():
import datasets
......@@ -371,15 +357,7 @@ class CustomTrainer(Seq2SeqTrainer):
is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
train_dataloader.sampler, RandomSampler
)
if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
# That was before PyTorch 1.11 however...
for _ in train_dataloader:
break
else:
# Otherwise we need to call the whooooole sampler cause there is some random operation added
# AT THE VERY END!
_ = list(train_dataloader.sampler)
_ = list(train_dataloader.sampler)
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment