Coverage for src/flag_gems/runtime/backend/_mthreads/ops/rand.py: 0%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13from flag_gems.utils.shape_utils import volume 

14 

# Module logger. The name is pinned to the mthreads backend package path and
# suffixed with the last dotted component of this module's ``__name__``.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Alias for the imported runtime ``device`` object. ``rand`` below takes a
# keyword parameter that is also called ``device``, which shadows the import
# inside the function body, so the module-level object is reached via this name.
device_ = device

19 

20 

@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Fill out_ptr[0:N] with values produced by Philox and converted through
    # uint_to_uniform_float.
    #
    # Each program handles 4 * BLOCK consecutive output elements: one Philox
    # call yields four words per lane, and word k of lane j is stored at
    # pid * 4 * BLOCK + k * BLOCK + j.
    #
    # out_ptr:       pointer to the output buffer.
    # N:             number of elements to write; stores beyond N are masked.
    # philox_seed:   Philox key (not specialized by the JIT).
    # philox_offset: 64-bit Philox counter base (not specialized by the JIT).
    # BLOCK:         lanes per program, supplied by the "rand" heuristic config.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Split the 64-bit offset into the two low 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    # Give every lane a distinct counter by adding its global lane index.
    i4 = (pid * BLOCK + lane).to(tl.uint32)
    c0 += i4
    # Zero tensor with the same shape/dtype as c0 for the two unused counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)

    # Widen the program id to int64 BEFORE multiplying. Doing the product in
    # 32-bit and casting afterwards wraps once pid * BLOCK * 4 reaches 2**31,
    # which yields wrong (negative) offsets for very large outputs.
    off_0 = pid.to(tl.int64) * BLOCK * 4 + lane
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy="evict_first")

52 

53 

# Output elements produced per kernel lane: rand_kernel unpacks four values
# from each tl.philox call, so the launch grid and the Philox offset increment
# in ``rand`` both divide the element count by this factor.
UNROLL = 4

55 

56 

def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Allocate a tensor of shape ``size`` and fill it using ``rand_kernel``.

    Args:
        size: Shape of the tensor to create.
        dtype: Element type; falls back to ``torch.get_default_dtype()``.
        layout: Accepted but not used by this implementation.
        device: Target device; falls back to ``torch.device(device_.name)``.
        pin_memory: Accepted but not used by this implementation.

    Returns:
        The newly allocated tensor after the kernel launch has been issued.
    """
    logger.debug("GEMS_MTHREADS RAND")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    result = torch.empty(size, device=device, dtype=dtype)
    numel = volume(size)

    def launch_grid(meta):
        # One program covers BLOCK * UNROLL output elements.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    offset_step = triton.cdiv(numel, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(offset_step)
    with torch_device_fn.device(device):
        rand_kernel[launch_grid](result, numel, philox_seed, philox_offset)
    return result