Coverage for src/flag_gems/runtime/backend/_cambricon/ops/uniform.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5from triton.language.extra.mlu.libdevice import philox as _philox 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13from flag_gems.utils.shape_utils import volume 

14 

15from ..utils import TOTAL_CORE_NUM 

16 

# Per-module logger nested under the package-wide "flag_gems" logger, so it
# inherits whatever handlers/levels are configured on that parent.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

18 

19 

@triton.heuristics(runtime.get_heuristic_config("uniform"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def uniform_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    from_,
    to,
    BLOCK: tl.constexpr,
):
    """Fill the first N elements at ``out_ptr`` with philox-based uniform samples.

    Each program handles chunks of ``UNROLL * BLOCK`` consecutive elements,
    starting at chunk index ``pid`` and stepping by ``num_programs`` chunks
    until N is reached. Per chunk, one ``_philox`` call on ``BLOCK`` counters
    yields ``UNROLL * BLOCK`` random words. These are converted to floats,
    scaled by ``(to - from_)``, shifted by ``from_`` and stored with a bounds
    mask.

    Args:
        out_ptr: output buffer. It is written as N consecutive elements, so
            the caller must pass a contiguous tensor.
        N: number of elements to fill.
        philox_seed: 64-bit philox key (split into two 32-bit halves below).
        philox_offset: 64-bit base philox counter (split likewise).
        from_, to: bounds of the sampled interval.
        BLOCK: philox counters consumed per chunk. It is provided by the
            "uniform" heuristic config.
    """
    UNROLL: tl.constexpr = 4  # philox generates 128 random bits at a time
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # The 32-bit halves of the key and base counter are the same for every
    # chunk, so compute them once instead of on every loop iteration.
    sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
    sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # i4_start: first counter (relative to c0) of the current chunk.
    # It tracks block_offset // UNROLL throughout the loop.
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL

    for block_offset in range(block_start, N, step):
        # NOTE(review): if c0 + i4_start wraps past 2**32 the overflow is not
        # carried into c1 (c1 is passed unchanged). The trailing 10 is
        # presumably the number of philox rounds (TODO confirm against MLU
        # libdevice).
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = uint_to_uniform_float(r) * (to - from_) + from_
        # Flatten to one value per output element. can_reorder=True lets the
        # compiler pick the element order. Which sample lands where is then
        # unspecified, which is harmless for i.i.d. samples.
        r = tl.reshape(r, [UNROLL * BLOCK], can_reorder=True)

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        tl.store(out_ptr + off, r, mask=off < N)
        i4_start += num_jobs * BLOCK

53 

54 

# Random words produced per philox counter (128 bits = 4 x 32-bit words).
# Must stay equal to the UNROLL constexpr defined inside uniform_kernel,
# because the host-side grid size and RNG increment are derived from it.
UNROLL: int = 4

56 

57 

def uniform_(self, from_=0.0, to=1.0, *, generator=None):
    """Fill ``self`` in place with uniform random values and return it.

    Values are produced by ``uniform_kernel`` (philox RNG), scaled by
    ``(to - from_)`` and shifted by ``from_``.

    Args:
        self: tensor to fill. Non-contiguous tensors are handled by sampling
            into a contiguous scratch tensor and copying back.
        from_: lower bound of the interval.
        to: upper bound of the interval.
        generator: optional generator. It is forwarded to
            ``philox_backend_seed_offset`` to obtain the philox seed and offset.

    Returns:
        ``self``.

    Raises:
        RuntimeError: if ``from_ <= to`` does not hold. This mirrors PyTorch's
            own ``uniform_`` check, which is bypassed when this op overrides it.
    """
    logger.debug("GEMS_CAMBRICON UNIFORM")
    if not from_ <= to:
        raise RuntimeError(
            "uniform_ expects to return a [from, to) range, "
            f"but found from={from_} > to={to}"
        )
    N = volume(self.shape)
    if N == 0:
        # Nothing to fill. Skipping also avoids launching a zero-size grid.
        return self

    if self.is_contiguous():
        _uniform_contiguous_(self, N, from_, to, generator)
    else:
        # The kernel writes N consecutive elements from the data pointer,
        # which is only correct for a contiguous layout. For strided views,
        # sample into a contiguous scratch tensor and copy into the view.
        scratch = self.new_empty(self.shape)
        _uniform_contiguous_(scratch, N, from_, to, generator)
        self.copy_(scratch)
    return self


def _uniform_contiguous_(out, N, from_, to, generator):
    """Launch ``uniform_kernel`` to fill the N elements of contiguous ``out``."""
    # At most TOTAL_CORE_NUM programs are launched. The kernel loops over any
    # remaining chunks itself.
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )

    # Each philox counter yields UNROLL values, so reserve ceil(N / UNROLL)
    # counters from the generator.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(out.device):
        uniform_kernel[grid_fn](
            out, N, philox_seed, philox_offset, from_, to, num_warps=1, num_stages=3
        )