Coverage for src/flag_gems/experimental_ops/glu.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import math # noqa: F401
3import torch
4import triton
5import triton.language as tl
@triton.jit
def glu_kernel(
    x_ptr,  # *Pointer* to input tensor data (flattened, contiguous).
    y_ptr,  # *Pointer* to output tensor data (flattened, contiguous).
    n_out_elements,  # Number of elements in the output tensor.
    inner_size,  # Product of sizes of dims after 'dim' in output shape.
    half_size,  # Size along 'dim' in output shape (i.e., original dim size // 2).
    outer_elems,  # Number of elements per 'outer' slice in the input: (2*half_size)*inner_size.
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one BLOCK_SIZE-wide slab of output indices.
    block_id = tl.program_id(axis=0)
    out_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = out_idx < n_out_elements

    # Decompose the flat output index into (outer, dim, inner) coordinates.
    # For contiguous layouts the output has shape [outer, half_size, inner_size]
    # and the input has shape [outer, 2*half_size, inner_size], so:
    #   first-half input index  = outer_pos*outer_elems + dim_pos*inner_size + inner_pos
    #   second-half input index = first-half index + half_size*inner_size
    slice_len = half_size * inner_size
    outer_pos = out_idx // slice_len
    within_slice = out_idx % slice_len
    dim_pos = within_slice // inner_size
    inner_pos = within_slice % inner_size

    left_idx = outer_pos * outer_elems + dim_pos * inner_size + inner_pos
    right_idx = left_idx + slice_len

    left = tl.load(x_ptr + left_idx, mask=in_bounds, other=0.0)
    right = tl.load(x_ptr + right_idx, mask=in_bounds, other=0.0)

    # Compute GLU = left * sigmoid(right) in float32 for numerical stability,
    # then cast back to the input dtype before storing.
    left_f32 = left.to(tl.float32)
    right_f32 = right.to(tl.float32)
    sig = 1.0 / (1.0 + tl.exp(-right_f32))
    result = (left_f32 * sig).to(left.dtype)

    tl.store(y_ptr + out_idx, result, mask=in_bounds)
61def _normalize_dim(dim: int, ndim: int) -> int:
62 if dim < 0:
63 dim += ndim
64 if not (0 <= dim < ndim):
65 actual_dim = dim - ndim if dim >= ndim else dim
66 raise IndexError(
67 f"Dimension out of range (expected to be in range of "
68 f"[{-ndim}, {ndim - 1}], but got {actual_dim})"
69 )
70 return dim
73def _check_dtype_supported(dtype: torch.dtype):
74 if dtype not in (torch.float16, torch.bfloat16, torch.float32):
75 raise TypeError(
76 f"Unsupported dtype {dtype}. Supported dtypes are: float16, bfloat16, float32."
77 )
def _glu_launch(x: torch.Tensor, dim: int, out: torch.Tensor = None) -> torch.Tensor:
    """Validate arguments and launch the GLU Triton kernel on ``x`` along ``dim``.

    The input is split in half along ``dim`` and the result is
    ``first_half * sigmoid(second_half)``. When ``out`` is provided it must be
    a contiguous CUDA tensor with the result shape and dtype; otherwise a
    fresh output tensor is allocated.

    Raises:
        AssertionError: if ``x`` (or ``out``) is not a CUDA tensor.
        TypeError: for unsupported dtypes.
        RuntimeError: for an odd size along ``dim`` or a mismatched ``out``.
    """
    if not x.is_cuda:
        raise AssertionError("Input tensor must be on CUDA device.")
    x = x.contiguous()
    _check_dtype_supported(x.dtype)

    ndim = x.dim()
    dim = _normalize_dim(dim, ndim)
    dim_len = x.size(dim)
    if dim_len % 2 != 0:
        raise RuntimeError(
            f"glu: dimension {dim} size must be even, but got {dim_len}."
        )
    half = dim_len // 2

    # The output keeps every dimension of the input except 'dim', which halves.
    result_shape = list(x.shape)
    result_shape[dim] = half

    if out is None:
        out = torch.empty(result_shape, device=x.device, dtype=x.dtype)
    else:
        if not out.is_cuda:
            raise AssertionError("Output tensor must be on CUDA device.")
        if tuple(out.shape) != tuple(result_shape):
            raise RuntimeError(
                f"glu_out: provided out has wrong shape. Expected {tuple(result_shape)}, got {tuple(out.shape)}."
            )
        if out.dtype != x.dtype:
            raise RuntimeError(
                f"glu_out: dtype mismatch. out.dtype={out.dtype}, expected {x.dtype}."
            )
        if not out.is_contiguous():
            raise RuntimeError("glu_out: output tensor must be contiguous.")
        out = out.contiguous()

    # Index-mapping parameters for the contiguous layout:
    # inner_size = product of output dims after 'dim' (1 when dim is last).
    inner_size = math.prod(result_shape[dim + 1:])
    n_out = out.numel()
    input_slice = 2 * half * inner_size  # elements per 'outer' slice of the input

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_out, meta["BLOCK_SIZE"]),)  # noqa: E731

    glu_kernel[grid](
        x,
        out,
        n_out,
        inner_size,
        half,
        input_slice,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out
def glu(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Gated Linear Unit: split ``input`` in half along ``dim`` and return
    ``first_half * sigmoid(second_half)`` as a new tensor."""
    return _glu_launch(input, dim)
def glu_out(
    input: torch.Tensor, dim: int = -1, out: torch.Tensor = None
) -> torch.Tensor:
    """GLU variant that writes the result into a caller-provided ``out``
    tensor; ``out`` is mandatory despite the keyword default."""
    if out is None:
        raise RuntimeError("glu_out: 'out' tensor must be provided.")
    return _glu_launch(input, dim, out=out)