Coverage for src/flag_gems/ops/fill.py: 84%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import pointwise_dynamic
10logger = logging.getLogger(__name__)
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # Broadcast the Python scalar to a block with the same shape and dtype as
    # ``inp``; ``inp``'s values are never read, it only supplies shape/dtype.
    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)
@pointwise_dynamic(
    is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_tensor_func(inp, value):
    # Elementwise copy of ``value`` into the output; ``inp`` is only used by
    # the pointwise_dynamic wrapper for shape/layout — its values are ignored.
    # NOTE(review): presumably ``value`` is a 0-dim tensor broadcast by the
    # wrapper — the callers below enforce value.ndim == 0 before launch.
    return value
def fill_scalar(input, value):
    """Return a fresh tensor shaped/typed like ``input``, filled with scalar ``value``.

    ``input``'s values are not read; it only determines the output's shape,
    dtype and device.
    """
    logger.debug("GEMS FILL (Dynamic)")
    result = torch.empty_like(input)
    # Run the kernel with the correct device active for the launch.
    with torch_device_fn.device(input.device):
        return fill_scalar_func(input, value, out0=result)
def fill_scalar_out(input, value, *, out=None):
    """Fill ``out`` with scalar ``value``; allocate a new tensor when ``out`` is None.

    Returns the filled tensor in both cases.
    """
    logger.debug("GEMS FILL_SCALAR_OUT")
    if out is not None:
        with torch_device_fn.device(input.device):
            fill_scalar_func(input, value, out0=out)
        return out
    # No destination supplied — fall back to the allocating variant.
    return fill_scalar(input, value)
def fill_tensor(input, value):
    """Return a fresh tensor like ``input`` filled with the 0-dim tensor ``value``.

    Raises:
        RuntimeError: if a device-resident ``value`` is not 0-dimensional.
    """
    # Host-resident value: read it out eagerly and reuse the scalar path.
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    result = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=result)
def fill_tensor_out(input, value, *, out=None):
    """Fill ``out`` with the 0-dim tensor ``value``; allocate when ``out`` is None.

    Host-resident values are read out and routed through the scalar variant.

    Raises:
        RuntimeError: if a device-resident ``value`` is not 0-dimensional.
    """
    logger.debug("GEMS FILL_TENSOR_OUT")
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)
    if value.ndim != 0:
        msg = f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        raise RuntimeError(msg)
    with torch_device_fn.device(input.device):
        fill_tensor_func(input, value, out0=out)
    return out
def fill_tensor_(self, value):
    """In-place fill of ``self`` with the 0-dim tensor ``value``; returns ``self``.

    Raises:
        RuntimeError: if a device-resident ``value`` is not 0-dimensional.
    """
    # Host-resident value: read it out and delegate to the scalar in-place op.
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    if value.ndim != 0:
        msg = f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        raise RuntimeError(msg)
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self
def fill_scalar_(self, value):
    """In-place fill of ``self`` with scalar ``value``; returns ``self``."""
    logger.debug("GEMS FILL_SCALAR_")
    # Write directly into ``self`` by passing it as the kernel's output.
    with torch_device_fn.device(self.device):
        fill_scalar_func(self, value, out0=self)
    return self