Build uploaded using `kernels`.
- build/torch-cpu/__init__.py +5 -0
- build/torch-cpu/_ops.py +8 -0
- build/torch-cpu/fp8_fbgemm/__init__.py +26 -0
- build/torch-cpu/metadata.json +1 -0
- build/torch-cpu/quantizer.py +262 -0
- build/torch-cuda/__init__.py +5 -0
- build/torch-cuda/_ops.py +8 -0
- build/torch-cuda/fp8_fbgemm/__init__.py +26 -0
- build/torch-cuda/metadata.json +1 -0
- build/torch-cuda/quantizer.py +262 -0
- build/torch-rocm/__init__.py +5 -0
- build/torch-rocm/_ops.py +8 -0
- build/torch-rocm/fp8_fbgemm/__init__.py +26 -0
- build/torch-rocm/metadata.json +1 -0
- build/torch-rocm/quantizer.py +262 -0
- build/torch-xpu/__init__.py +5 -0
- build/torch-xpu/_ops.py +8 -0
- build/torch-xpu/fp8_fbgemm/__init__.py +26 -0
- build/torch-xpu/metadata.json +1 -0
- build/torch-xpu/quantizer.py +262 -0
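The `build/` tree above follows the per-backend layout that the `kernels` loader consumes: one variant directory per backend (`torch-cpu`, `torch-cuda`, `torch-rocm`, `torch-xpu`), each exposing the same `quantize_fp8_per_row` entry point. A minimal usage sketch, assuming the `kernels` package is installed; the repository id below is a placeholder, since the real repo id is not part of this commit:

```python
import torch
from kernels import get_kernel  # Hugging Face `kernels` loader

# Placeholder repo id; substitute the repository this build was pushed to.
fp8_fbgemm = get_kernel("<org>/<repo>")

x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
x_fp8, x_inv_scale = fp8_fbgemm.quantize_fp8_per_row(x)
```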
build/torch-cpu/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .quantizer import quantize_fp8_per_row

__all__ = [
    "quantize_fp8_per_row"
]
build/torch-cpu/_ops.py
ADDED
@@ -0,0 +1,8 @@
import torch
ops = torch.ops._fp8_fbgemm_5f3c84f_dirty

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
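`_ops.py` pins every op lookup to the build-specific dispatcher namespace `_fp8_fbgemm_5f3c84f_dirty`, so two installed versions of the kernel cannot collide. A hedged sketch of how the helper is meant to be used when wiring ops to PyTorch registration APIs; the op name below is illustrative, not one defined by this build:

```python
from ._ops import add_op_namespace_prefix, ops

# "_fp8_fbgemm_5f3c84f_dirty::my_op"  (illustrative op name)
qualified_name = add_op_namespace_prefix("my_op")

# `ops` is the torch.ops namespace handle for this build; registered ops become
# reachable as attributes (e.g. ops.my_op(...)) once native ops are registered.
```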
build/torch-cpu/fp8_fbgemm/__init__.py
ADDED
@@ -0,0 +1,26 @@
import ctypes
import sys

import importlib
from pathlib import Path
from types import ModuleType

def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
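`_import_from_path` is a generic import-by-path helper: the loaded module is registered in `sys.modules` under a hash of its absolute path rather than its file name, so identical file names coming from different build variants never alias each other. The final line then copies everything the variant-level `__init__.py` defines into this package's namespace. A small standalone sketch of the helper; the path is illustrative:

```python
from pathlib import Path

# Load an arbitrary Python file as a module object; the sys.modules key is a
# hex hash of the absolute path, not "my_module".
mod = _import_from_path(Path("/tmp/my_module.py"))
print(mod.__name__)  # hex digest derived from the path, not "my_module"
```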
build/torch-cpu/metadata.json
ADDED
@@ -0,0 +1 @@
{"python-depends":[]}
build/torch-cpu/quantizer.py
ADDED
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license
# copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py


import torch
import triton
import triton.language as tl
from torch import nn
from triton import Config
from typing import Any, Optional

def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
    """
    Helper function to get constant values for the current platform.

    Returns:
        pt_dtype (torch.dtype): The correct torch fp8 datatype.
        tl_dtype (tl.dtype): The correct triton fp8 datatype.
        max_fp8 (float): The maximum representable value for the fp8 datatype.
        eps (float): Minimum clip value to prevent divide by zero.
    """
    pt_fp8_dtype = torch.float8_e4m3fn
    tl_fp8_dtype = tl.float8e4nv
    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12


@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): higher precision input tensor of 4 dimensions.
        A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
        scale_ub (Tensor): [1] Maximum value allowed for scale.
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3 (input row size)
        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
    a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on

    K_in = K
    if JAGGED:
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, don't process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE
    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max
    tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE


def quantize_fp8_per_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Args:
        a (Tensor): higher precision input tensor of 4 dimensions.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16).
    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per row.
    """
    # Handle meta tensors (skip kernel execution)
    if a.device.type == "meta":
        pt_dtype, _, _, _ = get_fp8_constants()
        a_shape = list(a.shape)
        if align_rows_to is not None:
            last_dim = a_shape[-1]
            padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
            a_shape[-1] = padded_last_dim

        # Return empty meta tensors with correct shapes
        return (
            torch.empty(a_shape, device="meta", dtype=pt_dtype),
            torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
        )

    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
    # If align_rows_to is provided, pad the last dimension to be a multiple of it
    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
        a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If input tensor is sufficiently large, we need to use int64 indexing.
    use_int64 = a.numel() > (2**31 - 1)
    grid = (num_rows,)
    _kernel_quantize_fp8_row[grid](
        a,
        a_scale,
        a_fp8,
        scale_ub,
        zero_start_index_M,
        a.shape[0],
        a.shape[1],
        a.shape[2],
        a.shape[3],
        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
        a.stride(3),
        a_fp8.stride(0),
        a_fp8.stride(1),
        a_fp8.stride(2),
        a_fp8.stride(3),
        (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
        (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
        JAGGED=zero_start_index_M is not None,
        USE_INT64=use_int64,
    )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
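For reference, a hedged usage sketch of `quantize_fp8_per_row` as defined above (shapes, device, and the import path are illustrative; the kernel needs a Triton-capable device):

```python
import torch
from quantizer import quantize_fp8_per_row  # assumes this file is importable as `quantizer`

x = torch.randn(4, 1000, device="cuda", dtype=torch.bfloat16)

# Row-wise FP8 quantization, padding each row up to a multiple of 16 columns.
x_fp8, x_scale = quantize_fp8_per_row(x, align_rows_to=16)
print(x_fp8.shape, x_fp8.dtype)  # torch.Size([4, 1008]) torch.float8_e4m3fn
print(x_scale.shape)             # torch.Size([4]), one reciprocal scale per row

# Dequantize: each row times its stored scale approximately recovers the input.
x_rec = x_fp8.to(torch.float32) * x_scale.unsqueeze(-1)
```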
build/torch-cuda/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .quantizer import quantize_fp8_per_row

__all__ = [
    "quantize_fp8_per_row"
]
build/torch-cuda/_ops.py
ADDED
@@ -0,0 +1,8 @@
import torch
ops = torch.ops._fp8_fbgemm_5f3c84f_dirty

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
build/torch-cuda/fp8_fbgemm/__init__.py
ADDED
@@ -0,0 +1,26 @@
import ctypes
import sys

import importlib
from pathlib import Path
from types import ModuleType

def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-cuda/metadata.json
ADDED
@@ -0,0 +1 @@
{"python-depends":[]}
build/torch-cuda/quantizer.py
ADDED
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license
# copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py


import torch
import triton
import triton.language as tl
from torch import nn
from triton import Config
from typing import Any, Optional

def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
    """
    Helper function to get constant values for the current platform.

    Returns:
        pt_dtype (torch.dtype): The correct torch fp8 datatype.
        tl_dtype (tl.dtype): The correct triton fp8 datatype.
        max_fp8 (float): The maximum representable value for the fp8 datatype.
        eps (float): Minimum clip value to prevent divide by zero.
    """
    pt_fp8_dtype = torch.float8_e4m3fn
    tl_fp8_dtype = tl.float8e4nv
    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12


@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): higher precision input tensor of 4 dimensions.
        A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
        scale_ub (Tensor): [1] Maximum value allowed for scale.
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3 (input row size)
        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
    a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on

    K_in = K
    if JAGGED:
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, don't process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE
    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max
    tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE


def quantize_fp8_per_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Args:
        a (Tensor): higher precision input tensor of 4 dimensions.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16).
    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per row.
    """
    # Handle meta tensors (skip kernel execution)
    if a.device.type == "meta":
        pt_dtype, _, _, _ = get_fp8_constants()
        a_shape = list(a.shape)
        if align_rows_to is not None:
            last_dim = a_shape[-1]
            padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
            a_shape[-1] = padded_last_dim

        # Return empty meta tensors with correct shapes
        return (
            torch.empty(a_shape, device="meta", dtype=pt_dtype),
            torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
        )

    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
    # If align_rows_to is provided, pad the last dimension to be a multiple of it
    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
        a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If input tensor is sufficiently large, we need to use int64 indexing.
    use_int64 = a.numel() > (2**31 - 1)
    grid = (num_rows,)
    _kernel_quantize_fp8_row[grid](
        a,
        a_scale,
        a_fp8,
        scale_ub,
        zero_start_index_M,
        a.shape[0],
        a.shape[1],
        a.shape[2],
        a.shape[3],
        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
        a.stride(3),
        a_fp8.stride(0),
        a_fp8.stride(1),
        a_fp8.stride(2),
        a_fp8.stride(3),
        (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
        (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
        JAGGED=zero_start_index_M is not None,
        USE_INT64=use_int64,
    )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
build/torch-rocm/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .quantizer import quantize_fp8_per_row

__all__ = [
    "quantize_fp8_per_row"
]
build/torch-rocm/_ops.py
ADDED
@@ -0,0 +1,8 @@
import torch
ops = torch.ops._fp8_fbgemm_5f3c84f_dirty

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
build/torch-rocm/fp8_fbgemm/__init__.py
ADDED
@@ -0,0 +1,26 @@
import ctypes
import sys

import importlib
from pathlib import Path
from types import ModuleType

def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-rocm/metadata.json
ADDED
@@ -0,0 +1 @@
{"python-depends":[]}
build/torch-rocm/quantizer.py
ADDED
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license
# copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py


import torch
import triton
import triton.language as tl
from torch import nn
from triton import Config
from typing import Any, Optional

def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
    """
    Helper function to get constant values for the current platform.

    Returns:
        pt_dtype (torch.dtype): The correct torch fp8 datatype.
        tl_dtype (tl.dtype): The correct triton fp8 datatype.
        max_fp8 (float): The maximum representable value for the fp8 datatype.
        eps (float): Minimum clip value to prevent divide by zero.
    """
    pt_fp8_dtype = torch.float8_e4m3fn
    tl_fp8_dtype = tl.float8e4nv
    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12


@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): higher precision input tensor of 4 dimensions.
        A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
        scale_ub (Tensor): [1] Maximum value allowed for scale.
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3 (input row size)
        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
    a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on

    K_in = K
    if JAGGED:
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, don't process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE
    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max
    tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE


def quantize_fp8_per_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Args:
        a (Tensor): higher precision input tensor of 4 dimensions.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16).
    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per row.
    """
    # Handle meta tensors (skip kernel execution)
    if a.device.type == "meta":
        pt_dtype, _, _, _ = get_fp8_constants()
        a_shape = list(a.shape)
        if align_rows_to is not None:
            last_dim = a_shape[-1]
            padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
            a_shape[-1] = padded_last_dim

        # Return empty meta tensors with correct shapes
        return (
            torch.empty(a_shape, device="meta", dtype=pt_dtype),
            torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
        )

    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
    # If align_rows_to is provided, pad the last dimension to be a multiple of it
    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
        a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If input tensor is sufficiently large, we need to use int64 indexing.
    use_int64 = a.numel() > (2**31 - 1)
    grid = (num_rows,)
    _kernel_quantize_fp8_row[grid](
        a,
        a_scale,
        a_fp8,
        scale_ub,
        zero_start_index_M,
        a.shape[0],
        a.shape[1],
        a.shape[2],
        a.shape[3],
        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
        a.stride(3),
        a_fp8.stride(0),
        a_fp8.stride(1),
        a_fp8.stride(2),
        a_fp8.stride(3),
        (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
        (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
        JAGGED=zero_start_index_M is not None,
        USE_INT64=use_int64,
    )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
build/torch-xpu/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .quantizer import quantize_fp8_per_row

__all__ = [
    "quantize_fp8_per_row"
]
build/torch-xpu/_ops.py
ADDED
@@ -0,0 +1,8 @@
import torch
ops = torch.ops._fp8_fbgemm_5f3c84f_dirty

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
build/torch-xpu/fp8_fbgemm/__init__.py
ADDED
@@ -0,0 +1,26 @@
import ctypes
import sys

import importlib
from pathlib import Path
from types import ModuleType

def _import_from_path(file_path: Path) -> ModuleType:
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module


globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-xpu/metadata.json
ADDED
@@ -0,0 +1 @@
{"python-depends":[]}
build/torch-xpu/quantizer.py
ADDED
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license
# copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py


import torch
import triton
import triton.language as tl
from torch import nn
from triton import Config
from typing import Any, Optional

def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
    """
    Helper function to get constant values for the current platform.

    Returns:
        pt_dtype (torch.dtype): The correct torch fp8 datatype.
        tl_dtype (tl.dtype): The correct triton fp8 datatype.
        max_fp8 (float): The maximum representable value for the fp8 datatype.
        eps (float): Minimum clip value to prevent divide by zero.
    """
    pt_fp8_dtype = torch.float8_e4m3fn
    tl_fp8_dtype = tl.float8e4nv
    return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12


@triton.autotune(
    configs=[
        Config({"BLOCK_SIZE": 512}),
        Config({"BLOCK_SIZE": 1024}),
        Config({"BLOCK_SIZE": 2048}),
        Config({"BLOCK_SIZE": 4096}),
        Config({"BLOCK_SIZE": 8192}),
    ],
    key=["K"],
)
@triton.jit
def _kernel_quantize_fp8_row(
    A,
    A_scale,
    A_fp8,
    scale_ub,
    zero_start_index_M,
    B,
    M,
    N,
    K,
    K_fp8,  # used when padding
    stride_ab,
    stride_am,
    stride_an,
    stride_ak,
    stride_ob,
    stride_om,
    stride_on,
    stride_ok,
    stride_zb,
    stride_zm,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    JAGGED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_INT64: tl.constexpr,
) -> None:
    """Quantize and scale each row.

    Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))

    Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
    in a max pass then scale/quantize pass.

    Todo:
        * Better tiling schemes.

    Args:
        A (Tensor): higher precision input tensor of 4 dimensions.
        A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
        A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
        scale_ub (Tensor): [1] Maximum value allowed for scale.
        B (int): Size of dimension 0
        M (int): Size of dimension 1
        N (int): Size of dimension 2
        K (int): Size of dimension 3 (input row size)
        K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
        stride_ab (int): Stride of b dimension of A.
        stride_am (int): Stride of m dimension of A.
        stride_an (int): Stride of n dimension of A.
        stride_ak (int): Stride of k dimension of A.
        stride_ob (int): Stride of b dimension of output.
        stride_om (int): Stride of m dimension of output.
        stride_on (int): Stride of n dimension of output.
        stride_ok (int): Stride of k dimension of output.
        stride_zb (int): Stride of b dimension of jagged index.
        stride_zm (int): Stride of m dimension of jagged index.
        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
        MAX_FP8 (float): Maximum expressible value for FP8.
        EPS (float): Epsilon value for numerical stability.
        CLAMP_MAX (bool): Whether to apply scale_ub.
        JAGGED (bool): Whether to use jagged indexing.
        BLOCK_SIZE (int): Block size for reduction.
        USE_INT64 (bool): Whether to use int64 indexing for large inputs.
    """
    pid = tl.program_id(0)
    # Use int64 indexing for large inputs. This is slower, but
    # needed to avoid index overflows.
    if USE_INT64:
        pid = pid.to(tl.int64)
    n_offset = tl.arange(0, BLOCK_SIZE)
    a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
    a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on

    K_in = K
    if JAGGED:
        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
        group_rows = tl.load(zero_start_index_M + z_offset_base)
        current_row = pid % N
        # If this row is empty, don't process any of it.
        if current_row >= group_rows:
            K_in = 0

    # Calculate max.
    cur_max = 0.0
    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        tile_max = tl.max(tl.abs(a))
        cur_max = tl.maximum(tile_max, cur_max)
        n_offset += BLOCK_SIZE
    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub)
        cur_max = tl.clamp(cur_max, EPS, ub)
    else:
        cur_max = tl.maximum(cur_max, EPS)
    # Scale and quantize.
    a_scale = MAX_FP8 / cur_max
    tl.store(A_scale + pid, 1.0 / a_scale)
    n_offset = tl.arange(0, BLOCK_SIZE)

    # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
    for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
        # Load from A if in range, else 0 (we're going all the way to K_fp8)
        a = tl.load(
            A + a_offset_base + n_offset * stride_ak,
            mask=n_offset < K_in,
            other=0.0,
        )
        # For elements >= K, a will be 0
        a_fp8 = a * a_scale
        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)

        # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
        tl.store(
            A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
            a_fp8,
            mask=n_offset < K_fp8,
        )
        n_offset += BLOCK_SIZE


def quantize_fp8_per_row(
    a: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    zero_start_index_M: Optional[torch.Tensor] = None,
    align_rows_to: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.

    Args:
        a (Tensor): higher precision input tensor of 4 dimensions.
        scale_ub (Tensor): Maximum allowed value for scale.
        zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
        align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16).
    Returns:
        torch.Tensor: fp8 scaled tensor.
        torch.Tensor: reciprocal scale tensor per row.
    """
    # Handle meta tensors (skip kernel execution)
    if a.device.type == "meta":
        pt_dtype, _, _, _ = get_fp8_constants()
        a_shape = list(a.shape)
        if align_rows_to is not None:
            last_dim = a_shape[-1]
            padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
            a_shape[-1] = padded_last_dim

        # Return empty meta tensors with correct shapes
        return (
            torch.empty(a_shape, device="meta", dtype=pt_dtype),
            torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
        )

    if scale_ub is not None and scale_ub.device != a.device:
        raise Exception("'scale_ub' must be on the same device as 'a'")
    if zero_start_index_M is not None and zero_start_index_M.device != a.device:
        raise Exception("'zero_start_index_M' must be on the same device as 'a'")

    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
    a_shape = a.shape
    while a.dim() < 4:
        a = a.unsqueeze(0)
    if zero_start_index_M is not None:
        # There should be one value of zero_start_index_M per NxK matrix.
        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
    # Get constant values.
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    num_rows = a.numel() // a.shape[-1]
    a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
    # If align_rows_to is provided, pad the last dimension to be a multiple of it
    if align_rows_to is not None:
        last_dim = a.shape[-1]
        padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
        a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
        a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
    else:
        a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)

    # If input tensor is sufficiently large, we need to use int64 indexing.
    use_int64 = a.numel() > (2**31 - 1)
    grid = (num_rows,)
    _kernel_quantize_fp8_row[grid](
        a,
        a_scale,
        a_fp8,
        scale_ub,
        zero_start_index_M,
        a.shape[0],
        a.shape[1],
        a.shape[2],
        a.shape[3],
        a_fp8.shape[3],
        a.stride(0),
        a.stride(1),
        a.stride(2),
        a.stride(3),
        a_fp8.stride(0),
        a_fp8.stride(1),
        a_fp8.stride(2),
        a_fp8.stride(3),
        (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
        (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
        JAGGED=zero_start_index_M is not None,
        USE_INT64=use_int64,
    )

    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])