Coverage for src/flag_gems/runtime/backend/_cambricon/ops/exponential_.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.mlu.libdevice import philox as _philox 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import ( 

11 philox_backend_seed_offset, 

12 uint_to_uniform_float, 

13) 

14 

15from ..utils import TOTAL_CORE_NUM 

16 

17logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

18 

19 

@triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double: tl.constexpr,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with Exponential(lambd) samples.

    Each program generates uniform random bits with the MLU philox
    intrinsic, converts them to uniform floats, and maps them to the
    exponential distribution via ``transform_exponential``.

    Args:
        out_ptr: pointer to the output buffer.
        N: total number of elements to generate.
        is_double: constexpr flag — True when the output is float64.
        lambd: rate parameter of the exponential distribution.
        eps: machine epsilon of the output dtype (used to clamp log(u)).
        philox_seed: 64-bit philox seed (passed unspecialized).
        philox_offset: 64-bit philox counter base (passed unspecialized).
        BLOCK: constexpr block size per philox call.
    """
    if is_double:
        UNROLL: tl.constexpr = 2  # philox generate 128 random bits at a time
    else:
        UNROLL: tl.constexpr = 4  # philox generate 128 random bits at a time
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # i4_start advances the philox counter so each block of each program
    # draws a distinct random stream.
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL

    for block_offset in range(block_start, N, step):
        # Split the 64-bit seed/offset into the 32-bit key/counter words
        # the philox intrinsic expects.
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # 10-round philox; returns UNROLL x BLOCK random 32-bit words.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = tl.reshape(r, [UNROLL * BLOCK], can_reorder=True)
        off = block_offset + tl.arange(0, UNROLL * BLOCK)

        if is_double:
            # For float64, reinterpret pairs of 32-bit words as 64-bit
            # integers before converting to uniform floats.
            r = r.to(tl.uint64, bitcast=True)
            f = uint_to_uniform_float(r)
        else:
            f = uint_to_uniform_float(r)
        y = transform_exponential(f, lambd, eps)
        tl.store(out_ptr + off, y, mask=off < N)
        # Advance the counter past the words consumed by all programs.
        i4_start += num_jobs * BLOCK

62 

63 

@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one 64-bit value (``hi`` is the
    upper half, ``lo`` the lower half)."""
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)

69 

70 

@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample ``u`` in (0, 1] to Exponential(lambd) via the
    inverse CDF, clamping log(u) near u == 1 to avoid returning -0.0 /
    underflow artifacts."""
    neg_half_eps = -0.5 * eps
    near_one = u >= 1.0 + neg_half_eps
    safe_log = tl.where(near_one, neg_half_eps, tl.math.log(u))
    return -1.0 / lambd * safe_log

78 

79 

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with Exponential(``lambd``) samples and return it.

    Args:
        x: floating-point tensor (float16/bfloat16/float32/float64). When
           ``x`` is non-contiguous the kernel writes a contiguous scratch
           buffer that is then copied back.
        lambd: rate parameter of the exponential distribution.
        generator: optional torch RNG generator used to derive the philox
           seed/offset.
    """
    logger.debug("GEMS_CAMBRICON EXPONENTIAL_")
    out_dtype = x.dtype
    out_device = x.device
    writable_in_place = x.is_contiguous()
    assert out_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    is_double = out_dtype in (torch.float64,)
    # float64 consumes 64 random bits per element, halving the unroll.
    unroll = 2 if is_double else 4
    total = x.numel()

    def grid(meta):
        # Cap the launch at the physical core count; each program strides.
        return (min(triton.cdiv(total, meta["BLOCK"] * unroll), TOTAL_CORE_NUM),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(total, unroll)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps = torch.finfo(out_dtype).eps
    scratch = (
        x
        if writable_in_place
        else torch.empty(x.size(), dtype=out_dtype, device=out_device)
    )
    with torch_device_fn.device(out_device):
        fused_exponential_kernel[grid](
            scratch,
            total,
            is_double,
            lambd,
            eps,
            philox_seed,
            philox_offset,
            num_warps=1,
            num_stages=3,
        )
    if not writable_in_place:
        x.copy_(scratch)
    return x