zaydzuhri committed
Commit 5b64e7c · verified · 1 Parent(s): 613202f

Add files using upload-large-folder tool

Files changed (50)
  1. fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
  2. fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
  3. logs/none_1_grtqk5/attempt_0/0/stderr.log +0 -0
  4. logs/none_1_grtqk5/attempt_0/1/stderr.log +0 -0
  5. logs/none_1_grtqk5/attempt_0/6/stderr.log +0 -0
  6. setup.py +51 -0
  7. torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  8. torchtitan/components/optimizer.py +303 -0
  9. torchtitan/datasets/hf_datasets.py +173 -0
  10. torchtitan/datasets/tokenizer/tiktoken.py +190 -0
  11. torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc +0 -0
  12. torchtitan/distributed/__pycache__/utils.cpython-312.pyc +0 -0
  13. torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
  14. torchtitan/experiments/deepseek_v3/README.md +40 -0
  15. torchtitan/experiments/deepseek_v3/generate.py +308 -0
  16. torchtitan/experiments/deepseek_v3/indices.py +195 -0
  17. torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
  18. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
  19. torchtitan/experiments/flux/README.md +23 -0
  20. torchtitan/experiments/flux/__init__.py +122 -0
  21. torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
  22. torchtitan/experiments/flux/model/autoencoder.py +388 -0
  23. torchtitan/experiments/flux/model/hf_embedder.py +40 -0
  24. torchtitan/experiments/flux/model/math.py +38 -0
  25. torchtitan/experiments/flux/model/model.py +177 -0
  26. torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
  27. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  28. torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
  29. torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
  30. torchtitan/experiments/llama4/README.md +29 -0
  31. torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc +0 -0
  32. torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
  33. torchtitan/experiments/llama4/model/args.py +109 -0
  34. torchtitan/experiments/llama4/model/moe.py +228 -0
  35. torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py +536 -0
  36. torchtitan/experiments/multimodal/__init__.py +37 -0
  37. torchtitan/experiments/multimodal/mm_collator.py +227 -0
  38. torchtitan/experiments/multimodal/requirements.txt +1 -0
  39. torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc +0 -0
  40. torchtitan/experiments/simple_fsdp/model.py +18 -0
  41. torchtitan/experiments/simple_fsdp/simple_fsdp.py +194 -0
  42. torchtitan/experiments/simple_fsdp/tests/__init__.py +5 -0
  43. torchtitan/models/__init__.py +10 -0
  44. torchtitan/models/__pycache__/__init__.cpython-312.pyc +0 -0
  45. torchtitan/models/__pycache__/norms.cpython-312.pyc +0 -0
  46. torchtitan/models/llama3/__init__.py +76 -0
  47. torchtitan/models/llama3/parallelize_llama.py +398 -0
  48. torchtitan/models/llama3/pipeline_llama.py +161 -0
  49. torchtitan/models/llama3/train_configs/llama3_405b.toml +63 -0
  50. torchtitan/tools/profiling.py +131 -0
fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc ADDED
Binary file (7.06 kB).
 
fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc ADDED
Binary file (20.6 kB).
 
logs/none_1_grtqk5/attempt_0/0/stderr.log ADDED
The diff for this file is too large to render.
 
logs/none_1_grtqk5/attempt_0/1/stderr.log ADDED
The diff for this file is too large to render.
 
logs/none_1_grtqk5/attempt_0/6/stderr.log ADDED
The diff for this file is too large to render.
 
setup.py ADDED
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-

import ast
import os
import re
from pathlib import Path

from setuptools import find_packages, setup

with open('README.md') as f:
    long_description = f.read()


def get_package_version():
    with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    return ast.literal_eval(version_match.group(1))


setup(
    name='flame',
    version=get_package_version(),
    description='A minimal training framework for scaling FLA models',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Songlin Yang, Yu Zhang',
    author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
    url='https://github.com/fla-org/flame',
    packages=find_packages(),
    license='MIT',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
        'Topic :: Scientific/Engineering :: Artificial Intelligence'
    ],
    python_requires='>=3.10',
    install_requires=[
        'torch==2.6',
        'torchdata',
        'transformers==4.51.3',
        'triton>=3.0',
        'datasets>=3.3.0',
        'einops',
        'ninja',
        'wandb',
        'tiktoken',
        'tensorboard',
        'python-dotenv'
    ],
)
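As a side note on `get_package_version()` above, here is a tiny standalone sketch (not part of this commit) of how the `__version__` line is parsed; the sample string is made up for illustration:

```python
import ast
import re

# A made-up stand-in for the contents of flame/__init__.py.
text = '__version__ = "0.1.0"\n'
match = re.search(r"^__version__\s*=\s*(.*)$", text, re.MULTILINE)
print(ast.literal_eval(match.group(1)))  # 0.1.0
```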
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (7.71 kB).
 
torchtitan/components/optimizer.py ADDED
@@ -0,0 +1,303 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Generic, Iterator, TypeVar

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import Optimizer

from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config_manager import JobConfig

__all__ = [
    "OptimizersContainer",
    "build_optimizers",
]


if has_torchft:
    import torchft as ft


T = TypeVar("T", bound=Optimizer)


class OptimizersContainer(Optimizer, Stateful, Generic[T]):
    """A container for multiple optimizers.

    This class is used to wrap multiple optimizers into a single object that can be
    used to reduce the complexity of the training loop. This mimics the behavior of
    ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.

    **Note**
    Users who want to customize the optimizer behavior can inherit from this class and
    extend the functionality as needed. The following methods must follow the same signature
    as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
    ``load_state_dict()``.

    **Limitations**
    This class assumes that all the optimizers are the same type and have the same
    configurations. With this assumption, TorchTitan can support lr scheduler resharding
    (e.g., loading a checkpoint with a different number of GPUs and/or different
    parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
    resharding for the optimizer state but not for the lr scheduler state, hence the limitation.

    Args:
        model_parts (List[nn.Module]): List of model parts to be optimized.
        optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
        name (str): Name of the optimizers.
    """

    optimizers: list[T]
    model_parts: list[nn.Module]

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
    ) -> None:
        all_params = []
        self.optimizers = []
        self.model_parts = model_parts
        for model in self.model_parts:
            params = [p for p in model.parameters() if p.requires_grad]
            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
            all_params.extend(params)
        self._validate_length(len(self.model_parts))
        self._post_init(all_params, optimizer_kwargs)

    def __iter__(self) -> Iterator[T]:
        return iter(self.optimizers)

    def __len__(self) -> int:
        return len(self.optimizers)

    def step(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
            optimizer.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs) -> None:
        for optimizer in self.optimizers:
            optimizer.zero_grad(*args, **kwargs)

    def state_dict(self) -> dict[str, Any]:
        func = functools.partial(
            get_optimizer_state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        return {
            k: v
            for sd in map(func, self.model_parts, self.optimizers)
            for k, v in sd.items()
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        func = functools.partial(
            set_optimizer_state_dict,
            optim_state_dict=state_dict,
            options=StateDictOptions(flatten_optimizer_state_dict=True),
        )
        list(map(func, self.model_parts, self.optimizers))

    def _validate_length(self, expected_length: int) -> None:
        assert expected_length == len(self.optimizers), (
            "Must pass one optimizer per model part or per param if "
            "using OptimizersInBackwardContainer."
        )

    def _post_init(
        self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
    ) -> None:
        # We need to call Optimizer.__init__() to initialize some necessary optimizer
        # functionality such as hooks.
        Optimizer.__init__(self, all_params, optimizer_kwargs)


class OptimizersInBackwardContainer(OptimizersContainer):
    """OptimizersContainer for executing ``optim.step()`` in backward pass.

    This class extends ``OptimizersContainer`` to support optimizer step in
    backward pass. ``step()`` and ``zero_grad()`` are no-op in this class.
    Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to
    execute these methods when the gradient is accumulated.
    """

    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
    ) -> None:
        all_params = []
        self.model_parts = model_parts

        optim_dict = {}
        for model in self.model_parts:
            for p in model.parameters():
                if p.requires_grad:
                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
                    all_params.append(p)

        def optim_hook(param) -> None:
            optim_dict[param].step()
            optim_dict[param].zero_grad()

        for model in self.model_parts:
            for param in model.parameters():
                if param.requires_grad:
                    param.register_post_accumulate_grad_hook(optim_hook)

        self.optimizers = list(optim_dict.values())

        self._validate_length(
            sum(len(list(model.parameters())) for model in self.model_parts)
        )
        self._post_init(all_params, optimizer_kwargs)

    def step(self) -> None:
        pass

    def zero_grad(self) -> None:
        pass


class FTOptimizersContainer(OptimizersContainer):
    def __init__(
        self,
        model_parts: list[nn.Module],
        optimizer_cls: type[T],
        optimizer_kwargs: dict[str, Any],
        ft_manager: "ft.Manager",
    ) -> None:
        super().__init__(model_parts, optimizer_cls, optimizer_kwargs)

        # Force to initialize the optimizer state so that `optim.step()`
        # won't be called by state_dict() and load_state_dict().
        _ = {
            k: v
            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
            for k, v in sd.items()
        }
        self.cache_state_dict: dict[str, Any] = {}
        self._ft_optimizer = ft.Optimizer(ft_manager, self)
        self._call_from_ft: bool = False

    def init_cache_state_dict(self) -> None:
        self.cache_state_dict = super().state_dict()

    def state_dict(self) -> dict[str, Any]:
        return self.cache_state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # We have to invalidate the `cache_state_dict` because optimizer uses
        # assign instead of copy when doing `load_state_dict()`. Without
        # invalidating the `cache_state_dict`, there will be memory leakage.
        self.cache_state_dict = {}
        super().load_state_dict(state_dict)
        self.init_cache_state_dict()

    def step(self, *args, **kwargs) -> None:
        """Calling the correct step() depending on the caller.

        TorchFT's OptimizerWrapper.step() is designed to be called only once
        per train step per ft.Manager regardless of how many optimizers are used.
        Hence we will need to appropriately dispatch the call.
        """
        if self._call_from_ft:
            super().step(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.step(*args, **kwargs)
            self._call_from_ft = False

    def zero_grad(self, *args, **kwargs) -> None:
        """Calling the correct zero_grad() depending on the caller.

        Check the comment in ``step()``.
        """
        if self._call_from_ft:
            super().zero_grad(*args, **kwargs)
        else:
            self._call_from_ft = True
            self._ft_optimizer.zero_grad(*args, **kwargs)
            self._call_from_ft = False


def build_optimizers(
    model_parts: list[nn.Module],
    job_config: JobConfig,
    ft_manager: FTManager,
) -> OptimizersContainer:
    """Create an OptimizersContainer for the given model parts and job config.

    This function creates an ``OptimizersContainer`` for the given model parts.
    ``job_config`` should define the correct optimizer name and parameters.
    This function currently supports creating ``OptimizersContainer`` and
    ``OptimizersInBackwardContainer``.

    **Note**
    Users who want to customize the optimizer behavior can create their own
    ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
    customized ``build_optimizers`` to ``TrainSpec`` will create the customized
    ``OptimizersContainer``.

    Args:
        model_parts (List[nn.Module]): List of model parts to be optimized.
        job_config (JobConfig): Job config containing the optimizer name and parameters.
    """
    optim_in_bwd = job_config.optimizer.early_step_in_backward
    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
        raise NotImplementedError(
            "Optimizers in backward is not supported with pipeline parallelism."
        )
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    eps = job_config.optimizer.eps

    optim_implementation = job_config.optimizer.implementation
    assert optim_implementation in ["fused", "foreach", "for-loop"]

    fused = optim_implementation == "fused"
    foreach = optim_implementation == "foreach"

    optimizer_kwargs = {
        "lr": lr,
        "eps": eps,
        "betas": (0.9, 0.95),
        "weight_decay": 0.1,
        "fused": fused,
        "foreach": foreach,
    }

    optimizer_classes = {
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
    }
    if name not in optimizer_classes:
        raise NotImplementedError(f"Optimizer {name} not added.")
    optimizer_cls = optimizer_classes[name]

    if optim_in_bwd and ft_manager.enabled:
        raise ValueError("TorchFT is not supported with optimizers in backward.")
    elif optim_in_bwd:
        return OptimizersInBackwardContainer(
            model_parts, optimizer_cls, optimizer_kwargs
        )
    elif ft_manager.enabled:
        return FTOptimizersContainer(
            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
        )
    else:
        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
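For reference, a minimal standalone sketch (plain PyTorch, illustrative names only, not part of this commit) of the container pattern used above: one optimizer per model part, with `step()`/`zero_grad()` simply fanned out to every wrapped optimizer.

```python
import torch
import torch.nn as nn

# Two toy "model parts", as a pipeline split might produce.
parts = [nn.Linear(8, 8), nn.Linear(8, 2)]

# One AdamW instance per part, mirroring OptimizersContainer.__init__.
opts = [torch.optim.AdamW(p.parameters(), lr=3e-4) for p in parts]

x = torch.randn(4, 8)
loss = parts[1](parts[0](x)).sum()
loss.backward()

# The container delegates the calls to every wrapped optimizer.
for opt in opts:
    opt.step()
for opt in opts:
    opt.zero_grad()
```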
torchtitan/datasets/hf_datasets.py ADDED
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable

import torch

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


def _load_c4_dataset(dataset_path: str):
    """Load C4 dataset with default configuration."""
    return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: dict[str, Any]) -> str:
    """Process C4 dataset sample text."""
    return sample["text"]


@dataclass
class DatasetConfig:
    path: str
    loader: Callable
    text_processor: Callable


# Add your dataset here - more information at docs/datasets.md
DATASETS = {
    "c4": DatasetConfig(
        path="allenai/c4",
        loader=_load_c4_dataset,
        text_processor=_process_c4_text,
    ),
    "c4_test": DatasetConfig(
        path="tests/assets/c4_test",
        loader=lambda path: load_dataset(path, split="train"),
        text_processor=_process_c4_text,
    ),
}


def _validate_dataset(
    dataset_name: str, dataset_path: str | None = None
) -> tuple[str, Callable, Callable]:
    """Validate dataset name and path."""
    if dataset_name not in DATASETS:
        raise ValueError(
            f"Dataset {dataset_name} is not supported. "
            f"Supported datasets are: {list(DATASETS.keys())}"
        )

    config = DATASETS[dataset_name]
    path = dataset_path or config.path
    logger.info(f"Preparing {dataset_name} dataset from {path}")
    return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
    ) -> None:
        # Force lowercase for consistent comparison
        dataset_name = dataset_name.lower()

        path, dataset_loader, text_processor = _validate_dataset(
            dataset_name, dataset_path
        )
        ds = dataset_loader(path)

        self.dataset_name = dataset_name
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self._text_processor = text_processor

        # Variables for checkpointing
        self._sample_idx = 0
        self._all_tokens: list[int] = []

    def _get_data_iter(self):
        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
            return iter([])

        it = iter(self._data)
        for _ in range(self._sample_idx):
            next(it)
        return it

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len

        while True:
            for sample in self._get_data_iter():
                # Use the dataset-specific text processor
                sample_text = self._text_processor(sample)
                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
                self._all_tokens.extend(sample_tokens)
                self._sample_idx += 1

                while len(self._all_tokens) >= max_buffer_token_len:
                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
                    # update tokens to the remaining tokens
                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
                    input = x[:-1]
                    label = x[1:]
                    yield {"input": input}, label

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        self._sample_idx = state_dict["sample_idx"]
        self._all_tokens = state_dict["token_buffer"]

    def state_dict(self):
        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader for HuggingFace datasets."""
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.batch_size
    seq_len = job_config.training.seq_len

    hf_ds = HuggingFaceDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=hf_ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
    )
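A minimal standalone sketch (toy data, not part of this commit) of the token-packing loop in `HuggingFaceDataset.__iter__` above: tokens from streamed documents accumulate in a buffer, fixed `seq_len + 1` chunks are sliced off, and each chunk becomes a next-token (input, label) pair.

```python
import torch

seq_len = 4
buffer: list[int] = []

# Pretend these are tokenized documents streaming in.
for doc_tokens in [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12]]:
    buffer.extend(doc_tokens)
    while len(buffer) >= seq_len + 1:
        chunk = torch.LongTensor(buffer[: seq_len + 1])
        buffer = buffer[seq_len + 1:]
        inputs, labels = chunk[:-1], chunk[1:]  # shifted by one token
        print(inputs.tolist(), labels.tolist())
```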
torchtitan/datasets/tokenizer/tiktoken.py ADDED
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
from pathlib import Path
from typing import cast, Literal

import tiktoken
from tiktoken.load import load_tiktoken_bpe

from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger


class TikTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Args:
        model_path (str): The path to the Tiktoken model file.
    """

    special_tokens: dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950

    def __init__(self, model_path: str):
        super().__init__()
        assert os.path.exists(
            model_path
        ), f"The tokenizer path does not exist: {model_path}"
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self._n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Literal["all"] | AbstractSet[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] | None = None,
    ) -> list[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed in the string.
            disallowed_special ("all"|set[str]): special tokens that raise an error when found in the string.

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        assert type(s) is str
        allowed_special = allowed_special or set()
        disallowed_special = disallowed_special or ()

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: list[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(list[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]


def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
    return TikTokenizer(job_config.model.tokenizer_path)
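To make the chunking behavior of `_split_whitespaces_or_nonwhitespaces` above concrete, here is a standalone copy of the same logic with a tiny cap (3 characters) so the splits are visible; the function name and inputs are illustrative only, not part of this commit.

```python
def split_ws_or_nonws(s: str, max_len: int):
    # Same logic as TikTokenizer._split_whitespaces_or_nonwhitespaces,
    # reproduced standalone for illustration.
    current_len = 0
    current_is_space = s[0].isspace() if s else False
    start = 0
    for i, ch in enumerate(s):
        is_space = ch.isspace()
        if current_is_space ^ is_space:
            current_len = 1
            current_is_space = is_space
        else:
            current_len += 1
            if current_len > max_len:
                yield s[start:i]
                start = i
                current_len = 1
    yield s[start:]

print(list(split_ws_or_nonws("aaaa   bbbbbb", 3)))
# ['aaa', 'a   bbb', 'bbb'] -- no chunk contains a same-class run longer than 3
```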
torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (7.82 kB).
 
torchtitan/distributed/__pycache__/utils.cpython-312.pyc ADDED
Binary file (14.9 kB).
 
torchtitan/experiments/deepseek_v3/LICENSE-CODE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
torchtitan/experiments/deepseek_v3/README.md ADDED
@@ -0,0 +1,40 @@
# Running DeepSeek in Titan (experimental)

This folder contains a DeepSeek model supporting v2 and v3, as well as kernels
and scripts needed to run it.

## Inference

### Prerequisites

You will need to download a DeepSeek model's weights if you want to run a
pre-trained checkpoint. We provide a script to download the weights from the
HuggingFace Model Hub:
```bash
python download.py [vX]
```
where `vX` can be `v2` or `v3`; both are supported. You may be required to create a
HuggingFace account and log in first.

### Running inference

The inference script is in `generate.py`. You can run it with the following
command:
```bash
torchrun --standalone --nproc-per-node 4 generate.py
```
This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
default.

Alternatively, you can run inference by using `bash inference.sh`, optionally
followed by your prompt.

## Training

The training script is in `train.py`. You can run it with the following command:
```bash
torchrun --standalone --nproc-per-node 8 train.py
```

This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
default, with pipeline parallel, expert parallel, and data parallel enabled.
torchtitan/experiments/deepseek_v3/generate.py ADDED
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# torchrun --standalone --nproc-per-node 4 generate.py

# use inference.sh "Your Question Here?" to run inference with a single prompt.

import sys
from dataclasses import dataclass

import torch
import torch.distributed as dist

from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchtitan.tools.utils import Color
from transformers import AutoTokenizer

# Uncomment the model you want to run.
model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
# model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)


def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
    """Parse and colorize chat output with optional colors for each role."""
    lines = text.split("\n")
    result = []

    current_role = None
    current_content = []

    def _process_current_content():
        if not current_role or not current_content:
            return None

        content = "\n".join(current_content)
        if current_role == "output":
            return (
                f"Output: {output_color}{content}{color.reset}"
                if output_color
                else f"Output: {content}"
            )
        else:
            try:
                prefix, rest = current_content[0].split(":", 1)
                role_color = user_color if current_role == "user" else assistant_color
                if role_color:
                    formatted = f"{prefix}:{role_color}{rest}{color.reset}"
                    if len(current_content) > 1:
                        formatted += (
                            f"{role_color}\n"
                            + "\n".join(current_content[1:])
                            + f"{color.reset}"
                        )
                    return formatted
            except ValueError:
                pass
            return content

    for line in lines:
        if line.startswith("Output:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "output"
            content = line[len("Output:") :].strip()
            if output_color:
                content = f"Output: {output_color}{content}{color.reset}"
            else:
                content = f"Output: {content}"
            result.append(content)
            current_content = []

        elif line.startswith("User:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "user"
            current_content = [line]

        elif line.startswith("Assistant:"):
            if processed := _process_current_content():
                result.append(processed)
            current_role = "assistant"
            current_content = [line]

        else:
            if current_content:
                current_content.append(line)
            elif line.strip() and current_role is None:
                # Handle system message at the beginning
                current_role = "output"
                if output_color:
                    result.append(f"Output: {output_color}{line.strip()}{color.reset}")
                else:
                    result.append(f"Output: {line.strip()}")

    # Process the last segment
    if processed := _process_current_content():
        result.append(processed)

    return "\n".join(result)


color = Color()


@dataclass
class DistConfig:
    mesh: DeviceMesh
    pp_mesh: DeviceMesh
    ep_mesh: DeviceMesh
    pp_size: int
    ep_size: int
    ep_rank: int
    pp_rank: int
    device: torch.device


def create_model(dist_config: DistConfig):
    model_args = deepseek_config_registry[model_id]
    model_args.ep_size = dist_config.ep_size
    model_args.num_stages = dist_config.pp_size
    model_args.stage_idx = dist_config.pp_rank
    model_args.max_seq_len = 16384

    with dist_config.device, dist_config.mesh:
        model = DeepseekForCausalLM(model_args)
        load_weights_from_hf(model, model_id, dist_config.device)
        model.eval()
        model.setup_symm_mem(torch.bfloat16, dist_config.device)

    stage = PipelineStage(
        model,
        dist_config.pp_rank,
        dist_config.pp_size,
        dist_config.device,
        group=dist_config.pp_mesh.get_group(),
    )
    pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
    return model, pp_schedule


def create_dist_config(mesh: DeviceMesh):
    rank = dist.get_rank()
    device_count = torch.cuda.device_count()
    device = torch.device("cuda", rank % device_count)

    dist_config = DistConfig(
        mesh=mesh,
        pp_mesh=mesh["pp"],
        ep_mesh=mesh["ep"],
        pp_rank=mesh["pp"].get_local_rank(),
        pp_size=mesh["pp"].size(),
        ep_size=mesh["ep"].size(),
        ep_rank=mesh["ep"].get_local_rank(),
        device=device,
    )
    return dist_config


def decode(tokenizer, x):
    output = tokenizer.decode(x[0])
    # Clean up the output by removing special tokens
    bos = tokenizer.bos_token
    output = output.replace(bos, "")
    # Truncate at end of sentence token
    eos_token = tokenizer.eos_token
    if eos_token and eos_token in output:
        output = output.split(eos_token)[0]
    colored_output = colorize_chat(
        output,
        user_color=color.green,
        assistant_color=color.cyan,
        output_color=color.blue,
    )
    return colored_output


@torch.inference_mode()
def generate(
    model,
    pp_schedule,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 50,
):
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    for _ in range(n_tokens):
        if dist_config.pp_size > 1:
            if dist_config.pp_rank == 0:
                pp_schedule.step(x)
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            elif dist_config.pp_rank == dist_config.pp_size - 1:
                preds = pp_schedule.step()
                next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
                x[:, next_idx] = next_token
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )
            else:
                pp_schedule.step()
                torch.distributed.broadcast(
                    x,
                    group=dist_config.pp_mesh.get_group(),
                    group_src=dist_config.pp_size - 1,
                )

            next_idx += 1
        else:
            preds = model(x)
            next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
            x[:, next_idx] = next_token
            next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"Without CUDA Graph:\n{colored_output}")


@torch.inference_mode()
def generate_with_cuda_graph(
    model,
    tokenizer,
    dist_config,
    messages: list[dict],
    n_tokens: int = 10,
):
    rank = dist.get_rank()
    device = dist_config.device
    x = tokenizer.apply_chat_template(
        [messages] * dist_config.pp_size,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    next_idx = x.shape[-1]
    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
    x = x.to(device)

    torch.cuda.synchronize()

    # Create CUDA graph
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        preds = model(x)

    # Run CUDA graph
    for _ in range(n_tokens):
        g.replay()
        next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
        x[:, next_idx] = next_token
        next_idx += 1

    if rank == 0:
        colored_output = decode(tokenizer, x)
        print(f"With CUDA Graph:\n{colored_output}")


if __name__ == "__main__":
    # Get user prompt from command line arguments
    user_prompt = "What is 2+2?"  # Default prompt
    if len(sys.argv) > 1:
        user_prompt = sys.argv[1]

    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
    rank = dist.get_rank()
    if rank == 0:
        print(
            f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
        )

    dist_config = create_dist_config(mesh)
    model, pp_schedule = create_model(dist_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]

    generate(model, pp_schedule, tokenizer, dist_config, messages)
    generate_with_cuda_graph(model, tokenizer, dist_config, messages)

    if rank == 0:
        print(f"\n{color.yellow}Closing inference mesh...{color.reset}")

    dist.destroy_process_group()
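The single-device branch of `generate()` above reduces to a greedy argmax over the logits at the last filled position. A toy standalone sketch (the `dummy_model` below is a made-up stand-in for `DeepseekForCausalLM`, not part of this commit):

```python
import torch

vocab_size, n_new = 16, 5

def dummy_model(tokens: torch.Tensor) -> torch.Tensor:
    # Stand-in for a causal LM: returns per-position logits.
    return torch.randn(tokens.shape[0], tokens.shape[1], vocab_size)

x = torch.randint(0, vocab_size, (1, 4))                       # prompt tokens
x = torch.cat([x, torch.zeros(1, n_new, dtype=torch.int64)], dim=-1)
next_idx = 4
for _ in range(n_new):
    preds = dummy_model(x)
    next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)  # greedy pick
    x[:, next_idx] = next_token
    next_idx += 1
print(x)
```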
torchtitan/experiments/deepseek_v3/indices.py ADDED
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton
import triton.language as tl


__all__ = ["generate_permute_indices"]


@triton.jit
def fill_indices_kernel(
    tokens_per_expert_group_ptr,  # *Pointer* to first input vector.
    start_index_values_ptr,  # *Pointer* to second input vector.
    write_offsets_ptr,  # *Pointer* to third input vector.
    output_ptr,  # *Pointer* to output vector.
    experts_per_rank,  # Number of experts per rank.
    num_ranks,  # Number of expert ranks.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # The total number of programs in the launch grid.
    num_programs = tl.num_programs(axis=0)
    # We map the programs (blocks) to the experts.
    for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
        # Read this expert's write offset.
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Loop over the ranks.
        for r in tl.range(num_ranks):
            # Slot in the tokens_per_expert_group array.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            # Write the indices.
            for l in tl.range(length):
                val = start_index + l
                tl.store(output_ptr + write_offset + l, val)
            write_offset += length


def fill_indices(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    # We need to preallocate the output.
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )
    # Analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks (TODO: bump this value).
    grid = lambda meta: (1,)
    # Each torch.tensor object is implicitly converted into a pointer to its first element.
    fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
    )
    return permuted_indices


def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    # We need to preallocate the output.
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
    # Fill the permuted indices
    # For each local expert
    for e in range(experts_per_rank):
        write_start = write_offsets[e]
        # For each remote rank
        for r in range(num_ranks):
            i = r * experts_per_rank + e
            start_index = start_index_values[i]
            length = tokens_per_expert_group[i]
            # Fill in the indices
            permuted_indices[write_start : write_start + length] = torch.arange(
                start_index, start_index + length
            )
            write_start += length
    return permuted_indices


def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    # Prepare permutation indices and the number of tokens for each expert. The
    # permutation indices are the indices of the tokens for each expert. The
    # number of tokens for each expert is the sum of the number of tokens for
    # such experts from all ranks. This number is aligned to the provided
    # alignment requirement (usually comes from group gemm).

    # Args:
    #     tokens_per_expert_group: number of tokens for each expert from all ranks.
    #     experts_per_rank: number of experts per rank.
    #     num_ranks: number of ranks.
    #     max_len: maximum length of the output index vector. If greater than
    #         total number of tokens, the remaining indices are set to -1.
    #     alignment: alignment for each returned element in `m_sizes`.
    #     use_cpu: whether to use cpu or gpu.
    # Returns:
    #     permuted_indices: permutation indices.
    #     m_sizes: number of tokens for each expert.

    # `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example:
    # From: |       rank 0      |       rank 1      |
    # To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
    #       | 4  | 2  | 1  | 3  | 1  | 2  | 3  | 4  |

    # Prefix sum to get the start index value of each expert
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Chunk sizes for each expert
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    # Align the chunk sizes to the given alignment
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
    # Select the method to fill the permuted indices
    fill_fn = fill_indices_cpu if use_cpu else fill_indices
    # Fill the permuted indices
    permuted_indices = fill_fn(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes


# Below is for testing only


def test():
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, device=device),
    )
    # Print the results
    print(permuted_indices_gpu)
    print(m_sizes)
    print("Success")


if __name__ == "__main__":
    test()
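To illustrate the prefix-sum bookkeeping in `generate_permute_indices` above, here is a CPU-only sketch that evaluates it on the example from the comments (two ranks, four experts each); the alignment of 4 is chosen here for readability, whereas the file's own test uses 32.

```python
import torch

# | rank 0        | rank 1        |   (4 experts per rank)
tokens_per_expert_group = torch.tensor([4, 2, 1, 3, 1, 2, 3, 4])
alignment, num_ranks = 4, 2

start_index_values = torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
m_sizes = (chunk_size_per_expert + alignment - 1) // alignment * alignment
write_offsets = torch.cumsum(m_sizes, 0) - m_sizes

print(start_index_values.tolist())    # [0, 4, 6, 7, 10, 11, 13, 16]
print(chunk_size_per_expert.tolist())  # [5, 4, 4, 7]  tokens per expert, summed over ranks
print(m_sizes.tolist())                # [8, 4, 4, 8]  rounded up to the alignment
print(write_offsets.tolist())          # [0, 8, 12, 16]
```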
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .triton_on_device_all_to_all_v import OnDeviceAllToAllV

__all__ = [
    "OnDeviceAllToAllV",
]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.distributed._symmetric_memory as symm_mem
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from .triton_barrier import blockwise_barrier
14
+ from .triton_utils import sync_threads
15
+
16
+
17
+ @triton.jit
18
+ def _exchange_row_offsets(
19
+ split_sizes_ptrs,
20
+ rank: tl.constexpr,
21
+ world_size: tl.constexpr,
22
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
23
+ ):
24
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
25
+
26
+ # split_sizes_ptr for all ranks
27
+ # All these vector stacks into split_sizes_matrix
28
+ split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
29
+
30
+ # split_sizes_matrix[remote_rank, :]
31
+ input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
32
+ tl.pointer_type(tl.int64)
33
+ )
34
+
35
+ offsets_ = tl.arange(0, world_size)
36
+ input_split_sizes = tl.load(
37
+ input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
38
+ )
39
+
40
+ num_rows = tl.load(input_split_sizes_ptr + rank)
41
+ input_row_offset = tl.sum(input_split_sizes) - num_rows
42
+
43
+ # split_sizes_matrix[:, rank]
44
+ output_split_sizes_ptrs = (
45
+ tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
46
+ )
47
+ output_split_sizes = tl.load(
48
+ output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
49
+ )
50
+ output_row_offset = tl.sum(output_split_sizes) - num_rows
51
+
52
+ return input_row_offset, output_row_offset, num_rows
53
+
54
+
55
+ @triton.jit
56
+ def on_device_all_to_all_v_kernel(
57
+ output_ptr,
58
+ output_splits_ptr,
59
+ input_ptrs,
60
+ input_splits_ptr,
61
+ signal_pad_ptrs,
62
+ dim: tl.constexpr, # Separate dim for easier vectorization
63
+ rank: tl.constexpr,
64
+ world_size: tl.constexpr,
65
+ BLOCKS_PER_REMOTE_RANK: tl.constexpr,
66
+ UNROLL_FACTOR: tl.constexpr,
67
+ BLOCK_SIZE: tl.constexpr,
68
+ ):
69
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
70
+ sync_threads()
71
+
72
+ remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
73
+ block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
74
+
75
+ input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
76
+ input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
77
+ )
78
+
79
+ output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
80
+ if block_offset == 0:
81
+ # Update output_splits
82
+ tl.store(output_splits_ptr + remote_rank, num_rows)
83
+
84
+ input_ptr = (
85
+ tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
86
+ tl.pointer_type(tl.bfloat16)
87
+ )
88
+ + input_row_offset * dim
89
+ )
90
+ output_ptr = output_ptr + output_row_offset * dim
91
+
92
+ outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
93
+ outer_loop_iters_per_rank = tl.cdiv(
94
+ tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
95
+ )
96
+ numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
97
+ offset = numel_per_rank * block_offset
98
+ end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
99
+
100
+ unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
101
+ for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
102
+ datas = []
103
+ for j in tl.range(
104
+ i,
105
+ i + outer_loop_step,
106
+ BLOCK_SIZE,
107
+ loop_unroll_factor=UNROLL_FACTOR,
108
+ ):
109
+ offsets = j + tl.arange(0, BLOCK_SIZE)
110
+ data = tl.load(input_ptr + offsets)
111
+ tl.store(output_ptr + offsets, data)
112
+
113
+ offset += unroll_region_size
114
+ while offset < end:
115
+ offsets = offset + tl.arange(0, BLOCK_SIZE)
116
+ mask = offsets < num_rows * dim
117
+ data = tl.load(input_ptr + offsets, mask=mask)
118
+ tl.store(output_ptr + offsets, data, mask=mask)
119
+ offset += BLOCK_SIZE
120
+
121
+ sync_threads()
122
+ blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
123
+ return
124
+
125
+
126
+ def _on_device_all_to_all_v(
127
+ output: torch.Tensor,
128
+ output_splits: torch.Tensor,
129
+ input: torch.Tensor,
130
+ input_splits: torch.Tensor,
131
+ group: dist.ProcessGroup = dist.group.WORLD,
132
+ BLOCKS_PER_REMOTE_RANK=8,
133
+ UNROLL_FACTOR: int = 8,
134
+ BLOCK_SIZE: int = 16384,
135
+ ):
136
+ assert output.dim() == 2, f"{output.shape}"
137
+ assert input.dim() == 2, f"{input.shape}"
138
+ assert output.shape[1] == input.shape[1]
139
+
140
+ dim = output.shape[1]
141
+ input_hdl = symm_mem.rendezvous(input, group=group)
142
+ input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
143
+
144
+ num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
145
+ kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
146
+ output,
147
+ output_splits,
148
+ input_hdl.buffer_ptrs_dev,
149
+ input_splits_hdl.buffer_ptrs_dev,
150
+ input_hdl.signal_pad_ptrs_dev,
151
+ dim=dim,
152
+ rank=input_hdl.rank,
153
+ world_size=input_hdl.world_size,
154
+ BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
155
+ UNROLL_FACTOR=UNROLL_FACTOR,
156
+ BLOCK_SIZE=BLOCK_SIZE,
157
+ num_warps=16,
158
+ )
159
+ # log_triton_kernel(kernel)
160
+ return output
161
+
162
+
163
+ class OnDeviceAllToAllV(torch.autograd.Function):
164
+ # A symmetric memory holding the grad_output during backward
165
+ grad_output_buf = None
166
+ # A symmetric memory for exchanges split sizes during both forward and backward
167
+ splits_buf = None
168
+ # Maximum output length (need to be set before use of OnDeviceAllToAllV)
169
+ max_output_len = None
170
+
171
+ @staticmethod
172
+ def forward(
173
+ ctx,
174
+ input: torch.Tensor,
175
+ input_splits: torch.Tensor,
176
+ group: dist.ProcessGroup = dist.group.WORLD,
177
+ ):
178
+ """
179
+ Args:
180
+ input: input tensor with data for all ranks concatenated.
181
+ input_splits: input splits of shape (group.world_size,)
182
+ group: process group to scope the collective.
183
+ """
184
+ # Initialize input splits buffer (one time only)
185
+ if OnDeviceAllToAllV.splits_buf is None:
186
+ OnDeviceAllToAllV.splits_buf = symm_mem.empty(
187
+ *input_splits.shape,
188
+ dtype=input_splits.dtype,
189
+ device=input_splits.device,
190
+ )
191
+
192
+ if OnDeviceAllToAllV.max_output_len is None:
193
+ raise RuntimeError(
194
+ "Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
195
+ )
196
+
197
+ # Allocate output buffer
198
+ output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
199
+ # Allocate output splits tensor
200
+ output_splits = torch.empty_like(input_splits)
201
+ # Copy input splits to the buffer
202
+ OnDeviceAllToAllV.splits_buf.copy_(input_splits)
203
+
204
+ # Shuffle input to output
205
+ _on_device_all_to_all_v(
206
+ output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
207
+ )
208
+
209
+ # Output splits in forward is the input splits in backward
210
+ ctx.save_for_backward(output_splits)
211
+ ctx.group = group
212
+ ctx.input_shape = input.shape
213
+ return output, output_splits
214
+
215
+ @staticmethod
216
+ def backward(ctx, grad_output, grad_splits):
217
+ """
218
+ Backward is implemented as a shuffle of the output's gradients to the input.
219
+ Args:
220
+ `grad_output`: output's gradients passed from the downstream.
221
+ `grad_splits`: unused.
222
+ """
223
+
224
+ # Initialize grad_output buffer (one time only)
225
+ if OnDeviceAllToAllV.grad_output_buf is None:
226
+ assert (
227
+ OnDeviceAllToAllV.max_output_len is not None
228
+ ), "`max_output_len` not set"
229
+ OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
230
+ OnDeviceAllToAllV.max_output_len,
231
+ *grad_output.shape[1:],
232
+ dtype=grad_output.dtype,
233
+ device=grad_output.device,
234
+ )
235
+
236
+ # TODO: is there a way to tell autograd to feed grad_output directly to
237
+ # our symm_mem buffer?
238
+ OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
239
+ grad_output
240
+ )
241
+
242
+ # Size info
243
+ (grad_output_splits,) = ctx.saved_tensors
244
+ OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
245
+ grad_input_splits = torch.empty_like(grad_output_splits) # unused
246
+ grad_input = grad_output.new_empty(*ctx.input_shape)
247
+
248
+ # Shuffle gradients back to the input
249
+ _on_device_all_to_all_v(
250
+ grad_input,
251
+ grad_input_splits,
252
+ OnDeviceAllToAllV.grad_output_buf,
253
+ OnDeviceAllToAllV.splits_buf,
254
+ group=ctx.group,
255
+ )
256
+ return grad_input, None, None
257
+
258
+
259
+ # Alias
260
+ on_device_all_to_all_v = OnDeviceAllToAllV.apply
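The autograd wrapper above requires `OnDeviceAllToAllV.max_output_len` to be set before the first call, and its forward pass rendezvous-es the `input` tensor, so that tensor has to live in symmetric memory. Below is a minimal, hypothetical usage sketch (not part of this commit): the even split sizes and shapes are illustrative assumptions, the import is indicated as a comment because it depends on the package layout, and it presumes `torch.distributed` and `torch.distributed._symmetric_memory` are initialized on every rank.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# from symm_mem_recipes import OnDeviceAllToAllV, on_device_all_to_all_v  # assumed import path


def run_example(dim: int = 128, rows_per_rank: int = 1024):
    group = dist.group.WORLD
    world_size = dist.get_world_size(group)

    # Forward rendezvous-es `input`, so allocate it in symmetric memory.
    inp = symm_mem.empty(rows_per_rank, dim, dtype=torch.bfloat16, device="cuda")
    inp.normal_()

    # Rows this rank sends to each peer; an even split is assumed for illustration.
    input_splits = torch.full(
        (world_size,), rows_per_rank // world_size, dtype=torch.int64, device="cuda"
    )

    # Must be set once before the first call: upper bound on rows any rank receives.
    OnDeviceAllToAllV.max_output_len = rows_per_rank * world_size

    output, output_splits = on_device_all_to_all_v(inp, input_splits, group)
    return output, output_splits
```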
torchtitan/experiments/flux/README.md ADDED
@@ -0,0 +1,23 @@
1
+ # FLUX model in torchtitan
2
+
3
+ ## Overview
4
+
5
+ ## Usage
6
+ First, download the autoencoder model from HuggingFace with your own access token:
7
+ ```bash
8
+ python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
9
+ ```
10
+ This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.
11
+
12
+ Run the following command to train the model on a single GPU:
13
+ ```bash
14
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
15
+ ```
16
+
17
+ ## TODO
18
+ - [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
19
+ - [ ] Implement CI test cases for the FLUX model and add more unit tests (e.g., for the preprocessor)
20
+ - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
21
+ - [ ] Support for distributed checkpointing and loading
22
+ - [ ] Implement init_weights() function to initialize the model weights
23
+ - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
torchtitan/experiments/flux/__init__.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
10
+ from torchtitan.components.optimizer import build_optimizers
11
+ from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
12
+ from torchtitan.experiments.flux.loss import build_mse_loss
13
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
14
+ from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
15
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
16
+
17
+ from .model.model import FluxModel, FluxModelArgs
18
+
19
+ __all__ = [
20
+ "FluxModelArgs",
21
+ "FluxModel",
22
+ "flux_configs",
23
+ "parallelize_flux",
24
+ ]
25
+
26
+
27
+ flux_configs = {
28
+ "flux-dev": FluxModelArgs(
29
+ in_channels=64,
30
+ out_channels=64,
31
+ vec_in_dim=768,
32
+ context_in_dim=512,
33
+ hidden_size=3072,
34
+ mlp_ratio=4.0,
35
+ num_heads=24,
36
+ depth=19,
37
+ depth_single_blocks=38,
38
+ axes_dim=(16, 56, 56),
39
+ theta=10_000,
40
+ qkv_bias=True,
41
+ guidance_embed=True,
42
+ autoencoder_params=AutoEncoderParams(
43
+ resolution=256,
44
+ in_channels=3,
45
+ ch=128,
46
+ out_ch=3,
47
+ ch_mult=(1, 2, 4, 4),
48
+ num_res_blocks=2,
49
+ z_channels=16,
50
+ scale_factor=0.3611,
51
+ shift_factor=0.1159,
52
+ ),
53
+ ),
54
+ "flux-schnell": FluxModelArgs(
55
+ in_channels=64,
56
+ out_channels=64,
57
+ vec_in_dim=768,
58
+ context_in_dim=4096,
59
+ hidden_size=3072,
60
+ mlp_ratio=4.0,
61
+ num_heads=24,
62
+ depth=19,
63
+ depth_single_blocks=38,
64
+ axes_dim=(16, 56, 56),
65
+ theta=10_000,
66
+ qkv_bias=True,
67
+ guidance_embed=False,
68
+ autoencoder_params=AutoEncoderParams(
69
+ resolution=256,
70
+ in_channels=3,
71
+ ch=128,
72
+ out_ch=3,
73
+ ch_mult=(1, 2, 4, 4),
74
+ num_res_blocks=2,
75
+ z_channels=16,
76
+ scale_factor=0.3611,
77
+ shift_factor=0.1159,
78
+ ),
79
+ ),
80
+ "flux-debug": FluxModelArgs(
81
+ in_channels=64,
82
+ out_channels=64,
83
+ vec_in_dim=768,
84
+ context_in_dim=512,
85
+ hidden_size=512,
86
+ mlp_ratio=4.0,
87
+ num_heads=4,
88
+ depth=2,
89
+ depth_single_blocks=2,
90
+ axes_dim=(16, 56, 56),
91
+ theta=10_000,
92
+ qkv_bias=True,
93
+ guidance_embed=True,
94
+ autoencoder_params=AutoEncoderParams(
95
+ resolution=256,
96
+ in_channels=3,
97
+ ch=128,
98
+ out_ch=3,
99
+ ch_mult=(1, 2, 4, 4),
100
+ num_res_blocks=2,
101
+ z_channels=16,
102
+ scale_factor=0.3611,
103
+ shift_factor=0.1159,
104
+ ),
105
+ ),
106
+ }
107
+
108
+
109
+ register_train_spec(
110
+ TrainSpec(
111
+ name="flux",
112
+ cls=FluxModel,
113
+ config=flux_configs,
114
+ parallelize_fn=parallelize_flux,
115
+ pipelining_fn=None,
116
+ build_optimizers_fn=build_optimizers,
117
+ build_lr_schedulers_fn=build_lr_schedulers,
118
+ build_dataloader_fn=build_flux_dataloader,
119
+ build_tokenizer_fn=None,
120
+ build_loss_fn=build_mse_loss,
121
+ )
122
+ )
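The registered configurations can be used directly to instantiate a model. A minimal sketch, assuming the flux experiment package is importable:

```python
# Sketch only: pick a registered config and build the model from it.
from torchtitan.experiments.flux import flux_configs
from torchtitan.experiments.flux.model.model import FluxModel

model_args = flux_configs["flux-debug"]          # small config for smoke tests
model = FluxModel.from_model_args(model_args)    # equivalent to FluxModel(model_args)
```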
torchtitan/experiments/flux/dataset/tokenizer.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
9
+
10
+
11
+ from typing import List
12
+
13
+ from torchtitan.components.tokenizer import Tokenizer
14
+ from transformers import CLIPTokenizer, T5Tokenizer
15
+
16
+
17
+ class FluxTokenizer(Tokenizer):
18
+ """
19
+ Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.
20
+
21
+ Args:
22
+ model_path (str): Path to the tokenizer on Hugging Face.
23
+
24
+ """
25
+
26
+ def __init__(self, model_path: str = "t5-small", max_length: int = 77):
27
+ super().__init__()
28
+ self._n_words = 8 # TODO(jianiw): check
29
+ self._max_length = max_length
30
+
31
+ self.is_clip = model_path.startswith("openai")
32
+
33
+ if self.is_clip:
34
+ self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
35
+ model_path, max_length=max_length
36
+ )
37
+ else:
38
+ self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
39
+ model_path, max_length=max_length
40
+ )
41
+
42
+ def encode(
43
+ self,
44
+ s: str,
45
+ ) -> List[int]:
46
+ """
47
+ Encode the prompt text into tokens.
48
+ """
49
+ tokens = self._tokenizer(
50
+ s,
51
+ truncation=True,
52
+ max_length=self._max_length,
53
+ return_length=False,
54
+ return_overflowing_tokens=False,
55
+ padding="max_length",
56
+ return_tensors="pt", # return pytorch tensors, default return List[int]
57
+ )["input_ids"]
58
+ return tokens
59
+
60
+ def decode(self, t: List[int]) -> str:
61
+ """
62
+ Decode tokens back into text. This function is not expected to be called.
63
+ """
64
+ return self._tokenizer.decode(t)
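A short usage sketch for the tokenizer above; the model names mirror the ones used in the debug config and tests, and are assumptions about which checkpoints are available:

```python
# Sketch only: encode a prompt with the two tokenizers FLUX conditions on.
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

t5_tok = FluxTokenizer("google/t5-v1_1-small", max_length=512)
clip_tok = FluxTokenizer("openai/clip-vit-large-patch14", max_length=77)

prompt = "a photo of a forest with mist"
t5_tokens = t5_tok.encode(prompt)      # tensor of shape [1, 512] (padded to max_length)
clip_tokens = clip_tok.encode(prompt)  # tensor of shape [1, 77]
```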
torchtitan/experiments/flux/model/autoencoder.py ADDED
@@ -0,0 +1,388 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from einops import rearrange
12
+ from safetensors.torch import load_file as load_sft
13
+ from torch import nn, Tensor
14
+
15
+
16
+ @dataclass
17
+ class AutoEncoderParams:
18
+ resolution: int = 256
19
+ in_channels: int = 3
20
+ ch: int = 128
21
+ out_ch: int = 3
22
+ ch_mult: tuple[int] = (1, 2, 4, 4)
23
+ num_res_blocks: int = 2
24
+ z_channels: int = 16
25
+ scale_factor: float = 0.3611
26
+ shift_factor: float = 0.1159
27
+
28
+
29
+ def swish(x: Tensor) -> Tensor:
30
+ return x * torch.sigmoid(x)
31
+
32
+
33
+ class AttnBlock(nn.Module):
34
+ def __init__(self, in_channels: int):
35
+ super().__init__()
36
+ self.in_channels = in_channels
37
+
38
+ self.norm = nn.GroupNorm(
39
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
40
+ )
41
+
42
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
43
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
44
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
45
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
46
+
47
+ def attention(self, h_: Tensor) -> Tensor:
48
+ h_ = self.norm(h_)
49
+ q = self.q(h_)
50
+ k = self.k(h_)
51
+ v = self.v(h_)
52
+
53
+ b, c, h, w = q.shape
54
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
55
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
56
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
57
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
58
+
59
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
60
+
61
+ def forward(self, x: Tensor) -> Tensor:
62
+ return x + self.proj_out(self.attention(x))
63
+
64
+
65
+ class ResnetBlock(nn.Module):
66
+ def __init__(self, in_channels: int, out_channels: int):
67
+ super().__init__()
68
+ self.in_channels = in_channels
69
+ out_channels = in_channels if out_channels is None else out_channels
70
+ self.out_channels = out_channels
71
+
72
+ self.norm1 = nn.GroupNorm(
73
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
74
+ )
75
+ self.conv1 = nn.Conv2d(
76
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
77
+ )
78
+ self.norm2 = nn.GroupNorm(
79
+ num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
80
+ )
81
+ self.conv2 = nn.Conv2d(
82
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
83
+ )
84
+ if self.in_channels != self.out_channels:
85
+ self.nin_shortcut = nn.Conv2d(
86
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
87
+ )
88
+
89
+ def forward(self, x):
90
+ h = x
91
+ h = self.norm1(h)
92
+ h = swish(h)
93
+ h = self.conv1(h)
94
+
95
+ h = self.norm2(h)
96
+ h = swish(h)
97
+ h = self.conv2(h)
98
+
99
+ if self.in_channels != self.out_channels:
100
+ x = self.nin_shortcut(x)
101
+
102
+ return x + h
103
+
104
+
105
+ class Downsample(nn.Module):
106
+ def __init__(self, in_channels: int):
107
+ super().__init__()
108
+ # no asymmetric padding in torch conv, must do it ourselves
109
+ self.conv = nn.Conv2d(
110
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
111
+ )
112
+
113
+ def forward(self, x: Tensor):
114
+ pad = (0, 1, 0, 1)
115
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
116
+ x = self.conv(x)
117
+ return x
118
+
119
+
120
+ class Upsample(nn.Module):
121
+ def __init__(self, in_channels: int):
122
+ super().__init__()
123
+ self.conv = nn.Conv2d(
124
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
125
+ )
126
+
127
+ def forward(self, x: Tensor):
128
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
129
+ x = self.conv(x)
130
+ return x
131
+
132
+
133
+ class Encoder(nn.Module):
134
+ def __init__(
135
+ self,
136
+ resolution: int,
137
+ in_channels: int,
138
+ ch: int,
139
+ ch_mult: list[int],
140
+ num_res_blocks: int,
141
+ z_channels: int,
142
+ ):
143
+ super().__init__()
144
+ self.ch = ch
145
+ self.num_resolutions = len(ch_mult)
146
+ self.num_res_blocks = num_res_blocks
147
+ self.resolution = resolution
148
+ self.in_channels = in_channels
149
+ # downsampling
150
+ self.conv_in = nn.Conv2d(
151
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
152
+ )
153
+
154
+ curr_res = resolution
155
+ in_ch_mult = (1,) + tuple(ch_mult)
156
+ self.in_ch_mult = in_ch_mult
157
+ self.down = nn.ModuleList()
158
+ block_in = self.ch
159
+ for i_level in range(self.num_resolutions):
160
+ block = nn.ModuleList()
161
+ attn = nn.ModuleList()
162
+ block_in = ch * in_ch_mult[i_level]
163
+ block_out = ch * ch_mult[i_level]
164
+ for _ in range(self.num_res_blocks):
165
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
166
+ block_in = block_out
167
+ down = nn.Module()
168
+ down.block = block
169
+ down.attn = attn
170
+ if i_level != self.num_resolutions - 1:
171
+ down.downsample = Downsample(block_in)
172
+ curr_res = curr_res // 2
173
+ self.down.append(down)
174
+
175
+ # middle
176
+ self.mid = nn.Module()
177
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
178
+ self.mid.attn_1 = AttnBlock(block_in)
179
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
180
+
181
+ # end
182
+ self.norm_out = nn.GroupNorm(
183
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
184
+ )
185
+ self.conv_out = nn.Conv2d(
186
+ block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
187
+ )
188
+
189
+ def forward(self, x: Tensor) -> Tensor:
190
+ # downsampling
191
+ hs = [self.conv_in(x)]
192
+ for i_level in range(self.num_resolutions):
193
+ for i_block in range(self.num_res_blocks):
194
+ h = self.down[i_level].block[i_block](hs[-1])
195
+ if len(self.down[i_level].attn) > 0:
196
+ h = self.down[i_level].attn[i_block](h)
197
+ hs.append(h)
198
+ if i_level != self.num_resolutions - 1:
199
+ hs.append(self.down[i_level].downsample(hs[-1]))
200
+
201
+ # middle
202
+ h = hs[-1]
203
+ h = self.mid.block_1(h)
204
+ h = self.mid.attn_1(h)
205
+ h = self.mid.block_2(h)
206
+ # end
207
+ h = self.norm_out(h)
208
+ h = swish(h)
209
+ h = self.conv_out(h)
210
+ return h
211
+
212
+
213
+ class Decoder(nn.Module):
214
+ def __init__(
215
+ self,
216
+ ch: int,
217
+ out_ch: int,
218
+ ch_mult: list[int],
219
+ num_res_blocks: int,
220
+ in_channels: int,
221
+ resolution: int,
222
+ z_channels: int,
223
+ ):
224
+ super().__init__()
225
+ self.ch = ch
226
+ self.num_resolutions = len(ch_mult)
227
+ self.num_res_blocks = num_res_blocks
228
+ self.resolution = resolution
229
+ self.in_channels = in_channels
230
+ self.ffactor = 2 ** (self.num_resolutions - 1)
231
+
232
+ # compute in_ch_mult, block_in and curr_res at lowest res
233
+ block_in = ch * ch_mult[self.num_resolutions - 1]
234
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
235
+ self.z_shape = (1, z_channels, curr_res, curr_res)
236
+
237
+ # z to block_in
238
+ self.conv_in = nn.Conv2d(
239
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
240
+ )
241
+
242
+ # middle
243
+ self.mid = nn.Module()
244
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
245
+ self.mid.attn_1 = AttnBlock(block_in)
246
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
247
+
248
+ # upsampling
249
+ self.up = nn.ModuleList()
250
+ for i_level in reversed(range(self.num_resolutions)):
251
+ block = nn.ModuleList()
252
+ attn = nn.ModuleList()
253
+ block_out = ch * ch_mult[i_level]
254
+ for _ in range(self.num_res_blocks + 1):
255
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
256
+ block_in = block_out
257
+ up = nn.Module()
258
+ up.block = block
259
+ up.attn = attn
260
+ if i_level != 0:
261
+ up.upsample = Upsample(block_in)
262
+ curr_res = curr_res * 2
263
+ self.up.insert(0, up) # prepend to get consistent order
264
+
265
+ # end
266
+ self.norm_out = nn.GroupNorm(
267
+ num_groups=32, num_channels=block_in, eps=1e-6, affine=True
268
+ )
269
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
270
+
271
+ def forward(self, z: Tensor) -> Tensor:
272
+ # get dtype for proper tracing
273
+ upscale_dtype = next(self.up.parameters()).dtype
274
+
275
+ # z to block_in
276
+ h = self.conv_in(z)
277
+
278
+ # middle
279
+ h = self.mid.block_1(h)
280
+ h = self.mid.attn_1(h)
281
+ h = self.mid.block_2(h)
282
+
283
+ # cast to proper dtype
284
+ h = h.to(upscale_dtype)
285
+ # upsampling
286
+ for i_level in reversed(range(self.num_resolutions)):
287
+ for i_block in range(self.num_res_blocks + 1):
288
+ h = self.up[i_level].block[i_block](h)
289
+ if len(self.up[i_level].attn) > 0:
290
+ h = self.up[i_level].attn[i_block](h)
291
+ if i_level != 0:
292
+ h = self.up[i_level].upsample(h)
293
+
294
+ # end
295
+ h = self.norm_out(h)
296
+ h = swish(h)
297
+ h = self.conv_out(h)
298
+ return h
299
+
300
+
301
+ class DiagonalGaussian(nn.Module):
302
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
303
+ super().__init__()
304
+ self.sample = sample
305
+ self.chunk_dim = chunk_dim
306
+
307
+ def forward(self, z: Tensor) -> Tensor:
308
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
309
+ if self.sample:
310
+ std = torch.exp(0.5 * logvar)
311
+ return mean + std * torch.randn_like(mean)
312
+ else:
313
+ return mean
314
+
315
+
316
+ class AutoEncoder(nn.Module):
317
+ def __init__(self, params: AutoEncoderParams):
318
+ super().__init__()
319
+ self.params = params
320
+ self.encoder = Encoder(
321
+ resolution=params.resolution,
322
+ in_channels=params.in_channels,
323
+ ch=params.ch,
324
+ ch_mult=params.ch_mult,
325
+ num_res_blocks=params.num_res_blocks,
326
+ z_channels=params.z_channels,
327
+ )
328
+ self.decoder = Decoder(
329
+ resolution=params.resolution,
330
+ in_channels=params.in_channels,
331
+ ch=params.ch,
332
+ out_ch=params.out_ch,
333
+ ch_mult=params.ch_mult,
334
+ num_res_blocks=params.num_res_blocks,
335
+ z_channels=params.z_channels,
336
+ )
337
+ self.reg = DiagonalGaussian()
338
+
339
+ self.scale_factor = params.scale_factor
340
+ self.shift_factor = params.shift_factor
341
+
342
+ def encode(self, x: Tensor) -> Tensor:
343
+ z = self.reg(self.encoder(x))
344
+ z = self.scale_factor * (z - self.shift_factor)
345
+ return z
346
+
347
+ def decode(self, z: Tensor) -> Tensor:
348
+ z = z / self.scale_factor + self.shift_factor
349
+ return self.decoder(z)
350
+
351
+ def forward(self, x: Tensor) -> Tensor:
352
+ return self.decode(self.encode(x))
353
+
354
+
355
+ def load_ae(
356
+ ckpt_path: str,
357
+ autoencoder_params: AutoEncoderParams,
358
+ device: str | torch.device = "cuda",
359
+ dtype=torch.bfloat16,
360
+ ) -> AutoEncoder:
361
+ """
362
+ Load the autoencoder from the given model name.
363
+ Args:
364
+ name (str): The name of the autoencoder.
365
+ device (str or torch.device): The device to load the autoencoder to.
366
+ Returns:
367
+ AutoEncoder: The loaded autoencoder.
368
+ """
369
+ # Loading the autoencoder
370
+ print("Init AE")
371
+ with torch.device(device):
372
+ ae = AutoEncoder(autoencoder_params)
373
+
374
+ if not os.path.exists(ckpt_path):
375
+ raise ValueError(
376
+ f"Autoencoder path {ckpt_path} does not exist. Please download it first."
377
+ )
378
+
379
+ if ckpt_path is not None:
380
+ sd = load_sft(ckpt_path, device=str(device))
381
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
382
+ if len(missing) > 0:
383
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
384
+ if len(unexpected) > 0:
385
+ print(
386
+ f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
387
+ )
388
+ return ae.to(dtype=dtype)
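As a rough sketch of how `load_ae` and the encode/decode pair fit together; the checkpoint path is the one produced by the download script in this change, and the input tensor is a random placeholder:

```python
# Sketch only: round-trip an image tensor through the autoencoder.
import torch

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams, load_ae

ae = load_ae(
    ckpt_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors",
    autoencoder_params=AutoEncoderParams(),
    device="cuda",
    dtype=torch.bfloat16,
)

img = torch.randn(1, 3, 256, 256, device="cuda", dtype=torch.bfloat16)
latent = ae.encode(img)    # [1, 16, 32, 32] with the default params (8x downsampling)
recon = ae.decode(latent)  # [1, 3, 256, 256]
```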
torchtitan/experiments/flux/model/hf_embedder.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torch import nn, Tensor
8
+ from transformers import CLIPTextModel, T5EncoderModel
9
+
10
+
11
+ class FluxEmbedder(nn.Module):
12
+ def __init__(self, version: str, **hf_kwargs):
13
+ super().__init__()
14
+ self.is_clip = version.startswith("openai")
15
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
16
+
17
+ if self.is_clip:
18
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
19
+ version, **hf_kwargs
20
+ )
21
+ else:
22
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
23
+ version, **hf_kwargs
24
+ )
25
+
26
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
27
+
28
+ def forward(self, batch_tokens: Tensor) -> Tensor:
29
+ """
30
+ batch_tokens: [bsz, embedding_length]
31
+
32
+ For T5 Encoder, embedding_length is 768
33
+ For CLIP, embedding_length is 256
34
+ """
35
+ outputs = self.hf_module(
36
+ input_ids=batch_tokens.to(self.hf_module.device),
37
+ attention_mask=None,
38
+ output_hidden_states=False,
39
+ )
40
+ return outputs[self.output_key]
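A usage sketch for the embedder; the model versions are the same ones referenced in the debug config, and the actual encoding calls are left commented because the token id tensors come from the tokenizer step:

```python
# Sketch only: construct CLIP and T5 text encoders for conditioning.
import torch

from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
    "cuda", dtype=torch.bfloat16
)
t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
    "cuda", dtype=torch.bfloat16
)

# Given token id tensors from FluxTokenizer.encode():
# clip_encodings = clip_encoder(clip_tokens)  # pooled embedding per prompt
# t5_encodings = t5_encoder(t5_tokens)        # per-token hidden states
```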
torchtitan/experiments/flux/model/math.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+
12
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
13
+ q, k = apply_rope(q, k, pe)
14
+
15
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
16
+ x = rearrange(x, "B H L D -> B L (H D)")
17
+
18
+ return x
19
+
20
+
21
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
22
+ assert dim % 2 == 0
23
+ scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
24
+ omega = 1.0 / (theta**scale)
25
+ out = torch.einsum("...n,d->...nd", pos, omega)
26
+ out = torch.stack(
27
+ [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
28
+ )
29
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
30
+ return out.float()
31
+
32
+
33
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
34
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
35
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
36
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
37
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
38
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
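A hypothetical shape-check sketch for the helpers above; the batch/head/sequence sizes are arbitrary, and the extra head axis added to the rotary table mirrors how the model broadcasts it over attention heads:

```python
# Sketch only: RoPE table construction and rotary attention on random tensors.
import torch

from torchtitan.experiments.flux.model.math import attention, rope

B, H, L, D = 2, 4, 16, 32                                   # D must be even for rope()
pos = torch.arange(L, dtype=torch.float32).expand(B, L)    # [B, L] positions
pe = rope(pos, D, theta=10_000)                             # [B, L, D/2, 2, 2]
pe = pe.unsqueeze(1)                                        # add a head axis for broadcasting

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
out = attention(q, k, v, pe)                                # [B, L, H*D]
```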
torchtitan/experiments/flux/model/model.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
16
+ from torchtitan.experiments.flux.model.layers import (
17
+ DoubleStreamBlock,
18
+ EmbedND,
19
+ LastLayer,
20
+ MLPEmbedder,
21
+ SingleStreamBlock,
22
+ timestep_embedding,
23
+ )
24
+
25
+ from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
26
+ from torchtitan.tools.logging import logger
27
+
28
+
29
+ @dataclass
30
+ class FluxModelArgs(BaseModelArgs):
31
+ in_channels: int = 64
32
+ out_channels: int = 64
33
+ vec_in_dim: int = 768
34
+ context_in_dim: int = 512
35
+ hidden_size: int = 3072
36
+ mlp_ratio: float = 4.0
37
+ num_heads: int = 24
38
+ depth: int = 19
39
+ depth_single_blocks: int = 38
40
+ axes_dim: tuple = (16, 56, 56)
41
+ theta: int = 10_000
42
+ qkv_bias: bool = True
43
+ guidance_embed: bool = True
44
+ autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)
45
+
46
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
47
+ # context_in_dim is the same as the T5 embedding dimension
48
+ self.context_in_dim = job_config.encoder.max_t5_encoding_len
49
+
50
+ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
51
+ # TODO(jianiw): Add the number of flops for the autoencoder
52
+ nparams = sum(p.numel() for p in model.parameters())
53
+ logger.warning("FLUX model haven't implement get_nparams_and_flops() function")
54
+ return nparams, 1
55
+
56
+
57
+ class FluxModel(nn.Module, ModelProtocol):
58
+ """
59
+ Transformer model for flow matching on sequences.
60
+
61
+ Args:
63
+ model_args: FluxModelArgs.
64
+
65
+ Attributes:
66
+ model_args (FluxModelArgs): Model configuration arguments.
66
+ """
67
+
68
+ def __init__(self, model_args: FluxModelArgs):
69
+ super().__init__()
70
+
71
+ self.model_args = model_args
72
+ self.in_channels = model_args.in_channels
73
+ self.out_channels = model_args.out_channels
74
+ if model_args.hidden_size % model_args.num_heads != 0:
75
+ raise ValueError(
76
+ f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
77
+ )
78
+ pe_dim = model_args.hidden_size // model_args.num_heads
79
+ if sum(model_args.axes_dim) != pe_dim:
80
+ raise ValueError(
81
+ f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
82
+ )
83
+ self.hidden_size = model_args.hidden_size
84
+ self.num_heads = model_args.num_heads
85
+ self.pe_embedder = EmbedND(
86
+ dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
87
+ )
88
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
89
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
90
+ self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
91
+ self.guidance_in = (
92
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
93
+ if model_args.guidance_embed
94
+ else nn.Identity()
95
+ )
96
+ self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)
97
+
98
+ self.double_blocks = nn.ModuleList(
99
+ [
100
+ DoubleStreamBlock(
101
+ self.hidden_size,
102
+ self.num_heads,
103
+ mlp_ratio=model_args.mlp_ratio,
104
+ qkv_bias=model_args.qkv_bias,
105
+ )
106
+ for _ in range(model_args.depth)
107
+ ]
108
+ )
109
+
110
+ self.single_blocks = nn.ModuleList(
111
+ [
112
+ SingleStreamBlock(
113
+ self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
114
+ )
115
+ for _ in range(model_args.depth_single_blocks)
116
+ ]
117
+ )
118
+
119
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
120
+
121
+ def init_weights(self, buffer_device=None):
122
+ # TODO(jianiw): replace placeholder with real weight init
123
+ for param in self.parameters():
124
+ param.data.uniform_(0, 0.1)
125
+
126
+ def forward(
127
+ self,
128
+ img: Tensor,
129
+ img_ids: Tensor,
130
+ txt: Tensor,
131
+ txt_ids: Tensor,
132
+ timesteps: Tensor,
133
+ y: Tensor,
134
+ guidance: Tensor | None = None,
135
+ ) -> Tensor:
136
+ if img.ndim != 3 or txt.ndim != 3:
137
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
138
+
139
+ # running on sequences img
140
+ img = self.img_in(img)
141
+ vec = self.time_in(timestep_embedding(timesteps, 256))
142
+ if self.model_args.guidance_embed:
143
+ if guidance is None:
144
+ raise ValueError(
145
+ "Didn't get guidance strength for guidance distilled model."
146
+ )
147
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
148
+ vec = vec + self.vector_in(y)
149
+ txt = self.txt_in(txt)
150
+
151
+ ids = torch.cat((txt_ids, img_ids), dim=1)
152
+ pe = self.pe_embedder(ids)
153
+
154
+ for block in self.double_blocks:
155
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
156
+
157
+ img = torch.cat((txt, img), 1)
158
+ for block in self.single_blocks:
159
+ img = block(img, vec=vec, pe=pe)
160
+ img = img[:, txt.shape[1] :, ...]
161
+
162
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
163
+ return img
164
+
165
+ @classmethod
166
+ def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
167
+ """
168
+ Initialize a Flux model from a FluxModelArgs object.
169
+
170
+ Args:
171
+ model_args (FluxModelArgs): Model configuration arguments.
172
+
173
+ Returns:
174
+ FluxModel: FluxModel model.
175
+
176
+ """
177
+ return cls(model_args)
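To make the expected input shapes concrete, here is a hypothetical forward-pass smoke test using the debug config; the batch size and sequence lengths are arbitrary, while the last dimensions follow in_channels=64, context_in_dim=512, vec_in_dim=768, and the 3-dim position ids:

```python
# Sketch only: run the debug-sized model on random inputs.
import torch

from torchtitan.experiments.flux import flux_configs
from torchtitan.experiments.flux.model.model import FluxModel

model = FluxModel(flux_configs["flux-debug"]).cuda()
bsz, img_seq, txt_seq = 2, 256, 64

pred = model(
    img=torch.randn(bsz, img_seq, 64, device="cuda"),
    img_ids=torch.zeros(bsz, img_seq, 3, device="cuda"),
    txt=torch.randn(bsz, txt_seq, 512, device="cuda"),
    txt_ids=torch.zeros(bsz, txt_seq, 3, device="cuda"),
    timesteps=torch.rand(bsz, device="cuda"),
    y=torch.randn(bsz, 768, device="cuda"),
    guidance=torch.full((bsz,), 3.5, device="cuda"),  # required since guidance_embed=True
)
# pred: [bsz, img_seq, 64]
```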
torchtitan/experiments/flux/scripts/download_autoencoder.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional
8
+
9
+ from requests.exceptions import HTTPError
10
+
11
+
12
+ def hf_download(
13
+ repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
14
+ ) -> None:
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ try:
18
+ hf_hub_download(
19
+ repo_id=repo_id,
20
+ filename=file_path,
21
+ local_dir=local_dir,
22
+ local_dir_use_symlinks=False,
23
+ token=hf_token,
24
+ )
25
+ except HTTPError as e:
26
+ if e.response.status_code == 401:
27
+ print(
28
+ "You need to pass a valid `--hf_token=...` to download private checkpoints."
29
+ )
30
+ else:
31
+ raise e
32
+
33
+
34
+ if __name__ == "__main__":
35
+ import argparse
36
+
37
+ parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
38
+ parser.add_argument(
39
+ "--repo_id",
40
+ type=str,
41
+ default="black-forest-labs/FLUX.1-dev",
42
+ help="Repository ID to download from. default to Flux-dev model",
43
+ )
44
+ parser.add_argument(
45
+ "--ae_path",
46
+ type=str,
47
+ default="ae.safetensors",
48
+ help="the autoencoder path relative to repo_id",
49
+ )
50
+ parser.add_argument(
51
+ "--hf_token", type=str, default=None, help="HuggingFace API token"
52
+ )
53
+ parser.add_argument(
54
+ "--local_dir",
55
+ type=str,
56
+ default="torchtitan/experiments/flux/assets/autoencoder/",
57
+ help="local directory to save the autoencoder",
58
+ )
59
+
60
+ args = parser.parse_args()
61
+ hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from PIL import ExifTags, Image
16
+
17
+ from torch import Tensor
18
+
19
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
20
+
21
+ from torchtitan.experiments.flux.model.autoencoder import (
22
+ AutoEncoder,
23
+ AutoEncoderParams,
24
+ load_ae,
25
+ )
26
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
27
+
28
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
29
+ from torchtitan.experiments.flux.utils import (
30
+ create_position_encoding_for_latents,
31
+ generate_noise_latent,
32
+ pack_latents,
33
+ preprocess_flux_data,
34
+ unpack_latents,
35
+ )
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(
43
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
44
+ ) -> Callable[[float], float]:
45
+ m = (y2 - y1) / (x2 - x1)
46
+ b = y1 - m * x1
47
+ return lambda x: m * x + b
48
+
49
+
50
+ def get_schedule(
51
+ num_steps: int,
52
+ image_seq_len: int,
53
+ base_shift: float = 0.5,
54
+ max_shift: float = 1.15,
55
+ shift: bool = True,
56
+ ) -> list[float]:
57
+ # extra step for zero
58
+ timesteps = torch.linspace(1, 0, num_steps + 1)
59
+
60
+ # shifting the schedule to favor high timesteps for higher signal images
61
+ if shift:
62
+ # estimate mu based on linear estimation between two points
63
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
64
+ timesteps = time_shift(mu, 1.0, timesteps)
65
+
66
+ return timesteps.tolist()
67
+
68
+
69
+ class TestGenerateImage:
70
+ def test_generate_image(self):
71
+ """
72
+ Run a forward pass of flux model to generate an image.
73
+ """
74
+ name = "flux-dev"
75
+ img_width = 512
76
+ img_height = 512
77
+ seed = None
78
+ prompt = (
79
+ "a photo of a forest with mist swirling around the tree trunks. The word "
80
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
81
+ )
82
+ device = "cuda"
83
+ num_steps = None
84
+ loop = False
85
+ guidance = 3.5
86
+ output_dir = "output"
87
+ add_sampling_metadata = True
88
+
89
+ prompt = prompt.split("|")
90
+ if len(prompt) == 1:
91
+ prompt = prompt[0]
92
+ additional_prompts = None
93
+ else:
94
+ additional_prompts = prompt[1:]
95
+ prompt = prompt[0]
96
+
97
+ assert not (
98
+ (additional_prompts is not None) and loop
99
+ ), "Do not provide additional prompts and set loop to True"
100
+
101
+ torch_device = torch.device(device)
102
+ if num_steps is None:
103
+ num_steps = 30
104
+
105
+ # allow for packing and conversion to latent space
106
+ img_height = 16 * (img_height // 16)
107
+ img_width = 16 * (img_width // 16)
108
+
109
+ # init all components
110
+ model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)
111
+
112
+ ae = load_ae(
113
+ ckpt_path="assets/autoencoder/ae.safetensors",
114
+ autoencoder_params=AutoEncoderParams(),
115
+ device=torch_device,
116
+ dtype=torch.bfloat16,
117
+ )
118
+ clip_tokenizer = FluxTokenizer(
119
+ model_path="openai/clip-vit-large-patch14", max_length=77
120
+ )
121
+ t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
122
+ clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
123
+ torch_device, dtype=torch.bfloat16
124
+ )
125
+ t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
126
+ torch_device, dtype=torch.bfloat16
127
+ )
128
+
129
+ rng = torch.Generator(device="cpu")
130
+
131
+ if seed is None:
132
+ seed = rng.seed()
133
+ print(f"Generating with seed {seed}:\n{prompt}")
134
+ t0 = time.perf_counter()
135
+ output_name = os.path.join(output_dir, f"img_{seed}.jpg")
136
+
137
+ # Tokenize the prompt, on CPU
138
+ clip_tokens = clip_tokenizer.encode(prompt)
139
+ t5_tokens = t5_tokenizer.encode(prompt)
140
+
141
+ batch = preprocess_flux_data(
142
+ device=torch_device,
143
+ dtype=torch.bfloat16,
144
+ autoencoder=None,
145
+ clip_encoder=clip_encoder,
146
+ t5_encoder=t5_encoder,
147
+ batch={
148
+ "clip_tokens": clip_tokens,
149
+ "t5_tokens": t5_tokens,
150
+ },
151
+ )
152
+
153
+ img = self._generate_images(
154
+ device=torch_device,
155
+ dtype=torch.bfloat16,
156
+ model=model,
157
+ decoder=ae,
158
+ img_width=img_width,
159
+ img_height=img_height,
160
+ denoising_steps=num_steps,
161
+ seed=seed,
162
+ clip_encodings=batch["clip_encodings"],
163
+ t5_encodings=batch["t5_encodings"],
164
+ guidance=guidance,
165
+ )
166
+
167
+ if torch.cuda.is_available():
168
+ torch.cuda.synchronize()
169
+ t1 = time.perf_counter()
170
+
171
+ print(f"Done in {t1 - t0:.1f}s.")
172
+
173
+ self._save_image(name, output_name, img, add_sampling_metadata, prompt)
174
+
175
+ def _generate_images(
176
+ self,
177
+ device: torch.device,
178
+ dtype: torch.dtype,
179
+ model: FluxModel,
180
+ decoder: AutoEncoder,
181
+ # image params:
182
+ img_width: int,
183
+ img_height: int,
184
+ # sampling params:
185
+ denoising_steps: int,
186
+ seed: int,
187
+ clip_encodings: torch.Tensor,
188
+ t5_encodings: torch.Tensor,
189
+ guidance: float = 4.0,
190
+ ):
191
+
192
+ bsz = clip_encodings.shape[0]
193
+ latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
194
+ _, latent_channels, latent_height, latent_width = latents.shape
195
+
196
+ # create denoising schedule
197
+ timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
198
+
199
+ # create positional encodings
200
+ POSITION_DIM = 3 # constant for Flux flow model
201
+ latent_pos_enc = create_position_encoding_for_latents(
202
+ bsz, latent_height, latent_width, POSITION_DIM
203
+ ).to(latents)
204
+ text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)
205
+
206
+ # convert img-like latents into sequences of patches
207
+ latents = pack_latents(latents)
208
+
209
+ # this is ignored for schnell
210
+ guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
211
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
212
+ t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
213
+ pred = model(
214
+ img=latents,
215
+ img_ids=latent_pos_enc,
216
+ txt=t5_encodings,
217
+ txt_ids=text_pos_enc,
218
+ y=clip_encodings,
219
+ timesteps=t_vec,
220
+ guidance=guidance_vec,
221
+ )
222
+
223
+ latents = latents + (t_prev - t_curr) * pred
224
+
225
+ # convert sequences of patches into img-like latents
226
+ latents = unpack_latents(latents, latent_height, latent_width)
227
+
228
+ img = decoder.decode(latents)
229
+ return img
230
+
231
+ def _save_image(
232
+ self,
233
+ name: str,
234
+ output_name: str,
235
+ x: torch.Tensor,
236
+ add_sampling_metadata: bool,
237
+ prompt: str,
238
+ ):
239
+ print(f"Saving {output_name}")
240
+ # bring into PIL format and save
241
+ x = x.clamp(-1, 1)
242
+ x = rearrange(x[0], "c h w -> h w c")
243
+
244
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
245
+
246
+ exif_data = Image.Exif()
247
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
248
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
249
+ exif_data[ExifTags.Base.Model] = name
250
+ if add_sampling_metadata:
251
+ exif_data[ExifTags.Base.ImageDescription] = prompt
252
+ img.save(output_name, exif=exif_data, quality=95, subsampling=0)
torchtitan/experiments/flux/train_configs/debug_model.toml ADDED
@@ -0,0 +1,68 @@
1
+
2
+ [job]
3
+ dump_folder = "./outputs"
4
+ description = "Flux debug model"
5
+ print_args = false
6
+ use_for_integration_test = true
7
+
8
+ [profiling]
9
+ enable_profiling = false
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 10
12
+ enable_memory_snapshot = false
13
+ save_memory_snapshot_folder = "memory_snapshot"
14
+
15
+ [metrics]
16
+ log_freq = 1
17
+ disable_color_printing = false
18
+ enable_tensorboard = false
19
+ save_tb_folder = "tb"
20
+ enable_wandb = false
21
+
22
+ [model]
23
+ name = "flux"
24
+ flavor = "flux-debug"
25
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
26
+ # test tokenizer.model, for debug purpose only
27
+ # tokenizer_path = "./tests/assets/test_tiktoken.model"
28
+ # converters = "float8"
29
+
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 8e-4
34
+ eps = 1e-8
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.0
41
+
42
+ [training]
43
+ batch_size = 32
44
+ seq_len = 512
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "cc12m"
49
+ guidance = 3.5
50
+ seed = 0
51
+
52
+ [encoder]
53
+ t5_encoder="google/t5-v1_1-small"
54
+ clip_encoder="openai/clip-vit-large-patch14"
55
+ max_t5_encoding_len=512
56
+ auto_encoder_path="torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image
57
+
58
+ [parallelism]
59
+ data_parallel_replicate_degree = 1
60
+ data_parallel_shard_degree = 1
61
+ fsdp_reshard_after_forward = "default" # default / never / always
62
+ tensor_parallel_degree = 1
63
+ enable_async_tensor_parallel = false
64
+ pipeline_parallel_degree = 1
65
+ context_parallel_degree = 1
66
+
67
+ [experimental]
68
+ custom_args_module = "torchtitan.experiments.flux.flux_argparser"
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py ADDED
@@ -0,0 +1,885 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ import math
10
+ import time
11
+
12
+ from typing import Dict, List, Tuple
13
+
14
+ # import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+
20
+ # from torchao_pr.mg_grouped_gemm import mg_grouped_gemm
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
25
+ )
26
+
27
+ # Try to import the optimized MG GEMM implementation
28
+ try:
29
+ from torchao_pr.mg_grouped_gemm import ( # grouped_gemm_backward,
30
+ grouped_gemm_forward,
31
+ )
32
+
33
+ has_mg_gemm = True
34
+ except ImportError:
35
+ logging.warning("MG GEMM implementation not found. Will use manual looping only.")
36
+ has_mg_gemm = False
37
+
38
+
39
+ class Router(nn.Module):
40
+ """
41
+ Router module that assigns tokens to experts.
42
+ """
43
+
44
+ def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
45
+ super().__init__()
46
+ self.input_dim = input_dim
47
+ self.num_experts = num_experts
48
+ self.top_k = top_k
49
+
50
+ # Routing layer
51
+ self.router = nn.Linear(input_dim, num_experts)
52
+
53
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
54
+ """
55
+ Route input tokens to experts.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)
59
+
60
+ Returns:
61
+ Tuple containing:
62
+ - router_logits: Raw routing probabilities
63
+ - dispatch_tensor: One-hot tensor indicating expert assignment
64
+ - expert_indices: List of indices for each expert's tokens
65
+ """
66
+ batch_size, seq_len, _ = x.shape
67
+
68
+ # Flatten batch and sequence dimensions
69
+ x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim)
70
+
71
+ # Compute routing probabilities
72
+ router_logits = self.router(x_flat) # (batch_size * seq_len, num_experts)
73
+
74
+ # Apply softmax to get probabilities
75
+ router_probs = F.softmax(router_logits, dim=-1)
76
+
77
+ # Get top-k experts for each token
78
+ top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
79
+
80
+ # Normalize top-k probabilities
81
+ top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
82
+
83
+ # Create dispatch tensor (one-hot representation of assignments)
84
+ dispatch_tensor = torch.zeros_like(router_probs)
85
+ token_indices = (
86
+ torch.arange(router_probs.size(0), device=router_probs.device)
87
+ .unsqueeze(1)
88
+ .expand(-1, self.top_k)
89
+ )
90
+ dispatch_tensor.scatter_(1, top_k_indices, top_k_probs) # .unsqueeze(-1))
91
+
92
+ # For each expert, get the indices of tokens routed to it
93
+ expert_indices = []
94
+ for expert_idx in range(self.num_experts):
95
+ # Get indices of tokens that have non-zero probability for this expert
96
+ indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[
97
+ 0
98
+ ]
99
+ expert_indices.append(indices)
100
+
101
+ return router_logits, dispatch_tensor, expert_indices
102
+
103
+
104
+ class Expert(nn.Module):
105
+ """
106
+ Individual expert module.
107
+ """
108
+
109
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
110
+ super().__init__()
111
+ self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
112
+ self.activation = nn.GELU()
113
+ self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)
114
+
115
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
116
+ x = self.fc1(x)
117
+ x = self.activation(x)
118
+ x = self.fc2(x)
119
+ return x
120
+
121
+
122
+ class MixtureOfExperts(nn.Module):
123
+ """
124
+ Mixture of Experts layer with support for both manual looping and grouped GEMM.
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ input_dim: int,
130
+ hidden_dim: int,
131
+ output_dim: int,
132
+ num_experts: int,
133
+ top_k: int = 2,
134
+ use_mg_gemm: bool = False,
135
+ ):
136
+ super().__init__()
137
+ self.input_dim = input_dim
138
+ self.hidden_dim = hidden_dim
139
+ self.output_dim = output_dim
140
+ self.num_experts = num_experts
141
+ self.top_k = top_k
142
+ self.use_mg_gemm = use_mg_gemm and has_mg_gemm
143
+
144
+ # Router
145
+ self.router = Router(input_dim, num_experts, top_k)
146
+
147
+ # Create expert modules
148
+ if self.use_mg_gemm:
149
+ # For MG GEMM, we need a single weight tensor for all experts
150
+ # First layer (input -> hidden)
151
+ self.expert_fc1_weight = nn.Parameter(
152
+ torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
153
+ )
154
+ # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))
155
+
156
+ # Second layer (hidden -> output)
157
+ self.expert_fc2_weight = nn.Parameter(
158
+ torch.randn(num_experts * output_dim, hidden_dim)
159
+ / math.sqrt(hidden_dim)
160
+ )
161
+ # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
162
+ else:
163
+ # For manual looping, create separate experts
164
+ self.experts = nn.ModuleList(
165
+ [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
166
+ )
167
+
168
+ def forward_manual_loop(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
169
+ """
170
+ Forward pass using manual looping over experts.
171
+ """
172
+ batch_size, seq_len, _ = x.shape
173
+ x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim)
174
+
175
+ # Get routing information
176
+ router_logits, dispatch_tensor, expert_indices = self.router(x)
177
+
178
+ # Initialize output tensor
179
+ final_output = torch.zeros(
180
+ batch_size * seq_len, self.output_dim, device=x.device
181
+ )
182
+
183
+ # Process each expert
184
+ for expert_idx, indices in enumerate(expert_indices):
185
+ if indices.numel() > 0:
186
+ # Get tokens routed to this expert
187
+ expert_inputs = x_flat[indices] # (num_tokens_for_expert, input_dim)
188
+
189
+ # Process tokens through expert
190
+ expert_outputs = self.experts[expert_idx](
191
+ expert_inputs
192
+ ) # (num_tokens_for_expert, output_dim)
193
+
194
+ # Scale outputs by router probabilities
195
+ scaled_outputs = expert_outputs * dispatch_tensor[
196
+ indices, expert_idx
197
+ ].unsqueeze(1)
198
+
199
+ # Add to final output
200
+ final_output.index_add_(0, indices, scaled_outputs)
201
+
202
+ # Reshape back to original dimensions
203
+ output = final_output.reshape(batch_size, seq_len, self.output_dim)
204
+
205
+ return output, router_logits
206
+
207
+ def forward_mg_gemm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
208
+ batch_size, seq_len, _ = x.shape
209
+ x_flat = x.reshape(-1, self.input_dim) # (batch_size * seq_len, input_dim)
210
+ total_tokens = batch_size * seq_len
211
+
212
+ # Get routing information
213
+ router_logits, dispatch_tensor, expert_indices = self.router(x)
214
+
215
+ # Get token counts for each expert
216
+ token_counts = [indices.numel() for indices in expert_indices]
217
+ m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)
218
+
219
+ print(f"Token counts per expert: {token_counts}")
220
+ print(f"m_sizes: {m_sizes}")
221
+
222
+ # Create the combined input tensor
223
+ combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)
224
+
225
+ start_idx = 0
226
+ for expert_idx, indices in enumerate(expert_indices):
227
+ if indices.numel() > 0:
228
+ end_idx = start_idx + indices.numel()
229
+ combined_input[start_idx:end_idx] = x_flat[indices]
230
+ start_idx = end_idx
231
+
232
+ print(f"combined_input shape: {combined_input.shape}")
233
+
234
+ # First layer: input -> hidden
235
+ fc1_weight_reshaped = self.expert_fc1_weight.reshape(
236
+ self.num_experts, self.hidden_dim, self.input_dim
237
+ )
238
+ fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)
239
+
240
+ print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")
241
+
242
+ # Run the grouped GEMM
243
+ hidden_outputs = grouped_gemm_forward(
244
+ combined_input, fc1_weight_combined, m_sizes
245
+ )
246
+
247
+ print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")
248
+
249
+ # Apply activation
250
+ hidden_outputs = F.gelu(hidden_outputs)
251
+
252
+ print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")
253
+
254
+ # Second layer: hidden -> output
255
+ # Reshape hidden_outputs to match expected dimensions
256
+ reshaped_hidden_outputs = []
257
+ start_idx = 0
258
+
259
+ for expert_idx, count in enumerate(token_counts):
260
+ if count > 0:
261
+ end_idx = start_idx + count
262
+ # Take this expert's outputs and reshape to [count, hidden_dim]
263
+ expert_output = hidden_outputs[
264
+ start_idx:end_idx,
265
+ expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
266
+ ]
267
+ reshaped_hidden_outputs.append(expert_output)
268
+ start_idx = end_idx
269
+
270
+ # Concatenate all reshaped outputs
271
+ hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)
272
+
273
+ # Reshape expert weights for second layer
274
+ fc2_weight_reshaped = self.expert_fc2_weight.reshape(
275
+ self.num_experts, self.output_dim, self.hidden_dim
276
+ )
277
+ fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)
278
+
279
+ print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")
280
+
281
+ # Run the second grouped GEMM
282
+ expert_outputs_combined = grouped_gemm_forward(
283
+ hidden_outputs, fc2_weight_combined, m_sizes
284
+ )
285
+
286
+ # Initialize final output tensor with correct shape
287
+ final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)
288
+
289
+ # Distribute the outputs back to the original token positions
290
+ start_idx = 0
291
+ for expert_idx, indices in enumerate(expert_indices):
292
+ if indices.numel() > 0:
293
+ end_idx = start_idx + indices.numel()
294
+ # Get this expert's outputs
295
+ expert_outputs = expert_outputs_combined[start_idx:end_idx]
296
+
297
+ print(
298
+ f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
299
+ )
300
+
301
+ # Scale outputs by router probabilities
302
+ scaled_outputs = expert_outputs * dispatch_tensor[
303
+ indices, expert_idx
304
+ ].unsqueeze(1)
305
+
306
+ # Ensure dimensions match before using index_add_
307
+ if scaled_outputs.shape[1] != final_output.shape[1]:
308
+ # print(
309
+ # f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}"
310
+ # )
311
+ # Reshape if needed - make sure output_dim is correct
312
+ scaled_outputs = scaled_outputs[:, : self.output_dim]
313
+
314
+ # Add to final output
315
+ final_output.index_add_(0, indices, scaled_outputs)
316
+
317
+ start_idx = end_idx
318
+
319
+ # Reshape back to original dimensions
320
+ output = final_output.reshape(batch_size, seq_len, self.output_dim)
321
+
322
+ return output, router_logits
323
+
324
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
325
+ if self.use_mg_gemm and has_mg_gemm:
326
+ return self.forward_mg_gemm(x)
327
+ else:
328
+ return self.forward_manual_loop(x)
329
+
330
+
331
+ class MoEModel(nn.Module):
332
+ """
333
+ Simple model using MoE layers.
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ vocab_size: int,
339
+ embed_dim: int,
340
+ hidden_dim: int,
341
+ num_experts: int,
342
+ top_k: int = 2,
343
+ use_mg_gemm: bool = False,
344
+ ):
345
+ super().__init__()
346
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
347
+ self.moe_layer = MixtureOfExperts(
348
+ input_dim=embed_dim,
349
+ hidden_dim=hidden_dim,
350
+ output_dim=embed_dim,
351
+ num_experts=num_experts,
352
+ top_k=top_k,
353
+ use_mg_gemm=use_mg_gemm,
354
+ )
355
+ self.output_layer = nn.Linear(embed_dim, vocab_size)
356
+
357
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
358
+ # x shape: (batch_size, seq_len)
359
+ embedded = self.embedding(x) # (batch_size, seq_len, embed_dim)
360
+ moe_output, router_logits = self.moe_layer(
361
+ embedded
362
+ ) # (batch_size, seq_len, embed_dim)
363
+ logits = self.output_layer(moe_output) # (batch_size, seq_len, vocab_size)
364
+ return logits, router_logits
365
+
366
+
367
+ def compute_load_balancing_loss(
368
+ router_logits: torch.Tensor, num_experts: int
369
+ ) -> torch.Tensor:
370
+ """
371
+ Compute the load balancing loss for MoE training.
372
+
373
+ Args:
374
+ router_logits (torch.Tensor): Router logits of shape (batch_size * seq_len, num_experts)
375
+ num_experts (int): Number of experts
376
+
377
+ Returns:
378
+ torch.Tensor: Load balancing loss
379
+ """
380
+ # Get router probabilities
381
+ router_probs = F.softmax(
382
+ router_logits, dim=-1
383
+ ) # (batch_size * seq_len, num_experts)
384
+
385
+ # Compute fraction of tokens routed to each expert
386
+ # Sum across the batch dimension and normalize
387
+ router_probs_sum = router_probs.sum(dim=0) # (num_experts,)
388
+ router_probs_sum = router_probs_sum / router_probs_sum.sum()
389
+
390
+ # Target mean probability per expert (reference value; not used directly below)
391
+ mean_prob = 1.0 / num_experts
392
+
393
+ # Penalize deviation from uniform routing: the loss is minimized (value 1.0)
394
+ # when every expert receives an equal 1/num_experts share of the probability mass
395
+ load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)
396
+
397
+ return load_balancing_loss
398
+
399
+
400
+ def generate_sample_data(
401
+ batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
402
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
403
+ """
404
+ Generate sample data for training.
405
+
406
+ Args:
407
+ batch_size (int): Batch size
408
+ seq_len (int): Sequence length
409
+ vocab_size (int): Vocabulary size
410
+ device (str): Device to use
411
+
412
+ Returns:
413
+ Tuple of input tokens and target tokens
414
+ """
415
+ # Generate random input tokens
416
+ inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
417
+
418
+ # Generate random target tokens
419
+ targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
420
+
421
+ return inputs, targets
422
+
423
+
424
+ def train_epoch(
425
+ model: nn.Module,
426
+ optimizer: torch.optim.Optimizer,
427
+ batch_size: int,
428
+ seq_len: int,
429
+ vocab_size: int,
430
+ num_batches: int,
431
+ device: str,
432
+ load_balance_coef: float = 0.01,
433
+ ) -> Dict[str, float]:
434
+ """
435
+ Train the model for one epoch.
436
+
437
+ Args:
438
+ model (nn.Module): Model to train
439
+ optimizer (torch.optim.Optimizer): Optimizer
440
+ batch_size (int): Batch size
441
+ seq_len (int): Sequence length
442
+ vocab_size (int): Vocabulary size
443
+ num_batches (int): Number of batches per epoch
444
+ device (str): Device to use
445
+ load_balance_coef (float): Coefficient for load balancing loss
446
+
447
+ Returns:
448
+ Dict containing training metrics
449
+ """
450
+ model.train()
451
+ total_loss = 0.0
452
+ total_acc = 0.0
453
+ start_time = time.time()
454
+
455
+ for i in range(num_batches):
456
+ # Generate sample data
457
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
458
+
459
+ # Forward pass
460
+ optimizer.zero_grad()
461
+ logits, router_logits = model(inputs)
462
+
463
+ # Compute loss
464
+ # Reshape for cross entropy loss
465
+ logits_flat = logits.reshape(-1, vocab_size)
466
+ targets_flat = targets.reshape(-1)
467
+
468
+ # Cross entropy loss
469
+ ce_loss = F.cross_entropy(logits_flat, targets_flat)
470
+
471
+ # Load balancing loss
472
+ lb_loss = compute_load_balancing_loss(
473
+ router_logits, model.moe_layer.num_experts
474
+ )
475
+
476
+ # Combined loss
477
+ loss = ce_loss + load_balance_coef * lb_loss
478
+
479
+ # Backward pass
480
+ loss.backward()
481
+ optimizer.step()
482
+
483
+ # Compute accuracy
484
+ preds = logits_flat.argmax(dim=-1)
485
+ correct = (preds == targets_flat).float().sum()
486
+ acc = correct / (batch_size * seq_len)
487
+
488
+ # Accumulate metrics
489
+ total_loss += loss.item()
490
+ total_acc += acc.item()
491
+
492
+ # Log progress
493
+ if (i + 1) % 10 == 0:
494
+ logging.info(
495
+ f"Batch {i + 1}/{num_batches} | "
496
+ f"Loss: {loss.item():.4f} | "
497
+ f"CE Loss: {ce_loss.item():.4f} | "
498
+ f"LB Loss: {lb_loss.item():.4f} | "
499
+ f"Acc: {acc.item():.4f}"
500
+ )
501
+
502
+ # Compute average metrics
503
+ avg_loss = total_loss / num_batches
504
+ avg_acc = total_acc / num_batches
505
+ epoch_time = time.time() - start_time
506
+
507
+ return {"loss": avg_loss, "acc": avg_acc, "time": epoch_time}
508
+
509
+
510
+ def evaluate(
511
+ model: nn.Module,
512
+ batch_size: int,
513
+ seq_len: int,
514
+ vocab_size: int,
515
+ num_batches: int,
516
+ device: str,
517
+ ) -> Dict[str, float]:
518
+ """
519
+ Evaluate the model.
520
+
521
+ Args:
522
+ model (nn.Module): Model to evaluate
523
+ batch_size (int): Batch size
524
+ seq_len (int): Sequence length
525
+ vocab_size (int): Vocabulary size
526
+ num_batches (int): Number of batches for evaluation
527
+ device (str): Device to use
528
+
529
+ Returns:
530
+ Dict containing evaluation metrics
531
+ """
532
+ model.eval()
533
+ total_loss = 0.0
534
+ total_acc = 0.0
535
+
536
+ with torch.no_grad():
537
+ for i in range(num_batches):
538
+ # Generate sample data
539
+ inputs, targets = generate_sample_data(
540
+ batch_size, seq_len, vocab_size, device
541
+ )
542
+
543
+ # Forward pass
544
+ logits, router_logits = model(inputs)
545
+
546
+ # Compute loss
547
+ logits_flat = logits.reshape(-1, vocab_size)
548
+ targets_flat = targets.reshape(-1)
549
+
550
+ # Cross entropy loss
551
+ loss = F.cross_entropy(logits_flat, targets_flat)
552
+
553
+ # Compute accuracy
554
+ preds = logits_flat.argmax(dim=-1)
555
+ correct = (preds == targets_flat).float().sum()
556
+ acc = correct / (batch_size * seq_len)
557
+
558
+ # Accumulate metrics
559
+ total_loss += loss.item()
560
+ total_acc += acc.item()
561
+
562
+ # Compute average metrics
563
+ avg_loss = total_loss / num_batches
564
+ avg_acc = total_acc / num_batches
565
+
566
+ return {"loss": avg_loss, "acc": avg_acc}
567
+
568
+
569
+ def measure_performance(
570
+ model: nn.Module,
571
+ batch_size: int,
572
+ seq_len: int,
573
+ vocab_size: int,
574
+ num_batches: int,
575
+ device: str,
576
+ ) -> Dict[str, float]:
577
+ """
578
+ Measure forward and backward pass performance.
579
+
580
+ Args:
581
+ model (nn.Module): Model to evaluate
582
+ batch_size (int): Batch size
583
+ seq_len (int): Sequence length
584
+ vocab_size (int): Vocabulary size
585
+ num_batches (int): Number of batches for measurement
586
+ device (str): Device to use
587
+
588
+ Returns:
589
+ Dict containing performance metrics
590
+ """
591
+ model.train()
592
+
593
+ # Create dummy optimizer
594
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
595
+
596
+ # Warmup
597
+ for _ in range(5):
598
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
599
+ logits, router_logits = model(inputs)
600
+ loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
601
+ loss.backward()
602
+ optimizer.zero_grad()
603
+
604
+ # Measure forward pass time
605
+ torch.cuda.synchronize()
606
+ forward_start = time.time()
607
+
608
+ for _ in range(num_batches):
609
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
610
+ with torch.no_grad():
611
+ logits, router_logits = model(inputs)
612
+
613
+ torch.cuda.synchronize()
614
+ forward_end = time.time()
615
+ forward_time = (forward_end - forward_start) / num_batches
616
+
617
+ # Measure backward pass time
618
+ torch.cuda.synchronize()
619
+ backward_start = time.time()
620
+
621
+ for _ in range(num_batches):
622
+ inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
623
+ logits, router_logits = model(inputs)
624
+ loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
625
+ loss.backward()
626
+ optimizer.zero_grad()
627
+
628
+ torch.cuda.synchronize()
629
+ backward_end = time.time()
630
+ backward_time = (backward_end - backward_start) / num_batches
631
+
632
+ return {
633
+ "forward_time": forward_time * 1000, # Convert to ms
634
+ "backward_time": backward_time * 1000, # Convert to ms
635
+ "total_time": (forward_time + backward_time) * 1000, # Convert to ms
636
+ }
637
+
638
+
639
+ def compare_methods(args):
640
+ """
641
+ Compare manual looping and MG GEMM implementations.
642
+ """
643
+ device = torch.device(args.device)
644
+
645
+ # Create models
646
+ manual_model = MoEModel(
647
+ vocab_size=args.vocab_size,
648
+ embed_dim=args.embed_dim,
649
+ hidden_dim=args.hidden_dim,
650
+ num_experts=args.num_experts,
651
+ top_k=args.top_k,
652
+ use_mg_gemm=False,
653
+ ).to(device)
654
+
655
+ if has_mg_gemm:
656
+ mg_model = MoEModel(
657
+ vocab_size=args.vocab_size,
658
+ embed_dim=args.embed_dim,
659
+ hidden_dim=args.hidden_dim,
660
+ num_experts=args.num_experts,
661
+ top_k=args.top_k,
662
+ use_mg_gemm=True,
663
+ ).to(device)
664
+ else:
665
+ mg_model = None
666
+
667
+ # Measure performance
668
+ logging.info("Measuring performance of manual looping method...")
669
+ manual_perf = measure_performance(
670
+ manual_model,
671
+ args.batch_size,
672
+ args.seq_len,
673
+ args.vocab_size,
674
+ args.perf_batches,
675
+ device,
676
+ )
677
+
678
+ if mg_model is not None:
679
+ logging.info("Measuring performance of MG GEMM method...")
680
+ mg_perf = measure_performance(
681
+ mg_model,
682
+ args.batch_size,
683
+ args.seq_len,
684
+ args.vocab_size,
685
+ args.perf_batches,
686
+ device,
687
+ )
688
+ else:
689
+ mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}
690
+
691
+ # Log results
692
+ logging.info("\n===== Performance Comparison =====")
693
+ logging.info("Model Configuration:")
694
+ logging.info(f" - Batch Size: {args.batch_size}")
695
+ logging.info(f" - Sequence Length: {args.seq_len}")
696
+ logging.info(f" - Embed Dimension: {args.embed_dim}")
697
+ logging.info(f" - Hidden Dimension: {args.hidden_dim}")
698
+ logging.info(f" - Number of Experts: {args.num_experts}")
699
+ logging.info(f" - Top-K: {args.top_k}")
700
+ logging.info("")
701
+
702
+ logging.info("Manual Looping Method:")
703
+ logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms")
704
+ logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms")
705
+ logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms")
706
+ logging.info("")
707
+
708
+ if mg_model is not None:
709
+ logging.info("MG GEMM Method:")
710
+ logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms")
711
+ logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms")
712
+ logging.info(f" - Total Time: {mg_perf['total_time']:.2f} ms")
713
+ logging.info("")
714
+
715
+ # Calculate speedup
716
+ forward_speedup = (
717
+ manual_perf["forward_time"] / mg_perf["forward_time"]
718
+ if mg_perf["forward_time"] > 0
719
+ else 0
720
+ )
721
+ backward_speedup = (
722
+ manual_perf["backward_time"] / mg_perf["backward_time"]
723
+ if mg_perf["backward_time"] > 0
724
+ else 0
725
+ )
726
+ total_speedup = (
727
+ manual_perf["total_time"] / mg_perf["total_time"]
728
+ if mg_perf["total_time"] > 0
729
+ else 0
730
+ )
731
+
732
+ logging.info("Speedup (MG GEMM vs Manual):")
733
+ logging.info(f" - Forward Speedup: {forward_speedup:.2f}x")
734
+ logging.info(f" - Backward Speedup: {backward_speedup:.2f}x")
735
+ logging.info(f" - Total Speedup: {total_speedup:.2f}x")
736
+ else:
737
+ logging.info("MG GEMM method not available.")
738
+
739
+
740
+ def train_model(args):
741
+ """
742
+ Train an MoE model.
743
+ """
744
+ device = torch.device(args.device)
745
+
746
+ # Create model
747
+ model = MoEModel(
748
+ vocab_size=args.vocab_size,
749
+ embed_dim=args.embed_dim,
750
+ hidden_dim=args.hidden_dim,
751
+ num_experts=args.num_experts,
752
+ top_k=args.top_k,
753
+ use_mg_gemm=args.use_mg_gemm and has_mg_gemm,
754
+ ).to(device)
755
+
756
+ # Create optimizer
757
+ optimizer = optim.Adam(model.parameters(), lr=args.lr)
758
+
759
+ # Log model information
760
+ logging.info("Model configuration:")
761
+ logging.info(f" - Vocabulary Size: {args.vocab_size}")
762
+ logging.info(f" - Embedding Dimension: {args.embed_dim}")
763
+ logging.info(f" - Hidden Dimension: {args.hidden_dim}")
764
+ logging.info(f" - Number of Experts: {args.num_experts}")
765
+ logging.info(f" - Top-K: {args.top_k}")
766
+ logging.info(f" - Using MG GEMM: {args.use_mg_gemm and has_mg_gemm}")
767
+
768
+ # Training loop
769
+ for epoch in range(args.epochs):
770
+ logging.info(f"\nEpoch {epoch + 1}/{args.epochs}")
771
+
772
+ # Train
773
+ train_metrics = train_epoch(
774
+ model=model,
775
+ optimizer=optimizer,
776
+ batch_size=args.batch_size,
777
+ seq_len=args.seq_len,
778
+ vocab_size=args.vocab_size,
779
+ num_batches=args.train_batches,
780
+ device=device,
781
+ load_balance_coef=args.load_balance_coef,
782
+ )
783
+
784
+ # Evaluate
785
+ eval_metrics = evaluate(
786
+ model=model,
787
+ batch_size=args.batch_size,
788
+ seq_len=args.seq_len,
789
+ vocab_size=args.vocab_size,
790
+ num_batches=args.eval_batches,
791
+ device=device,
792
+ )
793
+
794
+ # Log metrics
795
+ logging.info(
796
+ f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
797
+ )
798
+ logging.info(
799
+ f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
800
+ )
801
+ logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")
802
+
803
+
804
+ if __name__ == "__main__":
805
+ parser = argparse.ArgumentParser(description="Train MoE model")
806
+
807
+ # Model parameters
808
+ parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
809
+ parser.add_argument(
810
+ "--embed_dim", type=int, default=512, help="Embedding dimension"
811
+ )
812
+ parser.add_argument(
813
+ "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
814
+ )
815
+ parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
816
+ parser.add_argument(
817
+ "--top_k", type=int, default=2, help="Top-k experts to route to"
818
+ )
819
+
820
+ # Training parameters
821
+ parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
822
+ parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
823
+ parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
824
+ parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
825
+ parser.add_argument(
826
+ "--train_batches",
827
+ type=int,
828
+ default=100,
829
+ help="Number of training batches per epoch",
830
+ )
831
+ parser.add_argument(
832
+ "--eval_batches", type=int, default=20, help="Number of evaluation batches"
833
+ )
834
+ parser.add_argument(
835
+ "--perf_batches",
836
+ type=int,
837
+ default=50,
838
+ help="Number of batches for performance testing",
839
+ )
840
+ parser.add_argument(
841
+ "--load_balance_coef",
842
+ type=float,
843
+ default=0.01,
844
+ help="Load balancing loss coefficient",
845
+ )
846
+
847
+ # Runtime parameters
848
+ parser.add_argument(
849
+ "--device",
850
+ type=str,
851
+ default="cuda" if torch.cuda.is_available() else "cpu",
852
+ help="Device to use (cuda or cpu)",
853
+ )
854
+ parser.add_argument(
855
+ "--use_mg_gemm",
856
+ action="store_true",
857
+ help="Use MG GEMM implementation if available",
858
+ )
859
+ parser.add_argument(
860
+ "--compare",
861
+ action="store_true",
862
+ help="Compare manual and MG GEMM implementations",
863
+ )
864
+ parser.add_argument("--train", action="store_true", help="Train the model")
865
+
866
+ args = parser.parse_args()
867
+
868
+ # Check for CUDA
869
+ if args.device == "cuda" and not torch.cuda.is_available():
870
+ logging.warning("CUDA not available, using CPU instead.")
871
+ args.device = "cpu"
872
+
873
+ # Log basic information
874
+ logging.info(f"PyTorch version: {torch.__version__}")
875
+ logging.info(f"Device: {args.device}")
876
+ logging.info(f"MG GEMM available: {has_mg_gemm}")
877
+
878
+ # Run the requested action
879
+ if args.compare:
880
+ compare_methods(args)
881
+ elif args.train:
882
+ train_model(args)
883
+ else:
884
+ # Default to comparison if no action specified
885
+ compare_methods(args)
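The load-balancing penalty computed by `compute_load_balancing_loss` above reduces to `num_experts * sum(p_bar_e ** 2)` over the normalized per-expert probability mass, which equals 1.0 for perfectly uniform routing and approaches `num_experts` when routing collapses onto a single expert. The following standalone sketch is illustrative only (it is not part of the uploaded files and only assumes the same formula as the script above):

```python
# Minimal sketch of the load-balancing term used in simpleMoE.py above.
# loss = num_experts * sum(p_bar ** 2), where p_bar is each expert's share of
# the total router probability mass: 1.0 for uniform routing, up to num_experts
# when all tokens collapse onto one expert.
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    probs = F.softmax(router_logits, dim=-1)      # (num_tokens, num_experts)
    per_expert = probs.sum(dim=0)                 # probability mass per expert
    per_expert = per_expert / per_expert.sum()    # normalize to a distribution
    return num_experts * torch.sum(per_expert * per_expert)


num_experts = 8
uniform_logits = torch.zeros(1024, num_experts)   # router is indifferent
skewed_logits = torch.zeros(1024, num_experts)
skewed_logits[:, 0] = 10.0                        # nearly everything goes to expert 0

print(load_balancing_loss(uniform_logits, num_experts))  # ~1.0
print(load_balancing_loss(skewed_logits, num_experts))   # close to 8.0
```

Adding this term with a small coefficient (the script's `--load_balance_coef`, default 0.01) therefore nudges the router toward spreading tokens evenly across experts.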
torchtitan/experiments/llama4/README.md ADDED
@@ -0,0 +1,29 @@
1
+ **The Llama 4 folder is still under development.**
2
+
3
+ #### Available features
4
+ - Llama 4 model definition (text-only), including the MoE architecture with token-choice routing
5
+ - Basic FSDP, TP, PP, CP support
6
+ - DCP checkpoint conversion scripts
7
+
8
+ #### Download Llama 4 tokenizer
9
+ ```bash
10
+ # Llama 4 tokenizer.model
11
+ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=...
12
+ ```
13
+
14
+ #### To be added
15
+ - Modeling
16
+ - iRoPE implementation
17
+ - load balance loss for token-choice MoE
18
+ - alternative expert-choice MoE
19
+ - multimodal support
20
+ - Kernel integration
21
+ - efficient bfloat16 GroupedGEMM kernels (from PyTorch core)
22
+ - efficient float8 GroupedGEMM kernels (from torchao)
23
+ - Parallelism
24
+ - performant TP implementation and torch.compile support for MoE layers
25
+ - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs
26
+ - Expert Parallel support
27
+ - Testing
28
+ - performance and loss convergence tests
29
+ - CI integration
torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.66 kB). View file
 
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
torchtitan/experiments/llama4/model/args.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ from torch import nn
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.protocols.train_spec import BaseModelArgs
16
+ from torchtitan.tools.logging import logger
17
+
18
+
19
+ @dataclass
20
+ class TransformerModelArgs(BaseModelArgs):
21
+ dim: int = 4096
22
+ n_layers: int = 32
23
+ n_heads: int = 32
24
+ n_kv_heads: Optional[int] = None
25
+ vocab_size: int = -1 # defined later by tokenizer
26
+ multiple_of: int = 256 # make SwiGLU hidden layer size a multiple of a large power of 2
27
+ ffn_dim_multiplier: Optional[float] = None
28
+ norm_eps: float = 1e-5
29
+ rope_theta: float = 10000
30
+
31
+ max_seq_len: int = 2048
32
+ # If `True`, then each transformer block init uses its layer ID, and if
33
+ # `False`, each uses the total number of transformer blocks
34
+ depth_init: bool = True
35
+ norm_type: str = "rmsnorm"
36
+
37
+ use_flex_attn: bool = False
38
+ attn_mask_type: str = "causal"
39
+ eos_id: int = 0
40
+
41
+ # MoE args
42
+ moe_enabled: bool = True
43
+ num_experts: int = 8
44
+ use_shared_expert: bool = True
45
+ auto_scale_hidden_dim: bool = True
46
+ # frequency of using MoE layer instead of feedforward layer in a transformer block
47
+ interleave_moe_layer_step: int = 2
48
+ # token-choice
49
+ top_k: int = 1
50
+
51
+ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
52
+ self.norm_type = job_config.model.norm_type
53
+ self.vocab_size = tokenizer.n_words
54
+ self.max_seq_len = job_config.training.seq_len
55
+ self.use_flex_attn = job_config.model.use_flex_attn
56
+
57
+ def get_nparams_and_flops(
58
+ self, model: nn.Module, seq_len: int
59
+ ) -> tuple[int, float]:
60
+ nparams_embedding = 0
61
+ nparams_moe_router = 0
62
+ nparams_shared_expert = 0
63
+ nparams_experts = 0
64
+ nparams_dense = 0
65
+
66
+ for name, p in model.named_parameters():
67
+ if "embedding" in name:
68
+ nparams_embedding += p.numel()
69
+ nparams_dense += p.numel()
70
+ elif "moe.shared_expert" in name:
71
+ nparams_shared_expert += p.numel()
72
+ elif "moe.router" in name:
73
+ nparams_moe_router += p.numel()
74
+ elif "moe.experts" in name:
75
+ nparams_experts += p.numel()
76
+ else:
77
+ nparams_dense += p.numel()
78
+
79
+ nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
80
+ nparams = nparams_dense + nparams_sparse
81
+ nparams_sparse_active = (
82
+ nparams_moe_router
83
+ + nparams_shared_expert
84
+ + nparams_experts * self.top_k // self.num_experts
85
+ )
86
+
87
+ logger.info(
88
+ f"Total parameter count: dense {nparams_dense:,}, "
89
+ f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
90
+ )
91
+
92
+ l, h, q, t = (
93
+ self.n_layers,
94
+ self.n_heads,
95
+ self.dim // self.n_heads,
96
+ seq_len,
97
+ )
98
+ # Reasoning behind the factor of 12 for the self-attention part of the formula:
99
+ # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
100
+ # 2. the flash attention does 1 more matmul recomputation in the backward
101
+ # but recomputation should not be counted in calculating MFU (+0)
102
+ # 3. each matmul performs 1 multiplication and 1 addition (*2)
103
+ # 4. we follow the convention and do not account for sparsity in causal attention
104
+ num_flops_per_token = (
105
+ 6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
106
+ + 12 * l * h * q * t
107
+ )
108
+
109
+ return nparams, num_flops_per_token
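To make the per-token FLOPs estimate in `get_nparams_and_flops` concrete, here is a hedged worked example with made-up parameter counts and a small hypothetical configuration (none of these numbers correspond to a real Llama 4 checkpoint); it only restates the formula above: 6 FLOPs per active non-embedding parameter plus `12 * n_layers * n_heads * head_dim * seq_len` for the attention matmuls.

```python
# Hedged worked example of the FLOPs-per-token estimate (hypothetical numbers).
n_layers, n_heads, dim, seq_len = 4, 8, 256, 2048
head_dim = dim // n_heads                 # q in the formula above

nparams_dense = 50_000_000                # assumed, includes the embedding
nparams_embedding = 10_000_000            # assumed
nparams_sparse_active = 5_000_000         # assumed router + shared expert + active experts

num_flops_per_token = (
    6 * (nparams_dense - nparams_embedding + nparams_sparse_active)   # 270,000,000
    + 12 * n_layers * n_heads * head_dim * seq_len                    #  25,165,824
)
print(f"{num_flops_per_token:,}")  # 295,165,824
```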
torchtitan/experiments/llama4/model/moe.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from .args import TransformerModelArgs
12
+
13
+
14
+ class GroupedExperts(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ hidden_dim: int,
19
+ num_experts: int,
20
+ ):
21
+ super().__init__()
22
+ self.num_experts = num_experts
23
+ self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
24
+ self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
25
+ self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
26
+
27
+ def forward(
28
+ self,
29
+ x: torch.Tensor,
30
+ num_local_tokens_per_expert: torch.Tensor | None = None,
31
+ ) -> torch.Tensor:
32
+ if num_local_tokens_per_expert is not None:
33
+ # a tuple of tensors indexed by experts
34
+ # each with shape (tokens_per_expert(varying), dim)
35
+ x = torch.split(
36
+ x,
37
+ split_size_or_sections=num_local_tokens_per_expert.tolist(),
38
+ dim=0,
39
+ )
40
+ out_experts_splits = []
41
+ for expert_idx, x_expert in enumerate(x):
42
+ w1, w2, w3 = (
43
+ self.w1[expert_idx],
44
+ self.w2[expert_idx],
45
+ self.w3[expert_idx],
46
+ )
47
+ h = F.silu(torch.matmul(x_expert, w1))
48
+ h = h * torch.matmul(x_expert, w3)
49
+ h = torch.matmul(h, w2)
50
+ # h shape (tokens_per_expert(varying), dim)
51
+ out_experts_splits.append(h)
52
+ out = torch.cat(out_experts_splits, dim=0)
53
+
54
+ # TODO: optimize with GroupedGEMM
55
+ # https://github.com/pytorch/pytorch/pull/150374
56
+ # _grouped_mm requires shapes to be a multiple of 8
57
+ # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
58
+ # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
59
+ # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
60
+ # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
61
+ else:
62
+ # x shape (num_experts, tokens_per_expert, dim)
63
+ h = F.silu(torch.bmm(x, self.w1))
64
+ h = h * torch.bmm(x, self.w3)
65
+ # out shape (num_experts, tokens_per_expert, dim)
66
+ out = torch.bmm(h, self.w2)
67
+ return out
68
+
69
+ def init_weights(self, init_std: float):
70
+ nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
71
+ nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
72
+ nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
73
+
74
+
75
+ class TokenChoiceTopKRouter(nn.Module):
76
+ """This class implements token-choice routing. In token-choice top-K routing, each token is
77
+ routed to top K experts based on the router scores.
78
+
79
+ Args:
80
+ gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
81
+ dim (int): Dimension of input tokens.
82
+ num_experts (int): Number of experts in each moe layer.
83
+ top_k (int): Number of experts each token will be routed to in token-choice routing.
84
+ use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ dim: int,
90
+ num_experts: int,
91
+ top_k: int,
92
+ use_sigmoid: bool = False,
93
+ ):
94
+ super().__init__()
95
+ self.gate = nn.Linear(dim, num_experts, bias=False)
96
+ self.num_experts = num_experts
97
+ self.top_k = top_k
98
+ self.use_sigmoid = use_sigmoid
99
+
100
+ def forward(
101
+ self, x: torch.Tensor
102
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
103
+ """
104
+ Args:
105
+ x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
106
+
107
+ Returns:
108
+ routed_input (torch.Tensor):
109
+ Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``.
110
+ token_indices (torch.Tensor):
111
+ Token indices for routed_input with shape ``(bs*slen*top_k,)``.
112
+ num_local_tokens_per_expert (torch.Tensor):
113
+ Number of tokens assigned to each expert with shape ``(num_experts,)``.
114
+ """
115
+ # scores shape (bs*slen, num_experts)
116
+ scores = self.gate(x)
117
+
118
+ # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
119
+ if self.use_sigmoid:
120
+ scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
121
+ else:
122
+ scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)
123
+
124
+ # top scores shape (bs*slen, top_k)
125
+ top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
126
+ # top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)
127
+
128
+ # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
129
+ num_local_tokens_per_expert = torch.histc(
130
+ selected_experts_indices.view(-1),
131
+ bins=self.num_experts,
132
+ min=0,
133
+ max=self.num_experts,
134
+ )
135
+ # token_indices_experts_sorted shape (bs*slen*top_k,)
136
+ token_indices_experts_sorted = torch.argsort(
137
+ selected_experts_indices.view(-1), stable=True
138
+ )
139
+ top_scores = top_scores.view(-1)[token_indices_experts_sorted]
140
+ token_indices_experts_sorted = token_indices_experts_sorted // self.top_k
141
+
142
+ return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert
143
+
144
+ def init_weights(self, init_std: float):
145
+ nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
146
+
147
+
148
+ # TODO: implement load balancing auxiliary loss for token-choice routing
149
+ class MoE(nn.Module):
150
+ def __init__(self, model_args: TransformerModelArgs):
151
+ super().__init__()
152
+ dim = model_args.dim
153
+ hidden_dim = 4 * model_args.dim
154
+ ffn_dim_multiplier = model_args.ffn_dim_multiplier
155
+ hidden_dim = int(2 * hidden_dim / 3)
156
+ if ffn_dim_multiplier is not None:
157
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
158
+
159
+ num_experts = model_args.num_experts
160
+
161
+ hidden_dim_denom = 1
162
+ if model_args.auto_scale_hidden_dim:
163
+ hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)
164
+
165
+ if model_args.auto_scale_hidden_dim:
166
+ hidden_dim = int(hidden_dim / hidden_dim_denom)
167
+ hidden_dim += -hidden_dim % model_args.multiple_of
168
+
169
+ self.experts = GroupedExperts(
170
+ dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
171
+ )
172
+ self.router = TokenChoiceTopKRouter(
173
+ dim=dim, num_experts=num_experts, top_k=model_args.top_k
174
+ )
175
+ self.shared_expert = (
176
+ GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
177
+ if model_args.use_shared_expert
178
+ else None
179
+ )
180
+
181
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
182
+ """
183
+ Args:
184
+ x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.
185
+
186
+ Returns:
187
+ out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
188
+ """
189
+ bs, slen, dim = x.shape
190
+ # top_scores and selected_indices shape (bs*slen*top_k,)
191
+ # num_local_tokens_per_expert shape (num_experts,)
192
+ (
193
+ top_scores,
194
+ token_indices,
195
+ num_local_tokens_per_expert,
196
+ ) = self.router(x.reshape(bs * slen, dim))
197
+
198
+ # shape (bs*slen*top_k, dim)
199
+ token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
200
+
201
+ # shape (bs*slen*top_k, dim)
202
+ routed_input = torch.gather(
203
+ x.view(-1, dim),
204
+ dim=0,
205
+ index=token_indices,
206
+ )
207
+ routed_input = routed_input * top_scores.reshape(-1, 1)
208
+
209
+ # shape (bs*slen*top_k, dim)
210
+ routed_output = self.experts(routed_input, num_local_tokens_per_expert)
211
+
212
+ # shared expert
213
+ if self.shared_expert is not None:
214
+ out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
215
+ bs * slen, dim
216
+ )
217
+ else:
218
+ out = torch.zeros_like(x.reshape(bs * slen, dim))
219
+
220
+ out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
221
+ out = out.reshape(bs, slen, dim)
222
+ return out
223
+
224
+ def init_weights(self, init_std: float):
225
+ self.experts.init_weights(init_std)
226
+ self.router.init_weights(init_std)
227
+ if self.shared_expert is not None:
228
+ self.shared_expert.init_weights(init_std)
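The routing path above hinges on a small amount of index bookkeeping: `torch.histc` counts how many routed rows each expert receives, a stable `torch.argsort` groups the flattened expert choices by expert id, and integer division by `top_k` maps each routed row back to its source token so the MoE layer can `gather` inputs and `scatter_add` outputs. The standalone sketch below is illustrative only, with hypothetical shapes and hand-picked expert assignments (it is not the module itself and omits the score weighting):

```python
# Sketch of the token-choice index bookkeeping used by TokenChoiceTopKRouter/MoE.
import torch

num_tokens, num_experts, top_k, dim = 6, 4, 2, 3
x = torch.arange(num_tokens * dim, dtype=torch.float32).reshape(num_tokens, dim)

# Hypothetical router decision: each token picks top_k expert ids.
selected = torch.tensor([[0, 2], [1, 2], [0, 3], [2, 3], [1, 1], [0, 2]])

# Rows each expert will receive (histc expects a floating-point input here).
tokens_per_expert = torch.histc(
    selected.view(-1).float(), bins=num_experts, min=0, max=num_experts
)

order = torch.argsort(selected.view(-1), stable=True)  # routed rows grouped by expert
token_idx = order // top_k                             # source token of each routed row

routed_input = x[token_idx]                            # (num_tokens * top_k, dim)

# Scatter the routed rows back onto their source tokens (unweighted here).
out = torch.zeros_like(x)
out.scatter_add_(0, token_idx.unsqueeze(1).expand(-1, dim), routed_input)

print(tokens_per_expert)                 # tensor([3., 3., 4., 2.])
print(torch.allclose(out, top_k * x))    # True: every token was gathered top_k times
```

In the real layer the routed rows are weighted by the router scores before being scattered back, so the accumulated output is a score-weighted mixture rather than a plain sum.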
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py ADDED
@@ -0,0 +1,536 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from dataclasses import dataclass
11
+ from typing import Any, Optional
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
16
+ from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
17
+ from torchtitan.components.checkpoint import MODEL
18
+ from torchtitan.config_manager import JobConfig
19
+ from torchtitan.tools.logging import init_logger, logger
20
+ from torchtitan.train import Trainer
21
+
22
+ # Sharding dims for MP checkpoints
23
+
24
+ column_parallel = [
25
+ "tok_embeddings",
26
+ "wq",
27
+ "wk",
28
+ "wv",
29
+ "wqkv",
30
+ "w_in_shared_FD",
31
+ "w_out_eF_D",
32
+ "w_swiglu_FD",
33
+ "output",
34
+ "_linear",
35
+ "c_fc",
36
+ "vision_projection",
37
+ ]
38
+
39
+ row_parallel = [
40
+ "wo",
41
+ "w_out_shared_DF",
42
+ "w_in_eD_F",
43
+ "moe_w_swiglu_eD_F",
44
+ "c_proj",
45
+ ]
46
+
47
+
48
+ def convert_to_titan_fqns(fqn: str) -> list[str]:
49
+ # From the stored checkpoint keys to TorchTitan keys.
50
+ if "wqkv" in fqn and "layer_norm_weight" not in fqn:
51
+ ret = []
52
+ for k in ("wq", "wk", "wv"):
53
+ ret.append(fqn.replace("wqkv", k))
54
+ return ret
55
+ return [fqn]
56
+
57
+
58
+ def get_shard_dim(fqn: str) -> Optional[int]:
59
+ if "bias" in fqn:
60
+ # Some bias params are still sharded
61
+ if "resblocks" in fqn:
62
+ for k in ("wq", "wk", "wv", "c_fc"):
63
+ if k in fqn:
64
+ return 0
65
+ return None
66
+ elif any([x in fqn for x in column_parallel]):
67
+ return 0
68
+ elif any([x in fqn for x in row_parallel]):
69
+ return 1
70
+ else:
71
+ return None
72
+
73
+
74
+ def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
75
+ qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards]
76
+ q = torch.cat([qkv[0] for qkv in qkvs], dim=0)
77
+ k = torch.cat([qkv[1] for qkv in qkvs], dim=0)
78
+ v = torch.cat([qkv[2] for qkv in qkvs], dim=0)
79
+ return q, k, v
80
+
81
+
82
+ @dataclass
83
+ class _Assignment:
84
+ loader_id: int
85
+ filename: str
86
+ fqns: tuple[str, ...]
87
+ shapes: tuple[torch.Size, ...]
88
+ dtypes: tuple[torch.dtype, ...]
89
+
90
+
91
+ @dataclass
92
+ class _AssignmentRound:
93
+ loader_assignments: dict[int, _Assignment] # List of assignments for each loader
94
+
95
+
96
+ class CheckpointConverter:
97
+ TOTAL_SHARDS = 8
98
+
99
+ def __init__(
100
+ self,
101
+ process_group: dist.ProcessGroup,
102
+ path: str,
103
+ loader_every_n_ranks: int = 8,
104
+ ) -> None:
105
+ self.path = path
106
+ self.pg = process_group
107
+ self.my_rank = dist.get_rank(self.pg)
108
+ self.loader_every_n_ranks = loader_every_n_ranks
109
+ self.loader_id = self.my_rank // loader_every_n_ranks
110
+ self.should_load = (
111
+ self.my_rank % loader_every_n_ranks == 0
112
+ and self.loader_id < CheckpointConverter.TOTAL_SHARDS
113
+ )
114
+ self.total_loader = CheckpointConverter.TOTAL_SHARDS
115
+ self.titan_fqn_to_stored_fqn: dict[str, str] = {}
116
+ self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {}
117
+ self.total_send_bytes = 0
118
+ self.total_recv_bytes = 0
119
+
120
+ def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
121
+ begin = time.time()
122
+ self._load_metadata()
123
+ self._create_fqn_mappings(state_dict)
124
+ rounds = self._get_load_assignments(state_dict)
125
+
126
+ for assignments in rounds:
127
+ loader_assignments = assignments.loader_assignments
128
+ loaded_state_dict = None
129
+ # Let each loader load its own data and move it to its GPU.
130
+ for i in range(self.total_loader):
131
+ # This loader doesn't have any loading assignment for this round.
132
+ if i not in loader_assignments:
133
+ continue
134
+ # This rank is not the loader
135
+ if i != self.loader_id or not self.should_load:
136
+ continue
137
+ loaded_state_dict = self._load_round(loader_assignments[i])
138
+
139
+ results = []
140
+ for i in range(self.total_loader):
141
+ if i not in loader_assignments:
142
+ continue
143
+
144
+ if i == self.loader_id and self.should_load:
145
+ # This rank is the loader. It needs to send the loaded data to
146
+ # the other ranks.
147
+ assert loaded_state_dict is not None
148
+ results.append(
149
+ self._reshard_send(loader_assignments[i], loaded_state_dict)
150
+ )
151
+ else:
152
+ results.append(
153
+ self._reshard_receive(loader_assignments[i], state_dict)
154
+ )
155
+
156
+ self._reshard(results, state_dict)
157
+
158
+ torch.cuda.synchronize()
159
+ logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.")
160
+ logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB")
161
+ logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB")
162
+ return state_dict
163
+
164
+ def _get_file_path(self, loader_id: int) -> str:
165
+ return os.path.join(self.path, f"consolidated.0{loader_id}.pth")
166
+
167
+ def _load_metadata(self) -> None:
168
+ if not self.should_load:
169
+ self.read_dict = {}
170
+ return
171
+ self.read_dict = torch.load(
172
+ self._get_file_path(self.loader_id),
173
+ mmap=True,
174
+ weights_only=False,
175
+ )
176
+
177
+ def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None:
178
+ if not self.read_dict:
179
+ return
180
+
181
+ # Create the mapping from the stored checkpoint keys to TorchTitan keys.
182
+ for fqn in list(self.read_dict.keys()):
183
+ titan_fqns = convert_to_titan_fqns(fqn)
184
+ # We don't know how to process _extra_state
185
+ if "_extra_state" in fqn:
186
+ self.read_dict.pop(fqn)
187
+ continue
188
+
189
+ if titan_fqns[0] not in state_dict:
190
+ for titan_fqn in titan_fqns:
191
+ assert titan_fqn not in state_dict
192
+ self.read_dict.pop(fqn)
193
+ continue
194
+ self.stored_fqn_to_titan_fqn[fqn] = titan_fqns
195
+ for titan_fqn in titan_fqns:
196
+ self.titan_fqn_to_stored_fqn[titan_fqn] = fqn
197
+
198
+ assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), (
199
+ set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()),
200
+ set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()),
201
+ )
202
+
203
+ def _get_load_assignments(
204
+ self, state_dict: dict[str, torch.Tensor]
205
+ ) -> list[_AssignmentRound]:
206
+ if self.my_rank == 0:
207
+ rounds: list[_AssignmentRound] = []
208
+ size = 0
209
+ fqns = []
210
+ shapes = []
211
+ dtypes = []
212
+
213
+ # All loaders must load all the FQNs because the checkpoint is purely TP sharded.
214
+ all_keys = list(self.read_dict.keys())
215
+ for fqn in all_keys:
216
+ fqns.append(fqn)
217
+ shapes.append(self.read_dict[fqn].shape)
218
+ dtypes.append(self.read_dict[fqn].dtype)
219
+ size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size()
220
+ if size < 1e9 and fqn != all_keys[-1]:
221
+ continue
222
+
223
+ logger.info(f"Adding {fqns} to round {len(rounds)}")
224
+ round_assignment = _AssignmentRound(loader_assignments={})
225
+ for loader_id in range(self.total_loader):
226
+ path = self._get_file_path(loader_id)
227
+ round_assignment.loader_assignments[loader_id] = _Assignment(
228
+ filename=path,
229
+ fqns=tuple(fqns),
230
+ shapes=tuple(shapes),
231
+ dtypes=tuple(dtypes),
232
+ loader_id=loader_id,
233
+ )
234
+ rounds.append(round_assignment)
235
+ size = 0
236
+ fqns.clear()
237
+ shapes.clear()
238
+ dtypes.clear()
239
+
240
+ object_list: list[Any] = [
241
+ rounds,
242
+ self.titan_fqn_to_stored_fqn,
243
+ self.stored_fqn_to_titan_fqn,
244
+ ]
245
+ else:
246
+ object_list = [None, None, None]
247
+
248
+ dist.broadcast_object_list(object_list, src=0, group=self.pg)
249
+ rounds = object_list[0]
250
+ self.titan_fqn_to_stored_fqn = object_list[1]
251
+ self.stored_fqn_to_titan_fqn = object_list[2]
252
+ return rounds
253
+
254
+ def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]:
255
+ ret = {}
256
+ assert self.read_dict
257
+ for fqn in assignment.fqns:
258
+ ret[fqn] = self.read_dict[fqn].to(device="cuda")
259
+ return ret
260
+
261
+ def _reshard_send(
262
+ self,
263
+ assignment: _Assignment,
264
+ loaded_state_dict: dict[str, torch.Tensor],
265
+ ) -> dict[str, torch.Tensor]:
266
+ flatten_tensors = [t.flatten() for t in loaded_state_dict.values()]
267
+ flatten_tensor = torch.concat(flatten_tensors)
268
+ assert self.loader_id == assignment.loader_id
269
+ rank = self.loader_id * self.loader_every_n_ranks
270
+ assert rank == self.my_rank
271
+ logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}")
272
+ logger.info(f"Sending {assignment.fqns}")
273
+ dist.broadcast(flatten_tensor, src=rank, group=self.pg)
274
+ self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
275
+ return loaded_state_dict
276
+
277
+ def _reshard_receive(
278
+ self, assignment: _Assignment, state_dict: dict[str, torch.Tensor]
279
+ ) -> dict[str, torch.Tensor]:
280
+ flatten_tensor = torch.empty(
281
+ sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)),
282
+ dtype=assignment.dtypes[0],
283
+ device="cuda",
284
+ )
285
+ rank = assignment.loader_id * self.loader_every_n_ranks
286
+ dist.broadcast(flatten_tensor, src=rank, group=self.pg)
287
+ self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
288
+
289
+ ret: dict[str, torch.Tensor] = {}
290
+ loc = 0
291
+ for fqn, shape, dtype in zip(
292
+ assignment.fqns, assignment.shapes, assignment.dtypes
293
+ ):
294
+ n_ele = math.prod(shape)
295
+ ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape)
296
+ loc += n_ele
297
+ return ret
298
+
299
+ def _reshard(
300
+ self,
301
+ results: list[dict[str, torch.Tensor]],
302
+ state_dict: dict[str, torch.Tensor],
303
+ ) -> None:
304
+ def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]):
305
+ titan_fqns = self.stored_fqn_to_titan_fqn[fqn]
306
+ assert len(titan_fqns) == len(full_tensors)
307
+ for titan_fqn, full_tensor in zip(titan_fqns, full_tensors):
308
+ dtensor = state_dict[titan_fqn]
309
+ logger.info(f"{titan_fqn} {full_tensor.sum()}")
310
+ assert isinstance(dtensor, DTensor)
311
+ shape, offset = compute_local_shape_and_global_offset(
312
+ full_tensor.shape, dtensor.device_mesh, dtensor.placements
313
+ )
314
+ slices = [
315
+ slice(cur_offset, cur_offset + cur_shape)
316
+ for cur_shape, cur_offset in zip(shape, offset)
317
+ ]
318
+ logger.info(
319
+ f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} "
320
+ f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} "
321
+ f"{dtensor.placements=} {dtensor.device_mesh=} "
322
+ )
323
+ dtensor.to_local().copy_(full_tensor[slices])
324
+
325
+ def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
326
+ if "wqkv" in fqn:
327
+ if "layer_norm" in fqn:
328
+ return (shards[0],)
329
+ return split_fused_qkv(shards)
330
+
331
+ shard_dim = get_shard_dim(fqn)
332
+ if shard_dim is None:
333
+ return (shards[0],)
334
+ return (torch.cat(shards, dim=shard_dim),)
335
+
336
+ fqns = list(results[0].keys())
337
+ for result in results:
338
+ assert list(result.keys()) == fqns
339
+
340
+ for fqn in fqns:
341
+ full_tensors = _concat_shards(fqn, [result[fqn] for result in results])
342
+ _inplace_copy(fqn, full_tensors)
343
+
344
+
345
+ def _create_verified_state_dict(
346
+ pg: dist.ProcessGroup, mesh: DeviceMesh
347
+ ) -> dict[str, torch.Tensor]:
348
+ placements = [Shard(0)]
349
+ state_dict = {
350
+ "tok_embeddings.weight": torch.rand(
351
+ 25256 * 8, 5120, device="cuda", dtype=torch.bfloat16
352
+ ),
353
+ "layers.47.attention.wqkv.layer_norm_weight": torch.rand(
354
+ 5120, device="cuda", dtype=torch.bfloat16
355
+ ),
356
+ "layers.47.attention.wq.weight": torch.rand(
357
+ 640 * 8, 5120, device="cuda", dtype=torch.bfloat16
358
+ ),
359
+ "layers.47.attention.wk.weight": torch.rand(
360
+ 128 * 8, 5120, device="cuda", dtype=torch.bfloat16
361
+ ),
362
+ "layers.47.attention.wv.weight": torch.rand(
363
+ 128 * 8, 5120, device="cuda", dtype=torch.bfloat16
364
+ ),
365
+ "layers.47.attention.wo.weight": torch.rand(
366
+ 5120, 640 * 8, device="cuda", dtype=torch.bfloat16
367
+ ),
368
+ # "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16),
369
+ # "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
370
+ # "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
371
+ "layers.47.feed_forward.w_in_shared_FD.weight": torch.rand(
372
+ 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
373
+ ),
374
+ "layers.47.feed_forward.w_out_shared_DF.weight": torch.rand(
375
+ 5120, 1024 * 8, device="cuda", dtype=torch.bfloat16
376
+ ),
377
+ "layers.47.feed_forward.w_swiglu_FD.weight": torch.rand(
378
+ 1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
379
+ ),
380
+ "layers.47.feed_forward.norm.weight": torch.rand(
381
+ 5120, device="cuda", dtype=torch.bfloat16
382
+ ),
383
+ "layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand(
384
+ 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
385
+ ),
386
+ "layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand(
387
+ 131072 * 8, 5120, device="cuda", dtype=torch.bfloat16
388
+ ),
389
+ "layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand(
390
+ 655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
391
+ ),
392
+ }
393
+ return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()}
394
+
395
+
396
+ def _verify_state_dict(
397
+ state_dict: dict[str, torch.Tensor], path: str, rank: int
398
+ ) -> None:
399
+ stored_state_dicts = [
400
+ torch.load(
401
+ os.path.join(path, f"consolidated.0{i}.pth"),
402
+ map_location="cpu",
403
+ weights_only=False,
404
+ mmap=True,
405
+ )
406
+ for i in range(8)
407
+ ]
408
+
409
+ def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None:
410
+ logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ")
411
+ shards = [stored_state_dicts[i][fqn] for i in range(8)]
412
+ full_tensor = dtensor.full_tensor()
413
+ logger.info(f"Gather {fqn} {full_tensor.shape} completely.")
414
+
415
+ if rank > 0:
416
+ return
417
+
418
+ if len(shards[0].shape) == 1:
419
+ assert full_tensor.shape == shards[0].shape, fqn
420
+ assert torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn
421
+ return
422
+ elif shards[0].shape[0] == full_tensor.shape[0]:
423
+ concat_shards = torch.cat(shards, dim=1)
424
+ logger.info(f"Load {fqn} completely.")
425
+ elif shards[0].shape[1] == full_tensor.shape[1]:
426
+ concat_shards = torch.cat(shards, dim=0)
427
+ logger.info(f"Load {fqn} completely.")
428
+
429
+ concat_shards = concat_shards.to(device="cuda")
430
+ logger.info(f"Move to GPU {fqn} completely.")
431
+
432
+ assert concat_shards.shape == full_tensor.shape, fqn
433
+ assert concat_shards.dtype == full_tensor.dtype, fqn
434
+ assert concat_shards.device == full_tensor.device, fqn
435
+ assert torch.allclose(concat_shards, full_tensor), fqn
436
+
437
+ for k, v in state_dict.items():
438
+ if "wq" in k and "wqkv" not in k:
439
+ pass
440
+ elif "wk" in k:
441
+ pass
442
+ elif "wv" in k:
443
+ pass
444
+ else:
445
+ assert v is not None, k
446
+ read_and_verify_tensor(k, v)
447
+
448
+
449
+ if __name__ == "__main__":
450
+ init_logger()
451
+ config = JobConfig()
452
+ config.parser.add_argument(
453
+ "--checkpoint.convert_path",
454
+ type=str,
455
+ default="",
456
+ help="""Specify the path of the target checkpoint to convert.""",
457
+ )
458
+ config.parser.add_argument(
459
+ "--checkpoint.convert_load_every_n_ranks",
460
+ type=int,
461
+ default=8,
462
+ help="""
463
+ Specify the interval at which ranks are assigned to load checkpoints.
464
+
465
+ For example, if this number is 4, then ranks 0, 4, 8, ... will load the
466
+ checkpoint. Each loader is responsible for loading one file. If there
467
+ are more loaders than files, only the first few loaders will be assigned
468
+ to load the checkpoint. The default value is 8.
469
+ """,
470
+ )
471
+ config.parser.add_argument(
472
+ "--checkpoint.fake_model",
473
+ action="store_true",
474
+ help="""If true, the model will be fake.""",
475
+ )
476
+ config.parse_args()
477
+ assert config.checkpoint.convert_path != ""
478
+
479
+ trainer: Optional[Trainer] = None
480
+
481
+ try:
482
+ trainer = Trainer(config)
483
+ if os.path.exists(trainer.checkpointer.folder):
484
+ raise RuntimeError(
485
+ "The checkpoint folder already exists. Abort to avoid overwriting "
486
+ f"the checkpoint. {trainer.checkpointer.folder=}"
487
+ )
488
+ if config.checkpoint.fake_model:
489
+ state_dict = _create_verified_state_dict(
490
+ trainer.world_mesh.get_group(), trainer.world_mesh
491
+ )
492
+ else:
493
+ state_dict = trainer.checkpointer.states[MODEL].state_dict()
494
+
495
+ size = 0
496
+ for v in state_dict.values():
497
+ size += v.numel() * v.element_size()
498
+ logger.info(f"Total size of the model: {size / 1e9:.2f} GB")
499
+
500
+ # PP is not supported yet; we would need to iterate over the PP dimension and
501
+ # extract the corresponding state_dict and device_mesh.
502
+ if "freqs_cis" in state_dict:
503
+ state_dict.pop("freqs_cis")
504
+
505
+ state_dict = CheckpointConverter(
506
+ process_group=trainer.world_mesh.get_group(),
507
+ path=config.checkpoint.convert_path,
508
+ loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks,
509
+ ).convert(state_dict)
510
+
511
+ class DummyModel:
512
+ def __init__(self, state_dict: dict[str, torch.Tensor]) -> None:
513
+ self._state_dict = state_dict
514
+
515
+ def state_dict(self) -> dict[str, torch.Tensor]:
516
+ return self._state_dict
517
+
518
+ if config.checkpoint.fake_model:
519
+ begin = time.time()
520
+ _verify_state_dict(
521
+ state_dict,
522
+ config.checkpoint.convert_path,
523
+ trainer.world_mesh.get_rank(),
524
+ )
525
+ dist.barrier()
526
+ logger.info(f"Verified state_dict in {time.time() - begin:.2f} seconds.")
527
+ else:
528
+ # oh, this is pretty bad, when can we get rid of the freqs_cis issue?
529
+ state_dict["freqs_cis"] = None
530
+ trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
531
+ trainer.checkpointer.model_weights_only = True
532
+ trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
533
+ trainer.checkpointer.save(curr_step=0, force=True)
534
+ time.sleep(2)
535
+ finally:
536
+ pass
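`_reshard_send` and `_reshard_receive` above rely on one invariant: the sender flattens and concatenates its tensors into a single buffer (which is what actually gets broadcast), and each receiver slices that buffer back apart using the `(fqn, shape, dtype)` metadata carried by the `_Assignment`. The sketch below shows only that pack/unpack round trip, without any `torch.distributed` calls and with made-up tensor names and shapes:

```python
# Sketch of the flatten/concatenate and slice-back bookkeeping used above.
import math
import torch

tensors = {
    "layers.0.attention.wq.weight": torch.randn(8, 4),   # hypothetical shapes
    "layers.0.attention.wk.weight": torch.randn(2, 4),
    "norm.weight": torch.randn(4),
}

# Sender side: one flat buffer per round (this is what dist.broadcast would carry).
flat = torch.cat([t.flatten() for t in tensors.values()])

# Receiver side: rebuild each tensor from the shared (fqn, shape) metadata.
rebuilt, loc = {}, 0
for fqn, t in tensors.items():
    n = math.prod(t.shape)
    rebuilt[fqn] = flat[loc : loc + n].view(t.shape)
    loc += n

for fqn, t in tensors.items():
    assert torch.equal(rebuilt[fqn], t), fqn
print("round trip ok")
```

The real implementation then copies each rebuilt full tensor into the local shard of the target `DTensor` via `compute_local_shape_and_global_offset`, which is the step that does the actual resharding.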
torchtitan/experiments/multimodal/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from mm_dataset import build_mm_dataloader
8
+
9
+ from torchtitan.components.loss import build_cross_entropy_loss
10
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
11
+ from torchtitan.components.optimizer import build_optimizers
12
+ from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
13
+ from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
14
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
15
+
16
+ from .model import ModelArgs, MultimodalDecoder, VisionEncoder
17
+
18
+ __all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"]
19
+
20
+ llama4_mm_configs = {
21
+ # TODO: add configs for llama4 multimodal
22
+ }
23
+
24
+ register_train_spec(
25
+ TrainSpec(
26
+ name="llama4_multimodal",
27
+ cls=MultimodalDecoder,
28
+ config=llama4_mm_configs,
29
+ parallelize_fn=parallelize_llama,
30
+ pipelining_fn=pipeline_llama,
31
+ build_optimizers_fn=build_optimizers,
32
+ build_lr_schedulers_fn=build_lr_schedulers,
33
+ build_dataloader_fn=build_mm_dataloader,
34
+ build_tokenizer_fn=build_tiktoken_tokenizer,
35
+ build_loss_fn=build_cross_entropy_loss,
36
+ )
37
+ )
torchtitan/experiments/multimodal/mm_collator.py ADDED
@@ -0,0 +1,227 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Any, Dict, List, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from tokenizer.tiktoken import IGNORE_INDEX
16
+
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+
20
+ def padded_collate(
21
+ batch: List[Dict[str, List[int]]],
22
+ padding_idx: int = 0,
23
+ ignore_idx: int = -100,
24
+ ) -> Dict[str, torch.Tensor]:
25
+ """Pad a batch of sequences to the longest sequence length in the batch, and
26
+ convert integer lists to tensors.
27
+
28
+ Args:
29
+ batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
30
+ padding_idx (int): Padding index for input ids. Defaults to 0.
31
+ ignore_idx (int): Padding index for labels. Defaults to -100.
32
+
33
+ Returns:
34
+ Dict[str, torch.Tensor]: Collated input and label tensors.
35
+
36
+ Example:
37
+ >>> token_pairs = [
38
+ >>> {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
39
+ >>> {"input_ids": [7,], "labels": [10,]},
40
+ >>> ]
41
+ >>> collated = padded_collate(
42
+ >>> batch=token_pairs,
43
+ >>> padding_idx=padding_idx,
44
+ >>> ignore_idx=ignore_idx,
45
+ >>> )
46
+ >>> collated["input_ids"]
47
+ >>> tensor([[1, 2, 3], [7, 0, 0]])
48
+ >>> collated["labels"]
49
+ >>> tensor([[4, 5, 6], [10, -100, -100]])
50
+ """
51
+ input_ids = pad_sequence(
52
+ [x["input_ids"] for x in batch],
53
+ batch_first=True,
54
+ padding_value=padding_idx,
55
+ )
56
+ labels = pad_sequence(
57
+ [x["labels"] for x in batch],
58
+ batch_first=True,
59
+ padding_value=ignore_idx,
60
+ )
61
+
62
+ input_ids_seq_len = input_ids.shape[-1]
63
+ labels_seq_len = labels.shape[-1]
64
+
65
+ # Hack to pad correctly and not use max_seq_len, which is costly
66
+ if input_ids_seq_len > labels_seq_len:
67
+ labels = F.pad(
68
+ labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
69
+ )
70
+ elif labels_seq_len > input_ids_seq_len:
71
+ input_ids = F.pad(
72
+ input_ids,
73
+ (0, labels_seq_len - input_ids_seq_len),
74
+ value=padding_idx,
75
+ )
76
+ return {"input_ids": input_ids, "labels": labels}
77
+
78
+
79
+ # NOTE Inspired from torchtune.data._collate.py
80
+ @dataclass
81
+ class MultiModalCollator:
82
+ padding_idx: int = 128004
83
+ ignore_idx: int = IGNORE_INDEX
84
+ pad_max_tiles: Optional[int] = None
85
+ pad_max_images: Optional[int] = None
86
+
87
+ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
88
+ """Pad a batch of text sequences, tiled image tensors, aspect ratios,
89
+ and cross attention masks. This can be used for both training and inference.
90
+
91
+ ``batch`` is expected to be a list of sample dicts containing the following::
92
+ - "input_ids": List[int] of length text_seq_len, varies across samples
93
+ - "labels": List[int] of length text_seq_len, varies across samples
94
+ - "encoder_input": Dict[str, List[torch.Tensor]]
95
+ - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
96
+ - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
97
+
98
+ Shape notation:
99
+ - c = channel dim
100
+ - h = height dim
101
+ - w = weight dim
102
+
103
+ Note:
104
+ For each element in the batch, ``len(images) == len(aspect_ratio)``.
105
+
106
+ This collater does the following:
107
+ (1) Pad text sequence and encoder mask to the longest sequence length in the batch
108
+ (2) Pad image tensors in the tile dimension with zeros to the largest number
109
+ of tiles in the batch
110
+ (3) Add empty images of zeros to samples up to max number of images in the batch
111
+ (4) Pad aspect ratios with (1,1) for all added padding images
112
+
113
+ Args:
114
+ batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
115
+ labels, images, and aspect_ratio.
116
+ padding_idx (int): Padding index for input token ids. Defaults to 0.
117
+ ignore_idx (int): Padding index for labels. Defaults to -100.
118
+ pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles
119
+ in the batch. Defaults to None.
120
+ pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
121
+ in the batch. Defaults to None.
122
+
123
+ Returns:
124
+ Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors.
125
+ - tokens: Tensor of shape (bsz, max_seq_len)
126
+ - labels: Tensor of shape (bsz, max_seq_len)
127
+ - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
128
+ - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
129
+
130
+ Example:
131
+ >>> image_id = 1
132
+ >>> tokens_per_tile = 5
133
+ >>> c, h, w = 1, 1, 1
134
+ >>> batch = [
135
+ ... {
136
+ ... "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
137
+ ... "encoder_input": {
138
+ ... # One image with two tiles, one image with three tiles
139
+ ... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
140
+ ... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
141
+ ... },
142
+ ... },
143
+ ... {
144
+ ... "input_ids": [1, 4], "labels": [8, 9],
145
+ ... "encoder_input": {
146
+ ... # One image with four tiles
147
+ ... "images": [torch.ones(4, c, h, w)],
148
+ ... "aspect_ratio": [torch.tensor([2, 2])],
149
+ ... },
150
+ ... },
151
+ ... ]
152
+ ... collator = MultiModalCollator(pad_max_tiles=4)
153
+ >>> model_inputs = collator(batch=batch)
154
+ >>> print(model_inputs["input_ids"])
155
+ tensor([[1, 2, 1, 3],
156
+ [1, 4, 0, 0]])
157
+ >>> print(model_inputs["labels"])
158
+ tensor([[4, 5, 6, 7],
159
+ [8, 9, -100, -100]])
160
+ >>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
161
+ torch.Size([2, 2, 4, 1, 1, 1])
162
+ >>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2)
163
+ torch.Size([2, 2, 2])
164
+ >>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four
165
+ tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
166
+ >>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four
167
+ tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
168
+ >>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded
169
+ tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
170
+ >>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample
171
+ tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
172
+ """
173
+ # Text tokens can be handled independently by existing collaters
174
+ text_only = [
175
+ {"input_ids": sample["input_ids"], "labels": sample["labels"]}
176
+ for sample in batch
177
+ ]
178
+ collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)
179
+
180
+ if self.pad_max_tiles is None:
181
+ # Get max number of tiles in batch
182
+ max_num_tiles = max(sample["images_tiles"].shape[0] for sample in batch)
183
+ else:
184
+ max_num_tiles = self.pad_max_tiles
185
+
186
+ # Pad images and aspect ratios to max number of tiles
187
+ batch_images = []
188
+ batch_aspect_ratios = []
189
+
190
+ for sample in batch:
191
+ sample_images = []
192
+ for image in sample["encoder_input"]["images"]:
193
+ # Single image in each sample has shape (n_tiles, c, h, w)
194
+ n_tiles = image.shape[0]
195
+ # Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len)
196
+ # where image_seq_len = n_tiles * tokens_per_tile
197
+ padding_tiles = max_num_tiles - n_tiles
198
+
199
+ # Image should now have shape (max_num_tiles, c, h, w)
200
+ padded_image = F.pad(
201
+ image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
202
+ )
203
+
204
+ sample_images.append(padded_image)
205
+ # Stack multiple images and masks per sample in num_images dimension
206
+ batch_images.append(torch.stack(sample_images))
207
+ batch_aspect_ratios.append(
208
+ torch.stack(sample["encoder_input"]["aspect_ratio"])
209
+ )
210
+ # Finally, pad images, masks, aspect ratios to max number of images in batch
211
+ # (bsz, max_num_images, max_num_tiles, c, h, w)
212
+ collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
213
+ # (bsz, max_num_images, 2)
214
+ collated_aspect_ratios = pad_sequence(
215
+ batch_aspect_ratios, batch_first=True, padding_value=1
216
+ )
217
+
218
+ batch_dict = {
219
+ "input_ids": collated_text["input_ids"],
220
+ "labels": collated_text["labels"],
221
+ "encoder_input": {
222
+ "images": collated_images,
223
+ "aspect_ratio": collated_aspect_ratios,
224
+ },
225
+ }
226
+
227
+ return batch_dict
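
A minimal usage sketch for the collator above (the dataset object `mm_dataset` is a placeholder; everything else uses only symbols from this file and torch):

    from torch.utils.data import DataLoader

    collator = MultiModalCollator(pad_max_tiles=4)
    loader = DataLoader(mm_dataset, batch_size=8, collate_fn=collator)  # mm_dataset: placeholder dataset
    for batch in loader:
        tokens = batch["input_ids"]                 # (bsz, max_seq_len)
        images = batch["encoder_input"]["images"]   # (bsz, max_num_images, max_num_tiles, c, h, w)
        break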
torchtitan/experiments/multimodal/requirements.txt ADDED
@@ -0,0 +1 @@
torchvision
torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc ADDED
Binary file (6.83 kB). View file
 
torchtitan/experiments/simple_fsdp/model.py ADDED
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama3 import Transformer, TransformerModelArgs
from .simple_fsdp import disable_data_parallel


class SimpleFSDPTransformer(Transformer):
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__(model_args)
        self.init_weights()

    def init_weights(self, *args, **kwargs):
        with disable_data_parallel():
            super().init_weights(*args, **kwargs)
torchtitan/experiments/simple_fsdp/simple_fsdp.py ADDED
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


_active_parametrization = True


@contextmanager
def disable_data_parallel():
    global _active_parametrization
    try:
        _active_parametrization = False
        yield
    finally:
        _active_parametrization = True


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None


def fsdp_policy():
    def _fsdp_recomp_policy():
        def _custom_policy(ctx, func, *args, **kwargs):
            to_recompute = func in {
                torch.ops._c10d_functional.all_gather_into_tensor.default,
                torch.ops._c10d_functional.wait_tensor.default,
                torch.ops.aten._to_copy.default,  # for dtype cast in FSDP
            }
            return (
                CheckpointPolicy.MUST_RECOMPUTE
                if to_recompute
                else CheckpointPolicy.MUST_SAVE
            )

        return _custom_policy

    return create_selective_checkpoint_contexts(_fsdp_recomp_policy())


class ReplicateComputation(torch.nn.Module):
    def __init__(self, device_mesh, param_sharding, mode, regional_ac, mp_policy):
        super().__init__()
        self.device_mesh = device_mesh
        self.param_sharding = param_sharding
        self.mode = mode
        self.compute_placements = [Replicate()] * self.device_mesh.ndim
        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
        self.regional_ac = regional_ac
        mp_policy = mp_policy or MixedPrecisionPolicy()
        self.param_dtype = mp_policy.param_dtype
        self.reduce_dtype = mp_policy.reduce_dtype

    def replicate_compute(self, x):
        # the data parallel runtime replicates parameters and does local compute;
        # the gradients are partial tensors that need to perform reduction
        # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)

        # NOTE: specifying mixed precision is only available in pytorch_intern24
        # https://github.com/tianyu-l/pytorch_intern24/pull/20
        # support for FSDP + TP (assuming TP shards the inner-most dim)
        if self.mode == "fully_shard" and x._spec.mesh.ndim == 2:
            dp_placement, tp_placement = x._spec.placements
            dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]

            # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
            # TODO: we should consider merging this logic into DTensor redistribute API
            sharded_local_tensor = x.to_local()
            sharded_dtensor = DTensor.from_local(
                sharded_local_tensor, dp_mesh, self.param_sharding
            )

            # the actual FSDP all-gather on dp_mesh
            # TODO(ruisizhang123): enable mixed-precision training here
            # add the forward_dtype and backward_dtype back after landing changes in PyTorch DTensor
            replicated_dtensor = sharded_dtensor.redistribute(
                placements=self.compute_placements,
                # forward_dtype=self.param_dtype,
                # backward_dtype=self.reduce_dtype,
            )

            # re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh
            # TODO: DTensor should support this mesh collapsing operation
            replicated_local_tensor = replicated_dtensor.to_local(
                grad_placements=self.grad_placements
            )
            output = DTensor.from_local(
                replicated_local_tensor, tp_mesh, (tp_placement,)
            )
        else:
            output = x.redistribute(
                placements=self.compute_placements,
                # forward_dtype=self.param_dtype,
                # backward_dtype=self.reduce_dtype,
            ).to_local(grad_placements=self.grad_placements)

        return output

    def forward(self, x):
        global _active_parametrization
        # This should never be set to true during forward, only outside for model
        # inspection / debugging / initialization
        # model initialization can be done now through
        #     with disable_data_parallel():
        #         model.init_weights()
        if not _active_parametrization:
            return x

        if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
            # apply checkpointing to implement reshard_after_forward
            output = checkpoint(
                self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
            )
        else:
            output = self.replicate_compute(x)

        return output


def data_parallel(
    model,
    device_mesh,
    mode="replicate",
    ac_mode: str = "none",
    mp_policy: Optional[MixedPrecisionPolicy] = None,
):
    if mode == "replicate":
        param_sharding = (Replicate(),)
    elif mode == "fully_shard":
        param_sharding = (Shard(0),)
    elif mode == "hybrid_shard":
        # replicate inter-host, fully shard intra-host
        param_sharding = (Replicate(), Shard(0))
        assert (
            device_mesh.ndim == 2
        ), "hybrid sharded data parallel requires 2D DeviceMesh"
    else:
        raise ValueError(f"Unsupported mode {mode}")

    modules = list(model.modules())

    # apply regional ac (with fsdp_policy) if no global ac is to be applied
    regional_ac = ac_mode == "none"

    for mod in modules:
        params_dict = dict(mod.named_parameters(recurse=False))
        for p_name, p in params_dict.items():
            if p is not None and p.numel() > 0:
                mod.register_parameter(
                    p_name,
                    # NOTE: for 2D we need to distribute_tensor a DTensor
                    # which requires latest change in pytorch_intern24
                    # https://github.com/tianyu-l/pytorch_intern24/pull/25
                    nn.Parameter(distribute_tensor(p, device_mesh, param_sharding)),
                )
                nn.utils.parametrize.register_parametrization(
                    mod,
                    p_name,
                    ReplicateComputation(
                        device_mesh,
                        param_sharding,
                        mode,
                        regional_ac,
                        mp_policy=mp_policy,
                    ),
                    unsafe=True,
                )

    return model
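
For orientation, a minimal sketch of how `data_parallel` above might be applied; `my_model` is a placeholder nn.Module and torch.distributed is assumed to be initialized already (e.g. via torchrun):

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
    sharded = data_parallel(
        my_model,  # placeholder module
        mesh,
        mode="fully_shard",
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    )
    with disable_data_parallel():
        pass  # weight init / inspection goes here, as SimpleFSDPTransformer.init_weights does above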
torchtitan/experiments/simple_fsdp/tests/__init__.py ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
torchtitan/models/__init__.py ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama3  # noqa: F401
torchtitan/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (195 Bytes). View file
 
torchtitan/models/__pycache__/norms.cpython-312.pyc ADDED
Binary file (1.39 kB). View file
 
torchtitan/models/llama3/__init__.py ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import Transformer, TransformerModelArgs
from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = [
    "parallelize_llama",
    "pipeline_llama",
    "TransformerModelArgs",
    "Transformer",
    "llama3_configs",
]


llama3_configs = {
    "debugmodel": TransformerModelArgs(
        dim=256, n_layers=8, n_heads=16, rope_theta=500000
    ),
    "8B": TransformerModelArgs(
        dim=4096,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=1024,
        rope_theta=500000,
    ),
    "70B": TransformerModelArgs(
        dim=8192,
        n_layers=80,
        n_heads=64,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=4096,
        rope_theta=500000,
    ),
    "405B": TransformerModelArgs(
        dim=16384,
        n_layers=126,
        n_heads=128,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=4096,
        rope_theta=500000,
    ),
}


register_train_spec(
    TrainSpec(
        name="llama3",
        cls=Transformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
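
For reference, the registered configs can be exercised directly; a small sketch (meta-device construction mirrors what the trainer normally does and is only an assumption here):

    import torch

    with torch.device("meta"):
        model = Transformer(llama3_configs["debugmodel"])
    print(sum(p.numel() for p in model.parameters()))  # parameter count of the debug model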
torchtitan/models/llama3/parallelize_llama.py ADDED
@@ -0,0 +1,398 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.model.use_flex_attn:
        if job_config.activation_checkpoint.mode == "selective":
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )

        if parallel_dims.cp_enabled:
            raise ValueError(
                "FlexAttention is not compatible with CP yet. "
                "We are still working on this."
            )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    #    transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    #       by folding (and unfolding) the batch dimension and the sequence dimension.
    #       Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    for layer_id, transformer_block in model.layers.items():
        if reshard_after_forward_policy == "always":
            reshard_after_forward = True
        elif reshard_after_forward_policy == "never":
            reshard_after_forward = False
        elif reshard_after_forward_policy == "default":
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = int(layer_id) < len(model.layers) - 1
        else:
            raise ValueError(
                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
            )
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
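
As a small illustration of the op-level selective AC policy above, the wrapper can be exercised on a toy two-matmul module; `SimpleNamespace` merely stands in for the real activation-checkpoint config object:

    from types import SimpleNamespace
    import torch

    toy = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
    ac_cfg = SimpleNamespace(mode="selective", selective_ac_option="op")
    wrapped = _apply_ac_to_transformer_block(toy, ac_cfg)
    out = wrapped(torch.randn(2, 16, requires_grad=True))
    out.sum().backward()  # every second mm output is recomputed rather than saved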
torchtitan/models/llama3/pipeline_llama.py ADDED
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import (
    build_pipeline_schedule,
    generate_split_points,
    stage_ids_this_rank,
)
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from .model import TransformerModelArgs


def pipeline_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    stages, model_parts = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        #       in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.

    It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    parallelism_config = job_config.parallelism

    splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
        parallelism_config.pipeline_parallel_schedule,
        parallelism_config.pipeline_parallel_layers_per_stage,
        parallel_dims.pp,
        model_config.n_layers,
    )

    def _build_stage(
        stage_idx: int,
        start_layer: str | None,
        stop_layer: str | None,
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            model.tok_embeddings = None

        drop_layers = start_layer is not None
        for name in list(model.layers.keys()):
            # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
            if f"layers.{name}" == start_layer:
                drop_layers = False
            if f"layers.{name}" == stop_layer:
                drop_layers = True
            if drop_layers:
                del model.layers[name]

        if not is_last:
            model.norm = None
            model.output = None

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
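
The start/stop bookkeeping in `_build_stage` can be illustrated in isolation; a toy sketch with eight fake layers (split points chosen only for illustration):

    layers = {str(i): f"block{i}" for i in range(8)}
    start_layer, stop_layer = "layers.2", "layers.4"
    drop = start_layer is not None
    kept = {}
    for name in layers:
        if f"layers.{name}" == start_layer:
            drop = False
        if f"layers.{name}" == stop_layer:
            drop = True
        if not drop:
            kept[name] = layers[name]
    print(sorted(kept))  # ['2', '3'] -- layers in [start, stop) are kept for this stage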
torchtitan/models/llama3/train_configs/llama3_405b.toml ADDED
@@ -0,0 +1,63 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 128 H100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 405B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "405B"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
converters = "float8"

[optimizer]
name = "AdamW"
lr = 8e-5
eps = 1e-8

[lr_scheduler]
warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 2
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 3000
compile = true
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8  # 8-way TP
enable_async_tensor_parallel = true
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full'  # ['none', 'selective', 'full']

[float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = "output"
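
The preset can be inspected programmatically; a small sketch using the standard-library TOML parser (Python 3.11+), with the path taken from this diff:

    import tomllib

    with open("torchtitan/models/llama3/train_configs/llama3_405b.toml", "rb") as f:
        cfg = tomllib.load(f)
    print(cfg["parallelism"]["tensor_parallel_degree"])     # 8
    print(cfg["float8"]["enable_fsdp_float8_all_gather"])   # True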
torchtitan/tools/profiling.py ADDED
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import pickle
import time

import torch

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how many memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000


@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
    # get user defined profiler settings
    enable_profiling = config.profiling.enable_profiling

    if enable_profiling:
        dump_dir = config.job.dump_folder
        save_trace_dir = config.profiling.save_traces_folder
        trace_dir = os.path.join(dump_dir, save_trace_dir)
        profile_freq = config.profiling.profile_freq

        rank = torch.distributed.get_rank()

        def trace_handler(prof):
            curr_trace_dir_name = "iteration_" + str(prof.step_num)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir, exist_ok=True)

            logger.info(f"Dumping profiler traces at step {prof.step_num}")
            begin = time.monotonic()
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
            logger.info(
                f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
            )

        logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir, exist_ok=True)

        warmup, active = WARMUP, 1
        wait = profile_freq - (active + warmup)
        assert (
            wait >= 0
        ), "profile_freq must be greater than or equal to warmup + active"
        gpu_device_profiled = None
        if torch.cuda.is_available():
            gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
        elif torch.xpu.is_available():
            gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                gpu_device_profiled,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=trace_handler,
            record_shapes=True,
        ) as torch_profiler:
            torch_profiler.step_num = global_step
            yield torch_profiler
    else:
        torch_profiler = contextlib.nullcontext()
        yield None


@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    enable_snapshot = config.profiling.enable_memory_snapshot
    if enable_snapshot:
        snapshot_folder = config.profiling.save_memory_snapshot_folder
        snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
        if not os.path.exists(snapshot_dir):
            os.makedirs(snapshot_dir, exist_ok=True)
        rank = torch.distributed.get_rank()

        class MemoryProfiler:
            def __init__(self, step_num: int, freq: int):
                torch.cuda.memory._record_memory_history(
                    max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
                )
                # when resuming training, we start from the last step
                self.step_num = step_num
                self.freq = freq

            def step(self, exit_ctx: bool = False):
                self.step_num += 1
                if not exit_ctx and self.step_num % self.freq != 0:
                    return
                if not exit_ctx:
                    curr_step = self.step_num
                    dir_name = f"iteration_{curr_step}"
                else:
                    # dump as iteration_0_exit if OOM at iter 1
                    curr_step = self.step_num - 1
                    dir_name = f"iteration_{curr_step}_exit"
                curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
                if not os.path.exists(curr_snapshot_dir):
                    os.makedirs(curr_snapshot_dir, exist_ok=True)
                logger.info(f"Dumping memory snapshot at step {curr_step}")
                begin = time.monotonic()
                with open(
                    f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
                ) as output:
                    pickle.dump(torch.cuda.memory._snapshot(), output)
                logger.info(
                    f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
                )

        logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
        profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
        try:
            yield profiler
        except torch.OutOfMemoryError as e:
            profiler.step(exit_ctx=True)
    else:
        yield None
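
A sketch of how the profiling context above is typically driven from a training loop; `job_config` and `train_step` are placeholders:

    with maybe_enable_profiling(job_config, global_step=0) as torch_profiler:
        for _ in range(job_config.training.steps):
            train_step()
            if torch_profiler:
                torch_profiler.step()  # advances the wait/warmup/active schedule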