Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/reglu.py: 0%
83 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2from typing import Any, Optional
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.utils import libentry, libtuner
# Child logger under the package-wide "flag_gems" logger; lstrip(".") removes
# any leading dots from the module path used as the child logger's name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_tile_m(args):
    """Heuristic BLOCK_M: ceil-divide M across the 12 clusters of the chip."""
    rows = args["M"]
    # Equivalent to triton.cdiv(rows, 12) for non-negative integer sizes.
    return (rows + 11) // 12
def heru_tile_n(args):
    """Heuristic BLOCK_N: the full column count N, capped at 8192."""
    cols = args["N"]
    return cols if cols <= 8192 else 8192
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_M": 1, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 2, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 4, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 8, "BLOCK_N": 1024}),
        triton.Config({"BLOCK_M": 6, "BLOCK_N": 32}),
        triton.Config({"BLOCK_M": 342, "BLOCK_N": 2048}),
        triton.Config({"BLOCK_M": 2731, "BLOCK_N": 256}),
    ],
    key=["M", "N"],
)
# @triton.heuristics(
#     values={
#         "BLOCK_M": heur_tile_m,
#         "BLOCK_N": heru_tile_n,
#     },
# )
@triton.jit
def dreglu_kernel(
    grad_output_ptr,  # (M, N) upstream gradient
    input_ptr,  # (M, 2*N) forward input, laid out as halves [a | b] along dim 1
    grad_input_ptr,  # (M, 2*N) output: gradient w.r.t. the forward input
    M,
    N,
    stride_grad_out_m,
    stride_grad_out_n,
    stride_in_m,
    stride_in_n,
    stride_grad_in_m,
    stride_grad_in_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """ReGLU backward kernel.

    With a = input[:, :N] and b = input[:, N:], the forward is relu(a) * b,
    so each program computes for its BLOCK_M x BLOCK_N tile:
        grad_a = grad_out * 1[a > 0] * b
        grad_b = grad_out * relu(a)
    and stores them into the matching halves of grad_input.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    grad_output_ptr += (
        offs_m[:, None] * stride_grad_out_m + offs_n[None, :] * stride_grad_out_n
    )
    # First half of the input (the gate, "a").
    input_ptr_a = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_n[None, :] * stride_in_n
    )
    # Second half of the input ("b"), offset by N columns.
    input_ptr_b = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_n[None, :] + N) * stride_in_n
    )
    grad_input_ptr_a = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_n[None, :] * stride_grad_in_n
    )
    grad_input_ptr_b = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_n[None, :] + N) * stride_grad_in_n
    )
    # One mask suffices: the a/b halves share the same (row, col) tile shape.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Compute in float32 for accuracy regardless of the storage dtype.
    grad_out = tl.load(grad_output_ptr, mask=mask, other=0.0).to(tl.float32)
    block_a = tl.load(input_ptr_a, mask=mask, other=0.0).to(tl.float32)
    block_b = tl.load(input_ptr_b, mask=mask, other=0.0).to(tl.float32)
    relu_a = tl.maximum(block_a, 0.0)
    # Derivative of relu(a): 1 where a > 0, else 0.
    d_relu_a = tl.where(block_a > 0, 1.0, 0.0)
    grad_a = grad_out * d_relu_a * block_b
    grad_b = grad_out * relu_a
    tl.store(grad_input_ptr_a, grad_a, mask=mask)
    tl.store(grad_input_ptr_b, grad_b, mask=mask)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("gated_activation"),
    key=["M", "N_OUT"],
)
@triton.jit
def reglu_kernel(
    x_ptr,  # (M, 2*N_OUT) input, laid out as halves [a | b] along dim 1
    y_ptr,  # (M, N_OUT) output
    M,
    N_OUT,
    stride_x_m,
    stride_x_n,
    stride_y_m,
    stride_y_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """ReGLU forward kernel.

    With a = x[:, :N_OUT] and b = x[:, N_OUT:], each program computes
    y = relu(a) * b for its BLOCK_M x BLOCK_N tile.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # First half of x (the gate, "a") and second half ("b", offset N_OUT cols).
    x_ptr_a = x_ptr + offs_m[:, None] * stride_x_m + offs_n[None, :] * stride_x_n
    x_ptr_b = (
        x_ptr + offs_m[:, None] * stride_x_m + (offs_n[None, :] + N_OUT) * stride_x_n
    )
    y_ptr = y_ptr + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N_OUT)
    block_a = tl.load(x_ptr_a, mask=mask, other=0.0)
    block_b = tl.load(x_ptr_b, mask=mask, other=0.0)
    # relu(a) written as a where-select.
    gate = tl.where(block_a > 0, block_a, 0.0)
    output = gate * block_b
    tl.store(y_ptr, output, mask=mask)
def reglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    """Apply the ReGLU activation along the last dimension.

    Splits the last dimension into halves ``a`` and ``b`` and returns
    ``relu(a) * b``, so the output's last dimension is half the input's.

    Args:
        input_tensor: tensor whose last dimension is even.
        quantizer: unused; kept for interface compatibility.

    Returns:
        Tensor of shape ``(*input_tensor.shape[:-1], last_dim // 2)`` with
        the same device and dtype as the input.

    Raises:
        ValueError: if the tensor is 0-dimensional or its last dimension
            is odd.
    """
    shape = input_tensor.shape
    if input_tensor.dim() < 1:
        raise ValueError("Input tensor must have at least 1 dimension.")
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        raise ValueError(
            f"The last dimension of the input tensor must be even, but got {last_dim}."
        )
    N_OUT = last_dim // 2
    output_shape = (*shape[:-1], N_OUT)
    # Bug fix: handle empty tensors BEFORE computing M. When last_dim == 0
    # the old `numel() // last_dim` raised ZeroDivisionError before the
    # empty-tensor early return could trigger.
    if input_tensor.numel() == 0:
        return torch.empty(
            output_shape, device=input_tensor.device, dtype=input_tensor.dtype
        )
    M = input_tensor.numel() // last_dim
    # Flatten leading dims into one row dimension for the 2D kernel.
    input_2d = input_tensor.contiguous().view(M, last_dim)
    output_2d = torch.empty(
        (M, N_OUT), device=input_tensor.device, dtype=input_tensor.dtype
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N_OUT, META["BLOCK_N"]),
    )
    reglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        N_OUT,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
    )
    return output_2d.view(output_shape)
def dreglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    """Backward pass of ReGLU.

    Given the forward input (whose last dimension holds the halves
    ``a`` and ``b``) and the upstream gradient, returns the gradient with
    respect to the forward input, with the same shape as ``input_tensor``.

    Args:
        grad_output: upstream gradient, last dim half of the input's.
        input_tensor: the forward input.
        quantizer: unused; kept for interface compatibility.

    Raises:
        ValueError: if leading dims differ or the input's last dimension
            is not exactly twice grad_output's.
    """
    shape = input_tensor.shape
    if shape[:-1] != grad_output.shape[:-1] or shape[-1] != 2 * grad_output.shape[-1]:
        raise ValueError(
            f"Shape mismatch: input {shape} vs grad_output {grad_output.shape}"
        )
    N = grad_output.shape[-1]
    # Bug fix: guard empty tensors before the M division. When the last dim
    # is 0, `numel() // shape[-1]` raised ZeroDivisionError; for any empty
    # input there is nothing to compute, so return an empty gradient.
    if grad_output.numel() == 0:
        return torch.empty_like(input_tensor)
    M = grad_output.numel() // N
    # Flatten leading dims into one row dimension for the 2D kernel.
    grad_output_2d = grad_output.contiguous().view(M, N)
    input_2d = input_tensor.contiguous().view(M, 2 * N)
    grad_input = torch.empty_like(input_2d)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N, META["BLOCK_N"]),
    )
    dreglu_kernel[grid](
        grad_output_2d,
        input_2d,
        grad_input,
        M,
        N,
        grad_output_2d.stride(0),
        grad_output_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_input.stride(0),
        grad_input.stride(1),
    )
    return grad_input.view(shape)