Coverage for src/flag_gems/runtime/backend/_cambricon/ops/uniform.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import triton
4import triton.language as tl
5from triton.language.extra.mlu.libdevice import philox as _philox
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
15from ..utils import TOTAL_CORE_NUM
# Module logger, created as a child of the package-wide "flag_gems" logger so
# that handlers/levels configured on "flag_gems" also govern this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.heuristics(runtime.get_heuristic_config("uniform"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def uniform_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    from_,
    to,
    BLOCK: tl.constexpr,
):
    # Fill out_ptr[0:N] with Philox-based random values mapped through
    # uint_to_uniform_float and scaled/shifted to span from_..to.
    # out_ptr is indexed as a flat buffer of N elements, so the caller must
    # pass contiguous storage.
    #
    # Each program handles tiles of UNROLL * BLOCK output elements and strides
    # over the output by num_programs * UNROLL * BLOCK elements per iteration,
    # so a grid smaller than the number of tiles still covers all N elements.

    # Philox generates 128 random bits (four 32-bit words) per counter, so
    # each counter yields UNROLL outputs. Must match the host-side UNROLL.
    UNROLL: tl.constexpr = 4
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    # Split the 64-bit seed and offset into 32-bit halves for the Philox call.
    # These do not depend on the loop, so compute them once up front.
    sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
    sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Counter index (one counter per UNROLL outputs) of this program's first tile.
    i4_start = pid * BLOCK
    # Element offset of this program's first tile, and the grid-wide stride.
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL

    for block_offset in range(block_start, N, step):
        # NOTE(review): assumed MLU libdevice signature is
        # (count, seed_lo, seed_hi, ctr0, ctr1, ctr2, ctr3, rounds) producing
        # BLOCK x 4 uint32 words -- confirm against triton...mlu.libdevice.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = uint_to_uniform_float(r) * (to - from_) + from_
        # can_reorder=True lets the compiler pick any element order when
        # flattening; acceptable because every element is an independent draw.
        r = tl.reshape(r, [UNROLL * BLOCK], can_reorder=True)

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        tl.store(out_ptr + off, r, mask=off < N)
        # Advance to the counters of this program's next tile.
        i4_start += num_jobs * BLOCK
# Output values produced per Philox counter (one call yields 128 random bits,
# i.e. four 32-bit words). Must stay equal to the constexpr UNROLL declared
# inside uniform_kernel; uniform_ uses it to size the launch grid and the
# Philox offset increment.
UNROLL = 4
def uniform_(self, from_=0.0, to=1.0, *, generator=None):
    """Fill ``self`` in place with uniform random values spanning ``from_``..``to``.

    Random bits come from a Philox stream whose seed/offset are obtained from
    ``philox_backend_seed_offset``, which is asked for an offset increment of
    ``cdiv(N, UNROLL)`` (one counter per UNROLL output elements). Whether ``to``
    itself can be produced depends on ``uint_to_uniform_float``.

    Args:
        self: Tensor to fill. Any layout is accepted; tensors that are not
            contiguous are filled through a contiguous scratch tensor.
        from_: Lower bound of the range.
        to: Upper bound of the range.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset``.

    Returns:
        ``self``, after being filled.
    """
    logger.debug("GEMS_CAMBRICON UNIFORM")
    N = volume(self.shape)
    if N == 0:
        # Nothing to fill: skip both the (empty-grid) kernel launch and the
        # seed/offset request.
        return self

    # The kernel writes N elements at unit stride from the base pointer, which
    # is only correct for contiguous storage. For any other layout, fill a
    # contiguous scratch tensor and copy it back with copy_, which honours
    # self's strides and leaves memory outside the view untouched.
    out = self if self.is_contiguous() else self.new_empty(self.shape)

    # Enough programs to cover every UNROLL * BLOCK tile, capped at the core
    # count; the kernel loops so that fewer programs still cover all of N.
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )

    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(self.device):
        uniform_kernel[grid_fn](
            out, N, philox_seed, philox_offset, from_, to, num_warps=1, num_stages=3
        )
    if out is not self:
        self.copy_(out)
    return self