Coverage for src/flag_gems/ops/scaled_softmax.py: 39%
122 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.utils import libentry
# Autotuning space for both the forward and backward kernels.
# BLOCK_Q = number of query rows handled per program, BLOCK_K = key-dimension
# tile width per inner-loop iteration. Small-BLOCK_Q configs (num_warps=4,
# num_stages=2) target short query sequences; large-BLOCK_Q configs
# (num_warps=8, num_stages=4) trade more warps/pipelining for long sequences.
# The best config is selected per (query_seq_len, key_seq_len) key.
autotune_configs = [
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 1024}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 1, "BLOCK_K": 2048}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 2, "BLOCK_K": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 4, "BLOCK_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 4, "BLOCK_K": 64}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 4, "BLOCK_K": 128}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 8, "BLOCK_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 8, "BLOCK_K": 64}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 16, "BLOCK_K": 32}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_Q": 32, "BLOCK_K": 128}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 32, "BLOCK_K": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 32, "BLOCK_K": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 128}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 64, "BLOCK_K": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_Q": 128, "BLOCK_K": 512}, num_warps=8, num_stages=4),
]
@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_forward_kernel(
    output_ptr,
    input_ptr,
    scale_factor,
    query_seq_len,
    key_seq_len,
    stride_b,
    stride_h,
    stride_q,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Compute softmax(input * scale_factor) along the key (last) dimension of a
    # (batch, heads, query, key) tensor. Each program handles one BLOCK_Q-row
    # stripe of one (batch, head) slice; grid = (cdiv(query, BLOCK_Q), heads, batch).
    #
    # Two passes over the key dimension in BLOCK_K tiles:
    #   pass 1 - streaming ("online") max + rescaled exp-sum, fp32 accumulators;
    #   pass 2 - re-read input, exponentiate against the final max, normalize, store.
    #
    # NOTE: key offsets are added to the pointers without a stride, so the key
    # dimension is assumed to have unit stride (contiguous last dim).
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)
    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)
    # Mask off rows past the end when query_seq_len % BLOCK_Q != 0.
    query_mask = query_offsets < query_seq_len
    # Base pointer of each query row handled by this program.
    row_start_ptr = (
        input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )
    # Per-row running maximum (m) and running sum of exp(x - m) (exp_sum).
    m = tl.full([BLOCK_Q], -float("inf"), dtype=tl.float32)
    exp_sum = tl.zeros([BLOCK_Q], dtype=tl.float32)
    # Pass 1: streaming max / exp-sum over key tiles.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        block_ptr = row_start_ptr[:, None] + k_offsets[None, :]
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask
        # Out-of-bounds slots load -inf so they contribute exp(-inf)=0 below;
        # ".ca" keeps the tile cached for the second pass.
        s_block = tl.load(
            block_ptr, mask=mask, other=-float("inf"), cache_modifier=".ca"
        )
        s_block = s_block * scale_factor
        m_new = tl.max(s_block, axis=1)
        m_old = m
        m = tl.maximum(m_old, m_new)
        # Rescale the running sum to the new max: exp(m_old - m) <= 1.
        s_prev = tl.exp(m_old - m)
        exp_sum = exp_sum * s_prev
        s_curr = tl.exp(s_block - m[:, None])
        # Re-mask before summing: exp(-inf - m) is 0 only when m is finite.
        l_new = tl.sum(tl.where(mask, s_curr, 0.0), axis=1)
        exp_sum = exp_sum + l_new
    # Reciprocal once per row; pass 2 multiplies instead of dividing per tile.
    exp_sum_inv = 1.0 / exp_sum
    out_row_start_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )
    # Pass 2: recompute scaled scores, normalize, and write the result.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        block_ptr_in = row_start_ptr[:, None] + k_offsets[None, :]
        block_ptr_out = out_row_start_ptr[:, None] + k_offsets[None, :]
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask
        # evict_first: each tile is read exactly once in this pass.
        s_block = tl.load(
            block_ptr_in, mask=mask, other=-float("inf"), eviction_policy="evict_first"
        )
        s_block = s_block * scale_factor
        s_block = s_block - m[:, None]
        p_block = tl.exp(s_block)
        p_block = p_block * exp_sum_inv[:, None]
        # ".cs" (streaming) store: output is not re-read by this kernel.
        tl.store(block_ptr_out, p_block, mask=mask, cache_modifier=".cs")
