Coverage for src/flag_gems/ops/uniform.py: 32%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems import runtime 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.random_utils import ( 

9 philox_backend_seed_offset, 

10 uint_to_uniform_float, 

11) 

12from flag_gems.utils.shape_utils import volume 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@triton.heuristics(runtime.get_heuristic_config("uniform"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def uniform_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    from_,
    to,
    BLOCK: tl.constexpr,
):
    """Write Philox-derived values, rescaled by ``from_`` and ``to``, to ``out_ptr``.

    Each program makes one ``tl.philox`` call per lane, which yields four
    result streams. The streams are stored to four consecutive BLOCK-sized
    segments, so one program covers ``4 * BLOCK`` output positions starting
    at ``program_id(0) * BLOCK * 4``. Positions at or beyond ``N`` are
    masked out.

    Args:
        out_ptr: Base pointer of the output buffer.
        N: Number of output elements; used as the store bound.
        philox_seed: Seed passed to ``tl.philox`` (cast to int64 first).
        philox_offset: 64-bit offset, split into the two low counter words.
        from_: Added to every scaled sample.
        to: Together with ``from_`` gives the scale factor ``to - from_``.
        BLOCK: Compile-time number of lanes per program.
    """
    pid = tl.program_id(0)
    lanes = tl.arange(0, BLOCK)

    seed = philox_seed.to(tl.int64)
    offset = philox_offset.to(tl.int64)

    # Split the 64-bit offset into two 32-bit counter words.
    ctr_lo = (offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_hi = ((offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    # Give every lane its own counter: one counter value per lane per program.
    ctr_lo += pid * BLOCK + lanes
    # Zero block with the same shape/dtype as the per-lane counter, used for
    # the two remaining counter words.
    zeros = ctr_lo * 0
    r0, r1, r2, r3 = tl.philox(seed, ctr_lo, ctr_hi, zeros, zeros)

    # uint_to_uniform_float presumably maps raw uint bits to a unit-interval
    # float -- confirm in flag_gems.utils.random_utils. The result is then
    # scaled by (to - from_) and shifted by from_.
    span = to - from_
    r0 = uint_to_uniform_float(r0) * span + from_
    r1 = uint_to_uniform_float(r1) * span + from_
    r2 = uint_to_uniform_float(r2) * span + from_
    r3 = uint_to_uniform_float(r3) * span + from_

    # The four result streams land in four back-to-back BLOCK-sized segments.
    off_0 = pid * BLOCK * 4 + lanes
    off_1 = off_0 + BLOCK
    off_2 = off_0 + 2 * BLOCK
    off_3 = off_0 + 3 * BLOCK

    tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy="evict_first")

48 

49 

# Number of output elements each kernel lane produces: uniform_kernel stores
# four Philox result streams (r0..r3) per lane, and uniform_ divides both the
# launch grid size and the RNG increment by this factor.
UNROLL = 4

51 

52 

def uniform_(self, from_=0.0, to=1.0, *, generator=None):
    """Fill ``self`` in place using ``uniform_kernel`` and return it.

    Args:
        self: Output object exposing ``.shape`` and ``.device`` (presumably a
            torch.Tensor -- confirm against callers); passed to the kernel as
            the output buffer.
        from_: Forwarded to the kernel as the additive shift / lower bound.
        to: Forwarded to the kernel; ``to - from_`` is the scale factor.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset``.

    Returns:
        ``self``, after the kernel launch has been issued.
    """
    logger.debug("GEMS UNIFORM")
    numel = volume(self.shape)

    def grid(meta):
        # One program covers BLOCK * UNROLL output elements.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # NOTE(review): the kernel indexes one Philox counter per lane
    # (pid * BLOCK + lane), so a partially filled last program can touch more
    # than cdiv(N, UNROLL) counters. Confirm philox_backend_seed_offset
    # advances the offset far enough to avoid counter reuse across calls.
    philox_seed, philox_offset = philox_backend_seed_offset(
        triton.cdiv(numel, UNROLL), generator=generator
    )

    with torch_device_fn.device(self.device):
        uniform_kernel[grid](
            self, numel, philox_seed, philox_offset, from_, to
        )
    return self

64 return self