Coverage for src/flag_gems/runtime/backend/_cambricon/ops/fill.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

# Child logger under the shared "flag_gems" root; lstrip(".") guards against a
# leading dot in __name__ (relevant when the module name is relative).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
    ],
    key=["N"],
    strategy=["log"],
)
@triton.jit(do_not_specialize=["value_scalar"])
def fill_scalar_kernel(
    out_ptr,
    N,
    value_scalar,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill the first N elements at out_ptr with the runtime scalar value_scalar.
    # do_not_specialize keeps the kernel from being re-JIT-compiled per value.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    # Grid-stride loop: each program starts at its own block and advances by
    # the total number of elements covered per sweep of all programs.
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Widen to int64 so element offsets don't overflow 32 bits for large N.
    block_start = block_start.to(tl.int64)
    for block_start_offset in range(block_start, N, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        # Mask out the tail of the final (possibly partial) block.
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)

40 

41 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["N"],
)
@triton.jit
def fill_tensor_kernel(
    out_ptr,
    N,
    value_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill the first N elements at out_ptr with the scalar stored at value_ptr
    # (a 0-dim tensor on the device).
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(axis=0)
    # Grid-stride loop: each program starts at its own block and advances by
    # the total number of elements covered per sweep of all programs.
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Widen to int64 so element offsets don't overflow 32 bits for large N.
    block_start = block_start.to(tl.int64)
    # The fill value is loop-invariant: load it once instead of re-reading it
    # from global memory on every iteration of the loop below.
    value_scalar = tl.load(value_ptr)
    for block_start_offset in range(block_start, N, step):
        offset = block_start_offset + tl.arange(0, BLOCK_SIZE)
        # Mask out the tail of the final (possibly partial) block.
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)

68 

69 

def fill_tensor(input, value):
    """Out-of-place fill: return a tensor shaped like ``input`` with every
    element set to the scalar held in the 0-dim tensor ``value``.

    Args:
        input: tensor whose shape/dtype/device the result copies.
        value: 0-dim tensor holding the fill value (loaded by the kernel).

    Returns:
        A new tensor; ``input`` is not modified.
    """
    logger.debug("GEMS_CAMBRICON FILL TENSOR")
    out = torch.empty_like(input)
    N = out.numel()
    # Nothing to fill for an empty tensor; also avoids launching the kernel
    # with a zero-sized grid (mirrors the early return in fill_scalar).
    if N == 0:
        return out
    # Cap the grid at TOTAL_CORE_NUM; the kernel grid-strides over the rest.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid_fn](out, N, value)
    return out

80 

81 

def fill_scalar(input, value):
    """Out-of-place fill: a tensor shaped like ``input`` whose elements all
    equal the Python scalar ``value``. ``input`` itself is left untouched."""
    logger.debug("GEMS_CAMBRICON FILL SCALAR")
    if 0 in input.shape:
        return input
    result = torch.empty_like(input)
    numel = result.numel()

    # Grid is capped at TOTAL_CORE_NUM; the kernel strides over the remainder.
    def grid(meta):
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(input.device):
        fill_scalar_kernel[grid](result, numel, value)
    return result

94 

95 

def fill_tensor_(self, value):
    """In-place fill: set every element of ``self`` to the scalar held in the
    0-dim tensor ``value`` and return ``self``.

    Raises:
        RuntimeError: if ``value`` is not a 0-dimension tensor (matches the
            contract of ``torch.Tensor.fill_``).
    """
    logger.debug("GEMS_CAMBRICON FILL_TENSOR_")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    N = self.numel()
    # Nothing to fill for an empty tensor; also avoids launching the kernel
    # with a zero-sized grid (mirrors the early return in fill_scalar_).
    if N == 0:
        return self
    # Cap the grid at TOTAL_CORE_NUM; the kernel grid-strides over the rest.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid_fn](self, N, value)
    return self

108 

109 

def fill_scalar_(self, value):
    """In-place fill: set every element of ``self`` to the Python scalar
    ``value`` and return ``self``."""
    logger.debug("GEMS_CAMBRICON FILL_SCALAR_")
    if 0 in self.shape:
        return self
    numel = self.numel()

    # Grid is capped at TOTAL_CORE_NUM; the kernel strides over the remainder.
    def grid(meta):
        return (min(triton.cdiv(numel, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid](self, numel, value)
    return self