Coverage for src/flag_gems/runtime/backend/_cambricon/ops/full.py: 0%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils.shape_utils import volume 

11 

12from ..utils import TOTAL_CORE_NUM 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15 

16 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit
def full_tensor_kernel(
    output_ptr,
    n_elements,
    fill_value_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``output_ptr[0:n_elements]`` with the scalar stored at ``fill_value_ptr``.

    The fill value arrives as a 0-d device tensor (pointer) so the value keeps
    its exact dtype (used for floating-point fills, see ``check_dtype``).
    Uses a grid-stride loop: each program covers every BLOCK_SIZE-th tile so
    the grid size can be capped at TOTAL_CORE_NUM by the launcher.

    NOTE(review): the original decorator passed
    ``do_not_specialize=["fill_value_or_ptr"]``, which matches no parameter of
    this kernel and therefore had no effect; it has been removed.
    """
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # int64 start offset so large tensors (> 2**31 elements) index correctly.
    block_start = block_start.to(tl.int64)
    # The fill value is loop-invariant: load it once instead of per iteration.
    fill_value = tl.load(fill_value_ptr)
    for block_start_offset in range(block_start, n_elements, step):
        offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(output_ptr + offsets, fill_value, mask=mask)

44 

45 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
# Fix: the original spelled do_not_specialize=["fill_value_or_ptr"], which
# matches no parameter and silently did nothing. Naming the real parameter
# prevents Triton from recompiling the kernel for every distinct fill value.
@triton.jit(do_not_specialize=["fill_value"])
def full_scalar_kernel(
    output_ptr,
    n_elements,
    fill_value,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``output_ptr[0:n_elements]`` with the Python scalar ``fill_value``.

    Grid-stride loop: each program writes every BLOCK_SIZE-th tile so the
    launcher can cap the grid at TOTAL_CORE_NUM regardless of tensor size.
    """
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # int64 start offset so large tensors (> 2**31 elements) index correctly.
    block_start = block_start.to(tl.int64)
    for block_start_offset in range(block_start, n_elements, step):
        offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(output_ptr + offsets, fill_value, mask=mask)

72 

73 

# Dtype groups used by check_dtype's range validation below.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate ``fill_value`` against ``dtype`` and normalize it for the kernels.

    Args:
        fill_value: Python bool/int/float to fill with.
        dtype: target torch dtype of the output tensor.
        device: device on which a 0-d tensor is allocated for float dtypes.

    Returns:
        - a Python scalar for bool/integer dtypes (consumed by the scalar kernel), or
        - a 0-d tensor of ``dtype`` on ``device`` for float dtypes (consumed by
          the tensor kernel), so the value is cast exactly once on device.

    Raises:
        RuntimeError: if ``fill_value`` cannot be represented in ``dtype``
            without overflow (matches ``torch.full`` semantics).
    """
    if isinstance(fill_value, bool):
        # bool -> numeric dtype: promote to int; bool -> bool passes through.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    else:
        # Fix: use isfinite instead of isinf so NaN is also rejected for
        # non-float dtypes (torch.full raises for nan -> int as well).
        if isinstance(fill_value, float) and not math.isfinite(fill_value):
            if dtype not in ALL_FLOAT_DTYPES:
                raise RuntimeError(
                    f"value {fill_value!r} cannot be converted to type {dtype} without overflow"
                )
        elif (
            dtype in ALL_INT_DTYPES
            and (
                fill_value < torch.iinfo(dtype).min
                or fill_value > torch.iinfo(dtype).max
            )
        ) or (
            dtype in ALL_FLOAT_DTYPES
            and (
                fill_value < torch.finfo(dtype).min
                or fill_value > torch.finfo(dtype).max
            )
        ):
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )
    if dtype in ALL_FLOAT_DTYPES:
        # Materialize floats as a 0-d device tensor to get exact dtype rounding.
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value

107 

108 

def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape ``size`` filled with ``fill_value``.

    Mirrors ``torch.full``: when ``dtype`` is None it is inferred from the
    Python type of ``fill_value`` (bool -> bool, int -> int64, otherwise the
    default float dtype); when given, the value is validated/normalized by
    ``check_dtype``. ``layout`` and ``pin_memory`` are accepted for signature
    compatibility.
    """
    logger.debug("GEMS_CAMBRICON FULL")
    device = torch.device("cpu") if device is None else device

    if dtype is not None:
        fill_value = check_dtype(fill_value, dtype, device)
    elif isinstance(fill_value, bool):
        dtype = torch.bool
    elif isinstance(fill_value, int):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()

    out = torch.empty(size, device=device, dtype=dtype)
    n_elements = volume(size)

    # Cap the launch grid at the core count; kernels grid-stride past it.
    def grid(meta):
        return (min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        # check_dtype yields a 0-d tensor for float dtypes, a scalar otherwise.
        kernel = (
            full_tensor_kernel
            if isinstance(fill_value, torch.Tensor)
            else full_scalar_kernel
        )
        kernel[grid](out, n_elements, fill_value)
    return out