Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/randn.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12from flag_gems.utils.shape_utils import volume 

13 

# Module-level logger, namespaced under the flag_gems root logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Prefer Triton's built-in pair_uniform_to_normal; fall back to a local
# Box-Muller implementation on Triton builds that do not provide one.
_builtin_pair_uniform_to_normal = getattr(tl, "pair_uniform_to_normal", None)
if _builtin_pair_uniform_to_normal is not None:
    pair_uniform_to_normal = _builtin_pair_uniform_to_normal
else:

    @triton.jit
    def pair_uniform_to_normal(u1, u2):
        """Box-Muller transform"""
        safe_u1 = tl.maximum(1.0e-7, u1)  # clamp away from 0 to keep log finite
        angle = 6.283185307179586 * u2  # 2 * pi * u2
        radius = tl.sqrt(-2.0 * tl.log(safe_u1))
        return radius * tl.cos(angle), radius * tl.sin(angle)

26 

27 

28device_ = device 

29 

30 

@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill out_ptr[0:N] with standard-normal samples.

    Each program instance performs BLOCK Philox draws; every draw yields four
    uint32 words that are turned into four normals via two Box-Muller
    transforms, so one program covers BLOCK * 4 output elements.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the two 32-bit Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Advance the low counter word per lane so each lane gets a distinct
    # stream. NOTE(review): any carry out of c0 is not propagated into c1 —
    # presumably the offsets from philox_backend_seed_offset make overflow
    # impossible; confirm against the offset accounting in random_utils.
    c0 += i4
    _O = c0 * 0  # zero vector with the same shape/dtype as c0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O, 7)
    # Map uint32 words to uniform floats in (0, 1).
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    # Two Box-Muller transforms: 4 uniforms -> 4 independent normals.
    n0, n1 = pair_uniform_to_normal(r0, r1)
    n2, n3 = pair_uniform_to_normal(r2, r3)
    # Each program writes four consecutive BLOCK-sized segments of the output.
    off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    # Masks guard the ragged tail when N is not a multiple of BLOCK * 4.
    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")

61 

62 

63UNROLL = 4 

64 

65 

def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape `size` filled with standard-normal samples.

    Mirrors torch.randn's keyword interface; `layout` and `pin_memory` are
    accepted for signature compatibility but are not used here.
    """
    logger.debug("GEMS RANDN")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    # Empty request: nothing to generate. This also avoids a crash below —
    # triton.next_power_of_2(0) is 0, which would make BLOCK_SIZE zero and
    # divide by zero in triton.cdiv when computing the grid.
    if N == 0:
        return out
    cluster_num = 12  # number of compute clusters to spread the work across
    BLOCK_SIZE = min(
        triton.next_power_of_2(triton.cdiv(N, cluster_num * UNROLL)),
        1024,
    )
    grid_fn = triton.cdiv(N, BLOCK_SIZE * UNROLL)
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    with torch_device_fn.device(device):
        randn_kernel[(grid_fn,)](out, N, philox_seed, philox_offset, BLOCK_SIZE)
    return out