Coverage for src/flag_gems/runtime/backend/_cambricon/ops/rand.py: 0%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.mlu.libdevice import philox as _philox 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils.random_utils import ( 

11 philox_backend_seed_offset, 

12 uint_to_uniform_float, 

13) 

14from flag_gems.utils.shape_utils import volume 

15 

16from ..utils import TOTAL_CORE_NUM 

17 

# Child logger under the package-level "flag_gems" logger; lstrip guards
# against a leading dot if this module is resolved via a relative name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias for the runtime device descriptor: rand()'s `device` keyword
# parameter shadows the imported `device` name, so the alias is needed there.
device_ = device

20 

21 

@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,  # output buffer receiving N uniform floats
    N,  # total number of elements to generate
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill out_ptr[0:N] with uniform random floats in [0, 1) using the
    MLU philox primitive. Each program instance writes UNROLL * BLOCK
    elements per loop iteration in a grid-stride pattern."""
    # The MLU philox call yields UNROLL (= 4) 32-bit words per counter value.
    UNROLL: tl.constexpr = 4
    # Widen to int64 so the 32-bit halves can be extracted below.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Per-program philox counter offset (in counter units, not elements).
    i4_start = pid * BLOCK
    # First output element index this program writes.
    block_start = pid * UNROLL * BLOCK
    # Distance between successive chunks handled by this program.
    step = num_jobs * BLOCK * UNROLL

    for block_offset in range(block_start, N, step):
        # Split the 64-bit seed and offset into the 32-bit words philox expects.
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # NOTE(review): c0 + i4_start can wrap in uint32 without carrying into
        # c1 — presumably the backend-issued offsets keep this in range; confirm.
        # Final args (0, 0, 10) match the philox4x32-10 convention upstream uses.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        # Map raw unsigned bits to uniform floats in [0, 1).
        r = uint_to_uniform_float(r)

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        # Flatten the UNROLL x BLOCK philox output into one contiguous vector.
        r = tl.reshape(r, [UNROLL * BLOCK], can_reorder=True)
        tl.store(out_ptr + off, r, mask=off < N)
        # Advance the counter so the next chunk draws fresh randomness.
        i4_start += num_jobs * BLOCK

53 

54 

55UNROLL = 4 

56 

57 

def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape *size* filled with uniform randoms in [0, 1).

    Mirrors the ``torch.rand`` keyword interface: *dtype* defaults to the
    global default dtype and *device* to the flag_gems runtime device.
    *layout* and *pin_memory* are accepted for signature compatibility but
    are not consumed by this implementation.
    """
    logger.debug("GEMS_CAMBRICON RAND")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)

    def grid_fn(meta):
        # One program per UNROLL*BLOCK chunk, capped at the core count.
        num_blocks = triton.cdiv(N, meta["BLOCK"] * UNROLL)
        return (min(num_blocks, TOTAL_CORE_NUM),)

    philox_seed, philox_offset = philox_backend_seed_offset(N)
    with torch_device_fn.device(device):
        rand_kernel[grid_fn](
            out, N, philox_seed, philox_offset, num_stages=3, num_warps=1
        )
    return out