Coverage for src/flag_gems/fused/reglu.py: 53% (78 statements)
coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
import logging
from typing import Any, Optional

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.utils import libentry, libtuner

logger = logging.getLogger(__name__)


@libentry()
@libtuner(
    configs=runtime.get_tuned_config("gated_activation"),
    key=["M", "N"],
)
@triton.jit
def dreglu_kernel(
    grad_output_ptr,
    input_ptr,
    grad_input_ptr,
    M,
    N,
    stride_grad_out_m,
    stride_grad_out_n,
    stride_in_m,
    stride_in_n,
    stride_grad_in_m,
    stride_grad_in_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Backward of ReGLU: the input packs [a | b] along the last dimension (width 2*N);
    # grad_a and grad_b are written into the matching halves of grad_input.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    grad_output_ptr += (
        offs_m[:, None] * stride_grad_out_m + offs_n[None, :] * stride_grad_out_n
    )
    input_ptr_a = (
        input_ptr + offs_m[:, None] * stride_in_m + offs_n[None, :] * stride_in_n
    )
    input_ptr_b = (
        input_ptr + offs_m[:, None] * stride_in_m + (offs_n[None, :] + N) * stride_in_n
    )
    grad_input_ptr_a = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + offs_n[None, :] * stride_grad_in_n
    )
    grad_input_ptr_b = (
        grad_input_ptr
        + offs_m[:, None] * stride_grad_in_m
        + (offs_n[None, :] + N) * stride_grad_in_n
    )
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    grad_out = tl.load(grad_output_ptr, mask=mask, other=0.0)
    block_a = tl.load(input_ptr_a, mask=mask, other=0.0)
    block_b = tl.load(input_ptr_b, mask=mask, other=0.0)
    relu_a = tl.maximum(block_a, 0.0)
    d_relu_a = tl.where(block_a > 0, 1.0, 0.0)  # ReLU subgradient, 0 at a == 0
    grad_a = grad_out * d_relu_a * block_b
    grad_b = grad_out * relu_a
    tl.store(grad_input_ptr_a, grad_a, mask=mask)
    tl.store(grad_input_ptr_b, grad_b, mask=mask)
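For readability, the rule this backward kernel implements can be written in eager PyTorch as the sketch below. It is illustrative only and not part of this module; the helper name dreglu_reference is made up for the example.

import torch

def dreglu_reference(grad_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # x packs [a | b] along the last dimension; ReGLU forward is relu(a) * b.
    a, b = x.chunk(2, dim=-1)
    grad_a = grad_output * (a > 0).to(x.dtype) * b  # d/da relu(a) = 1 where a > 0
    grad_b = grad_output * torch.relu(a)            # d/db (relu(a) * b) = relu(a)
    return torch.cat([grad_a, grad_b], dim=-1)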
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("gated_activation"),
    key=["M", "N_OUT"],
)
@triton.jit
def reglu_kernel(
    x_ptr,
    y_ptr,
    M,
    N_OUT,
    stride_x_m,
    stride_x_n,
    stride_y_m,
    stride_y_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Forward of ReGLU: x packs [a | b] along the last dimension (width 2*N_OUT);
    # each output element is relu(a) * b.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    x_ptr_a = x_ptr + offs_m[:, None] * stride_x_m + offs_n[None, :] * stride_x_n
    x_ptr_b = (
        x_ptr + offs_m[:, None] * stride_x_m + (offs_n[None, :] + N_OUT) * stride_x_n
    )
    y_ptr = y_ptr + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N_OUT)
    block_a = tl.load(x_ptr_a, mask=mask, other=0.0)
    block_b = tl.load(x_ptr_b, mask=mask, other=0.0)
    gate = tl.where(block_a > 0, block_a, 0.0)  # relu(a)
    output = gate * block_b
    tl.store(y_ptr, output, mask=mask)


def reglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
    shape = input_tensor.shape
    if input_tensor.dim() < 1:
        raise ValueError("Input tensor must have at least 1 dimension.")
    last_dim = shape[-1]
    if last_dim % 2 != 0:
        raise ValueError(
            f"The last dimension of the input tensor must be even, but got {last_dim}."
        )
    N_OUT = last_dim // 2
    M = input_tensor.numel() // last_dim
    if input_tensor.numel() == 0:
        # Nothing to compute; return an empty tensor with the halved last dimension.
        output_shape = (*shape[:-1], N_OUT)
        return torch.empty(
            output_shape, device=input_tensor.device, dtype=input_tensor.dtype
        )
    # Flatten leading dimensions so the kernel sees a 2D (M, 2 * N_OUT) problem.
    input_2d = input_tensor.contiguous().view(M, last_dim)
    output_2d = torch.empty(
        (M, N_OUT), device=input_tensor.device, dtype=input_tensor.dtype
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N_OUT, META["BLOCK_N"]),
    )
    reglu_kernel[grid](
        input_2d,
        output_2d,
        M,
        N_OUT,
        input_2d.stride(0),
        input_2d.stride(1),
        output_2d.stride(0),
        output_2d.stride(1),
    )
    output_shape = (*shape[:-1], N_OUT)
    return output_2d.view(output_shape)
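A minimal usage sketch for the forward wrapper, assuming a CUDA device and that the module is importable as flag_gems.fused.reglu (the path suggested by this report):

import torch
from flag_gems.fused.reglu import reglu

x = torch.randn(4, 8, device="cuda", dtype=torch.float16)  # last dim must be even
y = reglu(x)                                               # shape (4, 4)
a, b = x.chunk(2, dim=-1)
torch.testing.assert_close(y, torch.relu(a) * b)           # compare against eager ReGLU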
def dreglu(
    grad_output: torch.Tensor,
    input_tensor: torch.Tensor,
    quantizer: Optional[Any] = None,
) -> torch.Tensor:
    shape = input_tensor.shape
    if shape[:-1] != grad_output.shape[:-1] or shape[-1] != 2 * grad_output.shape[-1]:
        raise ValueError(
            f"Shape mismatch: input {shape} vs grad_output {grad_output.shape}"
        )
    M = grad_output.numel() // grad_output.shape[-1]
    N = grad_output.shape[-1]
    # Flatten leading dimensions: grad_output becomes (M, N), the saved input (M, 2 * N).
    grad_output_2d = grad_output.contiguous().view(M, N)
    input_2d = input_tensor.contiguous().view(M, 2 * N)
    grad_input = torch.empty_like(input_2d)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N, META["BLOCK_N"]),
    )
    dreglu_kernel[grid](
        grad_output_2d,
        input_2d,
        grad_input,
        M,
        N,
        grad_output_2d.stride(0),
        grad_output_2d.stride(1),
        input_2d.stride(0),
        input_2d.stride(1),
        grad_input.stride(0),
        grad_input.stride(1),
    )
    return grad_input.view(shape)
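Likewise, a hedged sanity check of the backward wrapper against autograd through the eager ReGLU definition; again a sketch under the same import-path and CUDA assumptions:

import torch
from flag_gems.fused.reglu import dreglu

x = torch.randn(4, 8, device="cuda", requires_grad=True)
a, b = x.chunk(2, dim=-1)
y_ref = torch.relu(a) * b
grad_out = torch.randn_like(y_ref)
(grad_ref,) = torch.autograd.grad(y_ref, x, grad_out)   # reference gradient w.r.t. x
torch.testing.assert_close(dreglu(grad_out, x), grad_ref)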