Coverage for src/flag_gems/runtime/backend/_hygon/ops/fill.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
9from ..utils.pointwise_dynamic import pointwise_dynamic
11logger = logging.getLogger(__name__)
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # Broadcast the Python scalar `value_scalar` over a tile with the same
    # shape and dtype as `inp`; `inp` is read only for its shape/dtype
    # metadata, not its values.
    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
@pointwise_dynamic(
    is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_tensor_func(inp, value):
    # Elementwise kernel that ignores `inp`'s values and yields `value` for
    # every output element; callers pass a 0-dim device tensor as `value`
    # (see fill_tensor / fill_tensor_), which the pointwise wrapper
    # presumably broadcasts — TODO confirm against pointwise_dynamic.
    return value
def fill_scalar(input, value):
    """Return a fresh tensor shaped/typed like ``input``, filled with the scalar ``value``.

    ``input`` is used only as a template (shape, dtype, device); its
    contents are never read.
    """
    logger.debug("GEMS FILL (Dynamic)")
    result = torch.empty_like(input)
    # Run the kernel on the device that owns `input`.
    guard = torch_device_fn.device(input.device)
    with guard:
        return fill_scalar_func(input, value, out0=result)
def fill_tensor(input, value):
    """Return a fresh tensor shaped/typed like ``input``, filled with the 0-dim tensor ``value``.

    Raises:
        RuntimeError: if a device-resident ``value`` is not 0-dimensional.
    """
    # A host-resident value cannot feed the device kernel; extract the
    # scalar and reuse the scalar path.
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    if value.ndim != 0:
        msg = f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        raise RuntimeError(msg)
    result = torch.empty_like(input)
    # Run the kernel on the device that owns `input`.
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=result)
def fill_tensor_(self, value):
    """In-place fill of ``self`` from the 0-dim tensor ``value``; returns ``self``.

    Raises:
        RuntimeError: if a device-resident ``value`` is not 0-dimensional.
    """
    # A host-resident value cannot feed the device kernel; extract the
    # scalar and reuse the in-place scalar path.
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    if value.ndim != 0:
        msg = f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        raise RuntimeError(msg)
    # Write the kernel output directly into `self` on its own device.
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self
def fill_scalar_(self, value):
    """In-place fill of ``self`` with the scalar ``value``; returns ``self``."""
    logger.debug("GEMS FILL_SCALAR_")
    # Write the kernel output directly into `self` on its own device.
    guard = torch_device_fn.device(self.device)
    with guard:
        fill_scalar_func(self, value, out0=self)
    return self