Coverage for src/flag_gems/runtime/backend/_cambricon/ops/randn.py: 0%
59 statements
Generated by coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.mlu.libdevice import philox as _philox
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils.random_utils import (
11 philox_backend_seed_offset,
12 uint_to_uniform_float,
13)
14from flag_gems.utils.shape_utils import volume
16from ..utils import TOTAL_CORE_NUM
# Per-module child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Use Triton's own Box-Muller helper when this Triton build exposes it;
# otherwise define an equivalent JIT helper locally.
pair_uniform_to_normal = getattr(tl, "pair_uniform_to_normal", None)
if pair_uniform_to_normal is None:

    @triton.jit
    def pair_uniform_to_normal(u1, u2):
        """Box-Muller: turn two uniform samples into two normal samples.

        Assumes u1 and u2 are uniform on (0, 1]. u1 is clamped from below at
        1e-7 so that log(u1) stays finite.
        """
        radius = tl.sqrt(-2.0 * tl.log(tl.maximum(1.0e-7, u1)))
        theta = 6.283185307179586 * u2  # 2 * pi * u2
        return radius * tl.cos(theta), radius * tl.sin(theta)


# Keep a handle on the runtime device object under a distinct name, because
# `randn` below has a `device` keyword argument that shadows `device`.
device_ = device
@triton.heuristics(runtime.get_heuristic_config("randn"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Write N normally-distributed samples to ``out_ptr[0:N]``.

    Work is split into tiles of ``UNROLL * BLOCK`` outputs. Program ``pid``
    starts at tile ``pid`` and strides by ``num_programs`` tiles (grid-stride
    loop), so a grid smaller than the number of tiles still covers all of N.

    Each tile draws one Philox call over ``BLOCK`` counters. Its 4 uint32
    lanes per counter become 4 uniforms, and then, via two Box-Muller pairs,
    4 normals. Samples are computed in float32.
    """
    # Outputs per Philox counter. Must equal the module-level UNROLL that the
    # host uses to size the launch grid.
    UNROLL: tl.constexpr = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Split the 64-bit seed and offset into 32-bit halves once. They are
    # loop-invariant, so there is no need to recompute them per tile.
    sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
    sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # First Philox counter (relative to c0) used by this program's first tile.
    i4_start = pid * BLOCK
    # First output element of this program's first tile.
    block_start = pid * UNROLL * BLOCK
    # Output elements covered by one sweep of all programs.
    step = num_jobs * BLOCK * UNROLL

    res = tl.empty([UNROLL, BLOCK], dtype=tl.float32)
    for block_offset in range(block_start, N, step):
        # NOTE(review): argument order presumably is
        # (n, key_lo, key_hi, c0, c1, c2, c3, n_rounds) - confirm against the
        # MLU libdevice docs. The result is indexed as r[:, 0..3] below, i.e.
        # assumed to be [BLOCK, 4].
        # NOTE(review): c0 + i4_start is uint32 arithmetic; a wrap of the low
        # word is not carried into c1.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = uniform = uint_to_uniform_float(r)

        res[0, :], res[1, :] = pair_uniform_to_normal(uniform[:, 0], uniform[:, 1])
        res[2, :], res[3, :] = pair_uniform_to_normal(uniform[:, 2], uniform[:, 3])

        off = block_offset + tl.arange(0, BLOCK * UNROLL)
        # can_reorder=True lets the compiler choose the element order of the
        # flattened tile; that is acceptable because every sample is drawn
        # the same way, so which sample lands where does not matter.
        tl.store(
            out_ptr + off,
            tl.reshape(res, [BLOCK * UNROLL], can_reorder=True),
            mask=off < N,
        )
        # Advance to this program's next tile (num_jobs tiles further on).
        i4_start += num_jobs * BLOCK
# Normal samples produced per Philox counter. Must equal the `UNROLL`
# constexpr inside `randn_kernel`, because the launch grid is sized with it.
UNROLL = 4


def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled with normal samples.

    Mirrors the keyword interface of ``torch.randn``. ``layout`` and
    ``pin_memory`` are accepted for signature compatibility but are ignored.

    Args:
        size: Output shape.
        dtype: Output dtype; defaults to ``torch.get_default_dtype()``.
        layout: Ignored.
        device: Output device; defaults to this backend's device.
        pin_memory: Ignored.

    Returns:
        The newly allocated and filled tensor.
    """
    logger.debug("GEMS_CAMBRICON RANDN")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill. Skip the zero-sized kernel launch and the Philox
        # seed/offset request entirely.
        return out

    def grid_fn(meta):
        # One program per tile, capped at the core count. The kernel's
        # grid-stride loop covers any tiles beyond the cap.
        return (min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),)

    # NOTE(review): the kernel draws BLOCK counters per UNROLL * BLOCK outputs,
    # so an increment of N looks like a conservative bound on the counters
    # consumed. Confirm the units expected by philox_backend_seed_offset.
    philox_seed, philox_offset = philox_backend_seed_offset(N)
    with torch_device_fn.device(device):
        randn_kernel[grid_fn](
            out, N, philox_seed, philox_offset, num_stages=3, num_warps=1
        )
    return out