Coverage for src/flag_gems/runtime/backend/_ascend/ops/fill.py: 0%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
# Module logger named after this op file (e.g. "flag_gems.runtime._ascend.ops.fill").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit(do_not_specialize=["value_scalar"])
def fill_scalar_kernel(
    out_ptr,
    N,
    value_scalar,
    BLOCK_SIZE: tl.constexpr,
    SUBBLOCK_SIZE: tl.constexpr,
):
    # Each program instance owns one BLOCK_SIZE-wide slice of the flat output
    # and writes it SUBBLOCK_SIZE elements at a time, masked at the tail so we
    # never store past N.
    block_start = tle.program_id(0) * BLOCK_SIZE
    lane = tl.arange(0, SUBBLOCK_SIZE)
    for step in tl.range(triton.cdiv(BLOCK_SIZE, SUBBLOCK_SIZE)):
        idx = block_start + step * SUBBLOCK_SIZE + lane
        tl.store(out_ptr + idx, value_scalar, mask=idx < N)
@libentry()
@triton.jit
def fill_tensor_kernel(
    out_ptr,
    N,
    value_ptr,
    BLOCK_SIZE: tl.constexpr,
    SUBBLOCK_SIZE: tl.constexpr,
):
    # Fill out_ptr[0:N] with the scalar stored at value_ptr (a 0-dim device
    # tensor). One program instance covers a BLOCK_SIZE-wide slice, writing
    # SUBBLOCK_SIZE elements per loop iteration, masked at the tail.
    pid = tle.program_id(0)
    pid_offset = pid * BLOCK_SIZE
    cols = tl.arange(0, SUBBLOCK_SIZE)
    num_loop = triton.cdiv(BLOCK_SIZE, SUBBLOCK_SIZE)
    # The fill value is loop-invariant: load it once per program instead of
    # once per sub-block iteration (the original re-loaded it in the loop).
    value_scalar = tl.load(value_ptr)
    for iloop in tl.range(num_loop):
        offset = pid_offset + iloop * SUBBLOCK_SIZE + cols
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)
def fill_tensor(input, value):
    """Return a new tensor shaped like *input*, filled with the scalar held in
    the 0-dim tensor *value*.

    Raises RuntimeError if *value* is not 0-dimensional. When *value* lives on
    the host, falls back to the scalar path since the kernel needs a device
    pointer to load from.
    """
    if not value.is_cuda:
        # Host-side value: extract the Python scalar and reuse the scalar path.
        return fill_scalar(input, value.item())
    logger.debug("GEMS_ASCEND FILL")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    out = torch.empty_like(input)
    N = out.numel()
    if N == 0:
        # Empty tensor: nothing to fill. Also avoids grid = min(40, 0) == 0,
        # which would make the BLOCK_SIZE division below raise ZeroDivisionError.
        return out
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)
    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid,](out, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return out
def fill_scalar(input, value):
    """Return a new tensor shaped like *input*, filled with the Python scalar
    *value*."""
    logger.debug("GEMS_ASCEND FILL")
    out = torch.empty_like(input)
    N = out.numel()
    if N == 0:
        # Empty tensor: nothing to fill. Also avoids grid = min(40, 0) == 0,
        # which would make the BLOCK_SIZE division below raise ZeroDivisionError.
        return out
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)
    with torch_device_fn.device(input.device):
        fill_scalar_kernel[grid,](out, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return out
def fill_tensor_(self, value):
    """In-place fill of *self* with the scalar held in the 0-dim tensor
    *value*; returns *self*.

    Raises RuntimeError if *value* is not 0-dimensional. When *value* lives on
    the host, falls back to the scalar in-place path.
    """
    if not value.is_cuda:
        # Host-side value: extract the Python scalar and reuse the scalar path.
        return fill_scalar_(self, value.item())
    logger.debug("GEMS_ASCEND FILL_TENSOR_")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    N = self.numel()
    if N == 0:
        # Empty tensor: nothing to fill. Also avoids grid = min(40, 0) == 0,
        # which would make the BLOCK_SIZE division below raise ZeroDivisionError.
        return self
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)
    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid,](self, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return self
def fill_scalar_(self, value):
    """In-place fill of *self* with the Python scalar *value*; returns
    *self*."""
    logger.debug("GEMS_ASCEND FILL_SCALAR_")
    N = self.numel()
    if N == 0:
        # Empty tensor: nothing to fill. Also avoids grid = min(40, 0) == 0,
        # which would make the BLOCK_SIZE division below raise ZeroDivisionError.
        return self
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)
    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid,](self, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return self