# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gc
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import torch
from torch._utils import _get_available_device_type, _get_device_module

from torchtitan.tools.logging import logger


def get_device_info():
    device_type = _get_available_device_type()
    if device_type is None:
        device_type = "cuda"  # default device_type: cuda
    device_module = _get_device_module(device_type)  # default device_module: torch.cuda
    return device_type, device_module


device_type, device_module = get_device_info()
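
# Illustrative usage (an assumption, not part of the original file): the
# module-level pair above lets callers stay device-agnostic, e.g.:
#
#   device = torch.device(f"{device_type}:0")
#   device_module.set_device(0)  # torch.cuda.set_device or torch.xpu.set_device
#   props = device_module.get_device_properties(0)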


# used to avoid stragglers in garbage collection
class GarbageCollection:
    def __init__(self, gc_freq=1000):
        assert gc_freq > 0, "gc_freq must be a positive integer"
        self.gc_freq = gc_freq
        # take over collection scheduling from the interpreter so that
        # collections happen only at explicitly chosen steps
        gc.disable()
        self.collect("Initial GC collection.")

    def run(self, step_count):
        if step_count > 1 and step_count % self.gc_freq == 0:
            self.collect("Performing periodic GC collection.")

    @staticmethod
    def collect(reason: str):
        begin = time.monotonic()
        # collect only up through generation 1 to keep pause times bounded
        gc.collect(1)
        logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)
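
# Illustrative usage (an assumption, not part of the original file): run() is
# meant to be called once per training step so collections fire at the same
# step on every rank, which is what avoids the stragglers mentioned above.
#
#   gc_handler = GarbageCollection(gc_freq=50)
#   for step in range(1, num_steps + 1):
#       gc_handler.run(step)
#       ...  # forward / backward / optimizer step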


# hardcoded BF16 peak flops for NVIDIA A100, H100, and H200 GPUs; AMD MI250X,
# MI300X, and MI325X; and Intel Data Center GPU Max 1550 (PVC)
def get_peak_flops(device_name: str) -> float:
    try:
        # Run the lspci command and capture the output
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
        # Filter the output for lines containing both "NVIDIA" and "H100"
        filtered_lines = [
            line
            for line in result.stdout.splitlines()
            if "NVIDIA" in line and "H100" in line
        ]
        # Join all filtered lines into a single string
        device_name = " ".join(filtered_lines) or device_name
    except FileNotFoundError as e:
        logger.warning(f"Error running lspci: {e}, falling back to device_name")
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: the values below are dense (without sparsity); NVIDIA's spec
        # sheets advertise twice these numbers with sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    else:  # for other GPU types, assume A100
        logger.warning(f"Peak flops undefined for: {device_name}, falling back to A100")
        return 312e12
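
# Illustrative usage (an assumption, not part of the original file): the peak
# value typically serves as the denominator of a model-flops-utilization (MFU)
# estimate, where achieved_flops_per_sec is measured by the caller:
#
#   peak_flops = get_peak_flops(device_module.get_device_name(0))
#   mfu_pct = 100 * achieved_flops_per_sec / peak_flops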


@dataclass(frozen=True)
class Color:
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"  # resets the foreground color only, not all attributes


@dataclass(frozen=True)
class NoColor:
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""


def check_if_feature_in_pytorch(
    feature_name: str,
    pull_request: str,
    min_nightly_version: Optional[str] = None,
) -> None:
    if "git" in torch.__version__:  # pytorch is built from source
        # notify users to check if the pull request is included in their pytorch
        logger.warning(
            "Detected that PyTorch is built from source. Please make sure the PR "
            f"({pull_request}) is included in PyTorch for correct {feature_name}."
        )
    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
        logger.warning(
            f"Detected that PyTorch version {torch.__version__} is older than "
            f"{min_nightly_version}. Please upgrade to a newer version to include "
            f"the change in ({pull_request}) for correct {feature_name}."
        )
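
# Example call (illustrative; the PR URL and version strings below are
# placeholders, not references to a real pull request):
#
#   check_if_feature_in_pytorch(
#       "some feature",
#       pull_request="https://github.com/pytorch/pytorch/pull/<number>",
#       min_nightly_version="2.4.0.dev20240601",
#   )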