|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
|
@pytest.fixture
def trainer():
    """Provide a fresh ``MagicMock`` that the tests attach to a data module as its trainer."""
    mock_trainer = MagicMock()
    return mock_trainer
|
|
|
|
|
|
|
|
@patch('nemo.collections.llm.gpt.data.core.GPTSFTDataset.__init__', return_value=None)
def test_finetuning_module(mock_gpt_sft_dataset, trainer) -> None:
    """Check that ``FineTuningDataModule`` initialises ``GPTSFTDataset`` exactly once.

    ``GPTSFTDataset.__init__`` is patched to return ``None``; the test builds
    the data module, attaches the mocked trainer, runs ``setup`` for the
    ``'train'`` stage, requests the train dataloader, and then asserts on the
    patched initialiser's call count.

    Args:
        mock_gpt_sft_dataset: Mock injected by ``@patch`` in place of
            ``GPTSFTDataset.__init__``.
        trainer: The ``trainer`` fixture (a ``MagicMock``).
    """
    # Imported inside the test body, as in the original.
    from nemo.collections.llm.gpt.data import FineTuningDataModule

    module_kwargs = {
        'seq_length': 2048,
        'micro_batch_size': 4,
        'global_batch_size': 8,
        'seed': 1234,
    }
    data_module = FineTuningDataModule('random_root', **module_kwargs)
    data_module.trainer = trainer

    data_module.setup(stage='train')
    data_module.train_dataloader()

    mock_gpt_sft_dataset.assert_called_once()
|
|
|
|
|
|
|
|
@patch('nemo.collections.llm.gpt.data.core.GPTSFTDataset.__init__', return_value=None)
def test_dolly_module(mock_gpt_sft_dataset, trainer) -> None:
    """Check that ``DollyDataModule`` initialises ``GPTSFTDataset`` exactly once.

    ``GPTSFTDataset.__init__`` is patched to return ``None``; the test builds
    the data module, attaches the mocked trainer, runs ``setup`` for the
    ``'train'`` stage, requests the train dataloader, and then asserts on the
    patched initialiser's call count.

    Args:
        mock_gpt_sft_dataset: Mock injected by ``@patch`` in place of
            ``GPTSFTDataset.__init__``.
        trainer: The ``trainer`` fixture (a ``MagicMock``).
    """
    # Imported inside the test body, as in the original.
    from nemo.collections.llm.gpt.data import DollyDataModule

    module_kwargs = {
        'seq_length': 2048,
        'micro_batch_size': 4,
        'global_batch_size': 8,
        'seed': 1234,
    }
    data_module = DollyDataModule(**module_kwargs)
    data_module.trainer = trainer

    data_module.setup(stage='train')
    data_module.train_dataloader()

    mock_gpt_sft_dataset.assert_called_once()
|
|
|
|
|
|
|
|
@patch('nemo.collections.llm.gpt.data.core.GPTSFTDataset.__init__', return_value=None)
def test_squad_module(mock_gpt_sft_dataset, trainer) -> None:
    """Check that ``SquadDataModule`` initialises ``GPTSFTDataset`` exactly once.

    ``GPTSFTDataset.__init__`` is patched to return ``None``; the test builds
    the data module, attaches the mocked trainer, runs ``setup`` for the
    ``'train'`` stage, requests the train dataloader, and then asserts on the
    patched initialiser's call count.

    Args:
        mock_gpt_sft_dataset: Mock injected by ``@patch`` in place of
            ``GPTSFTDataset.__init__``.
        trainer: The ``trainer`` fixture (a ``MagicMock``).
    """
    # Imported inside the test body, as in the original.
    from nemo.collections.llm.gpt.data import SquadDataModule

    module_kwargs = {
        'seq_length': 2048,
        'micro_batch_size': 4,
        'global_batch_size': 8,
        'seed': 1234,
    }
    data_module = SquadDataModule(**module_kwargs)
    data_module.trainer = trainer

    data_module.setup(stage='train')
    data_module.train_dataloader()

    mock_gpt_sft_dataset.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|