Coverage for src/flag_gems/ops/fill.py: 86%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import pointwise_dynamic 

9 

# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)

11 

12 

# Elementwise "fill with a scalar" functor, wrapped by pointwise_dynamic.
# Only ``inp`` is declared a tensor argument (is_tensor=[True, False]); the
# fill value is passed as a plain scalar. promotion_methods=[(0, "DEFAULT")]
# presumably derives the single output's dtype from argument 0 -- the exact
# rule lives in flag_gems.utils.pointwise_dynamic (TODO confirm).
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # The values in ``inp`` are never used: it only supplies the shape and
    # dtype of the result. Every element becomes ``value_scalar`` cast to
    # ``inp.dtype``.
    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)

19 

20 

# Elementwise "fill with a tensor value" functor, wrapped by pointwise_dynamic.
# Both arguments are declared tensors (is_tensor=[True, True]). The callers in
# this module reject non-0-dim ``value`` tensors before calling, so ``value``
# is a single element that pointwise_dynamic presumably broadcasts against
# ``inp`` (TODO confirm). promotion_methods=[(0, "DEFAULT")] presumably ties
# the output dtype to ``inp``, so ``value`` would be cast on store.
@pointwise_dynamic(
    is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_tensor_func(inp, value):
    # ``inp`` is not read here; presumably it only drives the output's
    # shape/dtype through the pointwise_dynamic wrapper.
    return value

27 

28 

def fill_scalar(input, value):
    """Out-of-place fill: return a new tensor like ``input`` filled with ``value``.

    Args:
        input: Tensor whose shape, dtype and device the result copies (via
            ``torch.empty_like``). Its element values do not affect the
            result, and ``input`` itself is not written to.
        value: Scalar fill value (for example the ``.item()`` of a tensor).

    Returns:
        The freshly allocated tensor filled by ``fill_scalar_func``.
    """
    logger.debug("GEMS FILL (Dynamic)")
    buffer = torch.empty_like(input)
    device = input.device
    # Launch on the device that owns ``input``.
    with torch_device_fn.device(device):
        result = fill_scalar_func(input, value, out0=buffer)
    return result

34 

35 

def fill_tensor(input, value):
    """Out-of-place fill: return a new tensor like ``input`` filled with ``value``.

    Args:
        input: Tensor whose shape, dtype and device the result copies (via
            ``torch.empty_like``); ``input`` itself is not written to.
        value: 0-dimensional tensor holding the fill value.

    Returns:
        A freshly allocated tensor with every element set to ``value``.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    # Validate before any dispatch so that every path, including the
    # host-scalar fallback below, rejects non-0-dim values. This matches ATen's
    # ``fill_``; previously a 1-element 1-D CPU tensor slipped through ``.item()``.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    # A value living on another device (CPU, or a different GPU) cannot be read
    # by a kernel launched on ``input.device``, so fetch it to the host instead.
    # Comparing devices rather than testing ``is_cuda`` also keeps same-device
    # values on the kernel path for non-CUDA backends.
    if value.device != input.device:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)

47 

48 

def fill_tensor_(self, value):
    """In-place fill: set every element of ``self`` to ``value`` and return ``self``.

    Args:
        self: Tensor to overwrite in place.
        value: 0-dimensional tensor holding the fill value.

    Returns:
        ``self``, after being filled.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional.
    """
    # Validate before any dispatch so the host-scalar fallback cannot accept a
    # non-0-dim value (e.g. a 1-element 1-D CPU tensor); matches ATen's fill_.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    # A value on a different device than ``self`` cannot be read by the kernel,
    # so fall back to a host scalar. Comparing devices (not ``is_cuda``) also
    # handles cross-GPU values and non-CUDA backends.
    if value.device != self.device:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    with torch_device_fn.device(self.device):
        # out0=self makes the kernel write the result into ``self`` in place.
        fill_tensor_func(self, value, out0=self)
    return self

60 

61 

def fill_scalar_(self, value):
    """In-place fill: set every element of ``self`` to the scalar ``value``.

    Args:
        self: Tensor to overwrite in place.
        value: Scalar fill value.

    Returns:
        ``self``, after being filled.
    """
    logger.debug("GEMS FILL_SCALAR_")
    dev = self.device
    with torch_device_fn.device(dev):
        # Passing ``self`` as out0 makes the kernel write straight into it.
        fill_scalar_func(self, value, out0=self)
    return self