Coverage for src/flag_gems/runtime/backend/_cambricon/ops/fill.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, libtuner
10from ..utils import TOTAL_CORE_NUM
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        # Autotune over block sizes; num_stages/num_warps fixed at 1 for this backend.
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
    ],
    key=["N"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["value_scalar"])
def fill_scalar_kernel(
    out_ptr,  # pointer to the output buffer to fill
    N,  # total number of elements to write
    value_scalar,  # scalar fill value (not specialized on, per do_not_specialize)
    BLOCK_SIZE: tl.constexpr,  # elements handled per loop iteration
):
    # Fill out_ptr[0:N] with value_scalar. Each program walks the buffer in
    # strides of (num_programs * BLOCK_SIZE) so any grid size covers all N
    # elements.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Widen to int64 before computing global offsets (large-N safety).
    block_start = block_start.to(tl.int64)
    for block_start_offset in range(block_start, N, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        # Mask out the tail of the final partial block.
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)
@libentry()
@libtuner(
    configs=[
        # Autotune over block sizes; num_stages/num_warps fixed at 1 for this backend.
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["N"],
)
@triton.jit
def fill_tensor_kernel(
    out_ptr,  # pointer to the output buffer to fill
    N,  # total number of elements to write
    value_ptr,  # pointer to a single-element tensor holding the fill value
    BLOCK_SIZE: tl.constexpr,  # elements handled per loop iteration
):
    # Fill out_ptr[0:N] with the value stored at value_ptr. Each program walks
    # the buffer in strides of (num_programs * BLOCK_SIZE) so any grid size
    # covers all N elements.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Widen to int64 before computing global offsets (large-N safety).
    block_start = block_start.to(tl.int64)
    # The fill value is loop-invariant: load it from global memory once,
    # instead of re-reading it on every iteration as before.
    value_scalar = tl.load(value_ptr)
    for block_start_offset in range(block_start, N, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        # Mask out the tail of the final partial block.
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)
def fill_tensor(input, value):
    """Out-of-place fill: return a tensor shaped like `input` filled with the
    value held in the (0-d) tensor `value`.

    Args:
        input: reference tensor supplying shape/dtype/device.
        value: tensor holding the fill value (loaded on-device by the kernel).

    Returns:
        A new tensor, `torch.empty_like(input)` filled with `value`.
    """
    logger.debug("GEMS_CAMBRICON FILL TENSOR")
    # Zero-element tensors: nothing to fill, and cdiv(0, BLOCK_SIZE) would
    # produce an empty launch grid. Guard early, consistent with fill_scalar.
    if 0 in input.shape:
        return torch.empty_like(input)
    out = torch.empty_like(input)
    N = out.numel()
    # One program per BLOCK_SIZE chunk, capped at the core count; the kernel
    # grid-strides to cover the remainder.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid_fn](out, N, value)
    return out
def fill_scalar(input, value):
    """Out-of-place fill: return a tensor shaped like `input` filled with the
    Python scalar `value`.

    Args:
        input: reference tensor supplying shape/dtype/device.
        value: scalar fill value, passed straight to the kernel.

    Returns:
        A new tensor, `torch.empty_like(input)` filled with `value`.
    """
    logger.debug("GEMS_CAMBRICON FILL SCALAR")
    out = torch.empty_like(input)
    if 0 in input.shape:
        # Zero-element tensor: return a fresh empty result rather than
        # aliasing `input` (the previous `return input` broke the
        # out-of-place contract by handing back the argument itself).
        return out
    N = out.numel()
    # One program per BLOCK_SIZE chunk, capped at the core count; the kernel
    # grid-strides to cover the remainder.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(input.device):
        fill_scalar_kernel[grid_fn](out, N, value)
    return out
def fill_tensor_(self, value):
    """In-place fill: overwrite every element of `self` with the value held
    in the 0-d tensor `value` and return `self`.

    Raises:
        RuntimeError: if `value` is not a 0-dimension tensor.
    """
    logger.debug("GEMS_CAMBRICON FILL_TENSOR_")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    # Zero-element tensors: nothing to write, and cdiv(0, BLOCK_SIZE) would
    # produce an empty launch grid. Guard early, consistent with fill_scalar_.
    if 0 in self.shape:
        return self
    N = self.numel()
    # One program per BLOCK_SIZE chunk, capped at the core count; the kernel
    # grid-strides to cover the remainder.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid_fn](self, N, value)
    return self
def fill_scalar_(self, value):
    """In-place fill: overwrite every element of `self` with the Python
    scalar `value` and return `self`.
    """
    logger.debug("GEMS_CAMBRICON FILL_SCALAR_")
    # Nothing to write for zero-element tensors.
    if 0 in self.shape:
        return self
    total = self.numel()

    # Launch at most TOTAL_CORE_NUM programs; the kernel grid-strides to
    # cover any remaining elements.
    def grid_fn(meta):
        return (min(triton.cdiv(total, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid_fn](self, total, value)
    return self