Coverage for src/flag_gems/runtime/backend/_cambricon/ops/randn.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.mlu.libdevice import philox as _philox 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils.random_utils import ( 

11 philox_backend_seed_offset, 

12 uint_to_uniform_float, 

13) 

14from flag_gems.utils.shape_utils import volume 

15 

16from ..utils import TOTAL_CORE_NUM 

17 

# Module logger, created as a child of the shared "flag_gems" logger. Leading
# dots are stripped from the module name before it is used as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Reuse ``tl.pair_uniform_to_normal`` when the installed Triton exposes it;
# when the attribute is missing, define a JIT helper with the same name here.
try:
    pair_uniform_to_normal = tl.pair_uniform_to_normal
except AttributeError:

    @triton.jit
    def pair_uniform_to_normal(u1, u2):
        """Turn a pair of uniform samples into a pair of normal samples.

        Implements the Box-Muller transform: ``u1`` drives the radius and
        ``u2`` drives the angle.

        Args:
            u1: Uniform samples; clamped from below at 1e-7 so that the
                logarithm never receives zero.
            u2: Uniform samples, scaled by 2*pi to obtain the angle.

        Returns:
            A 2-tuple ``(radius * cos(angle), radius * sin(angle))``.
        """
        clamped = tl.maximum(1.0e-7, u1)
        angle = 6.283185307179586 * u2  # 2 * pi
        radius = tl.sqrt(-2.0 * tl.log(clamped))
        return radius * tl.cos(angle), radius * tl.sin(angle)


# Second name for the imported ``device`` object. ``randn`` has a keyword
# parameter that is also called ``device`` and shadows the import inside the
# function body, so the function reaches the module-level object through this
# alias instead.
device_ = device

33 

34 

@triton.heuristics(runtime.get_heuristic_config("randn"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Write normally distributed samples to ``out_ptr[0:N]``.

    Each program works on tiles of ``UNROLL * BLOCK`` consecutive output
    elements. It starts at the tile selected by its program id and then
    advances by ``num_programs * UNROLL * BLOCK`` until the tile start reaches
    ``N``. Samples are computed in float32 and stored with a bounds mask.

    Args:
        out_ptr: Base pointer of the output buffer.
        N: Number of elements to write.
        philox_seed: Philox seed; widened to int64 and split into two 32-bit
            halves before being handed to the generator.
        philox_offset: Philox offset; widened and split the same way.
        BLOCK: Compile-time tile width. NOTE(review): presumably provided by
            the "randn" heuristic config attached above - confirm.
    """
    UNROLL: tl.constexpr = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    prog_id = tl.program_id(0)
    prog_count = tl.num_programs(0)
    # The generator index advances by BLOCK per tile while the output index
    # advances by UNROLL * BLOCK: every one of the BLOCK rows returned by
    # ``_philox`` supplies four values (columns 0-3 are read below).
    counter_base = prog_id * BLOCK
    first_elem = prog_id * UNROLL * BLOCK
    stride = prog_count * BLOCK * UNROLL

    # Scratch tile that is overwritten on every iteration.
    normals = tl.empty([UNROLL, BLOCK], dtype=tl.float32)
    for tile_begin in range(first_elem, N, stride):
        # Low / high 32-bit words of the 64-bit seed and offset.
        seed_lo = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        seed_hi = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        ctr_lo = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        ctr_hi = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # NOTE(review): the argument meaning of the MLU philox intrinsic is
        # not visible in this file; the trailing 10 is presumably the round
        # count - confirm against the libdevice documentation.
        bits = _philox(
            BLOCK, seed_lo, seed_hi, ctr_lo + counter_base, ctr_hi, 0, 0, 10
        )
        uniforms = uint_to_uniform_float(bits)

        # Two transforms turn the four uniform columns into four normal rows.
        normals[0, :], normals[1, :] = pair_uniform_to_normal(
            uniforms[:, 0], uniforms[:, 1]
        )
        normals[2, :], normals[3, :] = pair_uniform_to_normal(
            uniforms[:, 2], uniforms[:, 3]
        )

        idx = tile_begin + tl.arange(0, BLOCK * UNROLL)
        # ``can_reorder=True`` lets the reshape change element order;
        # presumably acceptable because every element is an independent
        # random sample.
        tl.store(
            out_ptr + idx,
            tl.reshape(normals, [BLOCK * UNROLL], can_reorder=True),
            mask=idx < N,
        )
        counter_base += prog_count * BLOCK

73 

74 

# Host-side counterpart of the ``UNROLL`` constant hard-coded inside
# ``randn_kernel``. The launch grid below is sized with it, so the two values
# have to stay equal.
UNROLL = 4


def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled by ``randn_kernel``.

    Args:
        size: Shape of the tensor to create.
        dtype: Element type; ``torch.get_default_dtype()`` when ``None``.
        layout: Accepted for signature compatibility; not read here.
        device: Target device; ``torch.device(device_.name)`` when ``None``.
        pin_memory: Accepted for signature compatibility; not read here.

    Returns:
        The freshly allocated tensor after the kernel launch has been issued.
    """
    logger.debug("GEMS_CAMBRICON RANDN")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    out = torch.empty(size, device=device, dtype=dtype)
    numel = volume(size)

    def grid_fn(meta):
        # One program per UNROLL * BLOCK tile, capped at the core count; the
        # kernel loops over any tiles beyond that cap.
        tiles = triton.cdiv(numel, meta["BLOCK"] * UNROLL)
        return (min(tiles, TOTAL_CORE_NUM),)

    philox_seed, philox_offset = philox_backend_seed_offset(numel)
    with torch_device_fn.device(device):
        randn_kernel[grid_fn](
            out, numel, philox_seed, philox_offset, num_stages=3, num_warps=1
        )
    return out