Coverage for src/flag_gems/runtime/backend/_cambricon/ops/rand.py: 0%
48 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.mlu.libdevice import philox as _philox
8from flag_gems import runtime
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils.random_utils import (
11 philox_backend_seed_offset,
12 uint_to_uniform_float,
13)
14from flag_gems.utils.shape_utils import volume
16from ..utils import TOTAL_CORE_NUM
# Module logger nested under the shared "flag_gems" root logger; any leading
# dots are stripped from __name__ so getChild() receives a valid child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias the runtime device descriptor: the `rand` wrapper below takes a
# torch-style `device=` keyword, which would otherwise shadow this import.
device_ = device
@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def rand_kernel(
    out_ptr,  # pointer to the output buffer receiving uniform floats
    N,  # total number of elements to generate
    philox_seed,  # Philox RNG seed (passed unspecialized to avoid recompiles)
    philox_offset,  # Philox RNG counter offset (also unspecialized)
    BLOCK: tl.constexpr,  # number of Philox counter slots per program per step
):
    """Fill out_ptr[0:N] with uniform random floats using the MLU Philox PRNG.

    Each Philox call yields UNROLL (4) values per counter slot, so one program
    covers UNROLL * BLOCK output elements per loop iteration, striding by the
    total work of all programs until N is exhausted.
    """
    UNROLL: tl.constexpr = 4  # values produced per Philox counter slot
    # Widen to int64 so the 32-bit half extraction below is well defined.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Counter-space start for this program; advanced by num_jobs * BLOCK each
    # iteration so no two programs ever reuse a counter value.
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL
    for block_offset in range(block_start, N, step):
        # Split the 64-bit seed and offset into the 32-bit halves expected by
        # the libdevice philox primitive (key = sl/sh, counter = c0/c1).
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # 10-round Philox; c0 is advanced by i4_start so each iteration of
        # each program draws from a distinct counter range.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        # Map raw uint bits to floats in [0, 1).
        r = uint_to_uniform_float(r)
        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        # Flatten the [UNROLL, BLOCK]-shaped draw; can_reorder is safe because
        # the values are i.i.d. uniforms, so element order is irrelevant.
        r = tl.reshape(r, [UNROLL * BLOCK], can_reorder=True)
        tl.store(out_ptr + off, r, mask=off < N)
        i4_start += num_jobs * BLOCK
55UNROLL = 4
def rand(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape `size` filled with uniform samples in [0, 1).

    Mirrors the keyword signature of `torch.rand`. `dtype` defaults to the
    torch default dtype and `device` to the flag_gems runtime device; `layout`
    and `pin_memory` are accepted for signature compatibility.
    """
    logger.debug("GEMS_CAMBRICON RAND")
    # Resolve torch-style defaults.
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    out = torch.empty(size, device=device, dtype=dtype)
    n_elements = volume(size)

    def grid(meta):
        # One program handles BLOCK * UNROLL elements per step; never launch
        # more programs than physical cores.
        programs_needed = triton.cdiv(n_elements, meta["BLOCK"] * UNROLL)
        return (min(programs_needed, TOTAL_CORE_NUM),)

    philox_seed, philox_offset = philox_backend_seed_offset(n_elements)
    with torch_device_fn.device(device):
        rand_kernel[grid](
            out, n_elements, philox_seed, philox_offset, num_stages=3, num_warps=1
        )
    return out