medmekk (HF Staff) committed · verified
Commit d62e31c · 1 Parent(s): 0d2ba94

Build uploaded using `kernels`.

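For context, a minimal sketch (not part of this commit) of how a build uploaded with the `kernels` tooling is typically consumed from Python; the repository id below is a placeholder assumption, and the library is expected to pick the matching build variant (CPU/CUDA/ROCm/XPU) for the local environment.

# Hedged sketch; "user-or-org/fp8-fbgemm" is a hypothetical repo id, not taken from this commit.
import torch
from kernels import get_kernel

fp8_fbgemm = get_kernel("user-or-org/fp8-fbgemm")
x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)
x_fp8, x_scale = fp8_fbgemm.quantize_fp8_per_row(x)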
build/torch-cpu/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .quantizer import quantize_fp8_per_row
+
+ __all__ = [
+     "quantize_fp8_per_row"
+ ]
build/torch-cpu/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._fp8_fbgemm_5f3c84f_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
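A brief illustrative note (not part of the commit) on how `_ops.py` is meant to be used: `add_op_namespace_prefix` qualifies a bare op name with this build's private namespace, the form expected by `torch.library`-style registration, while `ops` exposes the same namespace under `torch.ops`. The op name below is only an example.

# Illustrative only: qualify a hypothetical op name with the build's namespace.
qualified = add_op_namespace_prefix("quantize_fp8_per_row")
# qualified == "_fp8_fbgemm_5f3c84f_dirty::quantize_fp8_per_row"
# If the compiled extension registers such an op, it would be reachable as
# ops.quantize_fp8_per_row via the `ops` handle defined above.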
build/torch-cpu/fp8_fbgemm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import sys
+
+ import importlib.util
+ from pathlib import Path
+ from types import ModuleType
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path for it to be unique, using the hex-encoded hash of
+     # the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
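To clarify the intent of the shim above (commentary, not part of the commit): `_import_from_path` loads a file as a module under a hash-derived name, so re-exporting the parent `__init__.py` here never collides with other entries in `sys.modules`. A tiny usage sketch:

# Illustrative only: load an arbitrary Python file without claiming its natural module name.
mod = _import_from_path(Path(__file__).parent.parent / "__init__.py")
# `mod` is registered in sys.modules under a name derived from the file path,
# and its public attributes are merged into this package's globals above.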
build/torch-cpu/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
build/torch-cpu/quantizer.py ADDED
@@ -0,0 +1,262 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license
+ # copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch import nn
+ from triton import Config
+ from typing import Any, Optional
+
+ def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
+     """
+     Helper function to get constant values for the current platform.
+
+     Returns:
+         pt_dtype (torch.dtype): The correct torch fp8 datatype.
+         tl_dtype (tl.dtype): The correct triton fp8 datatype.
+         max_fp8 (float): The maximum representable value for the fp8 datatype.
+         eps (float): Minimum clip value to prevent divide by zero.
+     """
+     pt_fp8_dtype = torch.float8_e4m3fn
+     tl_fp8_dtype = tl.float8e4nv
+     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
+
+
+ @triton.autotune(
+     configs=[
+         Config({"BLOCK_SIZE": 512}),
+         Config({"BLOCK_SIZE": 1024}),
+         Config({"BLOCK_SIZE": 2048}),
+         Config({"BLOCK_SIZE": 4096}),
+         Config({"BLOCK_SIZE": 8192}),
+     ],
+     key=["K"],
+ )
+ @triton.jit
+ def _kernel_quantize_fp8_row(
+     A,
+     A_scale,
+     A_fp8,
+     scale_ub,
+     zero_start_index_M,
+     B,
+     M,
+     N,
+     K,
+     K_fp8,  # used when padding
+     stride_ab,
+     stride_am,
+     stride_an,
+     stride_ak,
+     stride_ob,
+     stride_om,
+     stride_on,
+     stride_ok,
+     stride_zb,
+     stride_zm,
+     TL_FP8_DTYPE: tl.constexpr,
+     MAX_FP8: tl.constexpr,
+     EPS: tl.constexpr,
+     CLAMP_MAX: tl.constexpr,
+     JAGGED: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ) -> None:
+     """Quantize and scale each row.
+
+     Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
+
+     Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
+     in a max pass then scale/quantize pass.
+
+     Todo:
+         * Better tiling schemes.
+
+     Args:
+         A (Tensor): higher precision input tensor of 4 dimensions.
+         A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
+         A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
+         scale_ub (Tensor): [1] Maximum value allowed for scale.
+         B (int): Size of dimension 0.
+         M (int): Size of dimension 1.
+         N (int): Size of dimension 2.
+         K (int): Size of dimension 3 (input row size).
+         K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K).
+         stride_ab (int): Stride of b dimension of A.
+         stride_am (int): Stride of m dimension of A.
+         stride_an (int): Stride of n dimension of A.
+         stride_ak (int): Stride of k dimension of A.
+         stride_ob (int): Stride of b dimension of output.
+         stride_om (int): Stride of m dimension of output.
+         stride_on (int): Stride of n dimension of output.
+         stride_ok (int): Stride of k dimension of output.
+         stride_zb (int): Stride of b dimension of jagged index.
+         stride_zm (int): Stride of m dimension of jagged index.
+         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+         MAX_FP8 (float): Maximum expressible value for FP8.
+         EPS (float): Epsilon value for numerical stability.
+         CLAMP_MAX (bool): Whether to apply scale_ub.
+         JAGGED (bool): Whether to use jagged indexing.
+         BLOCK_SIZE (int): Block size for reduction.
+         USE_INT64 (bool): Whether to use int64 indexing for large inputs.
+     """
+     pid = tl.program_id(0)
+     # Use int64 indexing for large inputs. This is slower, but
+     # needed to avoid index overflows.
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+     a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
+     a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on
+
+     K_in = K
+     if JAGGED:
+         z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+         group_rows = tl.load(zero_start_index_M + z_offset_base)
+         current_row = pid % N
+         # If this row is empty, don't process any of it.
+         if current_row >= group_rows:
+             K_in = 0
+
+     # Calculate max.
+     cur_max = 0.0
+     for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         tile_max = tl.max(tl.abs(a))
+         cur_max = tl.maximum(tile_max, cur_max)
+         n_offset += BLOCK_SIZE
+     # Clamp max value appropriately.
+     if CLAMP_MAX:
+         ub = tl.load(scale_ub)
+         cur_max = tl.clamp(cur_max, EPS, ub)
+     else:
+         cur_max = tl.maximum(cur_max, EPS)
+     # Scale and quantize.
+     a_scale = MAX_FP8 / cur_max
+     tl.store(A_scale + pid, 1.0 / a_scale)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+
+     # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8.
+     for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+         # Load from A if in range, else 0 (we're going all the way to K_fp8).
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         # For elements >= K, a will be 0.
+         a_fp8 = a * a_scale
+         # Clamp A to fp8 range to make sure there's no overflow.
+         # This is required for AMD. Nvidia's default saturation
+         # handles it, but it's nice to have anyway.
+         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+         # Store the full new row in its place (for elements >= K, a_fp8 is already 0).
+         tl.store(
+             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+             a_fp8,
+             mask=n_offset < K_fp8,
+         )
+         n_offset += BLOCK_SIZE
+
+
+ def quantize_fp8_per_row(
+     a: torch.Tensor,
+     scale_ub: Optional[torch.Tensor] = None,
+     zero_start_index_M: Optional[torch.Tensor] = None,
+     align_rows_to: Optional[int] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
+
+     Args:
+         a (Tensor): higher precision input tensor of up to 4 dimensions.
+         scale_ub (Tensor): Maximum allowed value for scale.
+         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+         align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., a multiple of 16).
+     Returns:
+         torch.Tensor: fp8 scaled tensor.
+         torch.Tensor: reciprocal scale tensor per row.
+     """
+     # Handle meta tensors (skip kernel execution).
+     if a.device.type == "meta":
+         pt_dtype, _, _, _ = get_fp8_constants()
+         a_shape = list(a.shape)
+         if align_rows_to is not None:
+             last_dim = a_shape[-1]
+             padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+             a_shape[-1] = padded_last_dim
+
+         # Return empty meta tensors with correct shapes.
+         return (
+             torch.empty(a_shape, device="meta", dtype=pt_dtype),
+             torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
+         )
+
+     if scale_ub is not None and scale_ub.device != a.device:
+         raise Exception("'scale_ub' must be on the same device as 'a'")
+     if zero_start_index_M is not None and zero_start_index_M.device != a.device:
+         raise Exception("'zero_start_index_M' must be on the same device as 'a'")
+
+     assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+     a_shape = a.shape
+     while a.dim() < 4:
+         a = a.unsqueeze(0)
+     if zero_start_index_M is not None:
+         # There should be one value of zero_start_index_M per NxK matrix.
+         zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
+     # Get constant values.
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+     num_rows = a.numel() // a.shape[-1]
+     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
+     # If align_rows_to is provided, pad the last dimension to be a multiple of it.
+     if align_rows_to is not None:
+         last_dim = a.shape[-1]
+         padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+         a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
+         a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+     else:
+         a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+
+     # If input tensor is sufficiently large, we need to use int64 indexing.
+     use_int64 = a.numel() > (2**31 - 1)
+     grid = (num_rows,)
+     _kernel_quantize_fp8_row[grid](
+         a,
+         a_scale,
+         a_fp8,
+         scale_ub,
+         zero_start_index_M,
+         a.shape[0],
+         a.shape[1],
+         a.shape[2],
+         a.shape[3],
+         a_fp8.shape[3],
+         a.stride(0),
+         a.stride(1),
+         a.stride(2),
+         a.stride(3),
+         a_fp8.stride(0),
+         a_fp8.stride(1),
+         a_fp8.stride(2),
+         a_fp8.stride(3),
+         (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
+         (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
+         TL_FP8_DTYPE=tl_dtype,
+         MAX_FP8=max_fp8,
+         EPS=eps,
+         CLAMP_MAX=scale_ub is not None,
+         JAGGED=zero_start_index_M is not None,
+         USE_INT64=use_int64,
+     )
+
+     return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
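To make the API above concrete, a minimal usage sketch (not part of the commit); the shapes, dtype, and device below are arbitrary assumptions, and the call follows `quantize_fp8_per_row` as defined in this file.

# Hedged usage sketch for quantize_fp8_per_row.
import torch

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
x_fp8, x_scale = quantize_fp8_per_row(x, align_rows_to=16)
# x_fp8: float8_e4m3fn tensor of shape (32, 4096); the last dim is padded up to a multiple of 16 if needed.
# x_scale: float32 tensor of shape (32,) holding the per-row reciprocal scales.
# Approximate dequantization: x ≈ x_fp8.to(torch.float32) * x_scale[:, None]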
build/torch-cuda/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .quantizer import quantize_fp8_per_row
+
+ __all__ = [
+     "quantize_fp8_per_row"
+ ]
build/torch-cuda/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._fp8_fbgemm_5f3c84f_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
build/torch-cuda/fp8_fbgemm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import sys
+
+ import importlib.util
+ from pathlib import Path
+ from types import ModuleType
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path for it to be unique, using the hex-encoded hash of
+     # the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-cuda/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
build/torch-cuda/quantizer.py ADDED
@@ -0,0 +1,262 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license
+ # copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch import nn
+ from triton import Config
+ from typing import Any, Optional
+
+ def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
+     """
+     Helper function to get constant values for the current platform.
+
+     Returns:
+         pt_dtype (torch.dtype): The correct torch fp8 datatype.
+         tl_dtype (tl.dtype): The correct triton fp8 datatype.
+         max_fp8 (float): The maximum representable value for the fp8 datatype.
+         eps (float): Minimum clip value to prevent divide by zero.
+     """
+     pt_fp8_dtype = torch.float8_e4m3fn
+     tl_fp8_dtype = tl.float8e4nv
+     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
+
+
+ @triton.autotune(
+     configs=[
+         Config({"BLOCK_SIZE": 512}),
+         Config({"BLOCK_SIZE": 1024}),
+         Config({"BLOCK_SIZE": 2048}),
+         Config({"BLOCK_SIZE": 4096}),
+         Config({"BLOCK_SIZE": 8192}),
+     ],
+     key=["K"],
+ )
+ @triton.jit
+ def _kernel_quantize_fp8_row(
+     A,
+     A_scale,
+     A_fp8,
+     scale_ub,
+     zero_start_index_M,
+     B,
+     M,
+     N,
+     K,
+     K_fp8,  # used when padding
+     stride_ab,
+     stride_am,
+     stride_an,
+     stride_ak,
+     stride_ob,
+     stride_om,
+     stride_on,
+     stride_ok,
+     stride_zb,
+     stride_zm,
+     TL_FP8_DTYPE: tl.constexpr,
+     MAX_FP8: tl.constexpr,
+     EPS: tl.constexpr,
+     CLAMP_MAX: tl.constexpr,
+     JAGGED: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ) -> None:
+     """Quantize and scale each row.
+
+     Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
+
+     Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
+     in a max pass then scale/quantize pass.
+
+     Todo:
+         * Better tiling schemes.
+
+     Args:
+         A (Tensor): higher precision input tensor of 4 dimensions.
+         A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
+         A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
+         scale_ub (Tensor): [1] Maximum value allowed for scale.
+         B (int): Size of dimension 0.
+         M (int): Size of dimension 1.
+         N (int): Size of dimension 2.
+         K (int): Size of dimension 3 (input row size).
+         K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K).
+         stride_ab (int): Stride of b dimension of A.
+         stride_am (int): Stride of m dimension of A.
+         stride_an (int): Stride of n dimension of A.
+         stride_ak (int): Stride of k dimension of A.
+         stride_ob (int): Stride of b dimension of output.
+         stride_om (int): Stride of m dimension of output.
+         stride_on (int): Stride of n dimension of output.
+         stride_ok (int): Stride of k dimension of output.
+         stride_zb (int): Stride of b dimension of jagged index.
+         stride_zm (int): Stride of m dimension of jagged index.
+         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+         MAX_FP8 (float): Maximum expressible value for FP8.
+         EPS (float): Epsilon value for numerical stability.
+         CLAMP_MAX (bool): Whether to apply scale_ub.
+         JAGGED (bool): Whether to use jagged indexing.
+         BLOCK_SIZE (int): Block size for reduction.
+         USE_INT64 (bool): Whether to use int64 indexing for large inputs.
+     """
+     pid = tl.program_id(0)
+     # Use int64 indexing for large inputs. This is slower, but
+     # needed to avoid index overflows.
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+     a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
+     a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on
+
+     K_in = K
+     if JAGGED:
+         z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+         group_rows = tl.load(zero_start_index_M + z_offset_base)
+         current_row = pid % N
+         # If this row is empty, don't process any of it.
+         if current_row >= group_rows:
+             K_in = 0
+
+     # Calculate max.
+     cur_max = 0.0
+     for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         tile_max = tl.max(tl.abs(a))
+         cur_max = tl.maximum(tile_max, cur_max)
+         n_offset += BLOCK_SIZE
+     # Clamp max value appropriately.
+     if CLAMP_MAX:
+         ub = tl.load(scale_ub)
+         cur_max = tl.clamp(cur_max, EPS, ub)
+     else:
+         cur_max = tl.maximum(cur_max, EPS)
+     # Scale and quantize.
+     a_scale = MAX_FP8 / cur_max
+     tl.store(A_scale + pid, 1.0 / a_scale)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+
+     # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8.
+     for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+         # Load from A if in range, else 0 (we're going all the way to K_fp8).
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         # For elements >= K, a will be 0.
+         a_fp8 = a * a_scale
+         # Clamp A to fp8 range to make sure there's no overflow.
+         # This is required for AMD. Nvidia's default saturation
+         # handles it, but it's nice to have anyway.
+         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+         # Store the full new row in its place (for elements >= K, a_fp8 is already 0).
+         tl.store(
+             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+             a_fp8,
+             mask=n_offset < K_fp8,
+         )
+         n_offset += BLOCK_SIZE
+
+
+ def quantize_fp8_per_row(
+     a: torch.Tensor,
+     scale_ub: Optional[torch.Tensor] = None,
+     zero_start_index_M: Optional[torch.Tensor] = None,
+     align_rows_to: Optional[int] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
+
+     Args:
+         a (Tensor): higher precision input tensor of up to 4 dimensions.
+         scale_ub (Tensor): Maximum allowed value for scale.
+         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+         align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., a multiple of 16).
+     Returns:
+         torch.Tensor: fp8 scaled tensor.
+         torch.Tensor: reciprocal scale tensor per row.
+     """
+     # Handle meta tensors (skip kernel execution).
+     if a.device.type == "meta":
+         pt_dtype, _, _, _ = get_fp8_constants()
+         a_shape = list(a.shape)
+         if align_rows_to is not None:
+             last_dim = a_shape[-1]
+             padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+             a_shape[-1] = padded_last_dim
+
+         # Return empty meta tensors with correct shapes.
+         return (
+             torch.empty(a_shape, device="meta", dtype=pt_dtype),
+             torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
+         )
+
+     if scale_ub is not None and scale_ub.device != a.device:
+         raise Exception("'scale_ub' must be on the same device as 'a'")
+     if zero_start_index_M is not None and zero_start_index_M.device != a.device:
+         raise Exception("'zero_start_index_M' must be on the same device as 'a'")
+
+     assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+     a_shape = a.shape
+     while a.dim() < 4:
+         a = a.unsqueeze(0)
+     if zero_start_index_M is not None:
+         # There should be one value of zero_start_index_M per NxK matrix.
+         zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
+     # Get constant values.
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+     num_rows = a.numel() // a.shape[-1]
+     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
+     # If align_rows_to is provided, pad the last dimension to be a multiple of it.
+     if align_rows_to is not None:
+         last_dim = a.shape[-1]
+         padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+         a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
+         a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+     else:
+         a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+
+     # If input tensor is sufficiently large, we need to use int64 indexing.
+     use_int64 = a.numel() > (2**31 - 1)
+     grid = (num_rows,)
+     _kernel_quantize_fp8_row[grid](
+         a,
+         a_scale,
+         a_fp8,
+         scale_ub,
+         zero_start_index_M,
+         a.shape[0],
+         a.shape[1],
+         a.shape[2],
+         a.shape[3],
+         a_fp8.shape[3],
+         a.stride(0),
+         a.stride(1),
+         a.stride(2),
+         a.stride(3),
+         a_fp8.stride(0),
+         a_fp8.stride(1),
+         a_fp8.stride(2),
+         a_fp8.stride(3),
+         (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
+         (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
+         TL_FP8_DTYPE=tl_dtype,
+         MAX_FP8=max_fp8,
+         EPS=eps,
+         CLAMP_MAX=scale_ub is not None,
+         JAGGED=zero_start_index_M is not None,
+         USE_INT64=use_int64,
+     )
+
+     return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
build/torch-rocm/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .quantizer import quantize_fp8_per_row
+
+ __all__ = [
+     "quantize_fp8_per_row"
+ ]
build/torch-rocm/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._fp8_fbgemm_5f3c84f_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
build/torch-rocm/fp8_fbgemm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import sys
+
+ import importlib.util
+ from pathlib import Path
+ from types import ModuleType
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path for it to be unique, using the hex-encoded hash of
+     # the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-rocm/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
build/torch-rocm/quantizer.py ADDED
@@ -0,0 +1,262 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license
+ # copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch import nn
+ from triton import Config
+ from typing import Any, Optional
+
+ def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
+     """
+     Helper function to get constant values for the current platform.
+
+     Returns:
+         pt_dtype (torch.dtype): The correct torch fp8 datatype.
+         tl_dtype (tl.dtype): The correct triton fp8 datatype.
+         max_fp8 (float): The maximum representable value for the fp8 datatype.
+         eps (float): Minimum clip value to prevent divide by zero.
+     """
+     pt_fp8_dtype = torch.float8_e4m3fn
+     tl_fp8_dtype = tl.float8e4nv
+     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
+
+
+ @triton.autotune(
+     configs=[
+         Config({"BLOCK_SIZE": 512}),
+         Config({"BLOCK_SIZE": 1024}),
+         Config({"BLOCK_SIZE": 2048}),
+         Config({"BLOCK_SIZE": 4096}),
+         Config({"BLOCK_SIZE": 8192}),
+     ],
+     key=["K"],
+ )
+ @triton.jit
+ def _kernel_quantize_fp8_row(
+     A,
+     A_scale,
+     A_fp8,
+     scale_ub,
+     zero_start_index_M,
+     B,
+     M,
+     N,
+     K,
+     K_fp8,  # used when padding
+     stride_ab,
+     stride_am,
+     stride_an,
+     stride_ak,
+     stride_ob,
+     stride_om,
+     stride_on,
+     stride_ok,
+     stride_zb,
+     stride_zm,
+     TL_FP8_DTYPE: tl.constexpr,
+     MAX_FP8: tl.constexpr,
+     EPS: tl.constexpr,
+     CLAMP_MAX: tl.constexpr,
+     JAGGED: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ) -> None:
+     """Quantize and scale each row.
+
+     Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
+
+     Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
+     in a max pass then scale/quantize pass.
+
+     Todo:
+         * Better tiling schemes.
+
+     Args:
+         A (Tensor): higher precision input tensor of 4 dimensions.
+         A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
+         A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
+         scale_ub (Tensor): [1] Maximum value allowed for scale.
+         B (int): Size of dimension 0.
+         M (int): Size of dimension 1.
+         N (int): Size of dimension 2.
+         K (int): Size of dimension 3 (input row size).
+         K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K).
+         stride_ab (int): Stride of b dimension of A.
+         stride_am (int): Stride of m dimension of A.
+         stride_an (int): Stride of n dimension of A.
+         stride_ak (int): Stride of k dimension of A.
+         stride_ob (int): Stride of b dimension of output.
+         stride_om (int): Stride of m dimension of output.
+         stride_on (int): Stride of n dimension of output.
+         stride_ok (int): Stride of k dimension of output.
+         stride_zb (int): Stride of b dimension of jagged index.
+         stride_zm (int): Stride of m dimension of jagged index.
+         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+         MAX_FP8 (float): Maximum expressible value for FP8.
+         EPS (float): Epsilon value for numerical stability.
+         CLAMP_MAX (bool): Whether to apply scale_ub.
+         JAGGED (bool): Whether to use jagged indexing.
+         BLOCK_SIZE (int): Block size for reduction.
+         USE_INT64 (bool): Whether to use int64 indexing for large inputs.
+     """
+     pid = tl.program_id(0)
+     # Use int64 indexing for large inputs. This is slower, but
+     # needed to avoid index overflows.
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+     a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
+     a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on
+
+     K_in = K
+     if JAGGED:
+         z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+         group_rows = tl.load(zero_start_index_M + z_offset_base)
+         current_row = pid % N
+         # If this row is empty, don't process any of it.
+         if current_row >= group_rows:
+             K_in = 0
+
+     # Calculate max.
+     cur_max = 0.0
+     for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         tile_max = tl.max(tl.abs(a))
+         cur_max = tl.maximum(tile_max, cur_max)
+         n_offset += BLOCK_SIZE
+     # Clamp max value appropriately.
+     if CLAMP_MAX:
+         ub = tl.load(scale_ub)
+         cur_max = tl.clamp(cur_max, EPS, ub)
+     else:
+         cur_max = tl.maximum(cur_max, EPS)
+     # Scale and quantize.
+     a_scale = MAX_FP8 / cur_max
+     tl.store(A_scale + pid, 1.0 / a_scale)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+
+     # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8.
+     for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+         # Load from A if in range, else 0 (we're going all the way to K_fp8).
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         # For elements >= K, a will be 0.
+         a_fp8 = a * a_scale
+         # Clamp A to fp8 range to make sure there's no overflow.
+         # This is required for AMD. Nvidia's default saturation
+         # handles it, but it's nice to have anyway.
+         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+         # Store the full new row in its place (for elements >= K, a_fp8 is already 0).
+         tl.store(
+             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+             a_fp8,
+             mask=n_offset < K_fp8,
+         )
+         n_offset += BLOCK_SIZE
+
+
+ def quantize_fp8_per_row(
+     a: torch.Tensor,
+     scale_ub: Optional[torch.Tensor] = None,
+     zero_start_index_M: Optional[torch.Tensor] = None,
+     align_rows_to: Optional[int] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
+
+     Args:
+         a (Tensor): higher precision input tensor of up to 4 dimensions.
+         scale_ub (Tensor): Maximum allowed value for scale.
+         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+         align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., a multiple of 16).
+     Returns:
+         torch.Tensor: fp8 scaled tensor.
+         torch.Tensor: reciprocal scale tensor per row.
+     """
+     # Handle meta tensors (skip kernel execution).
+     if a.device.type == "meta":
+         pt_dtype, _, _, _ = get_fp8_constants()
+         a_shape = list(a.shape)
+         if align_rows_to is not None:
+             last_dim = a_shape[-1]
+             padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+             a_shape[-1] = padded_last_dim
+
+         # Return empty meta tensors with correct shapes.
+         return (
+             torch.empty(a_shape, device="meta", dtype=pt_dtype),
+             torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
+         )
+
+     if scale_ub is not None and scale_ub.device != a.device:
+         raise Exception("'scale_ub' must be on the same device as 'a'")
+     if zero_start_index_M is not None and zero_start_index_M.device != a.device:
+         raise Exception("'zero_start_index_M' must be on the same device as 'a'")
+
+     assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+     a_shape = a.shape
+     while a.dim() < 4:
+         a = a.unsqueeze(0)
+     if zero_start_index_M is not None:
+         # There should be one value of zero_start_index_M per NxK matrix.
+         zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
+     # Get constant values.
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+     num_rows = a.numel() // a.shape[-1]
+     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
+     # If align_rows_to is provided, pad the last dimension to be a multiple of it.
+     if align_rows_to is not None:
+         last_dim = a.shape[-1]
+         padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+         a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
+         a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+     else:
+         a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+
+     # If input tensor is sufficiently large, we need to use int64 indexing.
+     use_int64 = a.numel() > (2**31 - 1)
+     grid = (num_rows,)
+     _kernel_quantize_fp8_row[grid](
+         a,
+         a_scale,
+         a_fp8,
+         scale_ub,
+         zero_start_index_M,
+         a.shape[0],
+         a.shape[1],
+         a.shape[2],
+         a.shape[3],
+         a_fp8.shape[3],
+         a.stride(0),
+         a.stride(1),
+         a.stride(2),
+         a.stride(3),
+         a_fp8.stride(0),
+         a_fp8.stride(1),
+         a_fp8.stride(2),
+         a_fp8.stride(3),
+         (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
+         (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
+         TL_FP8_DTYPE=tl_dtype,
+         MAX_FP8=max_fp8,
+         EPS=eps,
+         CLAMP_MAX=scale_ub is not None,
+         JAGGED=zero_start_index_M is not None,
+         USE_INT64=use_int64,
+     )
+
+     return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
build/torch-xpu/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .quantizer import quantize_fp8_per_row
+
+ __all__ = [
+     "quantize_fp8_per_row"
+ ]
build/torch-xpu/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._fp8_fbgemm_5f3c84f_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_fp8_fbgemm_5f3c84f_dirty::{op_name}"
build/torch-xpu/fp8_fbgemm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import sys
+
+ import importlib.util
+ from pathlib import Path
+ from types import ModuleType
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is: after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path for it to be unique, using the hex-encoded hash of
+     # the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-xpu/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
build/torch-xpu/quantizer.py ADDED
@@ -0,0 +1,262 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license
+ # copied from https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+
+
+ import torch
+ import triton
+ import triton.language as tl
+ from torch import nn
+ from triton import Config
+ from typing import Any, Optional
+
+ def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
+     """
+     Helper function to get constant values for the current platform.
+
+     Returns:
+         pt_dtype (torch.dtype): The correct torch fp8 datatype.
+         tl_dtype (tl.dtype): The correct triton fp8 datatype.
+         max_fp8 (float): The maximum representable value for the fp8 datatype.
+         eps (float): Minimum clip value to prevent divide by zero.
+     """
+     pt_fp8_dtype = torch.float8_e4m3fn
+     tl_fp8_dtype = tl.float8e4nv
+     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
+
+
+ @triton.autotune(
+     configs=[
+         Config({"BLOCK_SIZE": 512}),
+         Config({"BLOCK_SIZE": 1024}),
+         Config({"BLOCK_SIZE": 2048}),
+         Config({"BLOCK_SIZE": 4096}),
+         Config({"BLOCK_SIZE": 8192}),
+     ],
+     key=["K"],
+ )
+ @triton.jit
+ def _kernel_quantize_fp8_row(
+     A,
+     A_scale,
+     A_fp8,
+     scale_ub,
+     zero_start_index_M,
+     B,
+     M,
+     N,
+     K,
+     K_fp8,  # used when padding
+     stride_ab,
+     stride_am,
+     stride_an,
+     stride_ak,
+     stride_ob,
+     stride_om,
+     stride_on,
+     stride_ok,
+     stride_zb,
+     stride_zm,
+     TL_FP8_DTYPE: tl.constexpr,
+     MAX_FP8: tl.constexpr,
+     EPS: tl.constexpr,
+     CLAMP_MAX: tl.constexpr,
+     JAGGED: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ) -> None:
+     """Quantize and scale each row.
+
+     Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
+
+     Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
+     in a max pass then scale/quantize pass.
+
+     Todo:
+         * Better tiling schemes.
+
+     Args:
+         A (Tensor): higher precision input tensor of 4 dimensions.
+         A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
+         A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
+         scale_ub (Tensor): [1] Maximum value allowed for scale.
+         B (int): Size of dimension 0.
+         M (int): Size of dimension 1.
+         N (int): Size of dimension 2.
+         K (int): Size of dimension 3 (input row size).
+         K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K).
+         stride_ab (int): Stride of b dimension of A.
+         stride_am (int): Stride of m dimension of A.
+         stride_an (int): Stride of n dimension of A.
+         stride_ak (int): Stride of k dimension of A.
+         stride_ob (int): Stride of b dimension of output.
+         stride_om (int): Stride of m dimension of output.
+         stride_on (int): Stride of n dimension of output.
+         stride_ok (int): Stride of k dimension of output.
+         stride_zb (int): Stride of b dimension of jagged index.
+         stride_zm (int): Stride of m dimension of jagged index.
+         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+         MAX_FP8 (float): Maximum expressible value for FP8.
+         EPS (float): Epsilon value for numerical stability.
+         CLAMP_MAX (bool): Whether to apply scale_ub.
+         JAGGED (bool): Whether to use jagged indexing.
+         BLOCK_SIZE (int): Block size for reduction.
+         USE_INT64 (bool): Whether to use int64 indexing for large inputs.
+     """
+     pid = tl.program_id(0)
+     # Use int64 indexing for large inputs. This is slower, but
+     # needed to avoid index overflows.
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+     a_offset_base = pid // (M * N) * stride_ab + (pid % (M * N)) // N * stride_am + (pid % (M * N)) % N * stride_an
+     a_fp8_offset_base = pid // (M * N) * stride_ob + (pid % (M * N)) // N * stride_om + (pid % (M * N)) % N * stride_on
+
+     K_in = K
+     if JAGGED:
+         z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+         group_rows = tl.load(zero_start_index_M + z_offset_base)
+         current_row = pid % N
+         # If this row is empty, don't process any of it.
+         if current_row >= group_rows:
+             K_in = 0
+
+     # Calculate max.
+     cur_max = 0.0
+     for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         tile_max = tl.max(tl.abs(a))
+         cur_max = tl.maximum(tile_max, cur_max)
+         n_offset += BLOCK_SIZE
+     # Clamp max value appropriately.
+     if CLAMP_MAX:
+         ub = tl.load(scale_ub)
+         cur_max = tl.clamp(cur_max, EPS, ub)
+     else:
+         cur_max = tl.maximum(cur_max, EPS)
+     # Scale and quantize.
+     a_scale = MAX_FP8 / cur_max
+     tl.store(A_scale + pid, 1.0 / a_scale)
+     n_offset = tl.arange(0, BLOCK_SIZE)
+
+     # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8.
+     for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
+         # Load from A if in range, else 0 (we're going all the way to K_fp8).
+         a = tl.load(
+             A + a_offset_base + n_offset * stride_ak,
+             mask=n_offset < K_in,
+             other=0.0,
+         )
+         # For elements >= K, a will be 0.
+         a_fp8 = a * a_scale
+         # Clamp A to fp8 range to make sure there's no overflow.
+         # This is required for AMD. Nvidia's default saturation
+         # handles it, but it's nice to have anyway.
+         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+
+         # Store the full new row in its place (for elements >= K, a_fp8 is already 0).
+         tl.store(
+             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
+             a_fp8,
+             mask=n_offset < K_fp8,
+         )
+         n_offset += BLOCK_SIZE
+
+
+ def quantize_fp8_per_row(
+     a: torch.Tensor,
+     scale_ub: Optional[torch.Tensor] = None,
+     zero_start_index_M: Optional[torch.Tensor] = None,
+     align_rows_to: Optional[int] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
+
+     Args:
+         a (Tensor): higher precision input tensor of up to 4 dimensions.
+         scale_ub (Tensor): Maximum allowed value for scale.
+         zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
+         align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., a multiple of 16).
+     Returns:
+         torch.Tensor: fp8 scaled tensor.
+         torch.Tensor: reciprocal scale tensor per row.
+     """
+     # Handle meta tensors (skip kernel execution).
+     if a.device.type == "meta":
+         pt_dtype, _, _, _ = get_fp8_constants()
+         a_shape = list(a.shape)
+         if align_rows_to is not None:
+             last_dim = a_shape[-1]
+             padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+             a_shape[-1] = padded_last_dim
+
+         # Return empty meta tensors with correct shapes.
+         return (
+             torch.empty(a_shape, device="meta", dtype=pt_dtype),
+             torch.empty(a_shape[:-1], device="meta", dtype=torch.float32)
+         )
+
+     if scale_ub is not None and scale_ub.device != a.device:
+         raise Exception("'scale_ub' must be on the same device as 'a'")
+     if zero_start_index_M is not None and zero_start_index_M.device != a.device:
+         raise Exception("'zero_start_index_M' must be on the same device as 'a'")
+
+     assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+     a_shape = a.shape
+     while a.dim() < 4:
+         a = a.unsqueeze(0)
+     if zero_start_index_M is not None:
+         # There should be one value of zero_start_index_M per NxK matrix.
+         zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
+     # Get constant values.
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+     num_rows = a.numel() // a.shape[-1]
+     a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
+     # If align_rows_to is provided, pad the last dimension to be a multiple of it.
+     if align_rows_to is not None:
+         last_dim = a.shape[-1]
+         padded_last_dim = ((last_dim + align_rows_to - 1) // align_rows_to) * align_rows_to
+         a_fp8 = torch.empty((*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype)
+         a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
+     else:
+         a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
+
+     # If input tensor is sufficiently large, we need to use int64 indexing.
+     use_int64 = a.numel() > (2**31 - 1)
+     grid = (num_rows,)
+     _kernel_quantize_fp8_row[grid](
+         a,
+         a_scale,
+         a_fp8,
+         scale_ub,
+         zero_start_index_M,
+         a.shape[0],
+         a.shape[1],
+         a.shape[2],
+         a.shape[3],
+         a_fp8.shape[3],
+         a.stride(0),
+         a.stride(1),
+         a.stride(2),
+         a.stride(3),
+         a_fp8.stride(0),
+         a_fp8.stride(1),
+         a_fp8.stride(2),
+         a_fp8.stride(3),
+         (zero_start_index_M.stride(0) if zero_start_index_M is not None else None),
+         (zero_start_index_M.stride(1) if zero_start_index_M is not None else None),
+         TL_FP8_DTYPE=tl_dtype,
+         MAX_FP8=max_fp8,
+         EPS=eps,
+         CLAMP_MAX=scale_ub is not None,
+         JAGGED=zero_start_index_M is not None,
+         USE_INT64=use_int64,
+     )
+
+     return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])