Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/exponential_.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.xpu.libdevice import log2 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import ( 

11 philox_backend_seed_offset, 

12 uint_to_uniform_float, 

13) 

14 

# Module logger nested under the shared "flag_gems" root logger so backend
# messages inherit the library-wide logging configuration.
# (Removed a commented-out, superseded version of heur_block that lived here.)
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

def heur_block(args):
    """Heuristic for the kernel BLOCK size.

    Spreads the N elements evenly over the device's 12 clusters and rounds
    the per-cluster share up to a power of two, as Triton block sizes require.
    """
    per_cluster = triton.cdiv(args["N"], 12)  # CLUSTER_NUM = 12
    return triton.next_power_of_2(per_cluster)

25 

26 

def heur_num_warps(args):
    """Heuristic for the warp count: larger problems get more warps.

    Thresholds: N <= 512 -> 4 warps, N <= 1024 -> 8 warps, else 16 warps.
    """
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4

34 

35 

@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
# @triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double: tl.constexpr,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Fill out_ptr[0:N] with Exponential(lambd) samples in one fused pass:
    # generate Philox random bits, map them to uniforms, transform to
    # exponentials, and store — no intermediate uniform tensor.
    #
    # Counter setup: the 64-bit philox_offset is split into two 32-bit
    # counter words (c0 = low half, c1 = high half).
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Each lane perturbs the low counter word by its global index so every
    # lane draws an independent random stream.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    # Zero vector with the same shape/dtype as c0, used for the two unused
    # key/counter slots of tl.philox.
    _O = c0 * 0
    # Philox yields four independent 32-bit random words per lane.
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # float64 path: each sample consumes two 32-bit words, so a lane
        # produces 2 doubles -> UNROLL = 2.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        # uint64 start offset avoids int32 overflow for very large tensors.
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        # evict_first: output is write-once, so don't let it pollute cache.
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Narrow-float path (fp16/bf16/fp32): one 32-bit word per sample,
        # so a lane produces 4 values -> UNROLL = 4.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")

92 

93 

@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one 64-bit value (``hi`` in the upper half)."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)

99 

100 

@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample ``u`` to Exponential(lambd) via inverse CDF: -log(u)/lambd.

    Samples at the very top of the uniform range are clamped so log(u) never
    hits exactly zero, keeping the result strictly positive.
    """
    half_neg_eps = -0.5 * eps
    # 1 / log2(e): rescales the device log2 into a natural logarithm.
    ln2 = 1.0 / 1.4426950408889634
    near_one = u >= 1.0 + half_neg_eps
    log_u = tl.where(near_one, half_neg_eps, log2(u) * ln2)
    return -1.0 / lambd * log_u

109 

110 

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with Exponential(lambd) samples and return it.

    Contiguous tensors are written directly by the kernel; non-contiguous
    tensors are sampled into a contiguous scratch buffer and copied back.
    """
    logger.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    writable_in_place = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    # fp64 consumes two 32-bit Philox words per sample (2 samples/thread);
    # narrower dtypes take one word each (4 samples/thread).
    unroll = 2 if is_double else 4
    numel = x.numel()

    def grid(meta):
        return (triton.cdiv(numel, meta["BLOCK"] * unroll),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    philox_seed, philox_offset = philox_backend_seed_offset(
        triton.cdiv(numel, unroll), generator=generator
    )
    eps = torch.finfo(dtype).eps
    out = x if writable_in_place else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid](
            out, numel, is_double, lambd, eps, philox_seed, philox_offset
        )
    if not writable_in_place:
        x.copy_(out)
    return x