Coverage for src/flag_gems/ops/fill.py: 84%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import pointwise_dynamic 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Triton pointwise kernel: broadcast a Python scalar across the output.
# `is_tensor=[True, False]` tells the pointwise_dynamic code generator that
# `inp` is a tensor argument while `value_scalar` is a plain (non-tensor)
# scalar; promotion follows the default rule keyed on argument 0.
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # Produce a tile of `inp`'s shape/dtype where every lane holds the scalar.
    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)

19 

20 

# Triton pointwise kernel: copy a (0-dim, device-resident) tensor value into
# every element of the output. Both arguments are tensors; `value` is
# broadcast by the pointwise_dynamic wrapper against `inp`'s shape.
@pointwise_dynamic(
    is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")], num_outputs=1
)
@triton.jit
def fill_tensor_func(inp, value):
    # Each lane simply takes the broadcast value.
    return value

27 

28 

def fill_scalar(input, value):
    """Return a new tensor shaped like `input`, filled with scalar `value`.

    `input` only supplies shape/dtype/device; its contents are not read.
    """
    logger.debug("GEMS FILL (Dynamic)")
    result = torch.empty_like(input)
    # Run the kernel with the correct device active for the launch.
    with torch_device_fn.device(input.device):
        filled = fill_scalar_func(input, value, out0=result)
    return filled

34 

35 

def fill_scalar_out(input, value, *, out=None):
    """Fill `out` with scalar `value` (shape taken from `input`).

    When `out` is None, allocates and returns a fresh tensor via
    `fill_scalar`; otherwise writes into `out` in place and returns it.
    """
    logger.debug("GEMS FILL_SCALAR_OUT")
    if out is not None:
        with torch_device_fn.device(input.device):
            fill_scalar_func(input, value, out0=out)
        return out
    return fill_scalar(input, value)

43 

44 

def fill_tensor(input, value):
    """Return a new tensor shaped like `input`, filled with the value held in
    the 0-dim tensor `value`.

    Args:
        input: tensor supplying shape/dtype/device of the result.
        value: 0-dim tensor holding the fill value. A CPU-resident value is
            converted to a Python scalar and routed through the scalar kernel,
            avoiding a host-to-device copy.

    Raises:
        RuntimeError: if `value` is not 0-dimensional.
    """
    # Validate before dispatching: previously a CPU multi-element `value`
    # reached `.item()` first and raised a generic conversion error instead of
    # the documented fill_ message. Checking up front makes CPU and device
    # tensors fail identically.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)

56 

57 

def fill_tensor_out(input, value, *, out=None):
    """Fill a tensor with the value held in the 0-dim tensor `value`.

    Args:
        input: tensor supplying shape/device for the operation.
        value: 0-dim tensor holding the fill value; CPU values are converted
            to Python scalars and routed through the scalar out-variant.
        out: optional destination tensor. When None a fresh tensor is
            allocated via `fill_tensor`; otherwise `out` is written in place
            and returned.

    Raises:
        RuntimeError: if `value` is not 0-dimensional.
    """
    # Validate first so every dispatch path (out=None, CPU value, device
    # value) raises the same documented error for a non-0-dim `value`.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    logger.debug("GEMS FILL_TENSOR_OUT")
    if out is None:
        return fill_tensor(input, value)
    if not value.is_cuda:
        return fill_scalar_out(input, value.item(), out=out)
    with torch_device_fn.device(input.device):
        fill_tensor_func(input, value, out0=out)
    return out

71 

72 

def fill_tensor_(self, value):
    """In-place fill of `self` with the value held in 0-dim tensor `value`.

    Args:
        self: destination tensor, modified in place and returned.
        value: 0-dim tensor holding the fill value; a CPU value is converted
            to a Python scalar and routed through `fill_scalar_`.

    Raises:
        RuntimeError: if `value` is not 0-dimensional.
    """
    # Validate before the CPU fast path: previously a CPU multi-element
    # `value` hit `.item()` and raised a generic conversion error instead of
    # the documented fill_ message.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self

84 

85 

def fill_scalar_(self, value):
    """In-place fill of `self` with scalar `value`; returns `self`."""
    logger.debug("GEMS FILL_SCALAR_")
    # Launch with the destination tensor's device active; writing to
    # out0=self makes the operation in-place.
    device_guard = torch_device_fn.device(self.device)
    with device_guard:
        fill_scalar_func(self, value, out0=self)
    return self