Coverage for src/flag_gems/ops/scaled_softmax.py: 41%
126 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
# Autotune search space shared by the forward and backward kernels.
# Small BLOCK_Q tiles (1-16) pair with 4 warps / 2 stages for short query
# batches; wide tiles (32-128) use 8 warps / 4 stages for large workloads.
autotune_configs = [
    triton.Config(
        {"BLOCK_Q": block_q, "BLOCK_K": block_k},
        num_warps=warps,
        num_stages=stages,
    )
    for block_q, block_k, warps, stages in (
        (1, 128, 4, 2),
        (1, 256, 4, 2),
        (1, 512, 4, 2),
        (1, 1024, 4, 2),
        (1, 2048, 4, 2),
        (2, 128, 4, 2),
        (4, 32, 4, 2),
        (4, 64, 4, 2),
        (4, 128, 4, 2),
        (8, 32, 4, 2),
        (8, 64, 4, 2),
        (16, 32, 4, 2),
        (32, 128, 8, 4),
        (32, 256, 8, 4),
        (32, 512, 8, 4),
        (64, 128, 8, 4),
        (64, 256, 8, 4),
        (64, 512, 8, 4),
        (64, 1024, 8, 4),
        (128, 512, 8, 4),
    )
]
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_forward_kernel(
    output_ptr,  # out: softmax probabilities, same layout as the input
    input_ptr,  # in: attention scores, logically (batch, heads, q_len, k_len)
    scale_factor,  # scalar multiplied into the scores before the softmax
    query_seq_len,
    key_seq_len,
    stride_b,  # batch stride of input/output, in elements
    stride_h,  # head stride
    stride_q,  # query-row stride; the key dim is indexed with unit stride
    BLOCK_Q: tl.constexpr,  # query rows handled by one program
    BLOCK_K: tl.constexpr,  # key columns processed per loop iteration
):
    """Row-wise softmax(scale_factor * input) over the key dimension.

    Two-pass online scheme: pass 1 streams over the key dimension keeping a
    running row max ``m`` and a running, rescaled sum of exponentials
    ``exp_sum``, so a full row of length key_seq_len never has to be held at
    once. Pass 2 re-reads the scores and writes exp(s - m) * (1 / exp_sum).
    Each program covers a (BLOCK_Q x key_seq_len) slab of one (batch, head).
    """
    # Grid: (ceil(q_len / BLOCK_Q), heads, batch).
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)

    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)

    # Rows past the end of the query dimension are masked out everywhere.
    query_mask = query_offsets < query_seq_len

    # Base pointer of each query row in this program's tile.
    row_start_ptr = (
        input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Online-softmax state, accumulated in fp32 for numerical stability:
    # m = running row max, exp_sum = sum of exp(s - m) seen so far.
    m = tl.full([BLOCK_Q], -float("inf"), dtype=tl.float32)
    exp_sum = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Pass 1: compute m and exp_sum one BLOCK_K-wide slice at a time.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        block_ptr = row_start_ptr[:, None] + k_offsets[None, :]

        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        # other=-inf so masked lanes cannot win the max; ".ca" hints the
        # cache to keep these lines, since pass 2 re-reads the same data.
        s_block = tl.load(
            block_ptr, mask=mask, other=-float("inf"), cache_modifier=".ca"
        )
        s_block = s_block * scale_factor

        m_new = tl.max(s_block, axis=1)
        m_old = m
        m = tl.maximum(m_old, m_new)

        # Rescale the old sum into the new max's frame of reference.
        s_prev = tl.exp(m_old - m)
        exp_sum = exp_sum * s_prev

        # Add this slice's contribution; tl.where zeroes masked lanes so
        # padded columns do not inflate the sum.
        s_curr = tl.exp(s_block - m[:, None])
        l_new = tl.sum(tl.where(mask, s_curr, 0.0), axis=1)
        exp_sum = exp_sum + l_new

    # One division per row; pass 2 multiplies by the reciprocal instead.
    exp_sum_inv = 1.0 / exp_sum

    out_row_start_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # Pass 2: re-read scores and emit normalized probabilities.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)

        block_ptr_in = row_start_ptr[:, None] + k_offsets[None, :]
        block_ptr_out = out_row_start_ptr[:, None] + k_offsets[None, :]

        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        # Data is consumed exactly once here, so hint eviction-first.
        s_block = tl.load(
            block_ptr_in, mask=mask, other=-float("inf"), eviction_policy="evict_first"
        )

        s_block = s_block * scale_factor
        s_block = s_block - m[:, None]
        p_block = tl.exp(s_block)
        p_block = p_block * exp_sum_inv[:, None]

        # ".cs" (streaming) store: output is not re-read by this kernel.
        tl.store(block_ptr_out, p_block, mask=mask, cache_modifier=".cs")
def scaled_softmax_forward(input_t: torch.Tensor, scale_factor: float) -> torch.Tensor:
    """Compute softmax(input_t * scale_factor) along the last dimension.

    Args:
        input_t: 4D tensor of attention scores, shape
            (batch, attn_heads, query_seq_len, key_seq_len), fp16 or bf16.
        scale_factor: Multiplier applied to the scores before the softmax.

    Returns:
        A new tensor of the same shape and dtype holding the probabilities.

    Raises:
        AssertionError: If the tensor is not 4D fp16/bf16, key_seq_len is
            > 16384 or not a multiple of 8, or query_seq_len <= 1.
    """
    logger.debug("GEMS SCALED SOFTMAX FORWARD")
    assert input_t.dim() == 4, "expected 4D tensor"
    batch_size, attn_heads, query_seq_len, key_seq_len = input_t.shape
    assert input_t.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    assert key_seq_len <= 16384, "Key sequence length must be 16384 or less"
    assert key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"
    assert query_seq_len > 1, "Query sequence length must be greater than 1"

    # The kernel indexes the key dimension with unit stride and never reads
    # input_t.stride(3), so the last dim must be contiguous. Match the
    # backward wrapper, which already calls .contiguous() on its inputs;
    # this is a no-op for already-contiguous tensors.
    input_t = input_t.contiguous()

    def grid(meta):
        # One program per BLOCK_Q query rows, per head, per batch.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    output_t = torch.empty_like(input_t)

    # Strides (in elements) shared by input and output, which empty_like
    # allocates with the same (now contiguous) layout.
    stride_b = input_t.stride(0)
    stride_h = input_t.stride(1)
    stride_q = input_t.stride(2)

    scaled_softmax_forward_kernel[grid](
        output_t,
        input_t,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )
    return output_t
@libentry()
@triton.autotune(configs=autotune_configs, key=["query_seq_len", "key_seq_len"])
@triton.jit
def scaled_softmax_backward_kernel(
    grad_input_ptr,  # out: gradient w.r.t. the pre-softmax scores
    grad_output_ptr,  # in: upstream gradient dP
    output_ptr,  # in: forward softmax probabilities P
    scale_factor,  # same scale applied in the forward pass
    query_seq_len,
    key_seq_len,
    stride_b,  # batch stride shared by all three tensors, in elements
    stride_h,  # head stride
    stride_q,  # query-row stride; the key dim is indexed with unit stride
    BLOCK_Q: tl.constexpr,  # query rows handled by one program
    BLOCK_K: tl.constexpr,  # key columns processed per loop iteration
):
    """Backward of scaled softmax: dS = scale_factor * P * (dP - D).

    D is the per-row reduction sum_k(P * dP). Pass 1 accumulates D in fp32
    one BLOCK_K-wide slice at a time; pass 2 re-reads P and dP and writes
    the gradient. Each program covers a (BLOCK_Q x key_seq_len) slab of one
    (batch, head).
    """
    # Grid: (ceil(q_len / BLOCK_Q), heads, batch).
    query_seq_tile_idx = tl.program_id(0)
    attn_head_idx = tl.program_id(1)
    batch_idx = tl.program_id(2)

    start_query_idx = query_seq_tile_idx * BLOCK_Q
    query_offsets = start_query_idx + tl.arange(0, BLOCK_Q)

    # Rows past the end of the query dimension are masked out everywhere.
    query_mask = query_offsets < query_seq_len

    # Per-row base pointers into P, dP, and the gradient output; all three
    # tensors share the same strides.
    output_row_ptr = (
        output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    grad_output_row_ptr = (
        grad_output_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    grad_input_row_ptr = (
        grad_input_ptr
        + batch_idx * stride_b
        + attn_head_idx * stride_h
        + query_offsets * stride_q
    )

    # D[q] = sum_k P[q, k] * dP[q, k], accumulated in fp32.
    D = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Pass 1: accumulate the row-wise dot product D.
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]

        # other=0.0 keeps masked lanes neutral for the sum; ".ca" hints the
        # cache to keep these lines, since pass 2 re-reads the same data.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, cache_modifier=".ca")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, cache_modifier=".ca")

        dot_block = P_block * dP_block
        D += tl.sum(tl.where(mask, dot_block, 0.0), axis=1)

    # Pass 2: re-read P and dP and emit dS = scale_factor * P * (dP - D).
    for k_block_idx in range(0, tl.cdiv(key_seq_len, BLOCK_K)):
        k_offsets = k_block_idx * BLOCK_K + tl.arange(0, BLOCK_K)
        row_mask = query_mask[:, None]
        col_mask = k_offsets[None, :] < key_seq_len
        mask = row_mask & col_mask

        ptr_P = output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dP = grad_output_row_ptr[:, None] + k_offsets[None, :]
        ptr_dS = grad_input_row_ptr[:, None] + k_offsets[None, :]

        # Data is consumed exactly once here, so hint eviction-first.
        P_block = tl.load(ptr_P, mask=mask, other=0.0, eviction_policy="evict_first")
        dP_block = tl.load(ptr_dP, mask=mask, other=0.0, eviction_policy="evict_first")

        dZ_block = P_block * (dP_block - D[:, None])
        dS_block = scale_factor * dZ_block

        # ".cs" (streaming) store: output is not re-read by this kernel.
        tl.store(ptr_dS, dS_block, mask=mask, cache_modifier=".cs")
def scaled_softmax_backward(
    grad_output: torch.Tensor, softmax_results: torch.Tensor, scale_factor: float
) -> torch.Tensor:
    """Gradient of scaled_softmax_forward w.r.t. the pre-softmax scores.

    Computes dS = scale_factor * P * (dP - sum_k(P * dP)) where P is the
    forward output and dP the upstream gradient.

    Args:
        grad_output: 4D upstream gradient dP, fp16 or bf16.
        softmax_results: 4D forward output P, same shape as grad_output,
            shape (batch, attn_heads, query_seq_len, key_seq_len).
        scale_factor: The scale applied in the forward pass.

    Returns:
        A new tensor (same shape/dtype as grad_output) holding dS.

    Raises:
        AssertionError: If either tensor is not 4D fp16/bf16, or their
            shapes disagree.
    """
    logger.debug("GEMS SCALED SOFTMAX BACKWARD")
    assert grad_output.dim() == 4, "expected 4D tensor"
    assert softmax_results.dim() == 4, "expected 4D tensor"
    assert grad_output.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    assert softmax_results.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Only fp16 and bf16 are supported"
    # The kernel indexes all three tensors with softmax_results' strides, so
    # a shape mismatch would read/write out of bounds — reject it up front.
    assert (
        grad_output.shape == softmax_results.shape
    ), "grad_output and softmax_results must have the same shape"

    # The kernel assumes a unit-stride key dimension and shared strides.
    grad_output = grad_output.contiguous()
    softmax_results = softmax_results.contiguous()

    batch_size, attn_heads, query_seq_len, key_seq_len = softmax_results.shape

    def grid(meta):
        # One program per BLOCK_Q query rows, per head, per batch.
        BLOCK_Q = meta["BLOCK_Q"]
        query_seq_tile_len = triton.cdiv(query_seq_len, BLOCK_Q)
        return (query_seq_tile_len, attn_heads, batch_size)

    grad_input = torch.empty_like(grad_output)

    # Strides (in elements); valid for all three tensors since both inputs
    # were made contiguous and empty_like matches grad_output's layout.
    stride_b = softmax_results.stride(0)
    stride_h = softmax_results.stride(1)
    stride_q = softmax_results.stride(2)

    scaled_softmax_backward_kernel[grid](
        grad_input,
        grad_output,
        softmax_results,
        scale_factor,
        query_seq_len,
        key_seq_len,
        stride_b,
        stride_h,
        stride_q,
    )

    return grad_input