Coverage for src/flag_gems/runtime/backend/_metax/ops/full.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.shape_utils import volume 

11 

12logger = logging.getLogger("flag_gems." + __name__) 

13 

14 

@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``output_ptr[:n_elements]`` with a single value.

    The value arrives either inline as a scalar or behind a pointer
    (a 1-element tensor), selected at compile time by FILL_VALUE_IS_PTR.
    """
    block_id = tle.program_id(axis=0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements
    # Resolve the fill value: dereference the pointer variant, otherwise
    # use the scalar argument directly.
    if FILL_VALUE_IS_PTR:
        value = tl.load(fill_value_or_ptr)
    else:
        value = fill_value_or_ptr
    tl.store(output_ptr + offsets, value, mask=in_bounds)

32 

33 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block}, num_warps=warps, num_stages=4)
        for warps in [2, 4, 8, 16]  # maca support up to 16
        for block in [1024, 2048, 4096, 8192]
    ],
    key=["N"],
)
@triton.jit()
def full_kernel_scale(
    output_ptr,
    N,
    fill_value,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``output_ptr[:N]`` with the scalar ``fill_value``.

    BLOCK_SIZE (and warp count) are chosen by the autotuner, keyed on N.
    """
    block_id = tle.program_id(axis=0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < N
    tl.store(output_ptr + offsets, fill_value, mask=in_bounds)

56 

57 

# Dtypes subject to the overflow check below.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate ``fill_value`` against ``dtype`` and normalize it for the kernel.

    Args:
        fill_value: Python scalar (bool/int/float) to fill with.
        dtype: Target torch dtype of the output tensor.
        device: Device on which a float64 value is materialized as a tensor.

    Returns:
        The (possibly converted) fill value: a Python scalar, or a 0-dim
        tensor on ``device`` when ``dtype`` is ``torch.double``.

    Raises:
        RuntimeError: If ``fill_value`` cannot be represented in ``dtype``
            without overflow.
    """
    if isinstance(fill_value, bool):
        # bool is an int subclass and always representable; coerce to 0/1
        # for non-bool target dtypes.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    elif (
        dtype in ALL_INT_DTYPES
        and (fill_value < torch.iinfo(dtype).min or fill_value > torch.iinfo(dtype).max)
    ) or (
        dtype in ALL_FLOAT_DTYPES
        # inf/nan are representable in every float dtype, so skip the
        # finite-range check for them.
        and not (math.isinf(fill_value) or math.isnan(fill_value))
        and (fill_value < torch.finfo(dtype).min or fill_value > torch.finfo(dtype).max)
    ):
        raise RuntimeError(
            f"value cannot be converted to type {dtype} without overflow"
        )
    # Idiom fix: direct equality instead of membership in a one-element list.
    if dtype == torch.double:
        # float64 values are passed to the kernel through a device tensor
        # (pointer path) — presumably to keep full float64 precision; see
        # the double-dtype special case in full_(). TODO(review): confirm.
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)

    return fill_value

81 

82 

def full_(out, N, dtype, device, fill_value):
    """Launch the fill kernel to write ``fill_value`` into ``out``.

    Dispatches to the autotuned scalar kernel when the value is (or can be
    reduced to) a Python scalar, and to the pointer-reading kernel when it
    must stay a tensor (the float64 case).
    """
    fill_is_tensor = isinstance(fill_value, torch.Tensor)
    use_scalar_kernel = not fill_is_tensor

    # A 1-element tensor can be unwrapped to a plain scalar, except for
    # float64, which is deliberately kept on-device as a pointer argument.
    if fill_is_tensor and fill_value.numel() == 1 and dtype not in [torch.double]:
        fill_value = fill_value.item()
        use_scalar_kernel = True

    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(device):
        if use_scalar_kernel:
            full_kernel_scale[grid_fn](out, N, fill_value)
        else:
            full_kernel[grid_fn](
                out,
                N,
                fill_value,
                FILL_VALUE_IS_PTR=fill_is_tensor,
                BLOCK_SIZE=1024,
            )
    return out

110 

111 

def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape ``size`` filled with ``fill_value``.

    When ``dtype`` is omitted, it is inferred from the Python type of
    ``fill_value`` (bool -> bool, int -> int64, otherwise the default
    float dtype). ``layout`` and ``pin_memory`` are accepted for API
    parity but are not used here.
    """
    logger.debug("METAX GEMS FULL")
    if device is None:
        device = torch.device("cpu")

    if dtype is not None:
        # Explicit dtype: validate/normalize the fill value against it.
        fill_value = check_dtype(fill_value, dtype, device)
    elif isinstance(fill_value, bool):
        dtype = torch.bool
    elif isinstance(fill_value, int):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()

    out = torch.empty(size, device=device, dtype=dtype)
    return full_(out, volume(size), dtype, device, fill_value)