Coverage for src/flag_gems/runtime/backend/_mthreads/ops/normal.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems import runtime 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12from flag_gems.utils.shape_utils import volume 

13 

try:
    # Prefer the implementation shipped with Triton when it exists.
    pair_uniform_to_normal = tl.pair_uniform_to_normal
except AttributeError:

    @triton.jit
    def pair_uniform_to_normal(u1, u2):
        """Map a pair of uniform samples to a pair of normal samples.

        Fallback Box-Muller transform, used only when ``triton.language``
        does not expose ``pair_uniform_to_normal``.

        Args:
            u1: uniform samples; clamped from below to 1e-7 so the
                logarithm stays finite.
            u2: uniform samples used to derive the angle.

        Returns:
            A 2-tuple ``(radius * cos(angle), radius * sin(angle))``.
        """
        # Angle over a full turn: 2 * pi * u2.
        angle = 6.283185307179586 * u2
        # Keep u1 away from zero before taking the logarithm.
        clamped = tl.maximum(1.0e-7, u1)
        radius = tl.sqrt(-2.0 * tl.log(clamped))
        return radius * tl.cos(angle), radius * tl.sin(angle)

25 

26 

# Module-level alias for the imported runtime ``device`` object; ``normal_``
# rebinds the name ``device`` locally, so the alias keeps the import reachable.
device_ = device

# Logger named after this backend package plus the last component of the
# module's dotted name.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

31 

32 

@triton.heuristics(runtime.get_heuristic_config("randn"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "mean", "std"])
def normal_kernel(
    out_ptr,
    N,
    mean,
    std,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with ``normal * std + mean`` samples.

    Each program draws one Philox result (four 32-bit words) per lane,
    turns the words into four normal samples and stores them as four
    BLOCK-wide segments, i.e. ``4 * BLOCK`` consecutive output elements
    per program.

    Args:
        out_ptr: pointer to the output buffer.
        N: number of elements to write; stores past ``N`` are masked off.
        mean: value added to every sample.
        std: factor every sample is multiplied by.
        philox_seed: Philox seed (widened to int64 here).
        philox_offset: 64-bit Philox counter offset; split into two
            32-bit counter words.
        BLOCK: compile-time number of lanes per program.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low and high 32-bit halves of the offset form the first two counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One distinct counter value per lane across the whole launch.
    i4 = (tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)).to(tl.uint32)
    c0 += i4
    # Zero block with the same shape/dtype as c0 for the unused counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # NOTE(review): uint_to_uniform_float presumably maps each 32-bit word to
    # a uniform float - confirm against flag_gems.utils.random_utils.
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    n0, n1 = pair_uniform_to_normal(r0, r1)
    n2, n3 = pair_uniform_to_normal(r2, r3)

    # Apply linear transform: val * std + mean
    n0 = n0 * std + mean
    n1 = n1 * std + mean
    n2 = n2 * std + mean
    n3 = n3 * std + mean

    # Widen the program id BEFORE multiplying: the previous form multiplied in
    # int32 and cast afterwards, so the base offset could wrap for very large
    # outputs (>= 2**31 elements).
    block_start = tl.program_id(0).to(tl.int64) * BLOCK * 4
    off_0 = block_start + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")

75 

76 

77UNROLL = 4 

78 

79 

80def normal_(self, mean=0, std=1, *, generator=None): 

81 logger.debug("GEMS_MTHREADS NORMAL_") 

82 shape = self.shape 

83 device = self.device 

84 N = volume(shape) 

85 grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) 

86 increment = triton.cdiv(N, UNROLL) 

87 philox_seed, philox_offset = philox_backend_seed_offset( 

88 increment, generator=generator 

89 ) 

90 with torch_device_fn.device(device): 

91 normal_kernel[grid_fn](self, N, mean, std, philox_seed, philox_offset) 

92 return self