Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/fill.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10from ..utils.pointwise_dynamic import pointwise_dynamic 

11 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation settings handed to both pointwise_dynamic kernels in this
# module. The positional arguments are interpreted by CodeGenConfig (imported
# from _kunlunxin.utils.codegen_config_utils); their individual meanings are
# not visible from this file — TODO confirm against that class.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=True,
    isCloseDtypeConvert=True,
)

22 

23 

# Elementwise body for the scalar fill, wrapped by pointwise_dynamic.
# `is_tensor=[True, False]` declares `inp` as a tensor operand and
# `value_scalar` as a non-tensor argument; promotion_methods presumably makes
# the output dtype follow argument 0 (`inp`) — TODO confirm in pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, "DEFAULT")],
    num_outputs=1,
    config=config_,
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    # The contents of `inp` are ignored; only its shape and dtype are used to
    # build a block filled with `value_scalar`.
    # Only '#' comments in this body: it is compiled by triton.jit.
    return tl.full(inp.shape, value_scalar, dtype=inp.dtype)

33 

34 

# Elementwise body for the tensor-valued fill, wrapped by pointwise_dynamic.
# Both arguments are declared as tensors (`is_tensor=[True, True]`). The
# callers in this module only pass a 0-dim `value`; presumably pointwise_dynamic
# broadcasts it to `inp`'s shape and the store into `out0` casts to the output
# dtype — TODO confirm in pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, True],
    promotion_methods=[(0, "DEFAULT")],
    num_outputs=1,
    config=config_,
)
@triton.jit
def fill_tensor_func(inp, value):
    # `inp` is not read here; it only supplies the iteration domain / output
    # dtype to the generated wrapper. Only '#' comments: compiled by triton.jit.
    return value

44 

45 

def fill_scalar(input, value):
    """Return a new tensor like ``input`` with every element set to ``value``.

    Args:
        input: tensor whose shape/dtype/device the result mirrors
            (allocated via ``torch.empty_like``); it is not modified.
        value: Python scalar to broadcast into the result.

    Returns:
        The freshly allocated, filled tensor produced by ``fill_scalar_func``.
    """
    logger.debug("GEMS FILL (Dynamic)")
    result = torch.empty_like(input)
    # Launch on the input's device so the kernel targets the right card.
    device_guard = torch_device_fn.device(input.device)
    with device_guard:
        filled = fill_scalar_func(input, value, out0=result)
    return filled

51 

52 

def fill_tensor(input, value):
    """Return a new tensor like ``input`` filled with the value in ``value``.

    Args:
        input: tensor whose shape/dtype/device the result mirrors
            (allocated via ``torch.empty_like``); it is not modified.
        value: 0-dimensional tensor holding the fill value. A value that is
            not on the accelerator (``is_cuda`` is False) is read with
            ``.item()`` and dispatched to ``fill_scalar``.

    Returns:
        The freshly allocated, filled tensor.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional, whatever its device.
    """
    # Validate before dispatching on device. PyTorch's fill_ rejects any
    # non-0-dim value tensor; checking later let a 1-element CPU tensor such
    # as torch.tensor([1.0]) slip through via .item().
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        # Host-resident value: extract the Python scalar and use the scalar kernel.
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)

64 

65 

def fill_tensor_(self, value):
    """Fill ``self`` in place with the value held in a 0-dim tensor.

    Args:
        self: tensor to overwrite; it is passed as both input and ``out0``.
        value: 0-dimensional tensor holding the fill value. A value that is
            not on the accelerator (``is_cuda`` is False) is read with
            ``.item()`` and dispatched to ``fill_scalar_``.

    Returns:
        ``self``, for in-place chaining.

    Raises:
        RuntimeError: if ``value`` is not 0-dimensional, whatever its device.
    """
    # Validate before dispatching on device, matching PyTorch's fill_, which
    # rejects non-0-dim values even when they hold a single element on CPU.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        # Host-resident value: extract the Python scalar and use the scalar kernel.
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self

77 

78 

def fill_scalar_(self, value):
    """Overwrite every element of ``self`` with the scalar ``value``.

    ``self`` serves as both the kernel input and its ``out0`` buffer, so the
    fill happens in place.

    Returns:
        ``self``, for in-place chaining.
    """
    logger.debug("GEMS FILL_SCALAR_")
    target = self
    device_guard = torch_device_fn.device(target.device)
    with device_guard:
        fill_scalar_func(target, value, out0=target)
    return target