Coverage for src/flag_gems/runtime/backend/_ascend/ops/exponential_.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13 

# Module-level logger scoped under the Ascend-backend ops namespace.
_op_name = __name__.rsplit(".", 1)[-1]
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + _op_name)

15 

16 

@triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,        # pointer to the output buffer (N elements)
    N,              # total number of elements to fill
    is_double,      # nonzero => float64 output path (2 samples/lane/round)
    lambd,          # rate parameter of the exponential distribution
    eps,            # dtype epsilon, used to clamp u away from 1.0 before log
    philox_seed,    # Philox RNG seed (from the backend generator state)
    philox_offset,  # Philox RNG 64-bit counter offset
    UNROLL,         # samples produced per lane per Philox round (2 or 4)
    BLOCK: tl.constexpr,  # lanes per program
):
    # Fill out_ptr[0:N] with exponential(lambd)-distributed samples using the
    # counter-based Philox RNG. Each program walks over tasks with a stride of
    # the total program count, so a fixed-size launch grid can cover any N;
    # each task covers BLOCK * UNROLL contiguous output elements.
    n_workers = tl.num_programs(0)
    pid = tl.program_id(0)
    n_tasks = tl.cdiv(N, BLOCK * UNROLL)
    tasks_per_worker = tl.cdiv(n_tasks, n_workers)

    for task_index in range(tasks_per_worker):
        task_id = pid + task_index * n_workers
        # Loop-invariant conversions, recomputed each iteration (harmless).
        philox_seed = philox_seed.to(tl.int64)
        philox_offset_64 = philox_offset.to(tl.int64)
        # Split the 64-bit offset into the two 32-bit Philox counter words.
        c0 = (philox_offset_64 & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset_64 >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # Per-element counter so every lane/task draws an independent stream.
        # NOTE(review): the add is in uint32 and any carry does not propagate
        # into c1 — presumably acceptable for supported sizes; verify for
        # very large N.
        i4 = task_id * BLOCK + tl.arange(0, BLOCK)
        c0 += i4
        _O = c0 * 0  # zero vector for the unused counter/key words
        r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
        if is_double:
            # float64 path: fuse two 32-bit Philox words into one 64-bit
            # uniform, so each round yields 2 doubles per lane.
            d0 = uint_to_uniform_float(paste_u64(r0, r2))
            d1 = uint_to_uniform_float(paste_u64(r1, r3))
            y0 = transform_exponential(d0, lambd, eps)
            y1 = transform_exponential(d1, lambd, eps)
            # UNROLL = 2
            start = task_id.to(tl.int64) * BLOCK * 2
            off_0 = start + tl.arange(0, BLOCK)
            off_1 = off_0 + BLOCK
            tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        else:
            # 32-bit-or-smaller path: each Philox word becomes one uniform,
            # yielding 4 samples per lane per round.
            f0 = uint_to_uniform_float(r0)
            f1 = uint_to_uniform_float(r1)
            f2 = uint_to_uniform_float(r2)
            f3 = uint_to_uniform_float(r3)
            y0 = transform_exponential(f0, lambd, eps)
            y1 = transform_exponential(f1, lambd, eps)
            y2 = transform_exponential(f2, lambd, eps)
            y3 = transform_exponential(f3, lambd, eps)
            # UNROLL = 4
            start = task_id.to(tl.int64) * BLOCK * 4
            off_0 = start + tl.arange(0, BLOCK)
            off_1 = off_0 + BLOCK
            off_2 = off_1 + BLOCK
            off_3 = off_2 + BLOCK
            tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
            tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")

75 

76 

@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    # Concatenate two 32-bit words into one 64-bit value: hi becomes the
    # upper 32 bits, lo the lower 32 bits.
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)

82 

83 

@triton.jit
def transform_exponential(u, lambd, eps):
    # Map a uniform sample u in (0, 1] to an exponential(lambd) sample via
    # inverse-CDF: x = -log(u) / lambd. Values of u at (or within eps/2 of)
    # 1.0 are clamped so log never evaluates exactly at the boundary.
    half_neg_eps = -0.5 * eps
    near_one = u >= 1.0 + half_neg_eps
    log_u = tl.where(near_one, half_neg_eps, tl.math.log(u))
    return -1.0 / lambd * log_u

91 

92 

def exponential_(x, lambd: float = 1.0, *, gen=None):
    """Fill ``x`` in place with exponential(lambd)-distributed samples.

    Args:
        x: Tensor to fill. Must be float16/bfloat16/float32/float64.
        lambd: Rate parameter of the exponential distribution.
        gen: Accepted for interface compatibility.
            NOTE(review): currently unused — the seed/offset always come from
            the backend Philox state; confirm whether an explicit generator
            should be honored here.

    Returns:
        ``x`` (the same tensor object), filled with samples.
    """
    logger.debug("GEMS_ASCEND EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    # The kernel writes a flat contiguous buffer; a non-contiguous input is
    # sampled into a scratch tensor and copied back at the end.
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = dtype in (torch.float64,)
    # float64 consumes two Philox words per sample, so the kernel unrolls
    # half as far per round.
    UNROLL = 2 if is_double else 4
    N = x.numel()

    def grid_fn(meta):
        # Cap the launch grid at 240 programs; the kernel itself loops over
        # any remaining tasks with a stride of the program count.
        return (min(triton.cdiv(N, meta["BLOCK"] * UNROLL), 240),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    eps = torch.finfo(dtype).eps
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            x_, N, is_double, lambd, eps, philox_seed, philox_offset, UNROLL
        )
    if not inplace:
        x.copy_(x_)
    return x