def scaled_softmax_forward(input_t: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Compute softmax(input_t * scale_factor) along the last (key) dimension.

    Args:
        input_t: 4D tensor of shape (batch, heads, query_seq_len, key_seq_len),
            dtype fp16 or bf16.
        scale_factor: multiplier applied to the scores before the softmax.

    Returns:
        Tensor of the same shape/dtype as ``input_t`` holding the normalized
        probabilities.

    Raises:
        AssertionError: on unsupported rank/dtype or sequence-length limits
            (key_seq_len <= 16384, key_seq_len % 8 == 0, query_seq_len > 1).
    """
    assert input_t.dim() == 4, "expected 4D tensor"
    assert input_t.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    # The kernel adds key offsets to row pointers without a key stride, i.e. it
    # requires a unit-stride last dimension. Force a contiguous layout here,
    # mirroring scaled_softmax_backward, so non-contiguous inputs are correct.
    input_t = input_t.contiguous()
    batch_size, attn_heads, query_seq_len, key_seq_len = input_t.shape
    assert key_seq_len <= 16384, "Key sequence length must be 16384 or less"
    assert key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"
    assert query_seq_len > 1, "Query sequence length must be greater than 1"

    def grid(meta):
        # One program per BLOCK_Q rows, per head, per batch element.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    output_t = torch.empty_like(input_t)
    stride_b = input_t.stride(0)
    stride_h = input_t.stride(1)
    stride_q = input_t.stride(2)
    scaled_softmax_forward_kernel[grid](
        output_t,
        input_t,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )
    return output_t
@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_backward_kernel(
    grad_input_ptr,
    grad_output_ptr,
    output_ptr,
    scale_factor,
    query_seq_len,
    key_seq_len,
    stride_b,
    stride_h,
    stride_q,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Backward of the scaled softmax:
    #   dS = scale_factor * P * (dP - sum_k(P * dP))
    # where P is the forward output (softmax probabilities) and dP the incoming
    # gradient. Each program handles one BLOCK_Q-row stripe of one (batch, head)
    # slice; grid = (cdiv(query, BLOCK_Q), heads, batch).
    #
    # NOTE: key offsets are added to the pointers without a stride, so all three
    # tensors are assumed to have a unit-stride (contiguous) last dimension.
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)
    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)
    # Mask off rows past the end when query_seq_len % BLOCK_Q != 0.
    query_mask = query_offsets < query_seq_len
    # Row base pointers for P (forward output), dP (grad_output) and dS (grad_input).
    output_row_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )
    grad_output_row_ptr = (
        grad_output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )
    grad_input_row_ptr = (
        grad_input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )
    # Per-row reduction D = sum_k(P * dP), accumulated in fp32.
    D = tl.zeros([BLOCK_Q], dtype=tl.float32)
    # Pass 1: accumulate D over key tiles.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask
        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]
        # ".ca" keeps the tiles cached for re-use in pass 2; OOB slots load 0.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, cache_modifier=".ca")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, cache_modifier=".ca")
        dot_block = P_block * dP_block
        D += tl.sum(tl.where(mask, dot_block, 0.0), axis=1)
    # Pass 2: re-read P and dP, form dS = scale * P * (dP - D), and store.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask
        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dS = grad_input_row_ptr[:, None] + k_offsets[None, :]
        # evict_first: each tile is read exactly once in this pass.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, eviction_policy="evict_first")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, eviction_policy="evict_first")
        dZ_block = P_block * (dP_block - D[:, None])
        dS_block = scale_factor * dZ_block
        # ".cs" (streaming) store: grad_input is not re-read by this kernel.
        tl.store(ptr_dS, dS_block, mask=mask, cache_modifier=".cs")
def scaled_softmax_backward(
    grad_output: torch.Tensor, softmax_results: torch.Tensor, scale_factor: float
) -> torch.Tensor:
    """Backward pass of the scaled softmax.

    Computes ``dS = scale_factor * P * (dP - sum_k(P * dP))`` where ``P`` is
    the forward softmax output and ``dP`` the incoming gradient.

    Args:
        grad_output: 4D gradient w.r.t. the softmax output, fp16 or bf16.
        softmax_results: 4D forward output ``P``, same shape as ``grad_output``,
            fp16 or bf16.
        scale_factor: the multiplier that was applied in the forward pass.

    Returns:
        Gradient w.r.t. the (pre-scale) softmax input, same shape/dtype as
        ``grad_output``.

    Raises:
        AssertionError: on unsupported rank/dtype or mismatched shapes.
    """
    assert grad_output.dim() == 4, "expected 4D tensor"
    assert softmax_results.dim() == 4, "expected 4D tensor"
    # The kernel walks both tensors with a single set of strides; mismatched
    # shapes would silently read out of bounds, so reject them up front.
    assert (
        grad_output.shape == softmax_results.shape
    ), "grad_output and softmax_results must have the same shape"
    assert grad_output.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    assert softmax_results.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    # The kernel requires a unit-stride last dimension on all three tensors.
    grad_output = grad_output.contiguous()
    softmax_results = softmax_results.contiguous()
    batch_size, attn_heads, query_seq_len, key_seq_len = softmax_results.shape

    def grid(meta):
        # One program per BLOCK_Q rows, per head, per batch element.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    grad_input = torch.empty_like(grad_output)
    stride_b = softmax_results.stride(0)
    stride_h = softmax_results.stride(1)
    stride_q = softmax_results.stride(2)
    scaled_softmax_backward_kernel[grid](
        grad_input,
        grad_output,
        softmax_results,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )
    return grad_input