Add files using upload-large-folder tool
- fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc +0 -0
- fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc +0 -0
- logs/none_1_grtqk5/attempt_0/0/stderr.log +0 -0
- logs/none_1_grtqk5/attempt_0/1/stderr.log +0 -0
- logs/none_1_grtqk5/attempt_0/6/stderr.log +0 -0
- setup.py +51 -0
- torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
- torchtitan/components/optimizer.py +303 -0
- torchtitan/datasets/hf_datasets.py +173 -0
- torchtitan/datasets/tokenizer/tiktoken.py +190 -0
- torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc +0 -0
- torchtitan/distributed/__pycache__/utils.cpython-312.pyc +0 -0
- torchtitan/experiments/deepseek_v3/LICENSE-CODE +21 -0
- torchtitan/experiments/deepseek_v3/README.md +40 -0
- torchtitan/experiments/deepseek_v3/generate.py +308 -0
- torchtitan/experiments/deepseek_v3/indices.py +195 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py +11 -0
- torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py +260 -0
- torchtitan/experiments/flux/README.md +23 -0
- torchtitan/experiments/flux/__init__.py +122 -0
- torchtitan/experiments/flux/dataset/tokenizer.py +64 -0
- torchtitan/experiments/flux/model/autoencoder.py +388 -0
- torchtitan/experiments/flux/model/hf_embedder.py +40 -0
- torchtitan/experiments/flux/model/math.py +38 -0
- torchtitan/experiments/flux/model/model.py +177 -0
- torchtitan/experiments/flux/scripts/download_autoencoder.py +61 -0
- torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
- torchtitan/experiments/flux/train_configs/debug_model.toml +68 -0
- torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
- torchtitan/experiments/llama4/README.md +29 -0
- torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc +0 -0
- torchtitan/experiments/llama4/model/args.py +109 -0
- torchtitan/experiments/llama4/model/moe.py +228 -0
- torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py +536 -0
- torchtitan/experiments/multimodal/__init__.py +37 -0
- torchtitan/experiments/multimodal/mm_collator.py +227 -0
- torchtitan/experiments/multimodal/requirements.txt +1 -0
- torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc +0 -0
- torchtitan/experiments/simple_fsdp/model.py +18 -0
- torchtitan/experiments/simple_fsdp/simple_fsdp.py +194 -0
- torchtitan/experiments/simple_fsdp/tests/__init__.py +5 -0
- torchtitan/models/__init__.py +10 -0
- torchtitan/models/__pycache__/__init__.cpython-312.pyc +0 -0
- torchtitan/models/__pycache__/norms.cpython-312.pyc +0 -0
- torchtitan/models/llama3/__init__.py +76 -0
- torchtitan/models/llama3/parallelize_llama.py +398 -0
- torchtitan/models/llama3/pipeline_llama.py +161 -0
- torchtitan/models/llama3/train_configs/llama3_405b.toml +63 -0
- torchtitan/tools/profiling.py +131 -0
fla/models/mamba/__pycache__/configuration_mamba.cpython-312.pyc
ADDED
Binary file (7.06 kB)
fla/modules/__pycache__/fused_linear_cross_entropy.cpython-312.pyc
ADDED
Binary file (20.6 kB)
logs/none_1_grtqk5/attempt_0/0/stderr.log
ADDED
The diff for this file is too large to render.
logs/none_1_grtqk5/attempt_0/1/stderr.log
ADDED
The diff for this file is too large to render.
logs/none_1_grtqk5/attempt_0/6/stderr.log
ADDED
The diff for this file is too large to render.
setup.py
ADDED
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+
+import ast
+import os
+import re
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+with open('README.md') as f:
+    long_description = f.read()
+
+
+def get_package_version():
+    with open(Path(os.path.dirname(os.path.abspath(__file__))) / 'flame' / '__init__.py') as f:
+        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
+    return ast.literal_eval(version_match.group(1))
+
+
+setup(
+    name='flame',
+    version=get_package_version(),
+    description='A minimal training framework for scaling FLA models',
+    long_description=long_description,
+    long_description_content_type='text/markdown',
+    author='Songlin Yang, Yu Zhang',
+    author_email='yangsl66@mit.edu, yzhang.cs@outlook.com',
+    url='https://github.com/fla-org/flame',
+    packages=find_packages(),
+    license='MIT',
+    classifiers=[
+        'Programming Language :: Python :: 3',
+        'License :: OSI Approved :: MIT License',
+        'Operating System :: OS Independent',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence'
+    ],
+    python_requires='>=3.10',
+    install_requires=[
+        'torch==2.6',
+        'torchdata',
+        'transformers==4.51.3',
+        'triton>=3.0',
+        'datasets>=3.3.0',
+        'einops',
+        'ninja',
+        'wandb',
+        'tiktoken',
+        'tensorboard',
+        'python-dotenv'
+    ],
+)
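The `get_package_version` helper above reads the version out of `flame/__init__.py` with a regex plus `ast.literal_eval`, rather than importing the package (which could fail before its dependencies are installed). A quick illustration of the same pattern on a made-up source string:

```python
import ast
import re

# Hypothetical __init__.py contents, just to show the pattern in isolation.
source = '__version__ = "0.1.0"\n'
match = re.search(r"^__version__\s*=\s*(.*)$", source, re.MULTILINE)
print(ast.literal_eval(match.group(1)))  # -> 0.1.0
```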
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc
ADDED
Binary file (7.71 kB)
torchtitan/components/optimizer.py
ADDED
@@ -0,0 +1,303 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+from typing import Any, Generic, Iterator, TypeVar
+
+import torch
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import (
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.optim import Optimizer
+
+from torchtitan.components.ft import FTManager, has_torchft
+from torchtitan.config_manager import JobConfig
+
+__all__ = [
+    "OptimizersContainer",
+    "build_optimizers",
+]
+
+
+if has_torchft:
+    import torchft as ft
+
+
+T = TypeVar("T", bound=Optimizer)
+
+
+class OptimizersContainer(Optimizer, Stateful, Generic[T]):
+    """A container for multiple optimizers.
+
+    This class is used to wrap multiple optimizers into a single object that can be
+    used to reduce the complexity of the training loop. This mimics the behavior of
+    ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.
+
+    **Note**
+    Users who want to customize the optimizer behavior can inherit from this class and
+    extend the functionality as needed. The following methods must follow the same signature
+    as the ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
+    ``load_state_dict()``.
+
+    **Limitations**
+    This class assumes that all the optimizers are the same type and have the same
+    configurations. With this assumption, TorchTitan can support lr scheduler resharding
+    (e.g., loading a checkpoint with a different number of GPUs and/or different
+    parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
+    resharding for the optimizer state but not for the lr scheduler state, hence the limitation.
+
+    Args:
+        model_parts (List[nn.Module]): List of model parts to be optimized.
+        optimizer_cls (type[T]): Optimizer class to instantiate for each model part.
+        optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
+    """
+
+    optimizers: list[T]
+    model_parts: list[nn.Module]
+
+    def __init__(
+        self,
+        model_parts: list[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: dict[str, Any],
+    ) -> None:
+        all_params = []
+        self.optimizers = []
+        self.model_parts = model_parts
+        for model in self.model_parts:
+            params = [p for p in model.parameters() if p.requires_grad]
+            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
+            all_params.extend(params)
+        self._validate_length(len(self.model_parts))
+        self._post_init(all_params, optimizer_kwargs)
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self.optimizers)
+
+    def __len__(self) -> int:
+        return len(self.optimizers)
+
+    def step(self, *args, **kwargs) -> None:
+        for optimizer in self.optimizers:
+            optimizer.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs) -> None:
+        for optimizer in self.optimizers:
+            optimizer.zero_grad(*args, **kwargs)
+
+    def state_dict(self) -> dict[str, Any]:
+        func = functools.partial(
+            get_optimizer_state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        return {
+            k: v
+            for sd in map(func, self.model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        func = functools.partial(
+            set_optimizer_state_dict,
+            optim_state_dict=state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        list(map(func, self.model_parts, self.optimizers))
+
+    def _validate_length(self, expected_length: int) -> None:
+        assert expected_length == len(self.optimizers), (
+            "Must pass one optimizer per model part or per param if "
+            "using OptimizersInBackwardContainer."
+        )
+
+    def _post_init(
+        self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
+    ) -> None:
+        # We need to call Optimizer.__init__() to initialize some necessary optimizer
+        # functionality such as hooks.
+        Optimizer.__init__(self, all_params, optimizer_kwargs)
+
+
+class OptimizersInBackwardContainer(OptimizersContainer):
+    """OptimizersContainer for executing ``optim.step()`` in backward pass.
+
+    This class extends ``OptimizersContainer`` to support the optimizer step in
+    the backward pass. ``step()`` and ``zero_grad()`` are no-ops in this class.
+    Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to
+    execute these methods when the gradient is accumulated.
+    """
+
+    def __init__(
+        self,
+        model_parts: list[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: dict[str, Any],
+    ) -> None:
+        all_params = []
+        self.model_parts = model_parts
+
+        optim_dict = {}
+        for model in self.model_parts:
+            for p in model.parameters():
+                if p.requires_grad:
+                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
+                    all_params.append(p)
+
+        def optim_hook(param) -> None:
+            optim_dict[param].step()
+            optim_dict[param].zero_grad()
+
+        for model in self.model_parts:
+            for param in model.parameters():
+                if param.requires_grad:
+                    param.register_post_accumulate_grad_hook(optim_hook)
+
+        self.optimizers = list(optim_dict.values())
+
+        self._validate_length(
+            sum(len(list(model.parameters())) for model in self.model_parts)
+        )
+        self._post_init(all_params, optimizer_kwargs)
+
+    def step(self) -> None:
+        pass
+
+    def zero_grad(self) -> None:
+        pass
+
+
+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: list[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: dict[str, Any],
+        ft_manager: "ft.Manager",
+    ) -> None:
+        super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
+
+        # Force to initialize the optimizer state so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.cache_state_dict: dict[str, Any] = {}
+        self._ft_optimizer = ft.Optimizer(ft_manager, self)
+        self._call_from_ft: bool = False
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+    def step(self, *args, **kwargs) -> None:
+        """Calling the correct step() depending on the caller.
+
+        TorchFT's OptimizerWrapper.step() is designed to be called only once
+        per train step per ft.Manager, regardless of how many optimizers are used.
+        Hence we will need to appropriately dispatch the call.
+        """
+        if self._call_from_ft:
+            super().step(*args, **kwargs)
+        else:
+            self._call_from_ft = True
+            self._ft_optimizer.step(*args, **kwargs)
+            self._call_from_ft = False
+
+    def zero_grad(self, *args, **kwargs) -> None:
+        """Calling the correct zero_grad() depending on the caller.
+
+        Check the comment in ``step()``.
+        """
+        if self._call_from_ft:
+            super().zero_grad(*args, **kwargs)
+        else:
+            self._call_from_ft = True
+            self._ft_optimizer.zero_grad(*args, **kwargs)
+            self._call_from_ft = False
+
+
+def build_optimizers(
+    model_parts: list[nn.Module],
+    job_config: JobConfig,
+    ft_manager: FTManager,
+) -> OptimizersContainer:
+    """Create an OptimizersContainer for the given model parts and job config.
+
+    This function creates an ``OptimizersContainer`` for the given model parts.
+    ``job_config`` should define the correct optimizer name and parameters.
+    This function currently supports creating ``OptimizersContainer`` and
+    ``OptimizersInBackwardContainer``.
+
+    **Note**
+    Users who want to customize the optimizer behavior can create their own
+    ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the
+    customized ``build_optimizers`` to ``TrainSpec`` will create the customized
+    ``OptimizersContainer``.
+
+    Args:
+        model_parts (List[nn.Module]): List of model parts to be optimized.
+        job_config (JobConfig): Job config containing the optimizer name and parameters.
+    """
+    optim_in_bwd = job_config.optimizer.early_step_in_backward
+    if optim_in_bwd and job_config.parallelism.pipeline_parallel_degree > 1:
+        raise NotImplementedError(
+            "Optimizers in backward is not supported with pipeline parallelism."
+        )
+    name = job_config.optimizer.name
+    lr = job_config.optimizer.lr
+    eps = job_config.optimizer.eps
+
+    optim_implementation = job_config.optimizer.implementation
+    assert optim_implementation in ["fused", "foreach", "for-loop"]
+
+    fused = optim_implementation == "fused"
+    foreach = optim_implementation == "foreach"
+
+    optimizer_kwargs = {
+        "lr": lr,
+        "eps": eps,
+        "betas": (0.9, 0.95),
+        "weight_decay": 0.1,
+        "fused": fused,
+        "foreach": foreach,
+    }
+
+    optimizer_classes = {
+        "Adam": torch.optim.Adam,
+        "AdamW": torch.optim.AdamW,
+    }
+    if name not in optimizer_classes:
+        raise NotImplementedError(f"Optimizer {name} not added.")
+    optimizer_cls = optimizer_classes[name]
+
+    if optim_in_bwd and ft_manager.enabled:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(
+            model_parts, optimizer_cls, optimizer_kwargs
+        )
+    elif ft_manager.enabled:
+        return FTOptimizersContainer(
+            model_parts, optimizer_cls, optimizer_kwargs, ft_manager.manager
+        )
+    else:
+        return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
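A minimal sketch of how `OptimizersContainer` is meant to be driven from a training loop, using plain `torch.optim.AdamW` on toy modules. The module shapes and hyperparameters are illustrative, and the import assumes this file is reachable as `torchtitan.components.optimizer`:

```python
import torch
import torch.nn as nn

from torchtitan.components.optimizer import OptimizersContainer

# One optimizer per model part, stepped and zeroed together.
model_parts = [nn.Linear(16, 16), nn.Linear(16, 16)]
optimizers = OptimizersContainer(
    model_parts,
    optimizer_cls=torch.optim.AdamW,
    optimizer_kwargs={
        "lr": 3e-4, "eps": 1e-8, "betas": (0.9, 0.95), "weight_decay": 0.1,
    },
)

x = torch.randn(4, 16)
loss = sum(part(x).sum() for part in model_parts)
loss.backward()
optimizers.step()       # steps every wrapped optimizer
optimizers.zero_grad()  # zeros grads across all model parts
```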
torchtitan/datasets/hf_datasets.py
ADDED
@@ -0,0 +1,173 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.utils.data import IterableDataset
+
+from torchtitan.components.dataloader import ParallelAwareDataloader
+from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.config_manager import JobConfig
+from torchtitan.tools.logging import logger
+
+
+def _load_c4_dataset(dataset_path: str):
+    """Load C4 dataset with default configuration."""
+    return load_dataset(dataset_path, name="en", split="train", streaming=True)
+
+
+def _process_c4_text(sample: dict[str, Any]) -> str:
+    """Process C4 dataset sample text."""
+    return sample["text"]
+
+
+@dataclass
+class DatasetConfig:
+    path: str
+    loader: Callable
+    text_processor: Callable
+
+
+# Add your dataset here - more information at docs/datasets.md
+DATASETS = {
+    "c4": DatasetConfig(
+        path="allenai/c4",
+        loader=_load_c4_dataset,
+        text_processor=_process_c4_text,
+    ),
+    "c4_test": DatasetConfig(
+        path="tests/assets/c4_test",
+        loader=lambda path: load_dataset(path, split="train"),
+        text_processor=_process_c4_text,
+    ),
+}
+
+
+def _validate_dataset(
+    dataset_name: str, dataset_path: str | None = None
+) -> tuple[str, Callable, Callable]:
+    """Validate dataset name and path."""
+    if dataset_name not in DATASETS:
+        raise ValueError(
+            f"Dataset {dataset_name} is not supported. "
+            f"Supported datasets are: {list(DATASETS.keys())}"
+        )
+
+    config = DATASETS[dataset_name]
+    path = dataset_path or config.path
+    logger.info(f"Preparing {dataset_name} dataset from {path}")
+    return path, config.loader, config.text_processor
+
+
+class HuggingFaceDataset(IterableDataset, Stateful):
+    def __init__(
+        self,
+        dataset_name: str,
+        dataset_path: str | None,
+        tokenizer: Tokenizer,
+        seq_len: int = 2048,
+        dp_rank: int = 0,
+        dp_world_size: int = 1,
+        infinite: bool = False,
+    ) -> None:
+        # Force lowercase for consistent comparison
+        dataset_name = dataset_name.lower()
+
+        path, dataset_loader, text_processor = _validate_dataset(
+            dataset_name, dataset_path
+        )
+        ds = dataset_loader(path)
+
+        self.dataset_name = dataset_name
+        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
+        self._tokenizer = tokenizer
+        self.seq_len = seq_len
+        self.infinite = infinite
+        self._text_processor = text_processor
+
+        # Variables for checkpointing
+        self._sample_idx = 0
+        self._all_tokens: list[int] = []
+
+    def _get_data_iter(self):
+        if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
+            return iter([])
+
+        it = iter(self._data)
+        for _ in range(self._sample_idx):
+            next(it)
+        return it
+
+    def __iter__(self):
+        max_buffer_token_len = 1 + self.seq_len
+
+        while True:
+            for sample in self._get_data_iter():
+                # Use the dataset-specific text processor
+                sample_text = self._text_processor(sample)
+                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
+                self._all_tokens.extend(sample_tokens)
+                self._sample_idx += 1
+
+                while len(self._all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
+                    # update tokens to the remaining tokens
+                    self._all_tokens = self._all_tokens[max_buffer_token_len:]
+                    input = x[:-1]
+                    label = x[1:]
+                    yield {"input": input}, label
+
+            if not self.infinite:
+                logger.warning(f"Dataset {self.dataset_name} has run out of data")
+                break
+            else:
+                # Reset offset for the next iteration
+                self._sample_idx = 0
+                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
+
+    def load_state_dict(self, state_dict):
+        self._sample_idx = state_dict["sample_idx"]
+        self._all_tokens = state_dict["token_buffer"]
+
+    def state_dict(self):
+        return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}
+
+
+def build_hf_dataloader(
+    dp_world_size: int,
+    dp_rank: int,
+    tokenizer: Tokenizer,
+    job_config: JobConfig,
+    infinite: bool = True,
+) -> ParallelAwareDataloader:
+    """Build a data loader for HuggingFace datasets."""
+    dataset_name = job_config.training.dataset
+    dataset_path = job_config.training.dataset_path
+    batch_size = job_config.training.batch_size
+    seq_len = job_config.training.seq_len
+
+    hf_ds = HuggingFaceDataset(
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        seq_len=seq_len,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        infinite=infinite,
+    )
+
+    return ParallelAwareDataloader(
+        dataset=hf_ds,
+        dp_rank=dp_rank,
+        dp_world_size=dp_world_size,
+        batch_size=batch_size,
+    )
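The token-packing loop in `__iter__` above buffers `seq_len + 1` tokens at a time so that the input and the label are the same window shifted by one position. A self-contained sketch of that packing, with a fake token stream standing in for tokenized samples:

```python
import torch

seq_len = 8
buffer: list[int] = []
stream = [list(range(i, i + 5)) for i in range(0, 30, 5)]  # fake tokenized samples

for sample_tokens in stream:
    buffer.extend(sample_tokens)
    # Emit fixed-size windows whenever enough tokens have accumulated.
    while len(buffer) >= seq_len + 1:
        window = torch.LongTensor(buffer[: seq_len + 1])
        buffer = buffer[seq_len + 1 :]  # keep the remainder for the next window
        inp, label = window[:-1], window[1:]  # label is input shifted by one
        print(inp.tolist(), label.tolist())
```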
torchtitan/datasets/tokenizer/tiktoken.py
ADDED
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
+import os
+from collections.abc import Collection, Iterator, Sequence, Set as AbstractSet
+from pathlib import Path
+from typing import cast, Literal
+
+import tiktoken
+from tiktoken.load import load_tiktoken_bpe
+
+from torchtitan.components.tokenizer import Tokenizer
+from torchtitan.config_manager import JobConfig
+from torchtitan.tools.logging import logger
+
+
+class TikTokenizer(Tokenizer):
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+
+    Args:
+        model_path (str): The path to the Tiktoken model file.
+    """
+
+    special_tokens: dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501, B950
+
+    def __init__(self, model_path: str):
+        super().__init__()
+        assert os.path.exists(
+            model_path
+        ), f"The tokenizer path does not exist: {model_path}"
+        assert os.path.isfile(model_path), model_path
+
+        mergeable_ranks = load_tiktoken_bpe(model_path)
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|reserved_special_token_2|>",
+            "<|reserved_special_token_3|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|reserved_special_token_4|>",
+            "<|eot_id|>",  # end of turn
+        ] + [
+            f"<|reserved_special_token_{i}|>"
+            for i in range(5, self.num_reserved_special_tokens - 5)
+        ]
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+
+        self._n_words: int = self.model.n_vocab
+        # BOS / EOS token IDs
+        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
+        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.pad_id: int = -1
+        self.stop_tokens = {
+            self.special_tokens["<|end_of_text|>"],
+            self.special_tokens["<|eot_id|>"],
+        }
+        logger.info(
+            f"TikTokenizer built: #words {self.n_words}, BOS ID {self.bos_id}, EOS ID {self.eos_id}"
+        )
+
+    def encode(
+        self,
+        s: str,
+        *,
+        bos: bool,
+        eos: bool,
+        allowed_special: Literal["all"] | AbstractSet[str] | None = None,
+        disallowed_special: Literal["all"] | Collection[str] | None = None,
+    ) -> list[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+            allowed_special ("all"|set[str]): allowed special tokens in string
+            disallowed_special ("all"|set[str]): special tokens that raise an error when in string
+
+        Returns:
+            list[int]: A list of token IDs.
+
+        By default, setting disallowed_special=() encodes a string by ignoring
+        special tokens. Specifically:
+        - Setting `disallowed_special` to () will cause all text corresponding
+          to special tokens to be encoded as natural text (instead of raising
+          an error).
+        - Setting `allowed_special` to "all" will cause all text corresponding
+          to special tokens to be encoded as special tokens.
+        """
+        assert type(s) is str
+        allowed_special = allowed_special or set()
+        disallowed_special = disallowed_special or ()
+
+        # The tiktoken tokenizer can handle <=400k chars without
+        # pyo3_runtime.PanicException.
+        TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+        # https://github.com/openai/tiktoken/issues/195
+        # Here we iterate over subsequences and split if we exceed the limit
+        # of max consecutive non-whitespace or whitespace characters.
+        MAX_NO_WHITESPACES_CHARS = 25_000
+
+        substrs = (
+            substr
+            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
+            for substr in self._split_whitespaces_or_nonwhitespaces(
+                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
+            )
+        )
+        t: list[int] = []
+        for substr in substrs:
+            t.extend(
+                self.model.encode(
+                    substr,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                )
+            )
+        if bos:
+            t.insert(0, self.bos_id)
+        if eos:
+            t.append(self.eos_id)
+        return t
+
+    def decode(self, t: Sequence[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
+        return self.model.decode(cast(list[int], t))
+
+    @staticmethod
+    def _split_whitespaces_or_nonwhitespaces(
+        s: str, max_consecutive_slice_len: int
+    ) -> Iterator[str]:
+        """
+        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
+        consecutive whitespaces or consecutive non-whitespaces.
+        """
+        current_slice_len = 0
+        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
+        slice_start = 0
+
+        for i in range(len(s)):
+            is_now_space = s[i].isspace()
+
+            if current_slice_is_space ^ is_now_space:
+                current_slice_len = 1
+                current_slice_is_space = is_now_space
+            else:
+                current_slice_len += 1
+                if current_slice_len > max_consecutive_slice_len:
+                    yield s[slice_start:i]
+                    slice_start = i
+                    current_slice_len = 1
+        yield s[slice_start:]
+
+
+def build_tiktoken_tokenizer(job_config: JobConfig) -> TikTokenizer:
+    return TikTokenizer(job_config.model.tokenizer_path)
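`_split_whitespaces_or_nonwhitespaces` exists to keep tiktoken's Rust side from panicking on pathological runs of characters (see the linked issue). A standalone rendering of the same splitting rule, with a tiny cap so the behavior is visible:

```python
def split_runs(s: str, max_consecutive_slice_len: int):
    # Same rule as the static method above: cut whenever a run of all-whitespace
    # or all-non-whitespace characters would exceed the cap.
    current_slice_len = 0
    current_slice_is_space = s[0].isspace() if s else False
    slice_start = 0
    for i, ch in enumerate(s):
        if current_slice_is_space ^ ch.isspace():
            current_slice_len = 1
            current_slice_is_space = ch.isspace()
        else:
            current_slice_len += 1
            if current_slice_len > max_consecutive_slice_len:
                yield s[slice_start:i]
                slice_start = i
                current_slice_len = 1
    yield s[slice_start:]

print(list(split_runs("aaaaa bb", 3)))  # -> ['aaa', 'aa bb']
```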
torchtitan/distributed/__pycache__/pipeline.cpython-312.pyc
ADDED
Binary file (7.82 kB)
torchtitan/distributed/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (14.9 kB)
torchtitan/experiments/deepseek_v3/LICENSE-CODE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 DeepSeek
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
torchtitan/experiments/deepseek_v3/README.md
ADDED
@@ -0,0 +1,40 @@
+# Running DeepSeek in Titan (experimental)
+
+This folder contains a DeepSeek model supporting v2 and v3 as well as kernels
+and scripts needed to run it.
+
+## Inference
+
+### Prerequisites:
+
+You will need to download a DeepSeek model's weights if you want to run a
+pre-trained checkpoint. We provide a script to download the weights from the
+HuggingFace Model Hub:
+```bash
+python download.py [vX]
+```
+where `vX` can be v2 or v3; both are supported. You may be required to create a
+HuggingFace account and log in first.
+
+### Running inference:
+
+The inference script is in `generate.py`. You can run it with the following
+command:
+```bash
+torchrun --standalone --nproc-per-node 4 generate.py
+```
+This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
+default.
+
+Alternatively, you can run inference by using `bash inference.sh`, optionally
+followed by your prompt.
+
+## Training
+
+The training script is in `train.py`. You can run it with the following command:
+```bash
+torchrun --standalone --nproc-per-node 8 train.py
+```
+
+This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
+default, with pipeline parallel, expert parallel, and data parallel enabled.
torchtitan/experiments/deepseek_v3/generate.py
ADDED
@@ -0,0 +1,308 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# torchrun --standalone --nproc-per-node 4 generate.py
+
+# use inference.sh "Your Question Here?" to run inference with a single prompt.
+
+import sys
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+
+from checkpoint import load_weights_from_hf
+from model import DeepseekForCausalLM
+from model_config import deepseek_config_registry
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torchtitan.tools.utils import Color
+from transformers import AutoTokenizer
+
+# Uncomment the model you want to run.
+model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
+# model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)
+
+
+def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
+    """Parse and colorize chat output with optional colors for each role."""
+    lines = text.split("\n")
+    result = []
+
+    current_role = None
+    current_content = []
+
+    def _process_current_content():
+        if not current_role or not current_content:
+            return None
+
+        content = "\n".join(current_content)
+        if current_role == "output":
+            return (
+                f"Output: {output_color}{content}{color.reset}"
+                if output_color
+                else f"Output: {content}"
+            )
+        else:
+            try:
+                prefix, rest = current_content[0].split(":", 1)
+                role_color = user_color if current_role == "user" else assistant_color
+                if role_color:
+                    formatted = f"{prefix}:{role_color}{rest}{color.reset}"
+                    if len(current_content) > 1:
+                        formatted += (
+                            f"{role_color}\n"
+                            + "\n".join(current_content[1:])
+                            + f"{color.reset}"
+                        )
+                    return formatted
+            except ValueError:
+                pass
+        return content
+
+    for line in lines:
+        if line.startswith("Output:"):
+            if processed := _process_current_content():
+                result.append(processed)
+            current_role = "output"
+            content = line[len("Output:") :].strip()
+            if output_color:
+                content = f"Output: {output_color}{content}{color.reset}"
+            else:
+                content = f"Output: {content}"
+            result.append(content)
+            current_content = []
+
+        elif line.startswith("User:"):
+            if processed := _process_current_content():
+                result.append(processed)
+            current_role = "user"
+            current_content = [line]
+
+        elif line.startswith("Assistant:"):
+            if processed := _process_current_content():
+                result.append(processed)
+            current_role = "assistant"
+            current_content = [line]
+
+        else:
+            if current_content:
+                current_content.append(line)
+            elif line.strip() and current_role is None:
+                # Handle system message at the beginning
+                current_role = "output"
+                if output_color:
+                    result.append(f"Output: {output_color}{line.strip()}{color.reset}")
+                else:
+                    result.append(f"Output: {line.strip()}")
+
+    # Process the last segment
+    if processed := _process_current_content():
+        result.append(processed)
+
+    return "\n".join(result)
+
+
+color = Color()
+
+
+@dataclass
+class DistConfig:
+    mesh: DeviceMesh
+    pp_mesh: DeviceMesh
+    ep_mesh: DeviceMesh
+    pp_size: int
+    ep_size: int
+    ep_rank: int
+    pp_rank: int
+    device: torch.device
+
+
+def create_model(dist_config: DistConfig):
+    model_args = deepseek_config_registry[model_id]
+    model_args.ep_size = dist_config.ep_size
+    model_args.num_stages = dist_config.pp_size
+    model_args.stage_idx = dist_config.pp_rank
+    model_args.max_seq_len = 16384
+
+    with dist_config.device, dist_config.mesh:
+        model = DeepseekForCausalLM(model_args)
+        load_weights_from_hf(model, model_id, dist_config.device)
+        model.eval()
+        model.setup_symm_mem(torch.bfloat16, dist_config.device)
+
+    stage = PipelineStage(
+        model,
+        dist_config.pp_rank,
+        dist_config.pp_size,
+        dist_config.device,
+        group=dist_config.pp_mesh.get_group(),
+    )
+    pp_schedule = ScheduleGPipe(stage, dist_config.pp_size)
+    return model, pp_schedule
+
+
+def create_dist_config(mesh: DeviceMesh):
+    rank = dist.get_rank()
+    device_count = torch.cuda.device_count()
+    device = torch.device("cuda", rank % device_count)
+
+    dist_config = DistConfig(
+        mesh=mesh,
+        pp_mesh=mesh["pp"],
+        ep_mesh=mesh["ep"],
+        pp_rank=mesh["pp"].get_local_rank(),
+        pp_size=mesh["pp"].size(),
+        ep_size=mesh["ep"].size(),
+        ep_rank=mesh["ep"].get_local_rank(),
+        device=device,
+    )
+    return dist_config
+
+
+def decode(tokenizer, x):
+    output = tokenizer.decode(x[0])
+    # Clean up the output by removing special tokens
+    bos = tokenizer.bos_token
+    output = output.replace(bos, "")
+    # Truncate at end of sentence token
+    eos_token = tokenizer.eos_token
+    if eos_token and eos_token in output:
+        output = output.split(eos_token)[0]
+    colored_output = colorize_chat(
+        output,
+        user_color=color.green,
+        assistant_color=color.cyan,
+        output_color=color.blue,
+    )
+    return colored_output
+
+
+@torch.inference_mode()
+def generate(
+    model,
+    pp_schedule,
+    tokenizer,
+    dist_config,
+    messages: list[dict],
+    n_tokens: int = 50,
+):
+    rank = dist.get_rank()
+    device = dist_config.device
+    x = tokenizer.apply_chat_template(
+        [messages] * dist_config.pp_size,
+        add_generation_prompt=True,
+        return_tensors="pt",
+    )
+    next_idx = x.shape[-1]
+    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
+    x = x.to(device)
+
+    for _ in range(n_tokens):
+        if dist_config.pp_size > 1:
+            if dist_config.pp_rank == 0:
+                pp_schedule.step(x)
+                torch.distributed.broadcast(
+                    x,
+                    group=dist_config.pp_mesh.get_group(),
+                    group_src=dist_config.pp_size - 1,
+                )
+            elif dist_config.pp_rank == dist_config.pp_size - 1:
+                preds = pp_schedule.step()
+                next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
+                x[:, next_idx] = next_token
+                torch.distributed.broadcast(
+                    x,
+                    group=dist_config.pp_mesh.get_group(),
+                    group_src=dist_config.pp_size - 1,
+                )
+            else:
+                pp_schedule.step()
+                torch.distributed.broadcast(
+                    x,
+                    group=dist_config.pp_mesh.get_group(),
+                    group_src=dist_config.pp_size - 1,
+                )
+
+            next_idx += 1
+        else:
+            preds = model(x)
+            next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
+            x[:, next_idx] = next_token
+            next_idx += 1
+
+    if rank == 0:
+        colored_output = decode(tokenizer, x)
+        print(f"Without CUDA Graph:\n{colored_output}")
+
+
+@torch.inference_mode()
+def generate_with_cuda_graph(
+    model,
+    tokenizer,
+    dist_config,
+    messages: list[dict],
+    n_tokens: int = 10,
+):
+    rank = dist.get_rank()
+    device = dist_config.device
+    x = tokenizer.apply_chat_template(
+        [messages] * dist_config.pp_size,
+        add_generation_prompt=True,
+        return_tensors="pt",
+    )
+    next_idx = x.shape[-1]
+    x = torch.cat([x, torch.zeros(x.shape[0], n_tokens, dtype=torch.int64)], dim=-1)
+    x = x.to(device)
+
+    torch.cuda.synchronize()
+
+    # Create CUDA graph
+    g = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(g):
+        preds = model(x)
+
+    # Run CUDA graph
+    for _ in range(n_tokens):
+        g.replay()
+        next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)
+        x[:, next_idx] = next_token
+        next_idx += 1
+
+    if rank == 0:
+        colored_output = decode(tokenizer, x)
+        print(f"With CUDA Graph:\n{colored_output}")
+
+
+if __name__ == "__main__":
+    # Get user prompt from command line arguments
+    user_prompt = "What is 2+2?"  # Default prompt
+    if len(sys.argv) > 1:
+        user_prompt = sys.argv[1]
+
+    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("pp", "ep"))
+    rank = dist.get_rank()
+    if rank == 0:
+        print(
+            f"{color.yellow}Running inference with {model_id} on {mesh_shape} mesh{color.reset}"
+        )
+
+    dist_config = create_dist_config(mesh)
+    model, pp_schedule = create_model(dist_config)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": user_prompt},
+    ]
+
+    generate(model, pp_schedule, tokenizer, dist_config, messages)
+    generate_with_cuda_graph(model, tokenizer, dist_config, messages)
+
+    if rank == 0:
+        print(f"\n{color.yellow}Closing inference mesh...{color.reset}")
+
+    dist.destroy_process_group()
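Both generation paths above pick the next token greedily: take the logits at position `next_idx - 1`, argmax over the vocabulary, and write the result into the next preallocated slot of `x`. A minimal sketch of that step on dummy logits (the shapes and random logits here are illustrative only):

```python
import torch

vocab_size, n_tokens = 32, 4
x = torch.tensor([[5, 9, 7, 0, 0, 0, 0]])  # 3 prompt tokens + preallocated zeros
next_idx = 3

for _ in range(n_tokens):
    preds = torch.randn(1, x.shape[-1], vocab_size)  # stand-in for model(x)
    next_token = torch.argmax(preds[:, next_idx - 1], dim=-1)  # greedy pick
    x[:, next_idx] = next_token  # fill the next preallocated slot
    next_idx += 1

print(x)
```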
torchtitan/experiments/deepseek_v3/indices.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton
import triton.language as tl


__all__ = ["generate_permute_indices"]


@triton.jit
def fill_indices_kernel(
    tokens_per_expert_group_ptr,  # *Pointer* to first input vector.
    start_index_values_ptr,  # *Pointer* to second input vector.
    write_offsets_ptr,  # *Pointer* to third input vector.
    output_ptr,  # *Pointer* to output vector.
    experts_per_rank,  # Number of experts per rank.
    num_ranks,  # Number of expert ranks.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # The total number of programs in the launch grid.
    num_programs = tl.num_programs(axis=0)
    # We map the programs (blocks) to the experts.
    for expert_id in tl.range(pid, experts_per_rank, step=num_programs):
        # Read this expert's write offset.
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Loop over the ranks.
        for r in tl.range(num_ranks):
            # Slot in the tokens_per_expert_group array.
            i = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + i)
            length = tl.load(tokens_per_expert_group_ptr + i)
            # Write the indices.
            for l in tl.range(length):
                val = start_index + l
                tl.store(output_ptr + write_offset + l, val)
            write_offset += length


def fill_indices(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    # We need to preallocate the output.
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )
    # Analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks (TODO: bump this value).
    grid = lambda meta: (1,)
    # Each torch.tensor object is implicitly converted into a pointer to its first element.
    fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
    )
    return permuted_indices


def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    # We need to preallocate the output.
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
    # Fill the permuted indices
    # For each local expert
    for e in range(experts_per_rank):
        write_start = write_offsets[e]
        # For each remote rank
        for r in range(num_ranks):
            i = r * experts_per_rank + e
            start_index = start_index_values[i]
            length = tokens_per_expert_group[i]
            # Fill in the indices
            permuted_indices[write_start : write_start + length] = torch.arange(
                start_index, start_index + length
            )
            write_start += length
    return permuted_indices


def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    # Prepare permutation indices and the number of tokens for each expert. The
    # permutation indices are the indices of the tokens for each expert. The
    # number of tokens for each expert is the sum of the number of tokens for
    # such experts from all ranks. This number is aligned to the provided
    # alignment requirement (usually comes from group gemm).

    # Args:
    #     tokens_per_expert_group: number of tokens for each expert from all ranks.
    #     experts_per_rank: number of experts per rank.
    #     num_ranks: number of ranks.
    #     max_len: maximum length of the output index vector. If greater than the
    #         total number of tokens, the remaining indices are set to -1.
    #     alignment: alignment for each returned element in `m_sizes`.
    #     use_cpu: whether to use cpu or gpu.
    # Returns:
    #     permuted_indices: permutation indices.
    #     m_sizes: number of tokens for each expert.

    # `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example:
    # From: |       rank 0      |       rank 1      |
    # To:   | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 |
    #       |  4 |  2 |  1 |  3 |  1 |  2 |  3 |  4 |

    # Prefix sum to get the start index value of each expert
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Chunk sizes for each expert
    chunk_size_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    # Align the chunk sizes to the given alignment
    m_sizes = ((chunk_size_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
    # Select the method to fill the permuted indices
    fill_fn = fill_indices_cpu if use_cpu else fill_indices
    # Fill the permuted indices
    permuted_indices = fill_fn(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes


# Below is for testing only


def test():
    device = torch.device("cuda", 0)
    experts_per_rank = 4
    num_ranks = 4
    tokens_per_expert_group = torch.full(
        (num_ranks * experts_per_rank,), 4, dtype=torch.int32, device=device
    )
    max_len = 128
    alignment = 32
    # Use the GPU kernel
    permuted_indices_gpu, m_sizes = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
    )
    # Use the CPU method
    permuted_indices_cpu, _ = generate_permute_indices(
        tokens_per_expert_group,
        experts_per_rank,
        num_ranks,
        max_len,
        alignment,
        use_cpu=True,
    )
    # Check that the results are the same
    assert torch.equal(permuted_indices_gpu.cpu(), permuted_indices_cpu)
    assert torch.equal(
        torch.remainder(m_sizes, alignment),
        torch.zeros(experts_per_rank, device=device),
    )
    # Print the results
    print(permuted_indices_gpu)
    print(m_sizes)
    print("Success")


if __name__ == "__main__":
    test()
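To make the alignment logic concrete, here is a small CPU-only sketch of how the example from the comments in `generate_permute_indices` plays out. The rank/expert counts and the alignment value are illustrative assumptions, not values taken from the training code.

```python
import torch

from torchtitan.experiments.deepseek_v3.indices import generate_permute_indices

# Same layout as the comment above: 2 ranks, 4 experts per rank.
tokens_per_expert_group = torch.tensor([4, 2, 1, 3, 1, 2, 3, 4], dtype=torch.int32)

permuted_indices, m_sizes = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank=4,
    num_ranks=2,
    max_len=64,      # large enough; unused slots stay -1
    alignment=8,     # illustrative alignment, e.g. from a group GEMM requirement
    use_cpu=True,
)

# Tokens per expert summed over ranks: [5, 4, 4, 7] -> aligned up to 8: [8, 8, 8, 8]
print(m_sizes)           # tensor([8, 8, 8, 8], dtype=torch.int32)
# Expert 0's slice lists token indices 0..3 (rank 0) then 10 (rank 1); padding stays -1.
print(permuted_indices[:8])  # tensor([ 0,  1,  2,  3, 10, -1, -1, -1], dtype=torch.int32)
```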
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .triton_on_device_all_to_all_v import OnDeviceAllToAllV

__all__ = [
    "OnDeviceAllToAllV",
]
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py
ADDED
|
@@ -0,0 +1,260 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
import torch.distributed._symmetric_memory as symm_mem
|
| 10 |
+
import triton
|
| 11 |
+
import triton.language as tl
|
| 12 |
+
|
| 13 |
+
from .triton_barrier import blockwise_barrier
|
| 14 |
+
from .triton_utils import sync_threads
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@triton.jit
|
| 18 |
+
def _exchange_row_offsets(
|
| 19 |
+
split_sizes_ptrs,
|
| 20 |
+
rank: tl.constexpr,
|
| 21 |
+
world_size: tl.constexpr,
|
| 22 |
+
BLOCKS_PER_REMOTE_RANK: tl.constexpr,
|
| 23 |
+
):
|
| 24 |
+
remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
|
| 25 |
+
|
| 26 |
+
# split_sizes_ptr for all ranks
|
| 27 |
+
# All these vector stacks into split_sizes_matrix
|
| 28 |
+
split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64))
|
| 29 |
+
|
| 30 |
+
# split_sizes_matrix[remote_rank, :]
|
| 31 |
+
input_split_sizes_ptr = tl.load(split_sizes_ptrs + remote_rank).to(
|
| 32 |
+
tl.pointer_type(tl.int64)
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
offsets_ = tl.arange(0, world_size)
|
| 36 |
+
input_split_sizes = tl.load(
|
| 37 |
+
input_split_sizes_ptr + offsets_, mask=offsets_ <= rank, other=0
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
num_rows = tl.load(input_split_sizes_ptr + rank)
|
| 41 |
+
input_row_offset = tl.sum(input_split_sizes) - num_rows
|
| 42 |
+
|
| 43 |
+
# split_sizes_matrix[:, rank]
|
| 44 |
+
output_split_sizes_ptrs = (
|
| 45 |
+
tl.load(split_sizes_ptrs + offsets_).to(tl.pointer_type(tl.int64)) + rank
|
| 46 |
+
)
|
| 47 |
+
output_split_sizes = tl.load(
|
| 48 |
+
output_split_sizes_ptrs, mask=offsets_ <= remote_rank, other=0
|
| 49 |
+
)
|
| 50 |
+
output_row_offset = tl.sum(output_split_sizes) - num_rows
|
| 51 |
+
|
| 52 |
+
return input_row_offset, output_row_offset, num_rows
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@triton.jit
|
| 56 |
+
def on_device_all_to_all_v_kernel(
|
| 57 |
+
output_ptr,
|
| 58 |
+
output_splits_ptr,
|
| 59 |
+
input_ptrs,
|
| 60 |
+
input_splits_ptr,
|
| 61 |
+
signal_pad_ptrs,
|
| 62 |
+
dim: tl.constexpr, # Separate dim for easier vectorization
|
| 63 |
+
rank: tl.constexpr,
|
| 64 |
+
world_size: tl.constexpr,
|
| 65 |
+
BLOCKS_PER_REMOTE_RANK: tl.constexpr,
|
| 66 |
+
UNROLL_FACTOR: tl.constexpr,
|
| 67 |
+
BLOCK_SIZE: tl.constexpr,
|
| 68 |
+
):
|
| 69 |
+
blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
|
| 70 |
+
sync_threads()
|
| 71 |
+
|
| 72 |
+
remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK
|
| 73 |
+
block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK
|
| 74 |
+
|
| 75 |
+
input_row_offset, output_row_offset, num_rows = _exchange_row_offsets(
|
| 76 |
+
input_splits_ptr, rank, world_size, BLOCKS_PER_REMOTE_RANK
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
output_splits_ptr = output_splits_ptr.to(tl.pointer_type(tl.uint64))
|
| 80 |
+
if block_offset == 0:
|
| 81 |
+
# Update output_splits
|
| 82 |
+
tl.store(output_splits_ptr + remote_rank, num_rows)
|
| 83 |
+
|
| 84 |
+
input_ptr = (
|
| 85 |
+
tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to(
|
| 86 |
+
tl.pointer_type(tl.bfloat16)
|
| 87 |
+
)
|
| 88 |
+
+ input_row_offset * dim
|
| 89 |
+
)
|
| 90 |
+
output_ptr = output_ptr + output_row_offset * dim
|
| 91 |
+
|
| 92 |
+
outer_loop_step = BLOCK_SIZE * UNROLL_FACTOR
|
| 93 |
+
outer_loop_iters_per_rank = tl.cdiv(
|
| 94 |
+
tl.cdiv(num_rows * dim, outer_loop_step), BLOCKS_PER_REMOTE_RANK
|
| 95 |
+
)
|
| 96 |
+
numel_per_rank = outer_loop_step * outer_loop_iters_per_rank
|
| 97 |
+
offset = numel_per_rank * block_offset
|
| 98 |
+
end = tl.minimum(numel_per_rank * (block_offset + 1), num_rows * dim)
|
| 99 |
+
|
| 100 |
+
unroll_region_size = (end - offset) // outer_loop_step * outer_loop_step
|
| 101 |
+
for i in tl.range(offset, offset + unroll_region_size, outer_loop_step):
|
| 102 |
+
datas = []
|
| 103 |
+
for j in tl.range(
|
| 104 |
+
i,
|
| 105 |
+
i + outer_loop_step,
|
| 106 |
+
BLOCK_SIZE,
|
| 107 |
+
loop_unroll_factor=UNROLL_FACTOR,
|
| 108 |
+
):
|
| 109 |
+
offsets = j + tl.arange(0, BLOCK_SIZE)
|
| 110 |
+
data = tl.load(input_ptr + offsets)
|
| 111 |
+
tl.store(output_ptr + offsets, data)
|
| 112 |
+
|
| 113 |
+
offset += unroll_region_size
|
| 114 |
+
while offset < end:
|
| 115 |
+
offsets = offset + tl.arange(0, BLOCK_SIZE)
|
| 116 |
+
mask = offsets < num_rows * dim
|
| 117 |
+
data = tl.load(input_ptr + offsets, mask=mask)
|
| 118 |
+
tl.store(output_ptr + offsets, data, mask=mask)
|
| 119 |
+
offset += BLOCK_SIZE
|
| 120 |
+
|
| 121 |
+
sync_threads()
|
| 122 |
+
blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed")
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _on_device_all_to_all_v(
|
| 127 |
+
output: torch.Tensor,
|
| 128 |
+
output_splits: torch.Tensor,
|
| 129 |
+
input: torch.Tensor,
|
| 130 |
+
input_splits: torch.Tensor,
|
| 131 |
+
group: dist.ProcessGroup = dist.group.WORLD,
|
| 132 |
+
BLOCKS_PER_REMOTE_RANK=8,
|
| 133 |
+
UNROLL_FACTOR: int = 8,
|
| 134 |
+
BLOCK_SIZE: int = 16384,
|
| 135 |
+
):
|
| 136 |
+
assert output.dim() == 2, f"{output.shape}"
|
| 137 |
+
assert input.dim() == 2, f"{input.shape}"
|
| 138 |
+
assert output.shape[1] == input.shape[1]
|
| 139 |
+
|
| 140 |
+
dim = output.shape[1]
|
| 141 |
+
input_hdl = symm_mem.rendezvous(input, group=group)
|
| 142 |
+
input_splits_hdl = symm_mem.rendezvous(input_splits, group=group)
|
| 143 |
+
|
| 144 |
+
num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK
|
| 145 |
+
kernel = on_device_all_to_all_v_kernel[(num_blocks, 1, 1)](
|
| 146 |
+
output,
|
| 147 |
+
output_splits,
|
| 148 |
+
input_hdl.buffer_ptrs_dev,
|
| 149 |
+
input_splits_hdl.buffer_ptrs_dev,
|
| 150 |
+
input_hdl.signal_pad_ptrs_dev,
|
| 151 |
+
dim=dim,
|
| 152 |
+
rank=input_hdl.rank,
|
| 153 |
+
world_size=input_hdl.world_size,
|
| 154 |
+
BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK,
|
| 155 |
+
UNROLL_FACTOR=UNROLL_FACTOR,
|
| 156 |
+
BLOCK_SIZE=BLOCK_SIZE,
|
| 157 |
+
num_warps=16,
|
| 158 |
+
)
|
| 159 |
+
# log_triton_kernel(kernel)
|
| 160 |
+
return output
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class OnDeviceAllToAllV(torch.autograd.Function):
|
| 164 |
+
# A symmetric memory holding the grad_output during backward
|
| 165 |
+
grad_output_buf = None
|
| 166 |
+
# A symmetric memory for exchanges split sizes during both forward and backward
|
| 167 |
+
splits_buf = None
|
| 168 |
+
# Maximum output length (need to be set before use of OnDeviceAllToAllV)
|
| 169 |
+
max_output_len = None
|
| 170 |
+
|
| 171 |
+
@staticmethod
|
| 172 |
+
def forward(
|
| 173 |
+
ctx,
|
| 174 |
+
input: torch.Tensor,
|
| 175 |
+
input_splits: torch.Tensor,
|
| 176 |
+
group: dist.ProcessGroup = dist.group.WORLD,
|
| 177 |
+
):
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
input: input tensor with data for all ranks concatenated.
|
| 181 |
+
input_splits: input splits of shape (group.world_size,)
|
| 182 |
+
group: process group to scope the collective.
|
| 183 |
+
"""
|
| 184 |
+
# Initialize input splits buffer (one time only)
|
| 185 |
+
if OnDeviceAllToAllV.splits_buf is None:
|
| 186 |
+
OnDeviceAllToAllV.splits_buf = symm_mem.empty(
|
| 187 |
+
*input_splits.shape,
|
| 188 |
+
dtype=input_splits.dtype,
|
| 189 |
+
device=input_splits.device,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if OnDeviceAllToAllV.max_output_len is None:
|
| 193 |
+
raise RuntimeError(
|
| 194 |
+
"Please set max output length via `OnDeviceAllToAllV.max_output_len = ...`"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Allocate output buffer
|
| 198 |
+
output = input.new_empty(OnDeviceAllToAllV.max_output_len, *input.shape[1:])
|
| 199 |
+
# Allocate output splits tensor
|
| 200 |
+
output_splits = torch.empty_like(input_splits)
|
| 201 |
+
# Copy input splits to the buffer
|
| 202 |
+
OnDeviceAllToAllV.splits_buf.copy_(input_splits)
|
| 203 |
+
|
| 204 |
+
# Shuffle input to output
|
| 205 |
+
_on_device_all_to_all_v(
|
| 206 |
+
output, output_splits, input, OnDeviceAllToAllV.splits_buf, group=group
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# Output splits in forward is the input splits in backward
|
| 210 |
+
ctx.save_for_backward(output_splits)
|
| 211 |
+
ctx.group = group
|
| 212 |
+
ctx.input_shape = input.shape
|
| 213 |
+
return output, output_splits
|
| 214 |
+
|
| 215 |
+
@staticmethod
|
| 216 |
+
def backward(ctx, grad_output, grad_splits):
|
| 217 |
+
"""
|
| 218 |
+
Backward is implemented as a shuffle of the output's gradients to the input.
|
| 219 |
+
Args:
|
| 220 |
+
`grad_output`: output's gradients passed from the downstream.
|
| 221 |
+
`grad_splits`: unused.
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
# Initialize grad_output buffer (one time only)
|
| 225 |
+
if OnDeviceAllToAllV.grad_output_buf is None:
|
| 226 |
+
assert (
|
| 227 |
+
OnDeviceAllToAllV.max_output_len is not None
|
| 228 |
+
), "`max_output_len` not set"
|
| 229 |
+
OnDeviceAllToAllV.grad_output_buf = symm_mem.empty(
|
| 230 |
+
OnDeviceAllToAllV.max_output_len,
|
| 231 |
+
*grad_output.shape[1:],
|
| 232 |
+
dtype=grad_output.dtype,
|
| 233 |
+
device=grad_output.device,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# TODO: is there a way to tell autograd to feed grad_output directly to
|
| 237 |
+
# our symm_mem buffer?
|
| 238 |
+
OnDeviceAllToAllV.grad_output_buf.narrow(0, 0, grad_output.shape[0]).copy_(
|
| 239 |
+
grad_output
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Size info
|
| 243 |
+
(grad_output_splits,) = ctx.saved_tensors
|
| 244 |
+
OnDeviceAllToAllV.splits_buf.copy_(grad_output_splits)
|
| 245 |
+
grad_input_splits = torch.empty_like(grad_output_splits) # unused
|
| 246 |
+
grad_input = grad_output.new_empty(*ctx.input_shape)
|
| 247 |
+
|
| 248 |
+
# Shuffle gradients back to the input
|
| 249 |
+
_on_device_all_to_all_v(
|
| 250 |
+
grad_input,
|
| 251 |
+
grad_input_splits,
|
| 252 |
+
OnDeviceAllToAllV.grad_output_buf,
|
| 253 |
+
OnDeviceAllToAllV.splits_buf,
|
| 254 |
+
group=ctx.group,
|
| 255 |
+
)
|
| 256 |
+
return grad_input, None, None
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Alias
|
| 260 |
+
on_device_all_to_all_v = OnDeviceAllToAllV.apply
|
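A rough sketch of how `OnDeviceAllToAllV` might be driven. It assumes a CUDA process group has already been initialized (e.g. via `torchrun`), that the GPUs support symmetric memory, and that the row count, hidden dimension, and split values below are purely illustrative; the input rows must live in a symmetric-memory allocation because the forward pass rendezvouses on them.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = dist.get_rank()
world_size = dist.get_world_size()
dim = 256  # illustrative hidden dimension

# Must be set once before the first call, as the forward pass checks it.
OnDeviceAllToAllV.max_output_len = 4096

# Rows for all destination ranks, concatenated in rank order, in symmetric memory.
inp = symm_mem.empty(1024, dim, dtype=torch.bfloat16, device="cuda")
inp.normal_()
# How many of those rows go to each destination rank (illustrative even split).
input_splits = torch.full(
    (world_size,), 1024 // world_size, dtype=torch.int64, device="cuda"
)

# Shuffle rows to their destination ranks; output_splits reports how many arrived from each peer.
output, output_splits = on_device_all_to_all_v(inp, input_splits)
```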
torchtitan/experiments/flux/README.md
ADDED
|
@@ -0,0 +1,23 @@
# FLUX model in torchtitan

## Overview

## Usage
First, download the autoencoder model from HuggingFace with your own access token:
```bash
python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token <your_access_token>
```
This step downloads the autoencoder model from HuggingFace and saves it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file.

Run the following command to train the model on a single GPU:
```bash
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --nproc_per_node=1 torchtitan/experiments/flux/train.py --job.config_file torchtitan/experiments/flux/train_configs/debug_model.toml
```

## TODO
- [ ] Support for multiple GPUs is coming soon (FSDP, etc.)
- [ ] Implement test cases in CI for the FLUX model and add more unit tests (e.g., a unit test for the preprocessor)
- [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.)
- [ ] Support for distributed checkpointing and loading
- [ ] Implement the init_weights() function to initialize the model weights
- [ ] Implement the num_flops_per_token calculation in the get_nparams_and_flops() function
torchtitan/experiments/flux/__init__.py
ADDED
|
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader
from torchtitan.experiments.flux.loss import build_mse_loss
from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.parallelize_flux import parallelize_flux
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model.model import FluxModel, FluxModelArgs

__all__ = [
    "FluxModelArgs",
    "FluxModel",
    "flux_configs",
    "parallelize_flux",
]


flux_configs = {
    "flux-dev": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-debug": FluxModelArgs(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=512,
        hidden_size=512,
        mlp_ratio=4.0,
        num_heads=4,
        depth=2,
        depth_single_blocks=2,
        axes_dim=(16, 56, 56),
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
        autoencoder_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


register_train_spec(
    TrainSpec(
        name="flux",
        cls=FluxModel,
        config=flux_configs,
        parallelize_fn=parallelize_flux,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_flux_dataloader,
        build_tokenizer_fn=None,
        build_loss_fn=build_mse_loss,
    )
)
torchtitan/experiments/flux/dataset/tokenizer.py
ADDED
|
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.


from typing import List

from torchtitan.components.tokenizer import Tokenizer
from transformers import CLIPTokenizer, T5Tokenizer


class FluxTokenizer(Tokenizer):
    """
    Tokenizing and encoding/decoding text using the T5 or CLIP tokenizer.

    Args:
        model_path (str): Path to the tokenizer on Hugging Face.

    """

    def __init__(self, model_path: str = "t5-small", max_length: int = 77):
        super().__init__()
        self._n_words = 8  # TODO(jianiw): check
        self._max_length = max_length

        self.is_clip = model_path.startswith("openai")

        if self.is_clip:
            self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                model_path, max_length=max_length
            )
        else:
            self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                model_path, max_length=max_length
            )

    def encode(
        self,
        s: str,
    ) -> List[int]:
        """
        Encode the prompt text into tokens.
        """
        tokens = self._tokenizer(
            s,
            truncation=True,
            max_length=self._max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",  # return PyTorch tensors; the default is List[int]
        )["input_ids"]
        return tokens

    def decode(self, t: List[int]) -> str:
        """
        Decode function. This function will not be called.
        """
        return self._tokenizer.decode(t)
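A brief sketch of how the two tokenizer paths might be exercised. The Hugging Face model ids are assumptions (any checkpoint published under the `openai` namespace triggers the CLIP branch); the printed shapes follow from `padding="max_length"` and `return_tensors="pt"`.

```python
# T5 branch (the default "t5-small"): prompts are padded/truncated to max_length.
t5_tok = FluxTokenizer(model_path="t5-small", max_length=77)
t5_ids = t5_tok.encode("a photo of a cat wearing a tiny hat")
print(t5_ids.shape)   # torch.Size([1, 77])

# CLIP branch: selected because the model path starts with "openai".
clip_tok = FluxTokenizer(model_path="openai/clip-vit-large-patch14", max_length=77)
clip_ids = clip_tok.encode("a photo of a cat wearing a tiny hat")
print(clip_ids.shape)  # torch.Size([1, 77])
```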
torchtitan/experiments/flux/model/autoencoder.py
ADDED
|
@@ -0,0 +1,388 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from safetensors.torch import load_file as load_sft
|
| 13 |
+
from torch import nn, Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class AutoEncoderParams:
|
| 18 |
+
resolution: int = 256
|
| 19 |
+
in_channels: int = 3
|
| 20 |
+
ch: int = 128
|
| 21 |
+
out_ch: int = 3
|
| 22 |
+
ch_mult: tuple[int] = (1, 2, 4, 4)
|
| 23 |
+
num_res_blocks: int = 2
|
| 24 |
+
z_channels: int = 16
|
| 25 |
+
scale_factor: float = 0.3611
|
| 26 |
+
shift_factor: float = 0.1159
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def swish(x: Tensor) -> Tensor:
|
| 30 |
+
return x * torch.sigmoid(x)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AttnBlock(nn.Module):
|
| 34 |
+
def __init__(self, in_channels: int):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.in_channels = in_channels
|
| 37 |
+
|
| 38 |
+
self.norm = nn.GroupNorm(
|
| 39 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 43 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 44 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 45 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
| 46 |
+
|
| 47 |
+
def attention(self, h_: Tensor) -> Tensor:
|
| 48 |
+
h_ = self.norm(h_)
|
| 49 |
+
q = self.q(h_)
|
| 50 |
+
k = self.k(h_)
|
| 51 |
+
v = self.v(h_)
|
| 52 |
+
|
| 53 |
+
b, c, h, w = q.shape
|
| 54 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
| 55 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
| 56 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
| 57 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
| 58 |
+
|
| 59 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 62 |
+
return x + self.proj_out(self.attention(x))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ResnetBlock(nn.Module):
|
| 66 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.in_channels = in_channels
|
| 69 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 70 |
+
self.out_channels = out_channels
|
| 71 |
+
|
| 72 |
+
self.norm1 = nn.GroupNorm(
|
| 73 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
| 74 |
+
)
|
| 75 |
+
self.conv1 = nn.Conv2d(
|
| 76 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 77 |
+
)
|
| 78 |
+
self.norm2 = nn.GroupNorm(
|
| 79 |
+
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
| 80 |
+
)
|
| 81 |
+
self.conv2 = nn.Conv2d(
|
| 82 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 83 |
+
)
|
| 84 |
+
if self.in_channels != self.out_channels:
|
| 85 |
+
self.nin_shortcut = nn.Conv2d(
|
| 86 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def forward(self, x):
|
| 90 |
+
h = x
|
| 91 |
+
h = self.norm1(h)
|
| 92 |
+
h = swish(h)
|
| 93 |
+
h = self.conv1(h)
|
| 94 |
+
|
| 95 |
+
h = self.norm2(h)
|
| 96 |
+
h = swish(h)
|
| 97 |
+
h = self.conv2(h)
|
| 98 |
+
|
| 99 |
+
if self.in_channels != self.out_channels:
|
| 100 |
+
x = self.nin_shortcut(x)
|
| 101 |
+
|
| 102 |
+
return x + h
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Downsample(nn.Module):
|
| 106 |
+
def __init__(self, in_channels: int):
|
| 107 |
+
super().__init__()
|
| 108 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 109 |
+
self.conv = nn.Conv2d(
|
| 110 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def forward(self, x: Tensor):
|
| 114 |
+
pad = (0, 1, 0, 1)
|
| 115 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
| 116 |
+
x = self.conv(x)
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Upsample(nn.Module):
|
| 121 |
+
def __init__(self, in_channels: int):
|
| 122 |
+
super().__init__()
|
| 123 |
+
self.conv = nn.Conv2d(
|
| 124 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def forward(self, x: Tensor):
|
| 128 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 129 |
+
x = self.conv(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Encoder(nn.Module):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
resolution: int,
|
| 137 |
+
in_channels: int,
|
| 138 |
+
ch: int,
|
| 139 |
+
ch_mult: list[int],
|
| 140 |
+
num_res_blocks: int,
|
| 141 |
+
z_channels: int,
|
| 142 |
+
):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.ch = ch
|
| 145 |
+
self.num_resolutions = len(ch_mult)
|
| 146 |
+
self.num_res_blocks = num_res_blocks
|
| 147 |
+
self.resolution = resolution
|
| 148 |
+
self.in_channels = in_channels
|
| 149 |
+
# downsampling
|
| 150 |
+
self.conv_in = nn.Conv2d(
|
| 151 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
curr_res = resolution
|
| 155 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 156 |
+
self.in_ch_mult = in_ch_mult
|
| 157 |
+
self.down = nn.ModuleList()
|
| 158 |
+
block_in = self.ch
|
| 159 |
+
for i_level in range(self.num_resolutions):
|
| 160 |
+
block = nn.ModuleList()
|
| 161 |
+
attn = nn.ModuleList()
|
| 162 |
+
block_in = ch * in_ch_mult[i_level]
|
| 163 |
+
block_out = ch * ch_mult[i_level]
|
| 164 |
+
for _ in range(self.num_res_blocks):
|
| 165 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 166 |
+
block_in = block_out
|
| 167 |
+
down = nn.Module()
|
| 168 |
+
down.block = block
|
| 169 |
+
down.attn = attn
|
| 170 |
+
if i_level != self.num_resolutions - 1:
|
| 171 |
+
down.downsample = Downsample(block_in)
|
| 172 |
+
curr_res = curr_res // 2
|
| 173 |
+
self.down.append(down)
|
| 174 |
+
|
| 175 |
+
# middle
|
| 176 |
+
self.mid = nn.Module()
|
| 177 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 178 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 179 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 180 |
+
|
| 181 |
+
# end
|
| 182 |
+
self.norm_out = nn.GroupNorm(
|
| 183 |
+
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
| 184 |
+
)
|
| 185 |
+
self.conv_out = nn.Conv2d(
|
| 186 |
+
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 190 |
+
# downsampling
|
| 191 |
+
hs = [self.conv_in(x)]
|
| 192 |
+
for i_level in range(self.num_resolutions):
|
| 193 |
+
for i_block in range(self.num_res_blocks):
|
| 194 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
| 195 |
+
if len(self.down[i_level].attn) > 0:
|
| 196 |
+
h = self.down[i_level].attn[i_block](h)
|
| 197 |
+
hs.append(h)
|
| 198 |
+
if i_level != self.num_resolutions - 1:
|
| 199 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 200 |
+
|
| 201 |
+
# middle
|
| 202 |
+
h = hs[-1]
|
| 203 |
+
h = self.mid.block_1(h)
|
| 204 |
+
h = self.mid.attn_1(h)
|
| 205 |
+
h = self.mid.block_2(h)
|
| 206 |
+
# end
|
| 207 |
+
h = self.norm_out(h)
|
| 208 |
+
h = swish(h)
|
| 209 |
+
h = self.conv_out(h)
|
| 210 |
+
return h
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class Decoder(nn.Module):
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
ch: int,
|
| 217 |
+
out_ch: int,
|
| 218 |
+
ch_mult: list[int],
|
| 219 |
+
num_res_blocks: int,
|
| 220 |
+
in_channels: int,
|
| 221 |
+
resolution: int,
|
| 222 |
+
z_channels: int,
|
| 223 |
+
):
|
| 224 |
+
super().__init__()
|
| 225 |
+
self.ch = ch
|
| 226 |
+
self.num_resolutions = len(ch_mult)
|
| 227 |
+
self.num_res_blocks = num_res_blocks
|
| 228 |
+
self.resolution = resolution
|
| 229 |
+
self.in_channels = in_channels
|
| 230 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
| 231 |
+
|
| 232 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 233 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 234 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 235 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 236 |
+
|
| 237 |
+
# z to block_in
|
| 238 |
+
self.conv_in = nn.Conv2d(
|
| 239 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# middle
|
| 243 |
+
self.mid = nn.Module()
|
| 244 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 245 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 246 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 247 |
+
|
| 248 |
+
# upsampling
|
| 249 |
+
self.up = nn.ModuleList()
|
| 250 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 251 |
+
block = nn.ModuleList()
|
| 252 |
+
attn = nn.ModuleList()
|
| 253 |
+
block_out = ch * ch_mult[i_level]
|
| 254 |
+
for _ in range(self.num_res_blocks + 1):
|
| 255 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 256 |
+
block_in = block_out
|
| 257 |
+
up = nn.Module()
|
| 258 |
+
up.block = block
|
| 259 |
+
up.attn = attn
|
| 260 |
+
if i_level != 0:
|
| 261 |
+
up.upsample = Upsample(block_in)
|
| 262 |
+
curr_res = curr_res * 2
|
| 263 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 264 |
+
|
| 265 |
+
# end
|
| 266 |
+
self.norm_out = nn.GroupNorm(
|
| 267 |
+
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
| 268 |
+
)
|
| 269 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 270 |
+
|
| 271 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 272 |
+
# get dtype for proper tracing
|
| 273 |
+
upscale_dtype = next(self.up.parameters()).dtype
|
| 274 |
+
|
| 275 |
+
# z to block_in
|
| 276 |
+
h = self.conv_in(z)
|
| 277 |
+
|
| 278 |
+
# middle
|
| 279 |
+
h = self.mid.block_1(h)
|
| 280 |
+
h = self.mid.attn_1(h)
|
| 281 |
+
h = self.mid.block_2(h)
|
| 282 |
+
|
| 283 |
+
# cast to proper dtype
|
| 284 |
+
h = h.to(upscale_dtype)
|
| 285 |
+
# upsampling
|
| 286 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 287 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 288 |
+
h = self.up[i_level].block[i_block](h)
|
| 289 |
+
if len(self.up[i_level].attn) > 0:
|
| 290 |
+
h = self.up[i_level].attn[i_block](h)
|
| 291 |
+
if i_level != 0:
|
| 292 |
+
h = self.up[i_level].upsample(h)
|
| 293 |
+
|
| 294 |
+
# end
|
| 295 |
+
h = self.norm_out(h)
|
| 296 |
+
h = swish(h)
|
| 297 |
+
h = self.conv_out(h)
|
| 298 |
+
return h
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class DiagonalGaussian(nn.Module):
|
| 302 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.sample = sample
|
| 305 |
+
self.chunk_dim = chunk_dim
|
| 306 |
+
|
| 307 |
+
def forward(self, z: Tensor) -> Tensor:
|
| 308 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
| 309 |
+
if self.sample:
|
| 310 |
+
std = torch.exp(0.5 * logvar)
|
| 311 |
+
return mean + std * torch.randn_like(mean)
|
| 312 |
+
else:
|
| 313 |
+
return mean
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class AutoEncoder(nn.Module):
|
| 317 |
+
def __init__(self, params: AutoEncoderParams):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.params = params
|
| 320 |
+
self.encoder = Encoder(
|
| 321 |
+
resolution=params.resolution,
|
| 322 |
+
in_channels=params.in_channels,
|
| 323 |
+
ch=params.ch,
|
| 324 |
+
ch_mult=params.ch_mult,
|
| 325 |
+
num_res_blocks=params.num_res_blocks,
|
| 326 |
+
z_channels=params.z_channels,
|
| 327 |
+
)
|
| 328 |
+
self.decoder = Decoder(
|
| 329 |
+
resolution=params.resolution,
|
| 330 |
+
in_channels=params.in_channels,
|
| 331 |
+
ch=params.ch,
|
| 332 |
+
out_ch=params.out_ch,
|
| 333 |
+
ch_mult=params.ch_mult,
|
| 334 |
+
num_res_blocks=params.num_res_blocks,
|
| 335 |
+
z_channels=params.z_channels,
|
| 336 |
+
)
|
| 337 |
+
self.reg = DiagonalGaussian()
|
| 338 |
+
|
| 339 |
+
self.scale_factor = params.scale_factor
|
| 340 |
+
self.shift_factor = params.shift_factor
|
| 341 |
+
|
| 342 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 343 |
+
z = self.reg(self.encoder(x))
|
| 344 |
+
z = self.scale_factor * (z - self.shift_factor)
|
| 345 |
+
return z
|
| 346 |
+
|
| 347 |
+
def decode(self, z: Tensor) -> Tensor:
|
| 348 |
+
z = z / self.scale_factor + self.shift_factor
|
| 349 |
+
return self.decoder(z)
|
| 350 |
+
|
| 351 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 352 |
+
return self.decode(self.encode(x))
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def load_ae(
|
| 356 |
+
ckpt_path: str,
|
| 357 |
+
autoencoder_params: AutoEncoderParams,
|
| 358 |
+
device: str | torch.device = "cuda",
|
| 359 |
+
dtype=torch.bfloat16,
|
| 360 |
+
) -> AutoEncoder:
|
| 361 |
+
"""
|
| 362 |
+
Load the autoencoder from the given model name.
|
| 363 |
+
Args:
|
| 364 |
+
name (str): The name of the autoencoder.
|
| 365 |
+
device (str or torch.device): The device to load the autoencoder to.
|
| 366 |
+
Returns:
|
| 367 |
+
AutoEncoder: The loaded autoencoder.
|
| 368 |
+
"""
|
| 369 |
+
# Loading the autoencoder
|
| 370 |
+
print("Init AE")
|
| 371 |
+
with torch.device(device):
|
| 372 |
+
ae = AutoEncoder(autoencoder_params)
|
| 373 |
+
|
| 374 |
+
if not os.path.exists(ckpt_path):
|
| 375 |
+
raise ValueError(
|
| 376 |
+
f"Autoencoder path {ckpt_path} does not exist. Please download it first."
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if ckpt_path is not None:
|
| 380 |
+
sd = load_sft(ckpt_path, device=str(device))
|
| 381 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
| 382 |
+
if len(missing) > 0:
|
| 383 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
| 384 |
+
if len(unexpected) > 0:
|
| 385 |
+
print(
|
| 386 |
+
f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
|
| 387 |
+
)
|
| 388 |
+
return ae.to(dtype=dtype)
|
torchtitan/experiments/flux/model/hf_embedder.py
ADDED
|
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn, Tensor
from transformers import CLIPTextModel, T5EncoderModel


class FluxEmbedder(nn.Module):
    def __init__(self, version: str, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version, **hf_kwargs
            )
        else:
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version, **hf_kwargs
            )

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, batch_tokens: Tensor) -> Tensor:
        """
        batch_tokens: [bsz, embedding_length]

        For the T5 encoder, embedding_length is 768.
        For CLIP, embedding_length is 256.
        """
        outputs = self.hf_module(
            input_ids=batch_tokens.to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
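A minimal sketch pairing the frozen embedder with the tokenizer from the dataset module. The checkpoint name is an illustrative assumption; a T5 path would exercise the `last_hidden_state` branch instead.

```python
from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

version = "openai/clip-vit-large-patch14"  # assumed CLIP checkpoint
tokenizer = FluxTokenizer(model_path=version, max_length=77)
embedder = FluxEmbedder(version=version)

tokens = tokenizer.encode("a photo of a cat")  # [1, 77] token ids
pooled = embedder(tokens)                      # CLIP path returns the "pooler_output"
print(pooled.shape)                            # e.g. torch.Size([1, 768]) for this checkpoint
```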
torchtitan/experiments/flux/model/math.py
ADDED
|
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
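For reference, a small shape check of the RoPE helpers above. The head count, sequence length, and single positional axis are arbitrary assumptions; the extra `unsqueeze(1)` mirrors how the positional embedding is broadcast over the head axis.

```python
import torch

from torchtitan.experiments.flux.model.math import attention, rope

B, H, L, D = 2, 4, 16, 32            # batch, heads, sequence length, head dim (even)
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

pos = torch.arange(L, dtype=torch.float32).expand(B, L)  # one positional axis
pe = rope(pos, dim=D, theta=10_000)                       # (B, L, D/2, 2, 2)
pe = pe.unsqueeze(1)                                      # broadcast over heads

out = attention(q, k, v, pe)
print(pe.shape)   # torch.Size([2, 1, 16, 16, 2, 2])
print(out.shape)  # torch.Size([2, 16, 128]) -> heads merged back into the channel dim
```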
torchtitan/experiments/flux/model/model.py
ADDED
|
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from torch import nn, Tensor
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
from torchtitan.experiments.flux.model.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
from torchtitan.tools.logging import logger


@dataclass
class FluxModelArgs(BaseModelArgs):
    in_channels: int = 64
    out_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    depth_single_blocks: int = 38
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        # context_in_dim is the same as the T5 embedding dimension
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning(
            "The FLUX model has not implemented get_nparams_and_flops() yet"
        )
        return nparams, 1


class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (FluxModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
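A sketch of driving a debug-sized configuration end to end on random tensors. The sequence lengths are arbitrary assumptions, and the id-tensor layout (one coordinate per entry of `axes_dim`, so three per position) assumes the layers module follows the reference Flux implementation.

```python
import torch

from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs

# Debug-sized args (mirrors the "flux-debug" entry registered in __init__.py).
args = FluxModelArgs(hidden_size=512, num_heads=4, depth=2, depth_single_blocks=2)
model = FluxModel.from_model_args(args)
model.init_weights()

bsz, img_len, txt_len = 2, 64, 16
img = torch.randn(bsz, img_len, args.in_channels)       # packed latent patches
img_ids = torch.zeros(bsz, img_len, 3)                  # 3 positional axes, assumed layout
txt = torch.randn(bsz, txt_len, args.context_in_dim)    # text encoder features
txt_ids = torch.zeros(bsz, txt_len, 3)
timesteps = torch.rand(bsz)
y = torch.randn(bsz, args.vec_in_dim)                   # pooled CLIP vector
guidance = torch.full((bsz,), 4.0)                      # required since guidance_embed=True

pred = model(img, img_ids, txt, txt_ids, timesteps, y, guidance=guidance)
print(pred.shape)  # torch.Size([2, 64, 64]) -> (bsz, img_len, out_channels)
```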
torchtitan/experiments/flux/scripts/download_autoencoder.py
ADDED
|
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from requests.exceptions import HTTPError


def hf_download(
    repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None
) -> None:
    from huggingface_hub import hf_hub_download

    try:
        hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
        )
    except HTTPError as e:
        if e.response.status_code == 401:
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        else:
            raise e


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Download the autoencoder from HuggingFace."
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Repository ID to download from. Defaults to the FLUX.1-dev model.",
    )
    parser.add_argument(
        "--ae_path",
        type=str,
        default="ae.safetensors",
        help="The autoencoder path relative to repo_id.",
    )
    parser.add_argument(
        "--hf_token", type=str, default=None, help="HuggingFace API token"
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="torchtitan/experiments/flux/assets/autoencoder/",
        help="Local directory to save the autoencoder.",
    )

    args = parser.parse_args()
    hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token)
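For reference, the same download can be driven from Python; the values below are just the script's own argparse defaults:

```python
from torchtitan.experiments.flux.scripts.download_autoencoder import hf_download

hf_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    file_path="ae.safetensors",
    local_dir="torchtitan/experiments/flux/assets/autoencoder/",
    hf_token=None,  # pass a token for gated or private repositories
)
```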
torchtitan/experiments/flux/tests/test_generate_image.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
import time
from typing import Callable

import torch
from einops import rearrange

from PIL import ExifTags, Image

from torch import Tensor

from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer

from torchtitan.experiments.flux.model.autoencoder import (
    AutoEncoder,
    AutoEncoderParams,
    load_ae,
)
from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder

from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
from torchtitan.experiments.flux.utils import (
    create_position_encoding_for_latents,
    generate_noise_latent,
    pack_latents,
    preprocess_flux_data,
    unpack_latents,
)


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()

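A quick sketch of what the shift does; the numbers follow directly from the linear fit above:

```python
# For image_seq_len = 1024 the linear fit gives mu = 0.63, so the shifted
# schedule keeps timesteps closer to 1 (more denoising effort early on).
mu = get_lin_function()(1024)                        # 0.63
sched = get_schedule(num_steps=4, image_seq_len=1024)
# sched starts at 1.0 and ends at 0.0; the interior points (~0.85, ~0.65,
# ~0.39) sit above the unshifted spacing of 0.75, 0.5, 0.25.
```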
class TestGenerateImage:
    def test_generate_image(self):
        """
        Run a forward pass of flux model to generate an image.
        """
        name = "flux-dev"
        img_width = 512
        img_height = 512
        seed = None
        prompt = (
            "a photo of a forest with mist swirling around the tree trunks. The word "
            '"FLUX" is painted over it in big, red brush strokes with visible texture'
        )
        device = "cuda"
        num_steps = None
        loop = False
        guidance = 3.5
        output_dir = "output"
        add_sampling_metadata = True

        prompt = prompt.split("|")
        if len(prompt) == 1:
            prompt = prompt[0]
            additional_prompts = None
        else:
            additional_prompts = prompt[1:]
            prompt = prompt[0]

        assert not (
            (additional_prompts is not None) and loop
        ), "Do not provide additional prompts and set loop to True"

        torch_device = torch.device(device)
        if num_steps is None:
            num_steps = 30

        # allow for packing and conversion to latent space
        img_height = 16 * (img_height // 16)
        img_width = 16 * (img_width // 16)

        # init all components
        model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)

        ae = load_ae(
            ckpt_path="assets/autoencoder/ae.safetensors",
            autoencoder_params=AutoEncoderParams(),
            device=torch_device,
            dtype=torch.bfloat16,
        )
        clip_tokenizer = FluxTokenizer(
            model_path="openai/clip-vit-large-patch14", max_length=77
        )
        t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
        clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
            torch_device, dtype=torch.bfloat16
        )
        t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
            torch_device, dtype=torch.bfloat16
        )

        rng = torch.Generator(device="cpu")

        if seed is None:
            seed = rng.seed()
        print(f"Generating with seed {seed}:\n{prompt}")
        t0 = time.perf_counter()
        output_name = os.path.join(output_dir, f"img_{seed}.jpg")

        # Tokenize the prompt, on CPU
        clip_tokens = clip_tokenizer.encode(prompt)
        t5_tokens = t5_tokenizer.encode(prompt)

        batch = preprocess_flux_data(
            device=torch_device,
            dtype=torch.bfloat16,
            autoencoder=None,
            clip_encoder=clip_encoder,
            t5_encoder=t5_encoder,
            batch={
                "clip_tokens": clip_tokens,
                "t5_tokens": t5_tokens,
            },
        )

        img = self._generate_images(
            device=torch_device,
            dtype=torch.bfloat16,
            model=model,
            decoder=ae,
            img_width=img_width,
            img_height=img_height,
            denoising_steps=num_steps,
            seed=seed,
            clip_encodings=batch["clip_encodings"],
            t5_encodings=batch["t5_encodings"],
            guidance=guidance,
        )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(f"Done in {t1 - t0:.1f}s.")

        self._save_image(name, output_name, img, add_sampling_metadata, prompt)

    def _generate_images(
        self,
        device: torch.device,
        dtype: torch.dtype,
        model: FluxModel,
        decoder: AutoEncoder,
        # image params:
        img_width: int,
        img_height: int,
        # sampling params:
        denoising_steps: int,
        seed: int,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        guidance: float = 4.0,
    ):
        bsz = clip_encodings.shape[0]
        latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
        _, latent_channels, latent_height, latent_width = latents.shape

        # create denoising schedule
        timesteps = get_schedule(denoising_steps, latent_channels, shift=True)

        # create positional encodings
        POSITION_DIM = 3  # constant for Flux flow model
        latent_pos_enc = create_position_encoding_for_latents(
            bsz, latent_height, latent_width, POSITION_DIM
        ).to(latents)
        text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)

        # convert img-like latents into sequences of patches
        latents = pack_latents(latents)

        # this is ignored for schnell
        guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
            pred = model(
                img=latents,
                img_ids=latent_pos_enc,
                txt=t5_encodings,
                txt_ids=text_pos_enc,
                y=clip_encodings,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            latents = latents + (t_prev - t_curr) * pred

        # convert sequences of patches into img-like latents
        latents = unpack_latents(latents, latent_height, latent_width)

        img = decoder.decode(latents)
        return img

    def _save_image(
        self,
        name: str,
        output_name: str,
        x: torch.Tensor,
        add_sampling_metadata: bool,
        prompt: str,
    ):
        print(f"Saving {output_name}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = name
        if add_sampling_metadata:
            exif_data[ExifTags.Base.ImageDescription] = prompt
        img.save(output_name, exif=exif_data, quality=95, subsampling=0)
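The update `latents = latents + (t_prev - t_curr) * pred` in `_generate_images` is a fixed-step Euler integrator for the flow ODE dx/dt = v(x, t). A toy sketch with a known velocity field shows the mechanics:

```python
# Euler integration of dx/dt = v(x, t) over a decreasing timestep list,
# mirroring the denoising loop above (a sketch, not the model's sampler).
import torch

def euler(v, x, ts):
    for t_curr, t_prev in zip(ts[:-1], ts[1:]):
        x = x + (t_prev - t_curr) * v(x, t_curr)
    return x

# For a straight-line rectified flow, v(x, t) = x1 - x0 is constant, so
# integrating from t=1 down to t=0 recovers x0 from x1 exactly.
x0, x1 = torch.zeros(3), torch.randn(3)
out = euler(lambda x, t: x1 - x0, x1, [1.0, 0.5, 0.0])
print(torch.allclose(out, x0))  # True
```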
torchtitan/experiments/flux/train_configs/debug_model.toml
ADDED
@@ -0,0 +1,68 @@
[job]
dump_folder = "./outputs"
description = "Flux debug model"
print_args = false
use_for_integration_test = true

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 10
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "flux"
flavor = "flux-debug"
norm_type = "rmsnorm"  # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
# tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 32
seq_len = 512
max_norm = 1.0  # grad norm clipping
steps = 10
compile = false
dataset = "cc12m"
guidance = 3.5
seed = 0

[encoder]
t5_encoder = "google/t5-v1_1-small"
clip_encoder = "openai/clip-vit-large-patch14"
max_t5_encoding_len = 512
auto_encoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors"  # Autoencoder to use for image

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 1
fsdp_reshard_after_forward = "default"  # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[experimental]
custom_args_module = "torchtitan.experiments.flux.flux_argparser"
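As with other torchtitan train configs, this file is consumed through the job config system; it would typically be selected via the `--job.config_file` flag (or the `CONFIG_FILE` variable used by the run scripts), though the exact entry point for the flux experiment may differ.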
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py
ADDED
@@ -0,0 +1,885 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import math
import time

from typing import Dict, List, Tuple

# import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# from torchao_pr.mg_grouped_gemm import mg_grouped_gemm

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Try to import the optimized MG GEMM implementation
try:
    from torchao_pr.mg_grouped_gemm import (  # grouped_gemm_backward,
        grouped_gemm_forward,
    )

    has_mg_gemm = True
except ImportError:
    logging.warning("MG GEMM implementation not found. Will use manual looping only.")
    has_mg_gemm = False


class Router(nn.Module):
    """
    Router module that assigns tokens to experts.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Routing layer
        self.router = nn.Linear(input_dim, num_experts)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Route input tokens to experts.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            Tuple containing:
            - router_logits: Raw routing logits
            - dispatch_tensor: One-hot tensor indicating expert assignment
            - expert_indices: List of indices for each expert's tokens
        """
        batch_size, seq_len, _ = x.shape

        # Flatten batch and sequence dimensions
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Compute routing probabilities
        router_logits = self.router(x_flat)  # (batch_size * seq_len, num_experts)

        # Apply softmax to get probabilities
        router_probs = F.softmax(router_logits, dim=-1)

        # Get top-k experts for each token
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Normalize top-k probabilities
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Create dispatch tensor (one-hot representation of assignments)
        dispatch_tensor = torch.zeros_like(router_probs)
        token_indices = (
            torch.arange(router_probs.size(0), device=router_probs.device)
            .unsqueeze(1)
            .expand(-1, self.top_k)
        )
        dispatch_tensor.scatter_(1, top_k_indices, top_k_probs)  # .unsqueeze(-1))

        # For each expert, get the indices of tokens routed to it
        expert_indices = []
        for expert_idx in range(self.num_experts):
            # Get indices of tokens that have non-zero probability for this expert
            indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[
                0
            ]
            expert_indices.append(indices)

        return router_logits, dispatch_tensor, expert_indices

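A toy routing run (a sketch) makes the three return values concrete:

```python
# Four tokens, two experts, top-1 routing.
import torch

router = Router(input_dim=8, num_experts=2, top_k=1)
logits, dispatch, per_expert = router(torch.randn(1, 4, 8))
print(logits.shape)    # torch.Size([4, 2]) -- one row per flattened token
print(dispatch.shape)  # torch.Size([4, 2]) -- one nonzero weight per row
print([idx.tolist() for idx in per_expert])  # token ids owned by each expert
```

With `top_k=1` the renormalized dispatch weight for each token is exactly 1.0, so the dispatch tensor degenerates to a one-hot assignment.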
class Expert(nn.Module):
    """
    Individual expert module.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x

class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts layer with support for both manual looping and grouped GEMM.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_mg_gemm = use_mg_gemm and has_mg_gemm

        # Router
        self.router = Router(input_dim, num_experts, top_k)

        # Create expert modules
        if self.use_mg_gemm:
            # For MG GEMM, we need a single weight tensor for all experts
            # First layer (input -> hidden)
            self.expert_fc1_weight = nn.Parameter(
                torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
            )
            # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))

            # Second layer (hidden -> output)
            self.expert_fc2_weight = nn.Parameter(
                torch.randn(num_experts * output_dim, hidden_dim)
                / math.sqrt(hidden_dim)
            )
            # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
        else:
            # For manual looping, create separate experts
            self.experts = nn.ModuleList(
                [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
            )

    def forward_manual_loop(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass using manual looping over experts.
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Initialize output tensor
        final_output = torch.zeros(
            batch_size * seq_len, self.output_dim, device=x.device
        )

        # Process each expert
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                # Get tokens routed to this expert
                expert_inputs = x_flat[indices]  # (num_tokens_for_expert, input_dim)

                # Process tokens through expert
                expert_outputs = self.experts[expert_idx](
                    expert_inputs
                )  # (num_tokens_for_expert, output_dim)

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward_mg_gemm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)
        total_tokens = batch_size * seq_len

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Get token counts for each expert
        token_counts = [indices.numel() for indices in expert_indices]
        m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)

        print(f"Token counts per expert: {token_counts}")
        print(f"m_sizes: {m_sizes}")

        # Create the combined input tensor
        combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)

        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                combined_input[start_idx:end_idx] = x_flat[indices]
                start_idx = end_idx

        print(f"combined_input shape: {combined_input.shape}")

        # First layer: input -> hidden
        fc1_weight_reshaped = self.expert_fc1_weight.reshape(
            self.num_experts, self.hidden_dim, self.input_dim
        )
        fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)

        print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")

        # Run the grouped GEMM
        hidden_outputs = grouped_gemm_forward(
            combined_input, fc1_weight_combined, m_sizes
        )

        print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")

        # Apply activation
        hidden_outputs = F.gelu(hidden_outputs)

        print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")

        # Second layer: hidden -> output
        # Reshape hidden_outputs to match expected dimensions
        reshaped_hidden_outputs = []
        start_idx = 0

        for expert_idx, count in enumerate(token_counts):
            if count > 0:
                end_idx = start_idx + count
                # Take this expert's outputs and reshape to [count, hidden_dim]
                expert_output = hidden_outputs[
                    start_idx:end_idx,
                    expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
                ]
                reshaped_hidden_outputs.append(expert_output)
                start_idx = end_idx

        # Concatenate all reshaped outputs
        hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)

        # Reshape expert weights for second layer
        fc2_weight_reshaped = self.expert_fc2_weight.reshape(
            self.num_experts, self.output_dim, self.hidden_dim
        )
        fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)

        print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")

        # Run the second grouped GEMM
        expert_outputs_combined = grouped_gemm_forward(
            hidden_outputs, fc2_weight_combined, m_sizes
        )

        # Initialize final output tensor with correct shape
        final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)

        # Distribute the outputs back to the original token positions
        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                # Get this expert's outputs
                expert_outputs = expert_outputs_combined[start_idx:end_idx]

                print(
                    f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
                )

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Ensure dimensions match before using index_add_
                if scaled_outputs.shape[1] != final_output.shape[1]:
                    # Reshape if needed - make sure output_dim is correct
                    scaled_outputs = scaled_outputs[:, : self.output_dim]

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

                start_idx = end_idx

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.use_mg_gemm and has_mg_gemm:
            return self.forward_mg_gemm(x)
        else:
            return self.forward_manual_loop(x)

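A minimal smoke test of the manual-loop path (a sketch; runs on CPU, no grouped-GEMM kernel required):

```python
import torch

moe = MixtureOfExperts(
    input_dim=16, hidden_dim=32, output_dim=16, num_experts=4, top_k=2,
    use_mg_gemm=False,
)
out, router_logits = moe(torch.randn(2, 8, 16))
print(out.shape)            # torch.Size([2, 8, 16])
print(router_logits.shape)  # torch.Size([16, 4]) -- one row per batch*seq token
```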
class MoEModel(nn.Module):
    """
    Simple model using MoE layers.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.moe_layer = MixtureOfExperts(
            input_dim=embed_dim,
            hidden_dim=hidden_dim,
            output_dim=embed_dim,
            num_experts=num_experts,
            top_k=top_k,
            use_mg_gemm=use_mg_gemm,
        )
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x shape: (batch_size, seq_len)
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        moe_output, router_logits = self.moe_layer(
            embedded
        )  # (batch_size, seq_len, embed_dim)
        logits = self.output_layer(moe_output)  # (batch_size, seq_len, vocab_size)
        return logits, router_logits

def compute_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the load balancing loss for MoE training.

    Args:
        router_logits (torch.Tensor): Router logits of shape (batch_size * seq_len, num_experts)
        num_experts (int): Number of experts

    Returns:
        torch.Tensor: Load balancing loss
    """
    # Get router probabilities
    router_probs = F.softmax(
        router_logits, dim=-1
    )  # (batch_size * seq_len, num_experts)

    # Compute fraction of tokens routed to each expert
    # Sum across the batch dimension and normalize
    router_probs_sum = router_probs.sum(dim=0)  # (num_experts,)
    router_probs_sum = router_probs_sum / router_probs_sum.sum()

    # Compute the mean probability per expert
    mean_prob = 1.0 / num_experts

    # Compute the fraction of tokens routed to each expert
    # The goal is to have uniform routing across experts
    load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)

    return load_balancing_loss

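A quick sanity check (a sketch): the loss `N * sum(p_bar_j^2)` reaches its minimum of 1.0 when routing mass is uniform across the N experts, and approaches N when a single expert absorbs everything:

```python
import torch

# Equal logits -> uniform routing -> loss == 1.0
uniform = torch.zeros(100, 4)
print(compute_load_balancing_loss(uniform, 4))  # tensor(1.0)

# Strongly skewed logits -> nearly all mass on expert 0 -> loss -> 4.0
skewed = torch.tensor([[10.0, 0.0, 0.0, 0.0]]).repeat(100, 1)
print(compute_load_balancing_loss(skewed, 4))   # ~tensor(4.0)
```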
def generate_sample_data(
    batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate sample data for training.

    Args:
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        device (str): Device to use

    Returns:
        Tuple of input tokens and target tokens
    """
    # Generate random input tokens
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    # Generate random target tokens
    targets = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    return inputs, targets

def train_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
    load_balance_coef: float = 0.01,
) -> Dict[str, float]:
    """
    Train the model for one epoch.

    Args:
        model (nn.Module): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches per epoch
        device (str): Device to use
        load_balance_coef (float): Coefficient for load balancing loss

    Returns:
        Dict containing training metrics
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    start_time = time.time()

    for i in range(num_batches):
        # Generate sample data
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)

        # Forward pass
        optimizer.zero_grad()
        logits, router_logits = model(inputs)

        # Compute loss
        # Reshape for cross entropy loss
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        # Cross entropy loss
        ce_loss = F.cross_entropy(logits_flat, targets_flat)

        # Load balancing loss
        lb_loss = compute_load_balancing_loss(
            router_logits, model.moe_layer.num_experts
        )

        # Combined loss
        loss = ce_loss + load_balance_coef * lb_loss

        # Backward pass
        loss.backward()
        optimizer.step()

        # Compute accuracy
        preds = logits_flat.argmax(dim=-1)
        correct = (preds == targets_flat).float().sum()
        acc = correct / (batch_size * seq_len)

        # Accumulate metrics
        total_loss += loss.item()
        total_acc += acc.item()

        # Log progress
        if (i + 1) % 10 == 0:
            logging.info(
                f"Batch {i + 1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"CE Loss: {ce_loss.item():.4f} | "
                f"LB Loss: {lb_loss.item():.4f} | "
                f"Acc: {acc.item():.4f}"
            )

    # Compute average metrics
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches
    epoch_time = time.time() - start_time

    return {"loss": avg_loss, "acc": avg_acc, "time": epoch_time}

def evaluate(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate the model.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for evaluation
        device (str): Device to use

    Returns:
        Dict containing evaluation metrics
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0

    with torch.no_grad():
        for i in range(num_batches):
            # Generate sample data
            inputs, targets = generate_sample_data(
                batch_size, seq_len, vocab_size, device
            )

            # Forward pass
            logits, router_logits = model(inputs)

            # Compute loss
            logits_flat = logits.reshape(-1, vocab_size)
            targets_flat = targets.reshape(-1)

            # Cross entropy loss
            loss = F.cross_entropy(logits_flat, targets_flat)

            # Compute accuracy
            preds = logits_flat.argmax(dim=-1)
            correct = (preds == targets_flat).float().sum()
            acc = correct / (batch_size * seq_len)

            # Accumulate metrics
            total_loss += loss.item()
            total_acc += acc.item()

    # Compute average metrics
    avg_loss = total_loss / num_batches
    avg_acc = total_acc / num_batches

    return {"loss": avg_loss, "acc": avg_acc}

def measure_performance(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Measure forward and backward pass performance.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for measurement
        device (str): Device to use

    Returns:
        Dict containing performance metrics
    """
    model.train()

    # Create dummy optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Warmup
    for _ in range(5):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    # Measure forward pass time
    if torch.cuda.is_available():  # guard so CPU runs do not crash
        torch.cuda.synchronize()
    forward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        with torch.no_grad():
            logits, router_logits = model(inputs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    forward_end = time.time()
    forward_time = (forward_end - forward_start) / num_batches

    # Measure backward pass time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    backward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    backward_end = time.time()
    backward_time = (backward_end - backward_start) / num_batches

    return {
        "forward_time": forward_time * 1000,  # Convert to ms
        "backward_time": backward_time * 1000,  # Convert to ms
        "total_time": (forward_time + backward_time) * 1000,  # Convert to ms
    }

def compare_methods(args):
    """
    Compare manual looping and MG GEMM implementations.
    """
    device = torch.device(args.device)

    # Create models
    manual_model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=False,
    ).to(device)

    if has_mg_gemm:
        mg_model = MoEModel(
            vocab_size=args.vocab_size,
            embed_dim=args.embed_dim,
            hidden_dim=args.hidden_dim,
            num_experts=args.num_experts,
            top_k=args.top_k,
            use_mg_gemm=True,
        ).to(device)
    else:
        mg_model = None

    # Measure performance
    logging.info("Measuring performance of manual looping method...")
    manual_perf = measure_performance(
        manual_model,
        args.batch_size,
        args.seq_len,
        args.vocab_size,
        args.perf_batches,
        device,
    )

    if mg_model is not None:
        logging.info("Measuring performance of MG GEMM method...")
        mg_perf = measure_performance(
            mg_model,
            args.batch_size,
            args.seq_len,
            args.vocab_size,
            args.perf_batches,
            device,
        )
    else:
        mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}

    # Log results
    logging.info("\n===== Performance Comparison =====")
    logging.info("Model Configuration:")
    logging.info(f" - Batch Size: {args.batch_size}")
    logging.info(f" - Sequence Length: {args.seq_len}")
    logging.info(f" - Embed Dimension: {args.embed_dim}")
    logging.info(f" - Hidden Dimension: {args.hidden_dim}")
    logging.info(f" - Number of Experts: {args.num_experts}")
    logging.info(f" - Top-K: {args.top_k}")
    logging.info("")

    logging.info("Manual Looping Method:")
    logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms")
    logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms")
    logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms")
    logging.info("")

    if mg_model is not None:
        logging.info("MG GEMM Method:")
        logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms")
        logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms")
        logging.info(f" - Total Time: {mg_perf['total_time']:.2f} ms")
        logging.info("")

        # Calculate speedup
        forward_speedup = (
            manual_perf["forward_time"] / mg_perf["forward_time"]
            if mg_perf["forward_time"] > 0
            else 0
        )
        backward_speedup = (
            manual_perf["backward_time"] / mg_perf["backward_time"]
            if mg_perf["backward_time"] > 0
            else 0
        )
        total_speedup = (
            manual_perf["total_time"] / mg_perf["total_time"]
            if mg_perf["total_time"] > 0
            else 0
        )

        logging.info("Speedup (MG GEMM vs Manual):")
        logging.info(f" - Forward Speedup: {forward_speedup:.2f}x")
        logging.info(f" - Backward Speedup: {backward_speedup:.2f}x")
        logging.info(f" - Total Speedup: {total_speedup:.2f}x")
    else:
        logging.info("MG GEMM method not available.")

def train_model(args):
    """
    Train an MoE model.
    """
    device = torch.device(args.device)

    # Create model
    model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=args.use_mg_gemm and has_mg_gemm,
    ).to(device)

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Log model information
    logging.info("Model configuration:")
    logging.info(f" - Vocabulary Size: {args.vocab_size}")
    logging.info(f" - Embedding Dimension: {args.embed_dim}")
    logging.info(f" - Hidden Dimension: {args.hidden_dim}")
    logging.info(f" - Number of Experts: {args.num_experts}")
    logging.info(f" - Top-K: {args.top_k}")
    logging.info(f" - Using MG GEMM: {args.use_mg_gemm and has_mg_gemm}")

    # Training loop
    for epoch in range(args.epochs):
        logging.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_metrics = train_epoch(
            model=model,
            optimizer=optimizer,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.train_batches,
            device=device,
            load_balance_coef=args.load_balance_coef,
        )

        # Evaluate
        eval_metrics = evaluate(
            model=model,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.eval_batches,
            device=device,
        )

        # Log metrics
        logging.info(
            f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
        )
        logging.info(
            f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
        )
        logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train MoE model")

    # Model parameters
    parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
    parser.add_argument(
        "--embed_dim", type=int, default=512, help="Embedding dimension"
    )
    parser.add_argument(
        "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
    )
    parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
    parser.add_argument(
        "--top_k", type=int, default=2, help="Top-k experts to route to"
    )

    # Training parameters
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batches",
        type=int,
        default=100,
        help="Number of training batches per epoch",
    )
    parser.add_argument(
        "--eval_batches", type=int, default=20, help="Number of evaluation batches"
    )
    parser.add_argument(
        "--perf_batches",
        type=int,
        default=50,
        help="Number of batches for performance testing",
    )
    parser.add_argument(
        "--load_balance_coef",
        type=float,
        default=0.01,
        help="Load balancing loss coefficient",
    )

    # Runtime parameters
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (cuda or cpu)",
    )
    parser.add_argument(
        "--use_mg_gemm",
        action="store_true",
        help="Use MG GEMM implementation if available",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Compare manual and MG GEMM implementations",
    )
    parser.add_argument("--train", action="store_true", help="Train the model")

    args = parser.parse_args()

    # Check for CUDA
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA not available, using CPU instead.")
        args.device = "cpu"

    # Log basic information
    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Device: {args.device}")
    logging.info(f"MG GEMM available: {has_mg_gemm}")

    # Run the requested action
    if args.compare:
        compare_methods(args)
    elif args.train:
        train_model(args)
    else:
        # Default to comparison if no action specified
        compare_methods(args)
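As a usage note (flags as defined above): `python simpleMoE.py --compare` benchmarks the manual-loop and grouped-GEMM paths, `python simpleMoE.py --train --epochs 1` runs a short training loop on random tokens, and with no flags the script defaults to the comparison.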
torchtitan/experiments/llama4/README.md
ADDED
@@ -0,0 +1,29 @@
**The Llama 4 folder is still under development.**

#### Available features
- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing
- Basic FSDP, TP, PP, CP support
- DCP checkpoint conversion scripts

#### Download Llama 4 tokenizer
```bash
# Llama 4 tokenizer.model
python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=...
```

#### To be added
- Modeling
    - iRoPE implementation
    - load-balancing loss for token-choice MoE
    - alternative expert-choice MoE
    - multimodal support
- Kernel integration
    - efficient bfloat16 GroupedGEMM kernels (from PyTorch core)
    - efficient float8 GroupedGEMM kernels (from torchao)
- Parallelism
    - performant TP implementation and torch.compile support for MoE layers
    - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs
    - Expert Parallel support
- Testing
    - performance and loss-convergence tests
    - CI integration
torchtitan/experiments/llama4/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.66 kB).
torchtitan/experiments/llama4/model/__pycache__/moe.cpython-312.pyc
ADDED
Binary file (10.5 kB).
torchtitan/experiments/llama4/model/args.py
ADDED
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional

from torch import nn
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig

from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger


@dataclass
class TransformerModelArgs(BaseModelArgs):
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 10000

    max_seq_len: int = 2048
    # If `True`, then each transformer block init uses its layer ID, and if
    # `False`, each uses the total number of transformer blocks
    depth_init: bool = True
    norm_type: str = "rmsnorm"

    use_flex_attn: bool = False
    attn_mask_type: str = "causal"
    eos_id: int = 0

    # MoE args
    moe_enabled: bool = True
    num_experts: int = 8
    use_shared_expert: bool = True
    auto_scale_hidden_dim: bool = True
    # frequency of using MoE layer instead of feedforward layer in a transformer block
    interleave_moe_layer_step: int = 2
    # token-choice
    top_k: int = 1

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        self.norm_type = job_config.model.norm_type
        self.vocab_size = tokenizer.n_words
        self.max_seq_len = job_config.training.seq_len
        self.use_flex_attn = job_config.model.use_flex_attn

    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
        nparams_embedding = 0
        nparams_moe_router = 0
        nparams_shared_expert = 0
        nparams_experts = 0
        nparams_dense = 0

        for name, p in model.named_parameters():
            if "embedding" in name:
                nparams_embedding += p.numel()
                nparams_dense += p.numel()
            elif "moe.shared_expert" in name:
                nparams_shared_expert += p.numel()
            elif "moe.router" in name:
                nparams_moe_router += p.numel()
            elif "moe.experts" in name:
                nparams_experts += p.numel()
            else:
                nparams_dense += p.numel()

        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
        nparams = nparams_dense + nparams_sparse
        nparams_sparse_active = (
            nparams_moe_router
            + nparams_shared_expert
            + nparams_experts * self.top_k // self.num_experts
        )

        logger.info(
            f"Total parameter count: dense {nparams_dense:,}, "
            f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
        )

        l, h, q, t = (
            self.n_layers,
            self.n_heads,
            self.dim // self.n_heads,
            seq_len,
        )
        # Reasoning behind the factor of 12 for the self-attention part of the formula:
        # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
        # 2. the flash attention does 1 more matmul recomputation in the backward
        #    but recomputation should not be counted in calculating MFU (+0)
        # 3. each matmul performs 1 multiplication and 1 addition (*2)
        # 4. we follow the convention and do not account for sparsity in causal attention
        num_flops_per_token = (
            6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
            + 12 * l * h * q * t
        )

        return nparams, num_flops_per_token
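
The accounting above treats only `top_k / num_experts` of the routed expert weights as active per token, and plugs that into the usual `6 * params + 12 * l * h * q * t` FLOPs-per-token estimate. The following is a minimal standalone sketch of that arithmetic; the helper name `flops_per_token` and all parameter counts are made up for illustration and are not a real Llama 4 configuration.

```python
# Standalone sketch of the active-parameter / FLOPs accounting above.
# All sizes below are illustrative, not a real Llama 4 config.

def flops_per_token(
    nparams_dense: int,
    nparams_embedding: int,
    nparams_moe_router: int,
    nparams_shared_expert: int,
    nparams_experts: int,
    top_k: int,
    num_experts: int,
    n_layers: int,
    n_heads: int,
    head_dim: int,
    seq_len: int,
) -> int:
    # Only top_k of num_experts expert weight sets participate per token.
    sparse_active = (
        nparams_moe_router
        + nparams_shared_expert
        + nparams_experts * top_k // num_experts
    )
    # 6 = fwd + bwd matmul cost per non-embedding parameter;
    # 12 * l * h * q * t covers the attention score/value matmuls.
    return 6 * (nparams_dense - nparams_embedding + sparse_active) + (
        12 * n_layers * n_heads * head_dim * seq_len
    )


if __name__ == "__main__":
    print(
        flops_per_token(
            nparams_dense=1_000_000,
            nparams_embedding=200_000,
            nparams_moe_router=10_000,
            nparams_shared_expert=50_000,
            nparams_experts=800_000,
            top_k=1,
            num_experts=8,
            n_layers=4,
            n_heads=8,
            head_dim=64,
            seq_len=2048,
        )
    )
```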
|
torchtitan/experiments/llama4/model/moe.py
ADDED
|
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from torch import nn

from .args import TransformerModelArgs


class GroupedExperts(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))

    def forward(
        self,
        x: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if num_local_tokens_per_expert is not None:
            # a tuple of tensors indexed by experts
            # each with shape (tokens_per_expert(varying), dim)
            x = torch.split(
                x,
                split_size_or_sections=num_local_tokens_per_expert.tolist(),
                dim=0,
            )
            out_experts_splits = []
            for expert_idx, x_expert in enumerate(x):
                w1, w2, w3 = (
                    self.w1[expert_idx],
                    self.w2[expert_idx],
                    self.w3[expert_idx],
                )
                h = F.silu(torch.matmul(x_expert, w1))
                h = h * torch.matmul(x_expert, w3)
                h = torch.matmul(h, w2)
                # h shape (tokens_per_expert(varying), dim)
                out_experts_splits.append(h)
            out = torch.cat(out_experts_splits, dim=0)

            # TODO: optimize with GroupedGEMM
            # https://github.com/pytorch/pytorch/pull/150374
            # _grouped_mm requires shapes to be multiple of 8
            # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
            # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
            # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
            # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
        else:
            # x shape (num_experts, tokens_per_expert, dim)
            h = F.silu(torch.bmm(x, self.w1))
            h = h * torch.bmm(x, self.w3)
            # out shape (num_experts, tokens_per_expert, dim)
            out = torch.bmm(h, self.w2)
        return out

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


class TokenChoiceTopKRouter(nn.Module):
    """This class implements token-choice routing. In token-choice top-K routing, each token is
    routed to top K experts based on the router scores.

    Args:
        gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts).
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        use_sigmoid: bool = False,
    ):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_sigmoid = use_sigmoid

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            routed_input (torch.Tensor):
                Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``.
            token_indices (torch.Tensor):
                Token indices for routed_input with shape ``(bs*slen*top_k,)``.
            num_local_tokens_per_expert (torch.Tensor):
                Number of tokens assigned to each expert with shape ``(num_experts,)``.
        """
        # scores shape (bs*slen, num_experts)
        scores = self.gate(x)

        # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # top scores shape (bs*slen, top_k)
        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
        # top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)

        # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
        num_local_tokens_per_expert = torch.histc(
            selected_experts_indices.view(-1),
            bins=self.num_experts,
            min=0,
            max=self.num_experts,
        )
        # token_indices_experts_sorted shape (bs*slen*top_k,)
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)


# TODO: implement load balancing auxiliary loss for token-choice routing
class MoE(nn.Module):
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        dim = model_args.dim
        hidden_dim = 4 * model_args.dim
        ffn_dim_multiplier = model_args.ffn_dim_multiplier
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

        hidden_dim_denom = 1
        if model_args.auto_scale_hidden_dim:
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)

        if model_args.auto_scale_hidden_dim:
            hidden_dim = int(hidden_dim / hidden_dim_denom)
        hidden_dim += -hidden_dim % model_args.multiple_of

        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        self.shared_expert = (
            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        # top_scores and selected_indices shape (bs*slen*top_k,)
        # num_local_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            token_indices,
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim))

        # shape (bs*slen*top_k, dim)
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)

        # shape (bs*slen*top_k, dim)
        routed_input = torch.gather(
            x.view(-1, dim),
            dim=0,
            index=token_indices,
        )
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

        # shared expert
        if self.shared_expert is not None:
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bs, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
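
To make the index bookkeeping in `TokenChoiceTopKRouter` concrete, here is a minimal torch-only sketch of the same routing math on toy shapes. It is not the module above, just an illustration of how top-k expert choices become per-expert token counts and an expert-sorted token order; the `.float()` cast before `torch.histc` is added here because `histc` expects a floating tensor.

```python
# Minimal sketch of token-choice routing index math; shapes are toy values.
import torch

num_tokens, num_experts, top_k = 6, 4, 2
scores = torch.randn(num_tokens, num_experts).softmax(dim=1)

# Each token picks its top_k experts.
top_scores, selected = torch.topk(scores, k=top_k, dim=1)

# Count how many (token, expert) assignments land on each expert.
tokens_per_expert = torch.histc(
    selected.float().view(-1), bins=num_experts, min=0, max=num_experts
)

# Sort the flattened assignments by expert id so tokens routed to the same
# expert are contiguous; map each routed slot back to its token via // top_k.
order = torch.argsort(selected.view(-1), stable=True)
sorted_scores = top_scores.view(-1)[order]
token_indices = order // top_k

print(tokens_per_expert)  # per-expert counts, summing to num_tokens * top_k
print(token_indices)      # which original token each routed slot refers to
```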
torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.py
ADDED
|
@@ -0,0 +1,536 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from typing import Any, Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor, Shard
|
| 16 |
+
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
|
| 17 |
+
from torchtitan.components.checkpoint import MODEL
|
| 18 |
+
from torchtitan.config_manager import JobConfig
|
| 19 |
+
from torchtitan.tools.logging import init_logger, logger
|
| 20 |
+
from torchtitan.train import Trainer
|
| 21 |
+
|
| 22 |
+
# Sharding dims for MP checkpoints
|
| 23 |
+
|
| 24 |
+
column_parallel = [
|
| 25 |
+
"tok_embeddings",
|
| 26 |
+
"wq",
|
| 27 |
+
"wk",
|
| 28 |
+
"wv",
|
| 29 |
+
"wqkv",
|
| 30 |
+
"w_in_shared_FD",
|
| 31 |
+
"w_out_eF_D",
|
| 32 |
+
"w_swiglu_FD",
|
| 33 |
+
"output",
|
| 34 |
+
"_linear",
|
| 35 |
+
"c_fc",
|
| 36 |
+
"vision_projection",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
row_parallel = [
|
| 40 |
+
"wo",
|
| 41 |
+
"w_out_shared_DF",
|
| 42 |
+
"w_in_eD_F",
|
| 43 |
+
"moe_w_swiglu_eD_F",
|
| 44 |
+
"c_proj",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def convert_to_titan_fqns(fqn: str) -> list[str]:
|
| 49 |
+
# From the stored checkpoint keys to TorchTitan keys.
|
| 50 |
+
if "wqkv" in fqn and "layer_norm_weight" not in fqn:
|
| 51 |
+
ret = []
|
| 52 |
+
for k in ("wq", "wk", "wv"):
|
| 53 |
+
ret.append(fqn.replace("wqkv", k))
|
| 54 |
+
return ret
|
| 55 |
+
return [fqn]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_shard_dim(fqn: str) -> Optional[int]:
|
| 59 |
+
if "bias" in fqn:
|
| 60 |
+
# Some bias params are still sharded
|
| 61 |
+
if "resblocks" in fqn:
|
| 62 |
+
for k in ("wq", "wk", "wv", "c_fc"):
|
| 63 |
+
if k in fqn:
|
| 64 |
+
return 0
|
| 65 |
+
return None
|
| 66 |
+
elif any([x in fqn for x in column_parallel]):
|
| 67 |
+
return 0
|
| 68 |
+
elif any([x in fqn for x in row_parallel]):
|
| 69 |
+
return 1
|
| 70 |
+
else:
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def split_fused_qkv(shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
|
| 75 |
+
qkvs = [torch.split(shard, [640, 128, 128]) for shard in shards]
|
| 76 |
+
q = torch.cat([qkv[0] for qkv in qkvs], dim=0)
|
| 77 |
+
k = torch.cat([qkv[1] for qkv in qkvs], dim=0)
|
| 78 |
+
v = torch.cat([qkv[2] for qkv in qkvs], dim=0)
|
| 79 |
+
return q, k, v
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
|
| 83 |
+
class _Assignment:
|
| 84 |
+
loader_id: int
|
| 85 |
+
filename: str
|
| 86 |
+
fqns: tuple[str, ...]
|
| 87 |
+
shapes: tuple[torch.Size, ...]
|
| 88 |
+
dtypes: tuple[torch.dtype, ...]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
|
| 92 |
+
class _AssignmentRound:
|
| 93 |
+
loader_assignments: dict[int, _Assignment] # List of assignments for each loader
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class CheckpointConverter:
|
| 97 |
+
TOTAL_SHARDS = 8
|
| 98 |
+
|
| 99 |
+
def __init__(
|
| 100 |
+
self,
|
| 101 |
+
process_group: dist.ProcessGroup,
|
| 102 |
+
path: str,
|
| 103 |
+
loader_every_n_ranks: int = 8,
|
| 104 |
+
) -> None:
|
| 105 |
+
self.path = path
|
| 106 |
+
self.pg = process_group
|
| 107 |
+
self.my_rank = dist.get_rank(self.pg)
|
| 108 |
+
self.loader_every_n_ranks = loader_every_n_ranks
|
| 109 |
+
self.loader_id = self.my_rank // loader_every_n_ranks
|
| 110 |
+
self.should_load = (
|
| 111 |
+
self.my_rank % loader_every_n_ranks == 0
|
| 112 |
+
and self.loader_id < CheckpointConverter.TOTAL_SHARDS
|
| 113 |
+
)
|
| 114 |
+
self.total_loader = CheckpointConverter.TOTAL_SHARDS
|
| 115 |
+
self.titan_fqn_to_stored_fqn: dict[str, str] = {}
|
| 116 |
+
self.stored_fqn_to_titan_fqn: dict[str, list[str]] = {}
|
| 117 |
+
self.total_send_bytes = 0
|
| 118 |
+
self.total_recv_bytes = 0
|
| 119 |
+
|
| 120 |
+
def convert(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 121 |
+
begin = time.time()
|
| 122 |
+
self._load_metadata()
|
| 123 |
+
self._create_fqn_mappings(state_dict)
|
| 124 |
+
rounds = self._get_load_assignments(state_dict)
|
| 125 |
+
|
| 126 |
+
for assignments in rounds:
|
| 127 |
+
loader_assignments = assignments.loader_assignments
|
| 128 |
+
loaded_state_dict = None
|
| 129 |
+
# Let each loader to load its own data and move to its GPU.
|
| 130 |
+
for i in range(self.total_loader):
|
| 131 |
+
# This loader doesn't have any loading assignment for this round.
|
| 132 |
+
if i not in loader_assignments:
|
| 133 |
+
continue
|
| 134 |
+
# This rank is not the loader
|
| 135 |
+
if i != self.loader_id or not self.should_load:
|
| 136 |
+
continue
|
| 137 |
+
loaded_state_dict = self._load_round(loader_assignments[i])
|
| 138 |
+
|
| 139 |
+
results = []
|
| 140 |
+
for i in range(self.total_loader):
|
| 141 |
+
if i not in loader_assignments:
|
| 142 |
+
continue
|
| 143 |
+
|
| 144 |
+
if i == self.loader_id and self.should_load:
|
| 145 |
+
# This rank is the loader. It needs to send the loaded data to
|
| 146 |
+
# the other ranks.
|
| 147 |
+
assert loaded_state_dict is not None
|
| 148 |
+
results.append(
|
| 149 |
+
self._reshard_send(loader_assignments[i], loaded_state_dict)
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
results.append(
|
| 153 |
+
self._reshard_receive(loader_assignments[i], state_dict)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self._reshard(results, state_dict)
|
| 157 |
+
|
| 158 |
+
torch.cuda.synchronize()
|
| 159 |
+
logger.info(f"Checkpoint conversion took {time.time() - begin:.2f} seconds.")
|
| 160 |
+
logger.info(f"Total send bytes: {self.total_send_bytes / 1e9:.2f} GB")
|
| 161 |
+
logger.info(f"Total recv bytes: {self.total_recv_bytes / 1e9:.2f} GB")
|
| 162 |
+
return state_dict
|
| 163 |
+
|
| 164 |
+
def _get_file_path(self, loader_id: int) -> str:
|
| 165 |
+
return os.path.join(self.path, f"consolidated.0{loader_id}.pth")
|
| 166 |
+
|
| 167 |
+
def _load_metadata(self) -> None:
|
| 168 |
+
if not self.should_load:
|
| 169 |
+
self.read_dict = {}
|
| 170 |
+
return
|
| 171 |
+
self.read_dict = torch.load(
|
| 172 |
+
self._get_file_path(self.loader_id),
|
| 173 |
+
mmap=True,
|
| 174 |
+
weights_only=False,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
def _create_fqn_mappings(self, state_dict: dict[str, torch.Tensor]) -> None:
|
| 178 |
+
if not self.read_dict:
|
| 179 |
+
return
|
| 180 |
+
|
| 181 |
+
# Create the mapping from the stored checkpoint keys to TorchTitan keys.
|
| 182 |
+
for fqn in list(self.read_dict.keys()):
|
| 183 |
+
titan_fqns = convert_to_titan_fqns(fqn)
|
| 184 |
+
# We don't know how to process _extra_state
|
| 185 |
+
if "_extra_state" in fqn:
|
| 186 |
+
self.read_dict.pop(fqn)
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
if titan_fqns[0] not in state_dict:
|
| 190 |
+
for titan_fqn in titan_fqns:
|
| 191 |
+
assert titan_fqns[0] not in state_dict
|
| 192 |
+
self.read_dict.pop(fqn)
|
| 193 |
+
continue
|
| 194 |
+
self.stored_fqn_to_titan_fqn[fqn] = titan_fqns
|
| 195 |
+
for titan_fqn in titan_fqns:
|
| 196 |
+
self.titan_fqn_to_stored_fqn[titan_fqn] = fqn
|
| 197 |
+
|
| 198 |
+
assert set(state_dict.keys()) == set(self.titan_fqn_to_stored_fqn.keys()), (
|
| 199 |
+
set(state_dict.keys()) - set(self.titan_fqn_to_stored_fqn.keys()),
|
| 200 |
+
set(self.titan_fqn_to_stored_fqn.keys()) - set(state_dict.keys()),
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
def _get_load_assignments(
|
| 204 |
+
self, state_dict: dict[str, torch.Tensor]
|
| 205 |
+
) -> list[_AssignmentRound]:
|
| 206 |
+
if self.my_rank == 0:
|
| 207 |
+
rounds: list[_AssignmentRound] = []
|
| 208 |
+
size = 0
|
| 209 |
+
fqns = []
|
| 210 |
+
shapes = []
|
| 211 |
+
dtypes = []
|
| 212 |
+
|
| 213 |
+
# All loader must load all the FQNs because the checkpoint is purely TP sharded.
|
| 214 |
+
all_keys = list(self.read_dict.keys())
|
| 215 |
+
for fqn in all_keys:
|
| 216 |
+
fqns.append(fqn)
|
| 217 |
+
shapes.append(self.read_dict[fqn].shape)
|
| 218 |
+
dtypes.append(self.read_dict[fqn].dtype)
|
| 219 |
+
size += self.read_dict[fqn].numel() * self.read_dict[fqn].element_size()
|
| 220 |
+
if size < 1e9 and fqn != all_keys[-1]:
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
+
logger.info(f"Adding {fqns} to round {len(rounds)}")
|
| 224 |
+
round_assignment = _AssignmentRound(loader_assignments={})
|
| 225 |
+
for loader_id in range(self.total_loader):
|
| 226 |
+
path = self._get_file_path(loader_id)
|
| 227 |
+
round_assignment.loader_assignments[loader_id] = _Assignment(
|
| 228 |
+
filename=path,
|
| 229 |
+
fqns=tuple(fqns),
|
| 230 |
+
shapes=tuple(shapes),
|
| 231 |
+
dtypes=tuple(dtypes),
|
| 232 |
+
loader_id=loader_id,
|
| 233 |
+
)
|
| 234 |
+
rounds.append(round_assignment)
|
| 235 |
+
size = 0
|
| 236 |
+
fqns.clear()
|
| 237 |
+
shapes.clear()
|
| 238 |
+
dtypes.clear()
|
| 239 |
+
|
| 240 |
+
object_list: list[Any] = [
|
| 241 |
+
rounds,
|
| 242 |
+
self.titan_fqn_to_stored_fqn,
|
| 243 |
+
self.stored_fqn_to_titan_fqn,
|
| 244 |
+
]
|
| 245 |
+
else:
|
| 246 |
+
object_list = [None, None, None]
|
| 247 |
+
|
| 248 |
+
dist.broadcast_object_list(object_list, src=0, group=self.pg)
|
| 249 |
+
rounds = object_list[0]
|
| 250 |
+
self.titan_fqn_to_stored_fqn = object_list[1]
|
| 251 |
+
self.stored_fqn_to_titan_fqn = object_list[2]
|
| 252 |
+
return rounds
|
| 253 |
+
|
| 254 |
+
def _load_round(self, assignment: _Assignment) -> dict[str, torch.Tensor]:
|
| 255 |
+
ret = {}
|
| 256 |
+
assert self.read_dict
|
| 257 |
+
for fqn in assignment.fqns:
|
| 258 |
+
ret[fqn] = self.read_dict[fqn].to(device="cuda")
|
| 259 |
+
return ret
|
| 260 |
+
|
| 261 |
+
def _reshard_send(
|
| 262 |
+
self,
|
| 263 |
+
assignment: _Assignment,
|
| 264 |
+
loaded_state_dict: dict[str, torch.Tensor],
|
| 265 |
+
) -> dict[str, torch.Tensor]:
|
| 266 |
+
flatten_tensors = [t.flatten() for t in loaded_state_dict.values()]
|
| 267 |
+
flatten_tensor = torch.concat(flatten_tensors)
|
| 268 |
+
assert self.loader_id == assignment.loader_id
|
| 269 |
+
rank = self.loader_id * self.loader_every_n_ranks
|
| 270 |
+
assert rank == self.my_rank
|
| 271 |
+
logger.info(f"Sending {assignment.filename} from {rank} {self.loader_id}")
|
| 272 |
+
logger.info(f"Sending {assignment.fqns}")
|
| 273 |
+
dist.broadcast(flatten_tensor, src=rank, group=self.pg)
|
| 274 |
+
self.total_send_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
|
| 275 |
+
return loaded_state_dict
|
| 276 |
+
|
| 277 |
+
def _reshard_receive(
|
| 278 |
+
self, assignment: _Assignment, state_dict: dict[str, torch.Tensor]
|
| 279 |
+
) -> dict[str, torch.Tensor]:
|
| 280 |
+
flatten_tensor = torch.empty(
|
| 281 |
+
sum(math.prod(s) for s, d in zip(assignment.shapes, assignment.dtypes)),
|
| 282 |
+
dtype=assignment.dtypes[0],
|
| 283 |
+
device="cuda",
|
| 284 |
+
)
|
| 285 |
+
rank = assignment.loader_id * self.loader_every_n_ranks
|
| 286 |
+
dist.broadcast(flatten_tensor, src=rank, group=self.pg)
|
| 287 |
+
self.total_recv_bytes += flatten_tensor.numel() * flatten_tensor.element_size()
|
| 288 |
+
|
| 289 |
+
ret: dict[str, torch.Tensor] = {}
|
| 290 |
+
loc = 0
|
| 291 |
+
for fqn, shape, dtype in zip(
|
| 292 |
+
assignment.fqns, assignment.shapes, assignment.dtypes
|
| 293 |
+
):
|
| 294 |
+
n_ele = math.prod(shape)
|
| 295 |
+
ret[fqn] = flatten_tensor[loc : loc + n_ele].view(shape)
|
| 296 |
+
loc += n_ele
|
| 297 |
+
return ret
|
| 298 |
+
|
| 299 |
+
def _reshard(
|
| 300 |
+
self,
|
| 301 |
+
results: list[dict[str, torch.Tensor]],
|
| 302 |
+
state_dict: dict[str, torch.Tensor],
|
| 303 |
+
) -> None:
|
| 304 |
+
def _inplace_copy(fqn: str, full_tensors: tuple[torch.Tensor, ...]):
|
| 305 |
+
titan_fqns = self.stored_fqn_to_titan_fqn[fqn]
|
| 306 |
+
assert len(titan_fqns) == len(full_tensors)
|
| 307 |
+
for titan_fqn, full_tensor in zip(titan_fqns, full_tensors):
|
| 308 |
+
dtensor = state_dict[titan_fqn]
|
| 309 |
+
logger.info(f"{titan_fqn} {full_tensor.sum()}")
|
| 310 |
+
assert isinstance(dtensor, DTensor)
|
| 311 |
+
shape, offset = compute_local_shape_and_global_offset(
|
| 312 |
+
full_tensor.shape, dtensor.device_mesh, dtensor.placements
|
| 313 |
+
)
|
| 314 |
+
slices = [
|
| 315 |
+
slice(cur_offset, cur_offset + cur_shape)
|
| 316 |
+
for cur_shape, cur_offset in zip(shape, offset)
|
| 317 |
+
]
|
| 318 |
+
logger.info(
|
| 319 |
+
f"Copying {titan_fqn} with {slices=} {dtensor._local_tensor.shape=} "
|
| 320 |
+
f"{shape=} {offset=} {self.my_rank=} {dtensor.shape=} {full_tensor.shape=} "
|
| 321 |
+
f"{dtensor.placements=} {dtensor.device_mesh=} "
|
| 322 |
+
)
|
| 323 |
+
dtensor.to_local().copy_(full_tensor[slices])
|
| 324 |
+
|
| 325 |
+
def _concat_shards(fqn, shards: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
|
| 326 |
+
if "wqkv" in fqn:
|
| 327 |
+
if "layer_norm" in fqn:
|
| 328 |
+
return (shards[0],)
|
| 329 |
+
return split_fused_qkv(shards)
|
| 330 |
+
|
| 331 |
+
shard_dim = get_shard_dim(fqn)
|
| 332 |
+
if shard_dim is None:
|
| 333 |
+
return (shards[0],)
|
| 334 |
+
return (torch.cat(shards, dim=shard_dim),)
|
| 335 |
+
|
| 336 |
+
fqns = list(results[0].keys())
|
| 337 |
+
for result in results:
|
| 338 |
+
assert list(result.keys()) == fqns
|
| 339 |
+
|
| 340 |
+
for fqn in fqns:
|
| 341 |
+
full_tensors = _concat_shards(fqn, [result[fqn] for result in results])
|
| 342 |
+
_inplace_copy(fqn, full_tensors)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def _create_verified_state_dict(
|
| 346 |
+
pg: dist.ProcessGroup, mesh: DeviceMesh
|
| 347 |
+
) -> dict[str, torch.Tensor]:
|
| 348 |
+
placements = [Shard(0)]
|
| 349 |
+
state_dict = {
|
| 350 |
+
"tok_embeddings.weight": torch.rand(
|
| 351 |
+
25256 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 352 |
+
),
|
| 353 |
+
"layers.47.attention.wqkv.layer_norm_weight": torch.rand(
|
| 354 |
+
5120, device="cuda", dtype=torch.bfloat16
|
| 355 |
+
),
|
| 356 |
+
"layers.47.attention.wq.weight": torch.rand(
|
| 357 |
+
640 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 358 |
+
),
|
| 359 |
+
"layers.47.attention.wk.weight": torch.rand(
|
| 360 |
+
128 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 361 |
+
),
|
| 362 |
+
"layers.47.attention.wv.weight": torch.rand(
|
| 363 |
+
128 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 364 |
+
),
|
| 365 |
+
"layers.47.attention.wo.weight": torch.rand(
|
| 366 |
+
5120, 640 * 8, device="cuda", dtype=torch.bfloat16
|
| 367 |
+
),
|
| 368 |
+
# "layers.47.feed_forward.router_DE": torch.rand(5120, 128, device="cuda", dtype=torch.bfloat16),
|
| 369 |
+
# "layers.47.feed_forward.running_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
|
| 370 |
+
# "layers.47.feed_forward.global_gate_stats_3E": torch.rand(3, 128, device="cuda", dtype=torch.bfloat16),
|
| 371 |
+
"layers.47.feed_forward.w_in_shared_FD.weight": torch.rand(
|
| 372 |
+
1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 373 |
+
),
|
| 374 |
+
"layers.47.feed_forward.w_out_shared_DF.weight": torch.rand(
|
| 375 |
+
5120, 1024 * 8, device="cuda", dtype=torch.bfloat16
|
| 376 |
+
),
|
| 377 |
+
"layers.47.feed_forward.w_swiglu_FD.weight": torch.rand(
|
| 378 |
+
1024 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 379 |
+
),
|
| 380 |
+
"layers.47.feed_forward.norm.weight": torch.rand(
|
| 381 |
+
5120, device="cuda", dtype=torch.bfloat16
|
| 382 |
+
),
|
| 383 |
+
"layers.47.feed_forward.experts.moe_w_in_eD_F": torch.rand(
|
| 384 |
+
655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
|
| 385 |
+
),
|
| 386 |
+
"layers.47.feed_forward.experts.moe_w_out_eF_D": torch.rand(
|
| 387 |
+
131072 * 8, 5120, device="cuda", dtype=torch.bfloat16
|
| 388 |
+
),
|
| 389 |
+
"layers.47.feed_forward.experts.moe_w_swiglu_eD_F": torch.rand(
|
| 390 |
+
655360, 1024 * 8, device="cuda", dtype=torch.bfloat16
|
| 391 |
+
),
|
| 392 |
+
}
|
| 393 |
+
return {k: distribute_tensor(v, mesh, placements) for k, v in state_dict.items()}
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def _verify_state_dict(
|
| 397 |
+
state_dict: dict[str, torch.Tensor], path: str, rank: int
|
| 398 |
+
) -> None:
|
| 399 |
+
stored_state_dicts = [
|
| 400 |
+
torch.load(
|
| 401 |
+
os.path.join(path, f"consolidated.0{i}.pth"),
|
| 402 |
+
map_location="cpu",
|
| 403 |
+
weights_only=False,
|
| 404 |
+
mmap=True,
|
| 405 |
+
)
|
| 406 |
+
for i in range(8)
|
| 407 |
+
]
|
| 408 |
+
|
| 409 |
+
def read_and_verify_tensor(fqn: str, dtensor: DTensor) -> None:
|
| 410 |
+
logger.info(f"Verifying {fqn} {dtensor.shape=} {dtensor.placements=} ")
|
| 411 |
+
shards = [stored_state_dicts[i][fqn] for i in range(8)]
|
| 412 |
+
full_tensor = dtensor.full_tensor()
|
| 413 |
+
logger.info(f"Gather {fqn} {full_tensor.shape} completely.")
|
| 414 |
+
|
| 415 |
+
if rank > 0:
|
| 416 |
+
return
|
| 417 |
+
|
| 418 |
+
if len(shards[0].shape) == 1:
|
| 419 |
+
assert full_tensor.shape == shards[0].shape, fqn
|
| 420 |
+
assert torch.allclose(shards[0].to(device="cuda"), full_tensor), fqn
|
| 421 |
+
return
|
| 422 |
+
elif shards[0].shape[0] == full_tensor.shape[0]:
|
| 423 |
+
concat_shards = torch.cat(shards, dim=1)
|
| 424 |
+
logger.info(f"Load {fqn} completely.")
|
| 425 |
+
elif shards[0].shape[1] == full_tensor.shape[1]:
|
| 426 |
+
concat_shards = torch.cat(shards, dim=0)
|
| 427 |
+
logger.info(f"Load {fqn} completely.")
|
| 428 |
+
|
| 429 |
+
concat_shards = concat_shards.to(device="cuda")
|
| 430 |
+
logger.info(f"Move to GPU {fqn} completely.")
|
| 431 |
+
|
| 432 |
+
assert concat_shards.shape == full_tensor.shape, fqn
|
| 433 |
+
assert concat_shards.dtype == full_tensor.dtype, fqn
|
| 434 |
+
assert concat_shards.device == full_tensor.device, fqn
|
| 435 |
+
assert torch.allclose(concat_shards, full_tensor), fqn
|
| 436 |
+
|
| 437 |
+
for k, v in state_dict.items():
|
| 438 |
+
if "wq" in k and "wqkv" not in k:
|
| 439 |
+
pass
|
| 440 |
+
elif "wk" in k:
|
| 441 |
+
pass
|
| 442 |
+
elif "wv" in k:
|
| 443 |
+
pass
|
| 444 |
+
else:
|
| 445 |
+
assert v is not None, k
|
| 446 |
+
read_and_verify_tensor(k, v)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
if __name__ == "__main__":
|
| 450 |
+
init_logger()
|
| 451 |
+
config = JobConfig()
|
| 452 |
+
config.parser.add_argument(
|
| 453 |
+
"--checkpoint.convert_path",
|
| 454 |
+
type=str,
|
| 455 |
+
default="",
|
| 456 |
+
help="""Specify the path of the target checkpoint to convert.""",
|
| 457 |
+
)
|
| 458 |
+
config.parser.add_argument(
|
| 459 |
+
"--checkpoint.convert_load_every_n_ranks",
|
| 460 |
+
type=int,
|
| 461 |
+
default=8,
|
| 462 |
+
help="""
|
| 463 |
+
Specify the interval at which ranks are assigned to load checkpoints.
|
| 464 |
+
|
| 465 |
+
For example, if this number is 4, then ranks 0, 4, 8, ... will load the
|
| 466 |
+
checkpoint. Each loader is responsible for loading one file. If there
|
| 467 |
+
are more loaders than files, only the first few loaders will be assigned
|
| 468 |
+
to load the checkpoint. The default value is 8.
|
| 469 |
+
""",
|
| 470 |
+
)
|
| 471 |
+
config.parser.add_argument(
|
| 472 |
+
"--checkpoint.fake_model",
|
| 473 |
+
action="store_true",
|
| 474 |
+
help="""If true, the model will be fake.""",
|
| 475 |
+
)
|
| 476 |
+
config.parse_args()
|
| 477 |
+
assert config.checkpoint.convert_path != ""
|
| 478 |
+
|
| 479 |
+
trainer: Optional[Trainer] = None
|
| 480 |
+
|
| 481 |
+
try:
|
| 482 |
+
trainer = Trainer(config)
|
| 483 |
+
if os.path.exists(trainer.checkpointer.folder):
|
| 484 |
+
raise RuntimeError(
|
| 485 |
+
"The checkpoint folder already exists. Abort to avoid overwriting "
|
| 486 |
+
f"the checkpoint. {trainer.checkpointer.folder=}"
|
| 487 |
+
)
|
| 488 |
+
if config.checkpoint.fake_model:
|
| 489 |
+
state_dict = _create_verified_state_dict(
|
| 490 |
+
trainer.world_mesh.get_group(), trainer.world_mesh
|
| 491 |
+
)
|
| 492 |
+
else:
|
| 493 |
+
state_dict = trainer.checkpointer.states[MODEL].state_dict()
|
| 494 |
+
|
| 495 |
+
size = 0
|
| 496 |
+
for v in state_dict.values():
|
| 497 |
+
size += v.numel() * v.element_size()
|
| 498 |
+
logger.info(f"Total size of the model: {size / 1e9:.2f} GB")
|
| 499 |
+
|
| 500 |
+
# Do not support PP yet, we will need to iterate over the PP dimension and
|
| 501 |
+
# extract the corresponding state_dict and device_mesh.
|
| 502 |
+
if "freq_cis" in state_dict:
|
| 503 |
+
state_dict.pop("freqs_cis")
|
| 504 |
+
|
| 505 |
+
state_dict = CheckpointConverter(
|
| 506 |
+
process_group=trainer.world_mesh.get_group(),
|
| 507 |
+
path=config.checkpoint.convert_path,
|
| 508 |
+
loader_every_n_ranks=config.checkpoint.convert_load_every_n_ranks,
|
| 509 |
+
).convert(state_dict)
|
| 510 |
+
|
| 511 |
+
class DummyModel:
|
| 512 |
+
def __init__(self, state_dict: dict[str, torch.Tensor]) -> None:
|
| 513 |
+
self._state_dict = state_dict
|
| 514 |
+
|
| 515 |
+
def state_dict(self) -> dict[str, torch.Tensor]:
|
| 516 |
+
return self._state_dict
|
| 517 |
+
|
| 518 |
+
if config.checkpoint.fake_model:
|
| 519 |
+
begin = time.time()
|
| 520 |
+
_verify_state_dict(
|
| 521 |
+
state_dict,
|
| 522 |
+
config.checkpoint.convert_path,
|
| 523 |
+
trainer.world_mesh.get_rank(),
|
| 524 |
+
)
|
| 525 |
+
dist.barrier()
|
| 526 |
+
logger.info(f"Verifies state_dict {time.time() - begin}.")
|
| 527 |
+
else:
|
| 528 |
+
# oh, this is pretty bad, when can we get rid of the freqs_cis issue?
|
| 529 |
+
state_dict["freqs_cis"] = None
|
| 530 |
+
trainer.checkpointer.states[MODEL] = DummyModel(state_dict)
|
| 531 |
+
trainer.checkpointer.model_weights_only = True
|
| 532 |
+
trainer.checkpointer.export_dtype = next(iter(state_dict.values())).dtype
|
| 533 |
+
trainer.checkpointer.save(curr_step=0, force=True)
|
| 534 |
+
time.sleep(2)
|
| 535 |
+
finally:
|
| 536 |
+
pass
|
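
The converter above reassembles tensor-parallel shards by concatenating along dim 0 for column-parallel weights and dim 1 for row-parallel weights, and splits each fused `wqkv` shard into separate q/k/v pieces before concatenating. Below is a hedged standalone sketch of that reshard pattern on dummy tensors; the 640/128/128 per-shard row split is taken from `split_fused_qkv` above, all other sizes are made up.

```python
# Standalone sketch of the shard-concatenation logic used by the converter.
# Tensor sizes are dummies; only the concat/split pattern matters.
import torch

num_shards, dim = 8, 16

# Column-parallel weights are sharded across output rows -> concat on dim 0.
col_shards = [torch.randn(4, dim) for _ in range(num_shards)]
col_full = torch.cat(col_shards, dim=0)          # (32, dim)

# Row-parallel weights are sharded across input columns -> concat on dim 1.
row_shards = [torch.randn(dim, 4) for _ in range(num_shards)]
row_full = torch.cat(row_shards, dim=1)          # (dim, 32)

# Fused wqkv shards carry q/k/v rows back to back; split each shard first,
# then concatenate per projection (640/128/128 rows per shard in the script).
qkv_shards = [torch.randn(640 + 128 + 128, dim) for _ in range(num_shards)]
qkvs = [torch.split(s, [640, 128, 128]) for s in qkv_shards]
q = torch.cat([x[0] for x in qkvs], dim=0)       # (640 * num_shards, dim)
k = torch.cat([x[1] for x in qkvs], dim=0)       # (128 * num_shards, dim)
v = torch.cat([x[2] for x in qkvs], dim=0)       # (128 * num_shards, dim)

print(col_full.shape, row_full.shape, q.shape, k.shape, v.shape)
```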
torchtitan/experiments/multimodal/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from mm_dataset import build_mm_dataloader

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import ModelArgs, MultimodalDecoder, VisionEncoder

__all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"]

llama4_mm_configs = {
    # TODO: add configs for llama4 multimodal
}

register_train_spec(
    TrainSpec(
        name="llama4_multimodal",
        cls=MultimodalDecoder,
        config=llama4_mm_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_mm_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
torchtitan/experiments/multimodal/mm_collator.py
ADDED
|
@@ -0,0 +1,227 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from tokenizer.tiktoken import IGNORE_INDEX
|
| 16 |
+
|
| 17 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def padded_collate(
|
| 21 |
+
batch: List[Dict[str, List[int]]],
|
| 22 |
+
padding_idx: int = 0,
|
| 23 |
+
ignore_idx: int = -100,
|
| 24 |
+
) -> Dict[str, torch.Tensor]:
|
| 25 |
+
"""Pad a batch of sequences to the longest sequence length in the batch, and
|
| 26 |
+
convert integer lists to tensors.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
|
| 30 |
+
padding_idx (int): Padding index for input ids. Defaults to 0.
|
| 31 |
+
ignore_idx (int): Padding index for labels. Defaults to -100.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Dict[str, torch.Tensor]: Collated input and label tensors.
|
| 35 |
+
|
| 36 |
+
Example:
|
| 37 |
+
>>> token_pairs = [
|
| 38 |
+
>>> {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
|
| 39 |
+
>>> {"input_ids": [7,], "labels": [10,]},
|
| 40 |
+
>>> ]
|
| 41 |
+
>>> collated = padded_collate(
|
| 42 |
+
>>> batch=token_pairs,
|
| 43 |
+
>>> padding_idx=padding_idx,
|
| 44 |
+
>>> ignore_idx=ignore_idx,
|
| 45 |
+
>>> )
|
| 46 |
+
>>> collated["input_ids"]
|
| 47 |
+
>>> tensor([[1, 2, 3], [7, 0, 0]])
|
| 48 |
+
>>> collated["labels"]
|
| 49 |
+
>>> tensor([[4, 5, 6], [10, -100, -100]])
|
| 50 |
+
"""
|
| 51 |
+
input_ids = pad_sequence(
|
| 52 |
+
[x["input_ids"] for x in batch],
|
| 53 |
+
batch_first=True,
|
| 54 |
+
padding_value=padding_idx,
|
| 55 |
+
)
|
| 56 |
+
labels = pad_sequence(
|
| 57 |
+
[x["labels"] for x in batch],
|
| 58 |
+
batch_first=True,
|
| 59 |
+
padding_value=ignore_idx,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
input_ids_seq_len = input_ids.shape[-1]
|
| 63 |
+
labels_seq_len = labels.shape[-1]
|
| 64 |
+
|
| 65 |
+
# Hack to pad correctly and not use max_seq_len, which is costly
|
| 66 |
+
if input_ids_seq_len > labels_seq_len:
|
| 67 |
+
labels = F.pad(
|
| 68 |
+
labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
|
| 69 |
+
)
|
| 70 |
+
elif labels_seq_len > input_ids_seq_len:
|
| 71 |
+
input_ids = F.pad(
|
| 72 |
+
input_ids,
|
| 73 |
+
(0, labels_seq_len - input_ids_seq_len),
|
| 74 |
+
value=padding_idx,
|
| 75 |
+
)
|
| 76 |
+
return {"input_ids": input_ids, "labels": labels}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# NOTE Inspired from torchtune.data._collate.py
|
| 80 |
+
@dataclass
|
| 81 |
+
class MultiModalCollator:
|
| 82 |
+
padding_idx: int = 128004
|
| 83 |
+
ignore_idx: int = IGNORE_INDEX
|
| 84 |
+
pad_max_tiles: Optional[int] = None
|
| 85 |
+
pad_max_images: Optional[int] = None
|
| 86 |
+
|
| 87 |
+
def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
| 88 |
+
"""Pad a batch of text sequences, tiled image tensors, aspect ratios,
|
| 89 |
+
and cross attention masks. This can be used for both training and inference.
|
| 90 |
+
|
| 91 |
+
``batch`` is expected to be a list of sample dicts containing the following::
|
| 92 |
+
- "input_ids": List[int] of length text_seq_len, varies across samples
|
| 93 |
+
- "labels": List[int] of length text_seq_len, varies across samples
|
| 94 |
+
- "encoder_input": Dict[str, List[torch.Tensor]]
|
| 95 |
+
- "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
|
| 96 |
+
- "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio
|
| 97 |
+
|
| 98 |
+
Shape notation:
|
| 99 |
+
- c = channel dim
|
| 100 |
+
- h = height dim
|
| 101 |
+
- w = weight dim
|
| 102 |
+
|
| 103 |
+
Note:
|
| 104 |
+
For each element in the batch, ``len(images) == len(aspect_ratio)``.
|
| 105 |
+
|
| 106 |
+
This collater does the following:
|
| 107 |
+
(1) Pad text sequence and encoder mask to the longest sequence length in the batch
|
| 108 |
+
(2) Pad image tensors in the tile dimension with zeros to the largest number
|
| 109 |
+
of tiles in the batch
|
| 110 |
+
(3) Add empty images of zeros to samples up to max number of images in the batch
|
| 111 |
+
(4) Pad aspect ratios with (1,1) for all added padding images
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
|
| 115 |
+
labels, images, and aspect_ratio.
|
| 116 |
+
padding_idx (int): Padding index for input token ids. Defaults to 0.
|
| 117 |
+
ignore_idx (int): Padding index for labels. Defaults to -100.
|
| 118 |
+
pad_max_tiles (Optional[int]): Maximum number of tiles to pad to. If None, will pad to the largest number of tiles
|
| 119 |
+
in the batch. Defaults to None.
|
| 120 |
+
pad_max_images (Optional[int]): Maximum number of images to pad to. If None, will pad to the largest number of images
|
| 121 |
+
in the batch. Defaults to None.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors.
|
| 125 |
+
- tokens: Tensor of shape (bsz, max_seq_len)
|
| 126 |
+
- labels: Tensor of shape (bsz, max_seq_len)
|
| 127 |
+
- images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
|
| 128 |
+
- aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
|
| 129 |
+
|
| 130 |
+
Example:
|
| 131 |
+
>>> image_id = 1
|
| 132 |
+
>>> tokens_per_tile = 5
|
| 133 |
+
>>> c, h, w = 1, 1, 1
|
| 134 |
+
>>> batch = [
|
| 135 |
+
... {
|
| 136 |
+
... "input_ids": [1, 2, 1, 3], "labels": [4, 5, 6, 7],
|
| 137 |
+
... "encoder_input": {
|
| 138 |
+
... # One image with two tiles, one image with three tiles
|
| 139 |
+
... "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
|
| 140 |
+
... "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
|
| 141 |
+
... },
|
| 142 |
+
... },
|
| 143 |
+
... {
|
| 144 |
+
... "input_ids": [1, 4], "labels": [8, 9],
|
| 145 |
+
... "encoder_input": {
|
| 146 |
+
... # One image with four tiles
|
| 147 |
+
... "images": [torch.ones(4, c, h, w)],
|
| 148 |
+
... "aspect_ratio": [torch.tensor([2, 2])],
|
| 149 |
+
... },
|
| 150 |
+
... },
|
| 151 |
+
... ]
|
| 152 |
+
... collator = MultiModalCollator(pad_max_tiles=4)
|
| 153 |
+
>>> model_inputs = collator(batch=batch)
|
| 154 |
+
>>> print(model_inputs["input_ids"])
|
| 155 |
+
tensor([[1, 2, 1, 3],
|
| 156 |
+
[1, 4, 0, 0]])
|
| 157 |
+
>>> print(model_inputs["labels"])
|
| 158 |
+
tensor([[4, 5, 6, 7],
|
| 159 |
+
[8, 9, -100, -100]])
|
| 160 |
+
>>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
|
| 161 |
+
torch.Size([2, 2, 4, 1, 1, 1])
|
| 162 |
+
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2)
|
| 163 |
+
torch.Size([2, 2, 2])
|
| 164 |
+
>>> print(model_inputs["encoder_input"]["images"][0, 0, ...]) # Image with two tiles got padded to four
|
| 165 |
+
tensor([[[[1.]]], [[[1.]]], [[[0.]]], [[[0.]]]])
|
| 166 |
+
>>> print(model_inputs["encoder_input"]["images"][0, 1, ...]) # Image with three tiles got padded to four
|
| 167 |
+
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[0.]]]])
|
| 168 |
+
>>> print(model_inputs["encoder_input"]["images"][1, 0, ...]) # Image with four tiles did not get padded
|
| 169 |
+
tensor([[[[1.]]], [[[1.]]], [[[1.]]], [[[1.]]]])
|
| 170 |
+
>>> print(model_inputs["encoder_input"]["images"][1, 1, ...]) # Extra padding image was added to second sample
|
| 171 |
+
tensor([[[[0.]]], [[[0.]]], [[[0.]]], [[[0.]]]])
|
| 172 |
+
"""
|
| 173 |
+
# Text tokens can be handled independently by existing collaters
|
| 174 |
+
text_only = [
|
| 175 |
+
{"input_ids": sample["input_ids"], "labels": sample["labels"]}
|
| 176 |
+
for sample in batch
|
| 177 |
+
]
|
| 178 |
+
collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)
|
| 179 |
+
|
| 180 |
+
if self.pad_max_tiles is None:
|
| 181 |
+
# Get max number of tiles in batch
|
| 182 |
+
max_num_tiles = max(sample["images_tiles"].shape[0] for sample in batch)
|
| 183 |
+
else:
|
| 184 |
+
max_num_tiles = self.pad_max_tiles
|
| 185 |
+
|
| 186 |
+
# Pad images and aspect ratios to max number of tiles
|
| 187 |
+
batch_images = []
|
| 188 |
+
batch_aspect_ratios = []
|
| 189 |
+
|
| 190 |
+
for sample in batch:
|
| 191 |
+
sample_images = []
|
| 192 |
+
for image in sample["encoder_input"]["images"]:
|
| 193 |
+
# Single image in each sample has shape (n_tiles, c, h, w)
|
| 194 |
+
n_tiles = image.shape[0]
|
| 195 |
+
# Single mask in each sample corresponds to a single image and has shape (text_seq_len, image_seq_len)
|
| 196 |
+
# where image_seq_len = n_tiles * tokens_per_tile
|
| 197 |
+
padding_tiles = max_num_tiles - n_tiles
|
| 198 |
+
|
| 199 |
+
# Image should now have shape (max_num_tiles, c, h, w)
|
| 200 |
+
padded_image = F.pad(
|
| 201 |
+
image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
sample_images.append(padded_image)
|
| 205 |
+
# Stack multiple images and masks per sample in num_images dimension
|
| 206 |
+
batch_images.append(torch.stack(sample_images))
|
| 207 |
+
batch_aspect_ratios.append(
|
| 208 |
+
torch.stack(sample["encoder_input"]["aspect_ratio"])
|
| 209 |
+
)
|
| 210 |
+
# Finally, pad images, masks, aspect ratios to max number of images in batch
|
| 211 |
+
# (bsz, max_num_images, max_num_tiles, c, h, w)
|
| 212 |
+
collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
|
| 213 |
+
# (bsz, max_num_images, 2)
|
| 214 |
+
collated_aspect_ratios = pad_sequence(
|
| 215 |
+
batch_aspect_ratios, batch_first=True, padding_value=1
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
batch_dict = {
|
| 219 |
+
"input_ids": collated_text["input_ids"],
|
| 220 |
+
"labels": collated_text["labels"],
|
| 221 |
+
"encoder_input": {
|
| 222 |
+
"images": collated_images,
|
| 223 |
+
"aspect_ratio": collated_aspect_ratios,
|
| 224 |
+
},
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
return batch_dict
|
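
The collator pads each image's tile dimension up to the batch maximum with `F.pad`, stacks the images per sample, and then pads the per-sample stacks with `pad_sequence`. The following is a small hedged sketch of just that padding step on toy tensors, separate from the `MultiModalCollator` class itself:

```python
# Standalone sketch of the tile/image padding performed by the collator.
# Tile counts and image sizes are toy values.
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

c, h, w = 3, 2, 2
max_num_tiles = 4

# Two samples: the first has images with 2 and 3 tiles, the second one 4-tile image.
samples = [
    [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
    [torch.ones(4, c, h, w)],
]

batch_images = []
for images in samples:
    padded = [
        # Pad only the leading (tile) dimension up to max_num_tiles.
        F.pad(img, (0, 0, 0, 0, 0, 0, 0, max_num_tiles - img.shape[0]), value=0)
        for img in images
    ]
    batch_images.append(torch.stack(padded))  # (num_images, max_num_tiles, c, h, w)

# Pad per-sample image stacks to the max number of images in the batch.
collated = pad_sequence(batch_images, batch_first=True, padding_value=0)
print(collated.shape)  # (2, 2, 4, 3, 2, 2)
```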
torchtitan/experiments/multimodal/requirements.txt
ADDED
|
@@ -0,0 +1 @@
torchvision
torchtitan/experiments/simple_fsdp/__pycache__/simple_fsdp.cpython-312.pyc
ADDED
|
Binary file (6.83 kB). View file
|
|
|
torchtitan/experiments/simple_fsdp/model.py
ADDED
|
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama3 import Transformer, TransformerModelArgs
from .simple_fsdp import disable_data_parallel


class SimpleFSDPTransformer(Transformer):
    def __init__(self, model_args: TransformerModelArgs):
        super().__init__(model_args)
        self.init_weights()

    def init_weights(self, *args, **kwargs):
        with disable_data_parallel():
            super().init_weights(*args, **kwargs)
torchtitan/experiments/simple_fsdp/simple_fsdp.py
ADDED
|
@@ -0,0 +1,194 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from torch.distributed._tensor import (
|
| 15 |
+
distribute_tensor,
|
| 16 |
+
DTensor,
|
| 17 |
+
Partial,
|
| 18 |
+
Replicate,
|
| 19 |
+
Shard,
|
| 20 |
+
)
|
| 21 |
+
from torch.utils.checkpoint import (
|
| 22 |
+
checkpoint,
|
| 23 |
+
CheckpointPolicy,
|
| 24 |
+
create_selective_checkpoint_contexts,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_active_parametrization = True
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@contextmanager
|
| 32 |
+
def disable_data_parallel():
|
| 33 |
+
global _active_parametrization
|
| 34 |
+
try:
|
| 35 |
+
_active_parametrization = False
|
| 36 |
+
yield
|
| 37 |
+
finally:
|
| 38 |
+
_active_parametrization = True
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass(frozen=True)
|
| 42 |
+
class MixedPrecisionPolicy:
|
| 43 |
+
param_dtype: Optional[torch.dtype] = None
|
| 44 |
+
reduce_dtype: Optional[torch.dtype] = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def fsdp_policy():
|
| 48 |
+
def _fsdp_recomp_policy():
|
| 49 |
+
def _custom_policy(ctx, func, *args, **kwargs):
|
| 50 |
+
to_recompute = func in {
|
| 51 |
+
torch.ops._c10d_functional.all_gather_into_tensor.default,
|
| 52 |
+
torch.ops._c10d_functional.wait_tensor.default,
|
| 53 |
+
torch.ops.aten._to_copy.default, # for dtype cast in FSDP
|
| 54 |
+
}
|
| 55 |
+
return (
|
| 56 |
+
CheckpointPolicy.MUST_RECOMPUTE
|
| 57 |
+
if to_recompute
|
| 58 |
+
else CheckpointPolicy.MUST_SAVE
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
return _custom_policy
|
| 62 |
+
|
| 63 |
+
return create_selective_checkpoint_contexts(_fsdp_recomp_policy())
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ReplicateComputation(torch.nn.Module):
|
| 67 |
+
def __init__(self, device_mesh, param_sharding, mode, regional_ac, mp_policy):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.device_mesh = device_mesh
|
| 70 |
+
self.param_sharding = param_sharding
|
| 71 |
+
self.mode = mode
|
| 72 |
+
self.compute_placements = [Replicate()] * self.device_mesh.ndim
|
| 73 |
+
self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
|
| 74 |
+
self.regional_ac = regional_ac
|
| 75 |
+
mp_policy = mp_policy or MixedPrecisionPolicy()
|
| 76 |
+
self.param_dtype = mp_policy.param_dtype
|
| 77 |
+
self.reduce_dtype = mp_policy.reduce_dtype
|
| 78 |
+
|
| 79 |
+
def replicate_compute(self, x):
|
| 80 |
+
# data parallel runtime replicate parameters and do local compute
|
| 81 |
+
# the gradients are partial tensors that needs to perform reduction
|
| 82 |
+
# (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
|
| 83 |
+
|
| 84 |
+
# NOTE: specifying mixed precision is only available in pytorch_intern24
|
| 85 |
+
# https://github.com/tianyu-l/pytorch_intern24/pull/20
|
| 86 |
+
# support for FSDP + TP (assuming TP shards the inner-most dim)
|
| 87 |
+
if self.mode == "fully_shard" and x._spec.mesh.ndim == 2:
|
| 88 |
+
dp_placement, tp_placement = x._spec.placements
|
| 89 |
+
dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
|
| 90 |
+
|
| 91 |
+
# re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
|
| 92 |
+
# TODO: we should consider merging this logic into DTensor redistribute API
|
| 93 |
+
sharded_local_tensor = x.to_local()
|
| 94 |
+
sharded_dtensor = DTensor.from_local(
|
| 95 |
+
sharded_local_tensor, dp_mesh, self.param_sharding
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# the actuall FSDP all-gather on dp_mesh
|
| 99 |
+
# TODO(ruisizhang123): enable mixed-precision training here
|
| 100 |
+
# add the forward_dtype and backward_dtype back after landing changes in PyTorch DTensor
|
| 101 |
+
replicated_dtensor = sharded_dtensor.redistribute(
|
| 102 |
+
placements=self.compute_placements,
|
| 103 |
+
# forward_dtype=self.param_dtype,
|
| 104 |
+
# backward_dtype=self.reduce_dtype,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# re-wrap 1D all-gathered DTensor on dp_mesh to 1D DTensor on tp_mesh
|
| 108 |
+
# TODO: DTensor should support this mesh collasping operation
|
| 109 |
+
replicated_local_tensor = replicated_dtensor.to_local(
|
| 110 |
+
grad_placements=self.grad_placements
|
| 111 |
+
)
|
| 112 |
+
output = DTensor.from_local(
|
| 113 |
+
replicated_local_tensor, tp_mesh, (tp_placement,)
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
output = x.redistribute(
|
| 117 |
+
placements=self.compute_placements,
|
| 118 |
+
# forward_dtype=self.param_dtype,
|
| 119 |
+
# backward_dtype=self.reduce_dtype,
|
| 120 |
+
).to_local(grad_placements=self.grad_placements)
|
| 121 |
+
|
| 122 |
+
return output
|
| 123 |
+
|
| 124 |
+
def forward(self, x):
|
| 125 |
+
global _active_parametrization
|
| 126 |
+
# This should never be set to true during forward, only outside for model
|
| 127 |
+
# inspection / debugging / initialization
|
| 128 |
+
# model initialization can be done now through
|
| 129 |
+
# with disable_data_parallel():
|
| 130 |
+
# model.init_weights()
|
| 131 |
+
if not _active_parametrization:
|
| 132 |
+
return x
|
| 133 |
+
|
| 134 |
+
if self.regional_ac and self.mode in ("fully_shard", "hybrid_shard"):
|
| 135 |
+
# apply checkpointing to implement reshard_after_forward
|
| 136 |
+
output = checkpoint(
|
| 137 |
+
self.replicate_compute, x, use_reentrant=False, context_fn=fsdp_policy
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
output = self.replicate_compute(x)
|
| 141 |
+
|
| 142 |
+
return output
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def data_parallel(
|
| 146 |
+
model,
|
| 147 |
+
device_mesh,
|
| 148 |
+
mode="replicate",
|
| 149 |
+
ac_mode: str = "none",
|
| 150 |
+
mp_policy: Optional[MixedPrecisionPolicy] = None,
|
| 151 |
+
):
|
| 152 |
+
if mode == "replicate":
|
| 153 |
+
param_sharding = (Replicate(),)
|
| 154 |
+
elif mode == "fully_shard":
|
| 155 |
+
param_sharding = (Shard(0),)
|
| 156 |
+
elif mode == "hybrid_shard":
|
| 157 |
+
# replicate inter-host, fully shard intra-host
|
| 158 |
+
param_sharding = (Replicate(), Shard(0))
|
| 159 |
+
assert (
|
| 160 |
+
device_mesh.ndim == 2
|
| 161 |
+
), "hybrid sharded data parallel requires 2D DeviceMesh"
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError(f"Unsupported mode {mode}")
|
| 164 |
+
|
| 165 |
+
modules = list(model.modules())
|
| 166 |
+
|
| 167 |
+
# apply regional ac (with fsdp_policy) if no global ac is to be applied
|
| 168 |
+
regional_ac = ac_mode == "none"
|
| 169 |
+
|
| 170 |
+
for mod in modules:
|
| 171 |
+
params_dict = dict(mod.named_parameters(recurse=False))
|
| 172 |
+
for p_name, p in params_dict.items():
|
| 173 |
+
if p is not None and p.numel() > 0:
|
| 174 |
+
mod.register_parameter(
|
| 175 |
+
p_name,
|
| 176 |
+
# NOTE: for 2D we need to distribute_tensor a DTensor
|
| 177 |
+
# which requires latest change in pytorch_intern24
|
| 178 |
+
# https://github.com/tianyu-l/pytorch_intern24/pull/25
|
| 179 |
+
nn.Parameter(distribute_tensor(p, device_mesh, param_sharding)),
|
| 180 |
+
)
|
| 181 |
+
nn.utils.parametrize.register_parametrization(
|
| 182 |
+
mod,
|
| 183 |
+
p_name,
|
| 184 |
+
ReplicateComputation(
|
| 185 |
+
device_mesh,
|
| 186 |
+
param_sharding,
|
| 187 |
+
mode,
|
| 188 |
+
regional_ac,
|
| 189 |
+
mp_policy=mp_policy,
|
| 190 |
+
),
|
| 191 |
+
unsafe=True,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
return model
|
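
For orientation, a minimal usage sketch of the `data_parallel` API above, assuming `torch.distributed` is already initialized (e.g. via torchrun) and a toy `nn.Linear` stand-in for the model; the mesh shape and module are illustrative only.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

# 1D device mesh over all ranks
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
model = nn.Linear(1024, 1024, device="cuda")

# Parameters are sharded on dim 0; each forward access triggers the
# ReplicateComputation parametrization, which all-gathers them on the fly.
model = data_parallel(model, mesh, mode="fully_shard")

x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()  # grads flow back as Partial("avg") and get reduced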
torchtitan/experiments/simple_fsdp/tests/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
torchtitan/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama3  # noqa: F401
torchtitan/models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (195 Bytes)
torchtitan/models/__pycache__/norms.cpython-312.pyc
ADDED
Binary file (1.39 kB)
torchtitan/models/llama3/__init__.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import Transformer, TransformerModelArgs
from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = [
    "parallelize_llama",
    "pipeline_llama",
    "TransformerModelArgs",
    "Transformer",
    "llama3_configs",
]


llama3_configs = {
    "debugmodel": TransformerModelArgs(
        dim=256, n_layers=8, n_heads=16, rope_theta=500000
    ),
    "8B": TransformerModelArgs(
        dim=4096,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=1024,
        rope_theta=500000,
    ),
    "70B": TransformerModelArgs(
        dim=8192,
        n_layers=80,
        n_heads=64,
        n_kv_heads=8,
        ffn_dim_multiplier=1.3,
        multiple_of=4096,
        rope_theta=500000,
    ),
    "405B": TransformerModelArgs(
        dim=16384,
        n_layers=126,
        n_heads=128,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=4096,
        rope_theta=500000,
    ),
}


register_train_spec(
    TrainSpec(
        name="llama3",
        cls=Transformer,
        config=llama3_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
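
As a quick sketch of consuming these configs directly, the `debugmodel` flavor is small enough to instantiate eagerly on one device; in the actual train loop the model is normally built on meta device first.

from torchtitan.models.llama3 import llama3_configs, Transformer

args = llama3_configs["debugmodel"]  # dim=256, n_layers=8, n_heads=16
model = Transformer(args)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the debug flavor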
torchtitan/models/llama3/parallelize_llama.py
ADDED
@@ -0,0 +1,398 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.model.use_flex_attn:
        if job_config.activation_checkpoint.mode == "selective":
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )

        if parallel_dims.cp_enabled:
            raise ValueError(
                "FlexAttention is not compatible with CP yet. "
                "We are still working on this."
            )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP does not support > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for layer_id, transformer_block in model.layers.items():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            "Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config)
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(transformer_block, fullgraph=True)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional): The policy to use for resharding after the forward pass.
            Defaults to "default". Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    for layer_id, transformer_block in model.layers.items():
        if reshard_after_forward_policy == "always":
            reshard_after_forward = True
        elif reshard_after_forward_policy == "never":
            reshard_after_forward = False
        elif reshard_after_forward_policy == "default":
            if pp_enabled:
                # For PP, do not reshard after forward to avoid per-microbatch
                # all-gathers, which can be expensive and non-overlapped
                reshard_after_forward = False
            else:
                # As an optimization, do not reshard after forward for the last
                # transformer block since FSDP would prefetch it immediately
                reshard_after_forward = int(layer_id) < len(model.layers) - 1
        else:
            raise ValueError(
                f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
            )
        fully_shard(
            transformer_block,
            **fsdp_config,
            reshard_after_forward=reshard_after_forward,
        )
    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")
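
A toy sketch of the "save every other mm" selective-checkpoint policy used above, applied to a standalone function instead of a TransformerBlock; the three-matmul function and shapes are illustrative only.

from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def policy_fn():
    meta = defaultdict(int)

    def _policy(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default:
            # keep forward and recompute counters separate, as in the file above
            key = "recompute_mm" if ctx.is_recompute else "forward_mm"
            meta[key] += 1
            if meta[key] % 2 == 1:  # save odd-numbered matmuls, recompute the rest
                return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return create_selective_checkpoint_contexts(_policy)

def fn(x, w):
    return (x @ w) @ w @ w  # three aten.mm calls

x = torch.randn(64, 64, requires_grad=True)
w = torch.randn(64, 64, requires_grad=True)
out = checkpoint(fn, x, w, use_reentrant=False, context_fn=policy_fn)
out.sum().backward()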
torchtitan/models/llama3/pipeline_llama.py
ADDED
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D pipeline parallelism to the Llama model.

import copy

import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import (
    _PipelineSchedule,
    get_schedule_class,
    ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.pipeline import (
    build_pipeline_schedule,
    generate_split_points,
    stage_ids_this_rank,
)
from torchtitan.protocols.train_spec import DeviceType, ParallelizeFunction
from torchtitan.tools.logging import logger

from .model import TransformerModelArgs


def pipeline_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
    parallelize_fn: ParallelizeFunction,
    loss_fn: LossFunction,
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    pp_mesh = world_mesh["pp"]

    stages, model_parts = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
    # optimizer, and checkpointing
    for i, m in enumerate(model_parts):
        # apply SPMD-style PT-D techniques
        m = parallelize_fn(m, world_mesh, parallel_dims, job_config)
        model_parts[i] = m
        # NOTE: this is to update the model in the stage
        # in case the model is modified e.g. by torch.compile
        stages[i].submod = m

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # This is used in the train loop to determine whether to pass in the input_ids and labels
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, model_parts, has_first_stage, has_last_stage


def pipeline_llama_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: TransformerModelArgs,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts the torch.nn.Module objects for the parts of the model configured to run on this rank.

    It wraps each model chunk in a ManualPipelineStage object and returns both the stage and model objects.

    The stage objects are used to create a pipeline schedule, and the model objects can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    parallelism_config = job_config.parallelism

    splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
        parallelism_config.pipeline_parallel_schedule,
        parallelism_config.pipeline_parallel_layers_per_stage,
        parallel_dims.pp,
        model_config.n_layers,
    )

    def _build_stage(
        stage_idx: int,
        start_layer: str | None,
        stop_layer: str | None,
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            model.tok_embeddings = None

        drop_layers = start_layer is not None
        for name in list(model.layers.keys()):
            # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
            if f"layers.{name}" == start_layer:
                drop_layers = False
            if f"layers.{name}" == stop_layer:
                drop_layers = True
            if drop_layers:
                del model.layers[name]

        if not is_last:
            model.norm = None
            model.output = None

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models
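
For intuition on `stage_ids_this_rank`, here is a hedged sketch of the usual "loop"-style assignment for looped schedules. The helper is torchtitan-internal, so this mirrors the common scheme rather than its exact implementation.

def loop_stage_ids(pp_rank: int, pp_size: int, num_stages: int) -> list[int]:
    # rank r owns stages r, r + pp_size, r + 2 * pp_size, ...
    return list(range(pp_rank, num_stages, pp_size))

assert loop_stage_ids(1, 4, 8) == [1, 5]  # 8 stages over 4 ranks, 2 chunks per rank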
torchtitan/models/llama3/train_configs/llama3_405b.toml
ADDED
@@ -0,0 +1,63 @@
# torchtitan Config.toml
# NOTE: this toml config is a preset for 128 H100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 405B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "405B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
converters = "float8"

[optimizer]
name = "AdamW"
lr = 8e-5
eps = 1e-8

[lr_scheduler]
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 2
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
compile = true
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8 # 8-way TP
enable_async_tensor_parallel = true
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']

[float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = "output"
torchtitan/tools/profiling.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import pickle
import time

import torch

from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how many memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000


@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
    # get user-defined profiler settings
    enable_profiling = config.profiling.enable_profiling

    if enable_profiling:
        dump_dir = config.job.dump_folder
        save_trace_dir = config.profiling.save_traces_folder
        trace_dir = os.path.join(dump_dir, save_trace_dir)
        profile_freq = config.profiling.profile_freq

        rank = torch.distributed.get_rank()

        def trace_handler(prof):
            curr_trace_dir_name = "iteration_" + str(prof.step_num)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir, exist_ok=True)

            logger.info(f"Dumping profiler traces at step {prof.step_num}")
            begin = time.monotonic()
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
            logger.info(
                f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
            )

        logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir, exist_ok=True)

        warmup, active = WARMUP, 1
        wait = profile_freq - (active + warmup)
        assert (
            wait >= 0
        ), "profile_freq must be greater than or equal to warmup + active"
        gpu_device_profiled = None
        if torch.cuda.is_available():
            gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
        elif torch.xpu.is_available():
            gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                gpu_device_profiled,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=trace_handler,
            record_shapes=True,
        ) as torch_profiler:
            torch_profiler.step_num = global_step
            yield torch_profiler
    else:
        yield None


@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    enable_snapshot = config.profiling.enable_memory_snapshot
    if enable_snapshot:
        snapshot_folder = config.profiling.save_memory_snapshot_folder
        snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
        if not os.path.exists(snapshot_dir):
            os.makedirs(snapshot_dir, exist_ok=True)
        rank = torch.distributed.get_rank()

        class MemoryProfiler:
            def __init__(self, step_num: int, freq: int):
                torch.cuda.memory._record_memory_history(
                    max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
                )
                # when resuming training, we start from the last step
                self.step_num = step_num
                self.freq = freq

            def step(self, exit_ctx: bool = False):
                self.step_num += 1
                if not exit_ctx and self.step_num % self.freq != 0:
                    return
                if not exit_ctx:
                    curr_step = self.step_num
                    dir_name = f"iteration_{curr_step}"
                else:
                    # dump as iteration_0_exit if OOM at iter 1
                    curr_step = self.step_num - 1
                    dir_name = f"iteration_{curr_step}_exit"
                curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
                if not os.path.exists(curr_snapshot_dir):
                    os.makedirs(curr_snapshot_dir, exist_ok=True)
                logger.info(f"Dumping memory snapshot at step {curr_step}")
                begin = time.monotonic()
                with open(
                    f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
                ) as output:
                    pickle.dump(torch.cuda.memory._snapshot(), output)
                logger.info(
                    f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
                )

        logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
        profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
        try:
            yield profiler
        except torch.OutOfMemoryError:
            profiler.step(exit_ctx=True)
    else:
        yield None
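
A hedged usage sketch for the two context managers above inside a training loop; `train_step` and `job_config` are illustrative stand-ins, not part of the file.

with maybe_enable_profiling(job_config, global_step=0) as torch_profiler, \
        maybe_enable_memory_snapshot(job_config, global_step=0) as memory_profiler:
    for step in range(job_config.training.steps):
        train_step()
        if torch_profiler:
            torch_profiler.step()   # advances the wait/warmup/active schedule
        if memory_profiler:
            memory_profiler.step()  # dumps a snapshot every profile_freq steps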