Coverage for src/flag_gems/ops/fill.py: 86% (43 statements)
coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import pointwise_dynamic
# Module-level logger, one per module by package convention.
10logger = logging.getLogger(__name__)
# Pointwise kernel: produce a tile of the same shape/dtype as `inp`, filled
# with the scalar `value_scalar`. `is_tensor=[True, False]` marks the second
# argument as a plain scalar rather than a tensor; promotion_methods
# presumably derives the output dtype from argument 0 — confirm against
# pointwise_dynamic's contract.
13@pointwise_dynamic(
14 is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")], num_outputs=1
15)
16@triton.jit
17def fill_scalar_func(inp, value_scalar):
 # `inp`'s values are never read; it only supplies shape and dtype.
18 return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
# Pointwise kernel: every output element is the fill value. Both arguments
# are tensors (`is_tensor=[True, True]`); callers pass `value` as a 0-dim
# tensor, which the pointwise machinery broadcasts across `inp`'s elements.
21@pointwise_dynamic(
22 is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], num_outputs=1
23)
24@triton.jit
25def fill_tensor_func(inp, value):
 # `inp` contributes only shape/dtype; the result is the broadcast value.
26 return value
29def fill_scalar(input, value):
30 logger.debug("GEMS FILL (Dynamic)")
31 out = torch.empty_like(input)
32 with torch_device_fn.device(input.device):
33 return fill_scalar_func(input, value, out0=out)
36def fill_tensor(input, value):
37 if not value.is_cuda:
38 return fill_scalar(input, value.item())
39 logger.debug("GEMS FILL (Dynamic)")
40 if value.ndim != 0:
41 raise RuntimeError(
42 f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
43 )
44 out = torch.empty_like(input)
45 with torch_device_fn.device(input.device):
46 return fill_tensor_func(input, value, out0=out)
49def fill_tensor_(self, value):
50 if not value.is_cuda:
51 return fill_scalar_(self, value.item())
52 logger.debug("GEMS FILL_TENSOR_")
53 if value.ndim != 0:
54 raise RuntimeError(
55 f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
56 )
57 with torch_device_fn.device(self.device):
58 fill_tensor_func(self, value, out0=self)
59 return self
62def fill_scalar_(self, value):
63 logger.debug("GEMS FILL_SCALAR_")
64 with torch_device_fn.device(self.device):
65 fill_scalar_func(self, value, out0=self)
66 return self