Upload custom kernels
- bitnet_kernel/bitnet_kernels.cu +72 -0
- bitnet_kernel/bitnet_kernels.h +83 -0
- build.toml +15 -0
- flake.nix +13 -0
- torch-ext/bitnet_kernel/__init__.py +4 -0
- torch-ext/torch_binding.cpp +11 -0
- torch-ext/torch_binding.h +5 -0
bitnet_kernel/bitnet_kernels.cu
ADDED
@@ -0,0 +1,72 @@
#include "bitnet_kernels.h"
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h> // provides at::cuda::getCurrentCUDAStream

extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
    // Dispatch to a kernel instantiation specialized for each supported (M, N, K) shape.
    if (M == 1 && N == 3840 && K == 2560){
        ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 2560 && K == 2560){
        ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 13824 && K == 2560){
        ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 2560 && K == 6912){
        ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 4800 && K == 3200){
        ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 3200 && K == 3200){
        ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 20480 && K == 3200){
        ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 3200 && K == 10240){
        ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 5120 && K == 27648){
        ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 55296 && K == 5120){
        ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else{
        // No specialized instantiation for this shape; the zero-initialized
        // output is returned unchanged.
        std::cout << "required ladder gemm kernel: M " << M << ", N " << N << ", K " << K << std::endl;
    }
}

torch::Tensor bitlinear_int8xint2_cpp(torch::Tensor input0, torch::Tensor input1, torch::Tensor s, torch::Tensor ws) {
    // Output keeps input0's leading dimensions; the last dimension becomes N
    // (the number of rows of the packed weight matrix).
    auto out_shape = input0.sizes().vec();
    out_shape.back() = input1.size(0);

    // Flatten leading dimensions into M; the weight stores four int2 values
    // per byte, so K is four times the packed width.
    int M = input0.size(0);
    if (out_shape.size() == 3) {
        M *= input0.size(1);
    }
    int N = input1.size(0);
    int K = input1.size(1) * 4;

    // Create output tensor
    auto output = torch::zeros(out_shape, torch::dtype(torch::kBFloat16).device(input0.device()));

    // Get CUDA stream
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Call kernel
    bitlinear_int8xint2(
        reinterpret_cast<int8_t*>(input0.data_ptr()),
        reinterpret_cast<int8_t*>(input1.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(s.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(ws.data_ptr()),
        M, N, K, stream
    );

    return output;
}
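For reference, the arithmetic this dispatcher ultimately performs can be sketched in a few lines of PyTorch. This is a hedged reference on unpacked weights, not the tiled int2 layout the kernel actually reads; bitlinear_reference and its argument names are illustrative, not part of this upload.

import torch

def bitlinear_reference(a_int8, w_int2, s, ws):
    # a_int8: (M, K) int8 activations; w_int2: (N, K) int8 weights with decoded
    # values in {-2, -1, 0, 1} (the CUDA path stores these packed four per byte).
    # s: per-tensor activation scale (bf16), read as s[0] in the kernel.
    # ws: (ws_num,) weight scales; output column n uses ws[n // (N // ws_num)].
    N, K = w_int2.shape
    ws_num = ws.numel()
    acc = a_int8.to(torch.int32) @ w_int2.to(torch.int32).t()  # int32 accumulation, as with __dp4a
    col_scale = ws[torch.arange(N, device=ws.device) // (N // ws_num)].float()
    return (acc.float() / s.float()[0] * col_scale).to(torch.bfloat16)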
bitnet_kernel/bitnet_kernels.h
ADDED
@@ -0,0 +1,83 @@
#include <cuda_runtime.h>
#include <math_constants.h>
#include <math.h>
#include <mma.h>
#include <iostream>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
#define TVM_ENABLE_EFFICIENT_SMEM_PTR_CAST 1
#else
#define TVM_ENABLE_EFFICIENT_SMEM_PTR_CAST 0
#endif

template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
{
    // Convert N packed 2-bit values into N int8 values, N/4 int32 words at a time.
    uint *i8s = reinterpret_cast<uint *>(_i8s);

    // i2s element order: {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
    uint const i2s = *_i2s;

    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // lop3 truth table 0b11101010: (a & b) | c
    static constexpr uint BOTTOM_MASK = 0x03030303;      // keep the low 2 bits of each byte
    static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;

#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        // Select one 2-bit field per byte: i8s[i] = (i2s >> (2 * i)) & BOTTOM_MASK
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
        // Shift the unsigned 2-bit range {0..3} down to the signed range {-2..1}.
        i8s[i] = __vsubss4(i8s[i], 0x02020202);
    }
}

template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
__global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
    constexpr int K_per_loop = 16;
    constexpr int wmma_K = 32;
    constexpr int wmma_N = 16;
    int in_thread_C_local[1];
    signed char A_local[K_per_loop];
    int B_reshape_local[1];
    signed char B_decode_local[K_per_loop];
    int red_buf0[1];
    in_thread_C_local[0] = 0;
#pragma unroll
    for (int k_0 = 0; k_0 < K / (K_per_loop * K_block_size); ++k_0) {
        // Each thread loads 16 int8 activations and one int32 word of packed
        // int2 weights from the tiled ("ladder") layout.
        *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
        B_reshape_local[0] = *(int*)(B +
            (((int)blockIdx.x) * N_block_size * K / 4) +
            (k_0 * K_block_size * K_per_loop * wmma_N / 4) +
            ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
            ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
            ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
            ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
        );
        decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
#pragma unroll
        for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
            // Accumulate four int8 x int8 products per __dp4a into a 32-bit sum.
            in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))], *(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
        }
    }
    red_buf0[0] = in_thread_C_local[0];
    // Shuffle-down tree reduction over the K_block_size threads sharing one output.
#pragma unroll
    for (int offset = K_block_size / 2; offset > 0; offset /= 2) {
        red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
    }
    int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
    int ws_idx = out_idx / (N / ws_num);
    // Lane 0 rescales the integer sum by the activation scale s[0] and the
    // per-group weight scale ws[ws_idx], then stores it as bfloat16.
    if (threadIdx.x == 0)
        dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0]) / (float)s[0] * (float)ws[ws_idx]);
}
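The lop3.b32 in decode_i2s_to_i8s applies truth table 0xea, i.e. (a & b) | c with b = BOTTOM_MASK and c = 0, so each output byte is one 2-bit field of the shifted input word; __vsubss4 then recenters it. A minimal Python sketch of the same bit manipulation (the function name is illustrative):

def decode_i2s_to_i8s_ref(i2s_word: int) -> list[int]:
    # Unpack one 32-bit word holding 16 int2 values into 16 int8 values.
    out = []
    for i in range(4):  # one 32-bit output word per iteration
        word = (i2s_word >> (2 * i)) & 0x03030303  # lop3: (a & BOTTOM_MASK) | 0
        for j in range(4):  # per-byte __vsubss4(word, 0x02020202)
            out.append(((word >> (8 * j)) & 0x3) - 2)  # {0..3} -> {-2..1}
    return out

# Given the interleaved element order noted in the header comment, the decoded
# bytes come out in natural order e0..e15.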
build.toml
ADDED
@@ -0,0 +1,15 @@
[general]
name = "bitnet_kernel"

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

[kernel.bitnet_kernel]
src = [
  "bitnet_kernel/bitnet_kernels.cu",
  "bitnet_kernel/bitnet_kernels.h"
]
depends = ["torch"]
flake.nix
ADDED
@@ -0,0 +1,13 @@
{
  description = "Flake for Torch kernel extension";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs = { self, kernel-builder, }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
torch-ext/bitnet_kernel/__init__.py
ADDED
@@ -0,0 +1,4 @@
from ._ops import ops

def bitnet_int8xint2_linear(input0, input1, s, ws):
    return ops.bitlinear_int8xint2_cpp(input0, input1, s, ws)
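A hedged usage sketch for one of the shapes the dispatcher supports (M=1, N=2560, K=2560, for which ws_num is 1). It assumes the built extension is importable as bitnet_kernel; the tensor contents here are placeholders, and a real weight tensor must already be in the kernel's packed, tiled int2 layout:

import torch
from bitnet_kernel import bitnet_int8xint2_linear

M, N, K = 1, 2560, 2560  # one of the shapes hard-coded in the dispatcher
x = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
w = torch.randint(-128, 128, (N, K // 4), dtype=torch.int8, device="cuda")  # four int2 values per byte
s = torch.ones(1, dtype=torch.bfloat16, device="cuda")   # activation scale; kernel reads s[0]
ws = torch.ones(1, dtype=torch.bfloat16, device="cuda")  # weight scale(s); ws_num = 1 here

y = bitnet_int8xint2_linear(x, w, s, ws)
print(y.shape, y.dtype)  # torch.Size([1, 2560]) torch.bfloat16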
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,11 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("bitlinear_int8xint2_cpp(Tensor input0, Tensor input1, Tensor s, Tensor ws) -> Tensor");
  ops.impl("bitlinear_int8xint2_cpp", torch::kCUDA, &bitlinear_int8xint2_cpp);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

// Must match the definition in bitnet_kernel/bitnet_kernels.cu: the op returns
// the bf16 output tensor and takes its arguments by value.
torch::Tensor bitlinear_int8xint2_cpp(torch::Tensor input0, torch::Tensor input1, torch::Tensor s, torch::Tensor ws);