# Coverage for src/flag_gems/runtime/backend/_mthreads/ops/celu.py: 0%
# 82 statements
# coverage.py v7.6.9, created at 2026-03-19 02:32 +0800
1import logging
2import math
3from typing import Tuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems.ops.celu import celu as default_celu
10from flag_gems.ops.celu import celu_ as default_celu_
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, tl_extra_shim
# Per-op logger; the suffix is this module's own file name ("celu").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Dtypes handled by the MUSA triton kernels below; anything else falls
# back to the generic flag_gems implementation.
_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}
# Backend-provided exp shim used inside the triton kernels.
exp = tl_extra_shim.exp
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256, "VEC": 4}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 256, "VEC": 2}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 2}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 4}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 1}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 2}, num_warps=8, num_stages=2),
    ],
    key=["n_elements", "dtype_size"],
)
@triton.jit
def celu_kernel_alpha1(
    x_ptr,  # pointer to the contiguous, flattened input
    out_ptr,  # pointer to the contiguous, flattened output
    n_elements,  # total number of elements
    dtype_size,  # element size in bytes; used only as an autotune key
    BLOCK_SIZE: tl.constexpr,
    VEC: tl.constexpr,  # extra elements handled per program (vector factor)
):
    # Specialized CELU for alpha == 1:
    #   out = x                if x > 0
    #   out = exp(x) - 1       if x <= 0
    pid = tl.program_id(0)
    BLOCK_ELEMS: tl.constexpr = BLOCK_SIZE * VEC
    # int64 offsets so tensors with more than 2**31 elements index correctly.
    offsets = (pid * BLOCK_ELEMS + tl.arange(0, BLOCK_ELEMS)).to(tl.int64)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute in fp32 regardless of storage dtype for accuracy.
    x_compute = x.to(tl.float32)
    neg_mask = x_compute <= 0
    # Feed 0.0 to exp() on positive lanes so exp never overflows there;
    # those lanes are discarded by the final where().
    exp_val = exp(tl.where(neg_mask, x_compute, 0.0))
    neg = exp_val - 1.0
    out = tl.where(neg_mask, neg, x_compute).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)
# NOTE(review): unlike celu_kernel_alpha1, this kernel is not wrapped in
# @libentry() — confirm whether that omission is intentional.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256, "VEC": 4}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 256, "VEC": 2}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 2}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 4}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 1}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 2}, num_warps=8, num_stages=2),
    ],
    key=["n_elements", "dtype_size"],
)
@triton.jit(do_not_specialize=["alpha"])  # avoid a recompile per alpha value
def celu_kernel(
    x_ptr,  # pointer to the contiguous, flattened input
    out_ptr,  # pointer to the contiguous, flattened output
    n_elements,  # total number of elements
    alpha,  # CELU alpha as a runtime scalar
    dtype_size,  # element size in bytes; used only as an autotune key
    BLOCK_SIZE: tl.constexpr,
    VEC: tl.constexpr,  # extra elements handled per program (vector factor)
):
    # General CELU:
    #   out = x                          if x > 0
    #   out = alpha * (exp(x/alpha) - 1) if x <= 0
    pid = tl.program_id(0)
    BLOCK_ELEMS: tl.constexpr = BLOCK_SIZE * VEC
    # int64 offsets so tensors with more than 2**31 elements index correctly.
    offsets = (pid * BLOCK_ELEMS + tl.arange(0, BLOCK_ELEMS)).to(tl.int64)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute in fp32 regardless of storage dtype for accuracy.
    x_compute = x.to(tl.float32)
    alpha_val = tl.full((1,), alpha, tl.float32)
    # Divide once, multiply per element.
    inv_alpha = 1.0 / alpha_val
    neg_mask = x_compute <= 0
    # Feed 0.0 to exp() on positive lanes so exp never overflows there;
    # those lanes are discarded by the final where().
    exp_val = exp(tl.where(neg_mask, x_compute * inv_alpha, 0.0))
    neg = alpha_val * (exp_val - 1.0)
    out = tl.where(neg_mask, neg, x_compute).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)
96def _use_triton_kernel(
97 A: torch.Tensor, alpha, *, is_inplace: bool
98) -> Tuple[bool, float]:
99 if not isinstance(A, torch.Tensor):
100 return False, 0.0
101 if A.device.type != "musa" or A.dtype not in _SUPPORTED_DTYPES:
102 return False, 0.0
103 if not A.is_contiguous() or A.numel() == 0:
104 return False, 0.0
105 try:
106 alpha_value = (
107 float(alpha) if not isinstance(alpha, torch.Tensor) else float(alpha.item())
108 )
109 except Exception:
110 return False, 0.0
111 if not math.isfinite(alpha_value):
112 return False, 0.0
113 return True, alpha_value
def _launch_celu(A: torch.Tensor, out: torch.Tensor, alpha_value: float):
    """Flatten the operands and dispatch the matching triton kernel.

    Writes CELU(A) into ``out`` (which may alias ``A`` for the in-place
    variant) and returns ``out``.
    """
    src = A.view(-1)
    dst = out.view(-1)
    numel = dst.numel()
    elem_size = dst.element_size()  # feeds the autotune key

    def grid(meta):
        # One program per BLOCK_SIZE * VEC elements.
        return (triton.cdiv(numel, meta["BLOCK_SIZE"] * meta["VEC"]),)

    with torch_device_fn.device(out.device):
        if alpha_value == 1.0:
            # alpha == 1 lets the kernel skip the divide/multiply by alpha.
            celu_kernel_alpha1[grid](src, dst, numel, elem_size)
        else:
            celu_kernel[grid](src, dst, numel, alpha_value, elem_size)
    return out
def celu(A, alpha=1.0):
    """CELU forward: MUSA triton fast path, generic fallback otherwise."""
    logger.debug("GEMS_MTHREADS CELU")
    eligible, alpha_value = _use_triton_kernel(A, alpha, is_inplace=False)
    if eligible:
        return _launch_celu(A, torch.empty_like(A), alpha_value)
    return default_celu(A, alpha=alpha)
def celu_(A, alpha=1.0):
    """In-place CELU: MUSA triton fast path, generic fallback otherwise."""
    logger.debug("GEMS_MTHREADS CELU_")
    eligible, alpha_value = _use_triton_kernel(A, alpha, is_inplace=True)
    if eligible:
        # In-place: the input tensor doubles as the output buffer.
        return _launch_celu(A, A, alpha_value)
    return default_celu_(A, alpha=alpha)