Coverage for src/flag_gems/runtime/backend/_cambricon/ops/zeros.py: 0%
40 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils.shape_utils import volume
10from ..utils import TOTAL_CORE_NUM
# Child logger under the package-wide "flag_gems" logger; lstrip(".") only
# strips leading dots, which __name__ normally has none of — harmless no-op.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias the runtime device descriptor so the bare name `device` stays free
# for use as the keyword parameter of zeros() below.
device_ = device
# Autotune over block size only; the selected config is cached per n_elements.
@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit
def zeros_kernel(
    output_ptr,  # pointer to the buffer to be zero-filled
    n_elements,  # total number of elements to write
    BLOCK_SIZE: tl.constexpr,  # elements handled per loop iteration (autotuned)
):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Promote to int64 before the loop so offset arithmetic cannot overflow
    # 32-bit indices on large tensors.
    block_start = block_start.to(tl.int64)
    # Grid-stride loop: each program zeroes every step-th block, so the grid
    # may be launched smaller than cdiv(n_elements, BLOCK_SIZE) (it is capped
    # at TOTAL_CORE_NUM by the callers below).
    for block_start_offset in range(block_start, n_elements, step):
        offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the final, possibly partial block
        tl.store(output_ptr + offsets, 0.0, mask=mask)
def zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a zero-filled tensor of shape *size* on the Cambricon device.

    Mirrors the ``torch.zeros`` factory signature; ``layout`` and
    ``pin_memory`` are accepted for dispatcher compatibility but unused.
    Missing ``dtype``/``device`` fall back to the torch default dtype and
    the flag_gems runtime device, respectively.
    """
    logger.debug("GEMS_CAMBRICON ZEROS")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    result = torch.empty(size, device=device, dtype=dtype)
    n_elements = volume(size)

    # Launch at most TOTAL_CORE_NUM programs; the kernel's grid-stride loop
    # covers any remaining blocks.
    def grid(meta):
        return (min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        zeros_kernel[grid](result, n_elements)
    return result
def zero_(x: torch.Tensor) -> torch.Tensor:
    """Zero-fill *x* in place and return it (implements ``Tensor.zero_``)."""
    logger.debug("GEMS_CAMBRICON ZERO_")
    total = x.numel()

    # Cap the grid at TOTAL_CORE_NUM; the kernel strides over the remainder.
    def grid(meta):
        return (min(triton.cdiv(total, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(x.device):
        zeros_kernel[grid](x, total)
    return x