Coverage for src/flag_gems/runtime/backend/_mthreads/ops/normal.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems import runtime
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
12from flag_gems.utils.shape_utils import volume
# Use the helper exported by triton.language when this Triton build has it;
# otherwise define an equivalent JIT function locally under the same name.
if hasattr(tl, "pair_uniform_to_normal"):
    pair_uniform_to_normal = tl.pair_uniform_to_normal
else:

    @triton.jit
    def pair_uniform_to_normal(u1, u2):
        """Box-Muller transform of two inputs into two outputs.

        ``u1`` is clamped to at least 1e-7 before its logarithm is taken.
        Returns the cosine branch first, then the sine branch.
        """
        # Clamp away from zero so the logarithm stays finite.
        radius = tl.sqrt(-2.0 * tl.log(tl.maximum(1.0e-7, u1)))
        # 6.283185307179586 is 2 * pi.
        angle = 6.283185307179586 * u2
        return radius * tl.cos(angle), radius * tl.sin(angle)
# Logger name: the mthreads ops package path followed by the last dotted
# component of this module's ``__name__``.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.rpartition(".")[2]}'
)
# Second module-level name bound to the imported runtime ``device`` object.
device_ = device
@triton.heuristics(runtime.get_heuristic_config("randn"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "mean", "std"])
def normal_kernel(
    out_ptr,
    N,
    mean,
    std,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with ``sample * std + mean`` values.

    Every lane draws one Philox result (four random words), converts the
    words to floats, turns them into four samples with two
    ``pair_uniform_to_normal`` calls, and applies the linear transform.
    A program therefore writes ``4 * BLOCK`` outputs, laid out as four
    consecutive chunks of ``BLOCK`` elements.

    Args:
        out_ptr: Pointer to the output buffer.
        N: Number of elements to write; stores at offsets >= N are masked.
        mean: Added to every scaled sample.
        std: Multiplies every sample before ``mean`` is added.
        philox_seed: Philox seed, converted to int64 inside the kernel.
        philox_offset: Philox offset, split into two 32-bit counter words.
        BLOCK: Lanes per program (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low and high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Give each lane a distinct counter by adding its global lane index to
    # the low word.
    # NOTE(review): this uint32 addition wraps without carrying into c1.
    i4 = (tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)).to(tl.uint32)
    c0 += i4
    # Zero block with the same shape as c0 for the two unused counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    n0, n1 = pair_uniform_to_normal(r0, r1)
    n2, n3 = pair_uniform_to_normal(r2, r3)

    # Apply linear transform: val * std + mean
    n0 = n0 * std + mean
    n1 = n1 * std + mean
    n2 = n2 * std + mean
    n3 = n3 * std + mean

    # Widen the program id BEFORE multiplying. A 32-bit product of
    # program_id * BLOCK * 4 overflows once the output has 2**31 or more
    # elements, and casting the already-wrapped product cannot repair it.
    pid = tl.program_id(0).to(tl.int64)
    off_0 = pid * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK
    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")
# Number of output values normal_kernel produces per lane (it stores n0..n3).
# normal_ uses it to size the launch grid and the Philox offset increment.
UNROLL = 4
def normal_(self, mean=0, std=1, *, generator=None):
    """Overwrite ``self`` in place via ``normal_kernel`` and return it.

    Args:
        self: Output object; its ``shape`` and ``device`` are read and it is
            passed to the kernel as the output pointer.
        mean: Forwarded to the kernel, which adds it to each scaled sample.
        std: Forwarded to the kernel, which multiplies each sample by it.
        generator: Forwarded to ``philox_backend_seed_offset``.

    Returns:
        ``self``, after the kernel launch has been issued.
    """
    logger.debug("GEMS_MTHREADS NORMAL_")
    numel = volume(self.shape)

    def launch_grid(meta):
        # One program covers BLOCK * UNROLL output elements.
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # Request ceil(numel / UNROLL) as the Philox increment for this call.
    seed, offset = philox_backend_seed_offset(
        triton.cdiv(numel, UNROLL), generator=generator
    )
    with torch_device_fn.device(self.device):
        normal_kernel[launch_grid](self, numel, mean, std, seed, offset)
    return self