Coverage for src/flag_gems/runtime/backend/_mthreads/ops/celu.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3from typing import Tuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.ops.celu import celu as default_celu 

10from flag_gems.ops.celu import celu_ as default_celu_ 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, tl_extra_shim 

13 

# Module-scoped logger named after this op module (resolves to "...ops.celu").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Dtypes the Triton kernels below accept; anything else falls back to the
# default flag_gems implementation (see _use_triton_kernel).
_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}
# exp implementation provided by the runtime shim layer, used inside kernels.
exp = tl_extra_shim.exp

20 

21 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256, "VEC": 4}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 256, "VEC": 2}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 2}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 4}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 1}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 2}, num_warps=8, num_stages=2),
    ],
    key=["n_elements", "dtype_size"],
)
@triton.jit
def celu_kernel_alpha1(
    x_ptr,
    out_ptr,
    n_elements,
    dtype_size,  # used for autotune key
    BLOCK_SIZE: tl.constexpr,
    VEC: tl.constexpr,
):
    # Specialized CELU for the common alpha == 1 case:
    #   out = x           if x > 0
    #   out = exp(x) - 1  otherwise
    # Math is done in fp32 and cast back to the input dtype on store.
    pid = tl.program_id(0)
    # Each program covers BLOCK_SIZE * VEC contiguous elements.
    BLOCK_ELEMS: tl.constexpr = BLOCK_SIZE * VEC
    # int64 offsets so indexing cannot overflow for very large tensors.
    offsets = (pid * BLOCK_ELEMS + tl.arange(0, BLOCK_ELEMS)).to(tl.int64)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

    x_compute = x.to(tl.float32)
    neg_mask = x_compute <= 0
    # Positive lanes feed 0.0 into exp(), so exp cannot overflow for large x.
    exp_val = exp(tl.where(neg_mask, x_compute, 0.0))
    neg = exp_val - 1.0
    out = tl.where(neg_mask, neg, x_compute).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)

56 

57 

# NOTE(review): unlike celu_kernel_alpha1 this kernel is not wrapped in
# @libentry() — confirm whether that is intentional.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256, "VEC": 4}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 256, "VEC": 2}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 2}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 512, "VEC": 4}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 1}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024, "VEC": 2}, num_warps=8, num_stages=2),
    ],
    key=["n_elements", "dtype_size"],
)
@triton.jit(do_not_specialize=["alpha"])  # avoid one recompile per alpha value
def celu_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    alpha,
    dtype_size,  # used for autotune key
    BLOCK_SIZE: tl.constexpr,
    VEC: tl.constexpr,
):
    # General CELU:
    #   out = x                            if x > 0
    #   out = alpha * (exp(x / alpha) - 1) otherwise
    # Math is done in fp32 and cast back to the input dtype on store.
    pid = tl.program_id(0)
    # Each program covers BLOCK_SIZE * VEC contiguous elements.
    BLOCK_ELEMS: tl.constexpr = BLOCK_SIZE * VEC
    # int64 offsets so indexing cannot overflow for very large tensors.
    offsets = (pid * BLOCK_ELEMS + tl.arange(0, BLOCK_ELEMS)).to(tl.int64)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)

    x_compute = x.to(tl.float32)
    # Materialize alpha as a 1-element fp32 tensor; it broadcasts below.
    alpha_val = tl.full((1,), alpha, tl.float32)
    # Multiply by the reciprocal instead of dividing per element.
    inv_alpha = 1.0 / alpha_val
    neg_mask = x_compute <= 0
    # Positive lanes feed 0.0 into exp(), so exp cannot overflow for large x.
    exp_val = exp(tl.where(neg_mask, x_compute * inv_alpha, 0.0))
    neg = alpha_val * (exp_val - 1.0)
    out = tl.where(neg_mask, neg, x_compute).to(x.dtype)

    tl.store(out_ptr + offsets, out, mask=mask)

94 

95 

96def _use_triton_kernel( 

97 A: torch.Tensor, alpha, *, is_inplace: bool 

98) -> Tuple[bool, float]: 

99 if not isinstance(A, torch.Tensor): 

100 return False, 0.0 

101 if A.device.type != "musa" or A.dtype not in _SUPPORTED_DTYPES: 

102 return False, 0.0 

103 if not A.is_contiguous() or A.numel() == 0: 

104 return False, 0.0 

105 try: 

106 alpha_value = ( 

107 float(alpha) if not isinstance(alpha, torch.Tensor) else float(alpha.item()) 

108 ) 

109 except Exception: 

110 return False, 0.0 

111 if not math.isfinite(alpha_value): 

112 return False, 0.0 

113 return True, alpha_value 

114 

115 

def _launch_celu(A: torch.Tensor, out: torch.Tensor, alpha_value: float):
    """Run the CELU Triton kernel elementwise over ``A`` into ``out``.

    ``A`` and ``out`` must be contiguous and the same size (``out`` may
    alias ``A`` for the in-place variant).  Returns ``out``.
    """
    src = A.view(-1)
    dst = out.view(-1)
    numel = dst.numel()
    elem_bytes = dst.element_size()  # part of the autotune key

    def grid(meta):
        # One program per BLOCK_SIZE * VEC contiguous elements.
        return (triton.cdiv(numel, meta["BLOCK_SIZE"] * meta["VEC"]),)

    with torch_device_fn.device(out.device):
        if alpha_value == 1.0:
            # Fast path: the alpha == 1 kernel skips the per-element
            # scale/rescale by alpha.
            celu_kernel_alpha1[grid](src, dst, numel, elem_bytes)
        else:
            celu_kernel[grid](src, dst, numel, alpha_value, elem_bytes)
    return out

128 

129 

def celu(A, alpha=1.0):
    """Out-of-place CELU for the mthreads backend.

    Dispatches to the local Triton kernels when ``_use_triton_kernel``
    accepts the input; otherwise defers to the generic flag_gems
    implementation.
    """
    logger.debug("GEMS_MTHREADS CELU")
    eligible, alpha_value = _use_triton_kernel(A, alpha, is_inplace=False)
    if eligible:
        result = torch.empty_like(A)
        return _launch_celu(A, result, alpha_value)
    return default_celu(A, alpha=alpha)

138 

139 

def celu_(A, alpha=1.0):
    """In-place CELU for the mthreads backend; writes the result into ``A``.

    Dispatches to the local Triton kernels when ``_use_triton_kernel``
    accepts the input; otherwise defers to the generic flag_gems in-place
    implementation.
    """
    logger.debug("GEMS_MTHREADS CELU_")
    eligible, alpha_value = _use_triton_kernel(A, alpha, is_inplace=True)
    if eligible:
        # Output aliases the input: the kernel reads and writes A directly.
        return _launch_celu(A, A, alpha_value)
    return default_celu_(A, alpha=alpha)