hyx21 committed · verified
Commit 8de321b · Parent(s): 8b81a10

Update modeling_llama_long_infllmv2.py

Files changed (1):
  1. modeling_llama_long_infllmv2.py +3 -3
modeling_llama_long_infllmv2.py CHANGED
```diff
@@ -50,10 +50,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.models.llama.configuration_llama import LlamaConfig
-from moba import moba_attn_varlen
+
 from functools import lru_cache
 from .cis_pooling import nosa_mean_pooling
-from native_sparse_attention.ops.triton.topk_sparse_attention import topk_sparse_attention
+
 
 logger = logging.get_logger(__name__)
 
@@ -680,7 +680,7 @@ class LlamaFlashAttention2(LlamaAttention):
 
         return attn_output, attn_weights, past_key_value
 
-from native_sparse_attention.ops.triton.topk_sparse_attention import topk_sparse_attention
+
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
```
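The change removes the two unconditional, top-level imports of the `moba` and `native_sparse_attention` packages (the latter was imported twice), replacing each with a blank line, likely so the module can be loaded without those optional Triton kernels installed. The file already guards `flash_attn` behind a `try`/`except`; the sketch below shows that same guarded-import pattern applied to the two removed imports. This is an illustration of the pattern only, not code from this commit:

```python
# Minimal sketch (assumption, not part of this commit): guard the optional
# kernels the same way the file already guards flash_attn.
try:
    from moba import moba_attn_varlen  # optional MoBA varlen attention kernel
except ImportError:
    moba_attn_varlen = None  # kernel unavailable; callers must check

try:
    from native_sparse_attention.ops.triton.topk_sparse_attention import (
        topk_sparse_attention,  # optional top-k sparse attention Triton kernel
    )
except ImportError:
    topk_sparse_attention = None


def topk_sparse_attention_available() -> bool:
    # Hypothetical helper so call sites can fail with a clear message at use
    # time instead of an ImportError at module load time.
    return topk_sparse_attention is not None
```

Binding the missing symbols to `None` keeps the module importable while preserving its namespace, so dispatch code can raise a descriptive error only when a sparse-attention path is actually requested.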