Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/fill.py: 0%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from flag_gems.runtime import torch_device_fn
10from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger nested under the shared "flag_gems" root logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation configuration shared by both fill kernels below.
# NOTE(review): the positional fields (512, per-axis caps, 32, True) take
# their meaning from CodeGenConfig's definition in
# _kunlunxin.utils.codegen_config_utils — confirm there before changing them.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseDtypeConvert=True,
)
# Elementwise kernel: broadcast a host scalar across every element of `inp`.
# pointwise_dynamic (project helper) presumably generates the launch/tiling
# wrapper from config_; `is_tensor` marks inp as a tensor arg and
# value_scalar as a non-tensor (scalar) arg.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, "DEFAULT")],
    num_outputs=1,
    config=config_,
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # Materialize a tile of inp's shape/dtype holding value_scalar everywhere.
    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
# Elementwise kernel: copy the (broadcast) device tensor `value` into each
# output element. Both arguments are tensors; the wrapper handles the
# broadcasting of the 0-dim `value` against `inp`.
@pointwise_dynamic(
    is_tensor=[True, True],
    promotion_methods=[(0, "DEFAULT")],
    num_outputs=1,
    config=config_,
)
@triton.jit
def fill_tensor_func(inp, value):
    return value
def fill_scalar(input, value):
    """Return a new tensor shaped like *input*, filled with the scalar *value*.

    The kernel is launched on *input*'s device; *input* itself is not modified.
    """
    result = torch.empty_like(input)
    logger.debug("GEMS FILL (Dynamic)")
    with torch_device_fn.device(input.device):
        return fill_scalar_func(input, value, out0=result)
def fill_tensor(input, value):
    """Return a new tensor shaped like *input*, filled with the value held in
    the 0-dim tensor *value*.

    Mirrors ``aten.fill_.Tensor`` semantics: *value* must be 0-dimensional.
    A 0-dim value living on CPU is unwrapped with ``.item()`` and routed to
    the scalar path.

    Raises:
        RuntimeError: if *value* is not 0-dimensional.
    """
    # Validate dimensionality before the CPU fallback so a non-0-dim CPU
    # value raises the intended error. Previously a 1-element CPU vector
    # slipped through via .item(), and larger CPU tensors raised torch's
    # generic .item() error instead of this message.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        # CPU scalar tensor: unwrap and reuse the scalar kernel.
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)
def fill_tensor_(self, value):
    """Fill *self* in place with the value held in the 0-dim tensor *value*.

    Mirrors ``aten.fill_.Tensor`` semantics: *value* must be 0-dimensional.
    A 0-dim CPU value is unwrapped with ``.item()`` and routed to the
    in-place scalar path.

    Returns:
        *self* (in-place convention).

    Raises:
        RuntimeError: if *value* is not 0-dimensional.
    """
    # Validate dimensionality before the CPU fallback so a non-0-dim CPU
    # value raises the intended error (previously a 1-element CPU vector
    # bypassed this check via .item()).
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self
def fill_scalar_(self, value):
    """Fill *self* in place with the scalar *value* and return *self*."""
    logger.debug("GEMS FILL_SCALAR_")
    device_ctx = torch_device_fn.device(self.device)
    with device_ctx:
        # Writing into out0=self makes the kernel an in-place fill.
        fill_scalar_func(self, value, out0=self)
    return self