Coverage for src/flag_gems/runtime/backend/_ascend/ops/fill.py: 0%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

12 

13 

@libentry()
@triton.jit(do_not_specialize=["value_scalar"])
def fill_scalar_kernel(
    out_ptr,  # pointer to the output buffer to fill
    N,  # total number of elements to fill
    value_scalar,  # fill value; not specialized, so new values don't trigger recompiles
    BLOCK_SIZE: tl.constexpr,  # contiguous span handled by one program instance
    SUBBLOCK_SIZE: tl.constexpr,  # elements written per inner-loop iteration
):
    # Fill out_ptr[0:N] with value_scalar. Each program covers one
    # BLOCK_SIZE span and walks it in SUBBLOCK_SIZE chunks.
    pid = tle.program_id(0)
    pid_offset = pid * BLOCK_SIZE
    cols = tl.arange(0, SUBBLOCK_SIZE)
    # Both operands are constexpr, so this cdiv folds at compile time.
    num_loop = triton.cdiv(BLOCK_SIZE, SUBBLOCK_SIZE)
    for iloop in tl.range(num_loop):
        offset = pid_offset + iloop * SUBBLOCK_SIZE + cols
        # mask guards the tail when N is not a multiple of the chunk size
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)

30 

31 

@libentry()
@triton.jit
def fill_tensor_kernel(
    out_ptr,  # pointer to the output buffer to fill
    N,  # total number of elements to fill
    value_ptr,  # pointer to a 0-d tensor holding the fill value
    BLOCK_SIZE: tl.constexpr,  # contiguous span handled by one program instance
    SUBBLOCK_SIZE: tl.constexpr,  # elements written per inner-loop iteration
):
    # Fill out_ptr[0:N] with the scalar stored at value_ptr. Each program
    # covers one BLOCK_SIZE span and walks it in SUBBLOCK_SIZE chunks.
    pid = tle.program_id(0)
    pid_offset = pid * BLOCK_SIZE
    cols = tl.arange(0, SUBBLOCK_SIZE)
    num_loop = triton.cdiv(BLOCK_SIZE, SUBBLOCK_SIZE)
    # The fill value is loop-invariant: load it once up front instead of
    # re-loading it on every loop iteration as the original did.
    value_scalar = tl.load(value_ptr)
    for iloop in tl.range(num_loop):
        offset = pid_offset + iloop * SUBBLOCK_SIZE + cols
        # mask guards the tail when N is not a multiple of the chunk size
        tl.store(out_ptr + offset, value_scalar, mask=offset < N)

49 

50 

def fill_tensor(input, value):
    """Out-of-place fill: return a tensor like `input` filled with `value`.

    Args:
        input: tensor providing the shape/dtype/device of the result.
        value: 0-d tensor holding the fill value. If it is not on the
            device (``is_cuda`` is False), fall back to the scalar path
            via ``value.item()``.

    Returns:
        A new tensor shaped like `input`, every element set to `value`.

    Raises:
        RuntimeError: if `value` is not a 0-dimension tensor.
    """
    if not value.is_cuda:
        # Host-resident value: materialize it as a Python scalar instead
        # of launching the tensor-load kernel.
        return fill_scalar(input, value.item())
    logger.debug("GEMS_ASCEND FILL")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    out = torch.empty_like(input)
    N = out.numel()
    if N == 0:
        # Nothing to fill; also avoids a zero division in the grid math below.
        return out
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)

    with torch_device_fn.device(input.device):
        fill_tensor_kernel[grid,](out, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return out

69 

70 

def fill_scalar(input, value):
    """Out-of-place fill: return a tensor like `input` filled with scalar `value`.

    Args:
        input: tensor providing the shape/dtype/device of the result.
        value: Python scalar fill value.

    Returns:
        A new tensor shaped like `input`, every element set to `value`.
    """
    logger.debug("GEMS_ASCEND FILL")
    out = torch.empty_like(input)
    N = out.numel()
    if N == 0:
        # Nothing to fill; also avoids a zero division in the grid math below.
        return out
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)

    with torch_device_fn.device(input.device):
        fill_scalar_kernel[grid,](out, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return out

83 

84 

def fill_tensor_(self, value):
    """In-place fill: set every element of `self` to the value in `value`.

    Args:
        self: tensor to fill in place.
        value: 0-d tensor holding the fill value. If it is not on the
            device (``is_cuda`` is False), fall back to the scalar path
            via ``value.item()``.

    Returns:
        `self`, filled in place.

    Raises:
        RuntimeError: if `value` is not a 0-dimension tensor.
    """
    if not value.is_cuda:
        # Host-resident value: materialize it as a Python scalar instead
        # of launching the tensor-load kernel.
        return fill_scalar_(self, value.item())
    logger.debug("GEMS_ASCEND FILL_TENSOR_")
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    N = self.numel()
    if N == 0:
        # Nothing to fill; also avoids a zero division in the grid math below.
        return self
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)

    with torch_device_fn.device(self.device):
        fill_tensor_kernel[grid,](self, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return self

102 

103 

def fill_scalar_(self, value):
    """In-place fill: set every element of `self` to scalar `value`.

    Args:
        self: tensor to fill in place.
        value: Python scalar fill value.

    Returns:
        `self`, filled in place.
    """
    logger.debug("GEMS_ASCEND FILL_SCALAR_")
    N = self.numel()
    if N == 0:
        # Nothing to fill; also avoids a zero division in the grid math below.
        return self
    # FIXME: 910B3&910B4 have 40 AIV cores while 910B1 has 50, 910B2 has 48.
    grid = min(40, N)
    BLOCK_SIZE = (N + grid - 1) // grid
    SUBBLOCK_SIZE = min(8192, BLOCK_SIZE)

    with torch_device_fn.device(self.device):
        fill_scalar_kernel[grid,](self, N, value, BLOCK_SIZE, SUBBLOCK_SIZE)
    return self

114 return self