Coverage for src/flag_gems/runtime/backend/_hygon/ops/exponential_.py: 0%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12 

13logger = logging.getLogger(__name__) 

14 

# float32 numeric guards used when clamping the uniform sample before the
# fast log.  NOTE(review): safe_fast_log below inlines these same literals
# (it builds them from `x` so they pick up the block's dtype), so these
# module-level names appear to serve as documentation of those magic numbers.
MIN_NORMAL_F32 = 1.17549435e-38  # smallest positive normal float32
# Largest value less than 1.0 to avoid log(1)=0 edge (though harmless)
MAX_U_F32 = 0.99999994  # nextafter(1.0, 0.0) in float32

@triton.jit
def safe_fast_log(x):
    """Cheap polynomial approximation of log(x) for float32 values in (0, 1).

    The input is first clamped into [FLT_MIN_NORMAL, nextafter(1.0, 0.0)] so
    the bit-level decomposition never sees zero, denormals, or values >= 1.
    """
    # Build the clamp bounds from x itself so they carry x's dtype.
    lo_bound = x * 0.0 + 1.17549435e-38
    hi_bound = x * 0.0 + 0.99999994
    x = tl.minimum(tl.maximum(x, lo_bound), hi_bound)

    # Decompose the IEEE-754 float32 representation: x = 2^e * m, m in [1, 2).
    raw = x.to(tl.int32, bitcast=True)
    e = (raw >> 23) - 127
    # mantissa field scaled by 2^-23, plus the implicit leading 1.
    m = (raw & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608) + 1.0

    # 4-term series for log(1 + t) with t = m - 1 (t in [0, 1)).
    t = m - 1.0
    series = t * (1.0 + t * (-0.5 + t * (0.3333333333 - t * 0.25)))
    # log(x) = log(m) + e * ln(2)
    return series + e.to(tl.float32) * 0.6931471805599453

# ===== Kernel with constexpr switch =====
@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=16, num_stages=3),
        triton.Config({"BLOCK": 2048}, num_warps=16, num_stages=4),
    ],
    key=["N", "is_double"],
)
# @triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,  # output buffer, filled with N exponential samples
    N,  # total number of elements to write
    is_double,  # nonzero -> float64 path (2 samples/lane), else 4 samples/lane
    inv_lambd,  # 1 / lambda, scale of the exponential distribution
    eps_minus,  # -0.5 * eps(dtype); stands in for log(u) when u ~= 1.0
    philox_seed,  # 64-bit philox RNG seed
    philox_offset,  # 64-bit philox RNG counter base
    BLOCK: tl.constexpr,
):
    # Fill out_ptr[0:N] with Exponential samples via y = -inv_lambd * log(u).
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit counter base into the two 32-bit philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One philox invocation per lane; each yields four 32-bit random words.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zero vector matching the counters' shape/dtype
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # float64 path: pair the 32-bit words into 64-bit values -> 2 outputs.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, inv_lambd, eps_minus)
        y1 = transform_exponential(d1, inv_lambd, eps_minus)
        UNROLL = 2
        # Each program owns a contiguous span of BLOCK * UNROLL elements.
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # float16/bfloat16/float32 path: one sample per 32-bit word -> 4 outputs.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, inv_lambd, eps_minus)
        y1 = transform_exponential(f1, inv_lambd, eps_minus)
        y2 = transform_exponential(f2, inv_lambd, eps_minus)
        y3 = transform_exponential(f3, inv_lambd, eps_minus)

        UNROLL = 4
        # Each program owns a contiguous span of BLOCK * UNROLL elements.
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_last")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_last")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_last")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_last")

@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one uint64, `hi` in the upper half."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)

@triton.jit
def transform_exponential(u, inv_lambd, eps_minus):
    """Map a uniform sample `u` to an exponential sample: -inv_lambd * log(u).

    `eps_minus` (the caller passes -0.5 * eps of the target dtype) replaces
    log(u) whenever u is within half an ulp of 1.0, so the result stays
    strictly positive instead of collapsing to log(1) = 0.
    """
    near_one = u >= 1.0 + eps_minus
    safe_log = tl.where(near_one, eps_minus, safe_fast_log(u))
    return -inv_lambd * safe_log

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill `x` in place with Exponential(lambd) samples and return it.

    Non-contiguous tensors are filled through a contiguous scratch buffer
    and copied back at the end.
    """
    logger.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    contiguous = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    # The kernel emits 2 (fp64) or 4 (fp16/bf16/fp32) samples per lane.
    UNROLL = 2 if is_double else 4
    N = x.numel()

    def grid_fn(meta):
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps_minus = -0.5 * torch.finfo(dtype).eps
    inv_lambd = 1.0 / lambd
    out = x if contiguous else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            out, N, is_double, inv_lambd, eps_minus, philox_seed, philox_offset
        )
    if not contiguous:
        x.copy_(out)
    return x