Coverage for src/flag_gems/runtime/backend/_cambricon/ops/ones.py: 0%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils import libentry, libtuner
9from flag_gems.utils.shape_utils import volume
11from ..utils import TOTAL_CORE_NUM
# Child logger under the shared "flag_gems" root logger.
# NOTE(review): str.lstrip(".") strips leading dot *characters*, not a prefix;
# harmless here because __name__ never starts with ".", but removeprefix(".")
# would state the intent more precisely.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias the runtime device descriptor so the `device` keyword parameter of
# ones() below can shadow the imported name without losing access to it.
device_ = device
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
    strategy=["align32"],
)
@triton.jit
def ones_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill the first n_elements slots behind output_ptr with 1.0. Each
    # program walks the flat index space with a stride of
    # num_programs * BLOCK_SIZE, so any launch size covers the whole tensor.
    job_id = tl.program_id(axis=0)
    job_count = tl.num_programs(axis=0)
    stride = job_count * BLOCK_SIZE
    # Widen the starting offset to int64 so the offsets stay valid for
    # tensors with more than 2**31 elements.
    start = (job_id * BLOCK_SIZE).to(tl.int64)
    for base in range(start, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        tl.store(output_ptr + idx, 1.0, mask=idx < n_elements)
def ones(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape ``size`` filled with ones (Cambricon backend).

    ``dtype`` defaults to torch's current default dtype and ``device`` to the
    runtime's configured device. ``layout`` and ``pin_memory`` are accepted
    for signature compatibility and are not consulted here.
    """
    logger.debug("GEMS_CAMBRICON ONES")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    out = torch.empty(size, device=device, dtype=dtype)
    n_elements = volume(size)

    # Cap the launch at TOTAL_CORE_NUM programs; the kernel's strided loop
    # handles any remainder beyond grid * BLOCK_SIZE.
    def grid(meta):
        return (min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        ones_kernel[grid](out, n_elements)
    return out