lll2343 committed on
Commit
2339df1
·
verified ·
1 Parent(s): 75380a8

Update attn_mask_utils.py

Browse files
Files changed (1) hide show
  1. attn_mask_utils.py +51 -0
attn_mask_utils.py CHANGED
@@ -96,6 +96,57 @@ def update_causal_mask_with_pad_non_visible_2d(
96
  return attn_mask_2d
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def update_causal_mask_for_one_gen_window_2d(
100
  input_ids: torch.Tensor,
101
  attn_mask_2d: torch.Tensor,
 
96
  return attn_mask_2d
97
 
98
 
99
def update_causal_mask_with_pad_non_visible_2d_for_ssd_cache(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    block_size: int = 4,
    use_cache: bool = True,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Update a 2D attention mask for Self-Speculative Decoding generation.

    Details are available in Appendix B, Figure 5 (of the accompanying paper).

    Args:
        input_ids: Input token IDs. Only forwarded to the prefill helper;
            unused in the decoding branch of this function.
        attn_mask_2d: 2D attention mask of shape [q_len, kv_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
            Modified in place and returned.
        block_size: Size of the diffusion window.
        use_cache: Whether a key-value cache is being used. Forwarded to the
            prefill helper; unused in the decoding branch.
        causal_attn: If True, maintains strict causal masking throughout.
            Forwarded to the prefill helper; unused in the decoding branch.

    Returns:
        The (in-place) modified attention mask.
    """
    q_len, kv_len = attn_mask_2d.shape

    if q_len == kv_len:
        # Prefill stage: square mask — delegate to the single-window update.
        return update_causal_mask_for_one_gen_window_2d(
            input_ids = input_ids,
            attn_mask_2d = attn_mask_2d,
            block_size = block_size,
            use_cache = use_cache,
            causal_attn = causal_attn
        )

    # Decoding stage (q_len != kv_len, presumably q_len < kv_len because the
    # KV cache holds earlier tokens — TODO confirm against the caller), as
    # shown in Appendix B. Walk backwards over diffusion windows of size
    # `block_size`, starting from the bottom-right corner of the mask.
    start_ix = q_len - block_size
    start_jx = kv_len - block_size
    for ix in range(block_size-1, -1, -1):
        # Open full attention within the current block-aligned window.
        attn_mask_2d[start_ix:start_ix+block_size, start_jx:start_jx+block_size] = 0.0
        # Hide this window plus the `ix` key columns immediately before it
        # from every query row below the current block.
        attn_mask_2d[start_ix+block_size:, start_jx-ix:start_jx+block_size] = -float('inf')

        # Step back to the previous window; the stride shrinks with ix.
        start_ix = start_ix - ix - block_size
        start_jx = start_jx - ix - block_size

    # Mask the single key column just past the last processed window for all
    # remaining query rows. NOTE(review): reconstructed from a diff with lost
    # indentation — placement after (rather than inside) the loop was inferred
    # from the blank-line grouping in the diff; confirm against the original
    # file.
    attn_mask_2d[start_ix+block_size:, start_jx+block_size-1] = -float('inf')

    return attn_mask_2d
148
+
149
+
150
  def update_causal_mask_for_one_gen_window_2d(
151
  input_ids: torch.Tensor,
152
  attn_mask_2d: torch.Tensor,