########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import gc
import types, torch, copy, time
from typing import List

import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.allow_tf32 = True
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script

pipeline = PIPELINE(None, "rwkv_vocab_v20230424")

########################################################################################################

print('\nNOTE: this is very inefficient (loads all weights to VRAM, and the KV cache is slow). A better method is to prefetch the DeepEmbed weights from RAM/SSD\n')

args = types.SimpleNamespace()

model_path = hf_hub_download(repo_id='Alic-Li/RWKV_v7_G1_Translate_ctx4096_20250620', filename='RWKV_v7s_G1_DEA_0.1B_Translate_ctx4096_20250917_latest.pth')
args.MODEL_NAME = model_path

args.n_layer = 12
args.n_embd = 768
args.vocab_size = 65536
args.head_size = 64

ctx_limit = 4096
gen_limit = 4096
penalty_decay = 0.996

NUM_TRIALS = 1
LENGTH_PER_TRIAL = 500
TEMPERATURE = 1.0
TOP_P = 0.0

DTYPE = torch.half

from torch.utils.cpp_extension import load

HEAD_SIZE = args.head_size

# ROCm_flag = torch.version.hip is not None
# if ROCm_flag:
#     load(name="wkv7s", sources=["cuda/wkv7s_op.cpp", f"cuda/wkv7s.cu"], is_python_module=False,
#         verbose=True, extra_cuda_cflags=["-xhip", "-fopenmp", "-ffast-math", "-O3", "-munsafe-fp-atomics", f"-D_N_={HEAD_SIZE}"])
# else:
#     load(name="wkv7s", sources=["cuda/wkv7s_op.cpp", f"cuda/wkv7s.cu"], is_python_module=False,
#         verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])

class WKV_7(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, r, w, k, v, a, b):
        with torch.no_grad():
            T, C = r.size()
            H = C // HEAD_SIZE
            N = HEAD_SIZE
            assert HEAD_SIZE == C // H
            assert all(x.dtype == DTYPE for x in [r, w, k, v, a, b])
            assert all(x.is_contiguous() for x in [r, w, k, v, a, b])
            y = torch.empty((T, C), device=k.device, dtype=DTYPE, requires_grad=False, memory_format=torch.contiguous_format)
            torch.ops.wkv7s.forward(1, T, C, H, state, r, w, k, v, a, b, y)
            return y

def RWKV7_OP(state, r, w, k, v, a, b):
    return WKV_7.apply(state, r, w, k, v, a, b)
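# Shape contract for the fused-kernel path, inferred from the asserts above and
# from how RWKV7_OP is invoked in RWKV_x070_TMix_seq below: r, w, k, v, a, b are
# (T, C) tensors in DTYPE with C = H * HEAD_SIZE, `state` is a float32
# (H, HEAD_SIZE, HEAD_SIZE) per-head state carried across calls, and y is (T, C).
# torch.ops.wkv7s.forward only exists once the commented `load(...)` call above
# has been uncommented and compiled; this demo uses the cuda-free Python loop in
# RWKV_x070_TMix_seq instead.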
########################################################################################################

class RWKV_x070(MyModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.n_embd = args.n_embd
        self.n_layer = args.n_layer
        self.eval()

        self.z = torch.load(args.MODEL_NAME, map_location='cpu')
        z = self.z
        self.n_head, self.head_size = z['blocks.0.att.r_k'].shape

        keys = list(z.keys())
        for k in keys:
            if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k or 'qq.weight' in k:
                z[k] = z[k].t()
            z[k] = z[k].squeeze().to(dtype=DTYPE)
            if k.endswith('att.r_k'):
                z[k] = z[k].flatten()
        assert self.head_size == args.head_size

        z['emb.weight'] = F.layer_norm(z['emb.weight'], (args.n_embd,), weight=z['blocks.0.ln0.weight'], bias=z['blocks.0.ln0.bias'])

        for i in range(self.n_layer): # !!! merge emb residual !!!
            z[f'blocks.{i}.ffn.s_emb.weight'] = z[f'blocks.{i}.ffn.s_emb.weight'] + z['emb.weight'] @ z[f'blocks.{i}.ffn.s_emb_x.weight'].t()
            z[f'blocks.{i}.qkv.k_emb.weight'] = z[f'blocks.{i}.qkv.k_emb.weight'] + z['emb.weight'] @ z[f'blocks.{i}.qkv.k_emb_x.weight'].t()
            z[f'blocks.{i}.qkv.v_emb.weight'] = z[f'blocks.{i}.qkv.v_emb.weight'] + z['emb.weight'] @ z[f'blocks.{i}.qkv.v_emb_x.weight'].t()

        z['blocks.0.att.v0'] = z['blocks.0.att.a0'] # actually ignored
        z['blocks.0.att.v1'] = z['blocks.0.att.a1'] # actually ignored
        z['blocks.0.att.v2'] = z['blocks.0.att.a2'] # actually ignored

    def forward(self, idx, state, full_output=False):
        if state is None:
            state = [None for _ in range(args.n_layer * 3 + 37)] # with KV cache etc.
            for i in range(args.n_layer): # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
                state[i*3+0] = torch.zeros(args.n_embd, dtype=DTYPE, requires_grad=False, device="cpu")
                state[i*3+1] = torch.zeros((args.n_embd // args.head_size, args.head_size, args.head_size), dtype=torch.float, requires_grad=False, device="cpu")
                state[i*3+2] = torch.zeros(args.n_embd, dtype=DTYPE, requires_grad=False, device="cpu")
            state[args.n_layer*3+0] = torch.empty((0), dtype=torch.int, requires_grad=False, device="cpu") # token idx cache
            for i in range(1, 1+24): # kv cache = 12*2*32 numbers per token
                state[args.n_layer*3+i] = torch.empty((0, 32), dtype=DTYPE, requires_grad=False, device="cpu")
            for i in range(1+24, 1+36): # token-shift cache for Q in DEA
                state[args.n_layer*3+i] = torch.zeros(256, dtype=DTYPE, requires_grad=False, device="cpu")

        if type(idx) is list:
            if len(idx) > 1:
                return self.forward_seq(idx, state, full_output)
            else:
                # return self.forward_one(idx[0], state) # sorry too busy to add forward_one mode
                return self.forward_seq(idx, state, full_output)
        else:
            # return self.forward_one(idx, state) # sorry too busy to add forward_one mode
            return self.forward_seq([idx], state, full_output)
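    # State layout (n_layer*3 + 37 slots, all kept on CPU here):
    #   [i*3+0]  att token-shift x_prev, (n_embd,)                    per layer i
    #   [i*3+1]  att wkv state, (n_head, head_size, head_size) fp32   per layer i
    #   [i*3+2]  ffn token-shift x_prev, (n_embd,)                    per layer i
    #   [n_layer*3+0]          ids of every token seen so far (grows each call)
    #   [n_layer*3+1 .. +24]   DEA k/v caches, one (T, 32) pair per layer
    #   [n_layer*3+25 .. +36]  DEA q token-shift, one (256,) vector per layer
    # The k/v caches are re-projected against the full cached context on every
    # call (see forward_seq below); this is the "slow KV cache" the NOTE at the
    # top of this file warns about.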
    @MyFunction
    def forward_seq(self, idx: List[int], state: List[torch.Tensor], full_output: bool = False):
        with torch.no_grad():
            z = self.z
            x = z['emb.weight'][idx]

            # append the new token ids to the cached context
            state[self.n_layer*3] = torch.cat((state[self.n_layer*3], torch.tensor(idx, dtype=torch.int, device=x.device)), dim=0)
            ctx = state[self.n_layer*3]

            v_first = torch.empty_like(x)
            for i in range(self.n_layer):
                bbb = f'blocks.{i}.'
                att = f'blocks.{i}.att.'
                ffn = f'blocks.{i}.ffn.'
                qkv = f'blocks.{i}.qkv.'

                # DeepEmbed Attention (DEA): q comes from the current chunk; k/v are
                # re-projected from the per-layer caches and modulated by DeepEmbed
                # lookups over the full token-id context
                q = x @ z[qkv+'qq.weight']

                k = x @ z[qkv+'k1']
                state[self.n_layer*3+1+i*2] = torch.cat((state[self.n_layer*3+1+i*2], k), dim=0)
                k = (state[self.n_layer*3+1+i*2] @ z[qkv+'k2']) * (z[qkv+'k_emb.weight'][ctx])

                v = x @ z[qkv+'v1']
                state[self.n_layer*3+1+i*2+1] = torch.cat((state[self.n_layer*3+1+i*2+1], v), dim=0)
                v = torch.tanh(state[self.n_layer*3+1+i*2+1] @ z[qkv+'v2']) * (z[qkv+'v_emb.weight'][ctx])

                # token-shift: q is carried across calls, k and v within the chunk
                qq = torch.cat((state[self.n_layer*3+1+24+i].unsqueeze(0), q[:-1,:]))
                state[self.n_layer*3+1+24+i] = q[-1,:]
                q = q + (qq - q) * z[qkv+'x_q']
                k = k + (F.pad(k, (0, 0, 1, -1)) - k) * z[qkv+'x_k']
                v = v + (F.pad(v, (0, 0, 1, -1)) - v) * z[qkv+'x_v']

                q = F.layer_norm(q, (256,), weight=z[qkv+'lnq.weight'], bias=z[qkv+'lnq.bias'])
                k = F.layer_norm(k, (256,), weight=z[qkv+'lnk.weight'], bias=z[qkv+'lnk.bias'])
                v = F.layer_norm(v, (self.n_embd,), weight=z[qkv+'lnv.weight'], bias=z[qkv+'lnv.bias'])

                scores = 64 * torch.tanh((q @ k.mT) * (1.0 / 1024.0)) # using soft-cap
                if len(idx) > 1: # causal mask, only needed when processing more than one token
                    mask = ~torch.tril(torch.ones(len(ctx), len(ctx), dtype=torch.bool, device=x.device))[-len(idx):,:]
                    scores = scores.masked_fill(mask, float('-inf'))
                qkv_out = scores.softmax(dim=-1) @ v # renamed from `qkv` to avoid shadowing the prefix string

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln1.weight'], bias=z[bbb+'ln1.bias'])
                xx, state[i*3+0], state[i*3+1], v_first = RWKV_x070_TMix_seq(i, self.n_head, self.head_size, xx, state[i*3+0], v_first, state[i*3+1],
                    z[att+'x_r'], z[att+'x_w'], z[att+'x_k'], z[att+'x_v'], z[att+'x_a'], z[att+'x_g'],
                    z[att+'w0'], z[att+'w1'], z[att+'w2'], z[att+'a0'], z[att+'a1'], z[att+'a2'],
                    z[att+'v0'], z[att+'v1'], z[att+'v2'], z[att+'g1'], z[att+'g2'],
                    z[att+'k_k'], z[att+'k_a'], z[att+'r_k'],
                    z[att+'receptance.weight'], z[att+'key.weight'], z[att+'value.weight'], z[att+'output.weight'],
                    z[att+'ln_x.weight'], z[att+'ln_x.bias'])
                x = x + xx + qkv_out

                xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'],
                    z[ffn+'s_emb.weight'][idx], z[ffn+'s1'], z[ffn+'s2'], z[ffn+'s0'])
                x = x + xx

            if not full_output:
                x = x[-1,:]
            x = F.layer_norm(x, (self.n_embd,), weight=z['ln_out.weight'], bias=z['ln_out.bias'])
            x = x @ z['head.weight']
            return x, state

########################################################################################################

@MyStatic
def RWKV_x070_TMix_seq(layer_id: int, H: int, N: int, x, x_prev, v_first, state,
                       x_r, x_w, x_k, x_v, x_a, x_g, w0, w1, w2, a0, a1, a2, v0, v1, v2, g1, g2,
                       k_k, k_a, r_k, R_, K_, V_, O_, ln_w, ln_b):
    T = x.shape[0]
    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
    xr, xw, xk, xv, xa, xg = x+xx*x_r, x+xx*x_w, x+xx*x_k, x+xx*x_v, x+xx*x_a, x+xx*x_g

    r = xr @ R_
    w = torch.tanh(xw @ w1) @ w2
    k = xk @ K_
    v = xv @ V_
    a = torch.sigmoid(a0 + (xa @ a1) @ a2)
    g = torch.sigmoid(xg @ g1) @ g2

    kk = torch.nn.functional.normalize((k * k_k).view(T,H,N), dim=-1, p=2.0).view(T,H*N)
    k = k * (1 + (a-1) * k_a)
    if layer_id == 0:
        v_first = v
    else:
        v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)

    ######## cuda-free method
    w = torch.exp(-0.606531 * torch.sigmoid((w0 + w).float())) # 0.606531 = exp(-0.5)
    for t in range(T):
        r_, w_, k_, v_, kk_, a_ = r[t], w[t], k[t], v[t], kk[t], a[t]
        vk = v_.view(H,N,1) @ k_.view(H,1,N)
        ab = (-kk_).view(H,N,1) @ (kk_*a_).view(H,1,N)
        state = state * w_.view(H,1,N) + state @ ab.float() + vk.float()
        xx[t] = (state.to(dtype=x.dtype) @ r_.view(H,N,1)).view(H*N)

    ######## cuda kernel method (requires the compiled wkv7s extension; note the different w parameterization)
    # w = -torch.nn.functional.softplus(-(w0 + w)) - 0.5
    # xx = RWKV7_OP(state, r, w, k, v, -kk, kk*a)

    xx = torch.nn.functional.group_norm(xx.view(T,H*N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).view(T,H*N)
    xx = xx + ((r * k * r_k).view(T,H,N).sum(dim=-1, keepdim=True) * v.view(T,H,N)).view(T,H*N)
    return (xx * g) @ O_, x[-1,:], state, v_first
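# For reference: a standalone, single-head, single-timestep sketch of the state
# update performed inside the cuda-free loop above (illustration only; nothing
# in this file calls it). The state S decays by w along the key dimension,
# receives a rank-1 correction through the L2-normalized key kk scaled by the
# learned in-context rate a, absorbs a rank-1 v k^T write, and is read out by r.
def _wkv7_ref_one_head(S: torch.Tensor, r: torch.Tensor, w: torch.Tensor,
                       k: torch.Tensor, v: torch.Tensor, kk: torch.Tensor, a: torch.Tensor):
    # S: (N, N) float32 state; r, w, k, v, kk, a: (N,) vectors for one head
    S = S * w.view(1, -1) + S @ ((-kk).view(-1, 1) @ (kk * a).view(1, -1)) + v.view(-1, 1) @ k.view(1, -1)
    return S, S @ r # updated state and the (N,) output for this timestep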
########################################################################################################

@MyStatic
def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_, semb_, s1_, s2_, s0_):
    T, C = x.shape
    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
    k = x + xx * x_k
    k = torch.relu(k @ K_) ** 2
    # DeepEmbed: a per-token (32, 32) matrix modulates the hidden activation
    ss = (x @ s1_).view(T,1,32) @ semb_.view(T,32,32)
    k = k * ((ss.view(T,32) @ s2_) + s0_)
    return k @ V_, x[-1,:]

model = RWKV_x070(args)

def evaluate(
    ctx,
    token_count=200,
    temperature=0.0,
    top_p=0.0,
    presencePenalty=0.0,
    countPenalty=0.0,
):
    args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
                         alpha_frequency=countPenalty,
                         alpha_presence=presencePenalty,
                         token_ban=[], # ban the generation of some tokens
                         token_stop=[0]) # stop generation whenever you see any token here

    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    state = None
    for i in range(int(token_count)):
        input_ids = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
        out, state = model.forward(input_ids, state)
        for n in occurrence: # repetition penalty
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        for xxx in occurrence:
            occurrence[xxx] *= penalty_decay

        ttt = pipeline.decode([token])
        www = 1
        if ttt in ' \t0123456789':
            www = 0
        #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
        #    www = 0.5
        if token not in occurrence:
            occurrence[token] = www
        else:
            occurrence[token] += www

        tmp = pipeline.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp: # only yield once the decoded bytes form valid text
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1

    del out
    del state
    gc.collect()
    # torch.cuda.empty_cache()
    yield out_str.strip()

def translate_english_to_chinese(english_text, token_count, temperature, top_p, presence_penalty, count_penalty):
    if not english_text.strip():
        yield "Chinese:\n请输入英文内容。" # "please enter English text"
        return
    full_prompt = f"English: {english_text}\n\nChinese:"
    for output in evaluate(full_prompt, token_count, temperature, top_p, presence_penalty, count_penalty):
        yield output

def translate_chinese_to_english(chinese_text, token_count, temperature, top_p, presence_penalty, count_penalty):
    if not chinese_text.strip():
        yield "English:\n请输入中文内容。" # "please enter Chinese text"
        return
    full_prompt = f"Chinese: {chinese_text}\n\nEnglish:"
    for output in evaluate(full_prompt, token_count, temperature, top_p, presence_penalty, count_penalty):
        yield output

with gr.Blocks(title="RWKV_v7s_G1_DEA_0.1B_Translate_ctx4096_20250917 English -> Chinese") as demo:
    with gr.Tab("English To Chinese"):
        gr.HTML(f"