Coverage for src/flag_gems/runtime/backend/_mthreads/ops/one_hot.py: 0%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module-scoped logger, namespaced under the mthreads backend ops package.
# NOTE(review): the f-string appears to rebuild __name__ itself (package path
# plus the final module component) — if so, logging.getLogger(__name__) would
# be equivalent; confirm the module's import path before simplifying.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

14 

15 

@libentry()
@triton.jit
def one_hot_kernel_16(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Dense one-hot kernel for actual_classes <= 16.

    Each program handles BLOCK_SIZE rows and writes the full
    (BLOCK_SIZE, actual_classes) tile of 0/1 values, so the output buffer
    does not need to be zero-initialized. Columns beyond actual_classes and
    rows beyond num_elements are masked off; an index that matches no column
    (e.g. a negative value) simply yields an all-zero row.
    """
    block_id = tle.program_id(axis=0)
    row_ids = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_mask = row_ids < num_elements

    class_ids = tl.load(input_ptr + row_ids, mask=row_mask, other=0)

    # Compare every loaded index against all 16 candidate columns at once.
    col_ids = tl.arange(0, 16)
    col_mask = col_ids < actual_classes
    onehot = tl.where(class_ids[:, None] == col_ids[None, :], 1, 0)

    dest = row_ids[:, None] * actual_classes + col_ids[None, :]
    tl.store(output_ptr + dest, onehot, mask=row_mask[:, None] & col_mask[None, :])

39 

40 

@libentry()
@triton.jit
def one_hot_kernel_32(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Dense one-hot kernel for actual_classes <= 32.

    Identical scheme to one_hot_kernel_16 with a 32-wide column range: each
    program emits a complete (BLOCK_SIZE, actual_classes) tile of 0/1 values,
    masking out columns past actual_classes and rows past num_elements.
    """
    pid = tle.program_id(axis=0)
    rows = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = rows < num_elements

    labels = tl.load(input_ptr + rows, mask=in_range, other=0)

    # Broadcast-compare labels against the 32 candidate class columns.
    cols = tl.arange(0, 32)
    tile = tl.where(labels[:, None] == cols[None, :], 1, 0)

    write_mask = in_range[:, None] & (cols < actual_classes)[None, :]
    addrs = rows[:, None] * actual_classes + cols[None, :]
    tl.store(output_ptr + addrs, tile, mask=write_mask)

64 

65 

@libentry()
@triton.jit
def one_hot_kernel_64(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Dense one-hot kernel for actual_classes <= 64.

    Same tile scheme as the 16/32 variants, with a 64-wide column range:
    every program writes its full (BLOCK_SIZE, actual_classes) tile, so the
    output does not need pre-zeroing. Out-of-range rows and columns are
    excluded via the combined store mask.
    """
    prog = tle.program_id(axis=0)
    element_ids = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    element_mask = element_ids < num_elements

    idx = tl.load(input_ptr + element_ids, mask=element_mask, other=0)

    # One equality test per (element, class) pair over the 64-wide range.
    classes = tl.arange(0, 64)
    bits = tl.where(idx[:, None] == classes[None, :], 1, 0)

    store_mask = element_mask[:, None] & (classes < actual_classes)[None, :]
    out_ptrs = output_ptr + element_ids[:, None] * actual_classes + classes[None, :]
    tl.store(out_ptrs, bits, mask=store_mask)

89 

90 

@libentry()
@triton.jit
def one_hot_set_one_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    num_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter-style one-hot kernel for large num_classes.

    Writes only the single 1 per row at position row * num_classes + index;
    the caller must pre-fill the output with zeros. No bounds check is done
    on the loaded index itself, so indices must already be validated to lie
    in [0, num_classes) before launch.
    """
    pid = tle.program_id(axis=0)
    elem_offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    elem_mask = elem_offsets < num_elements

    class_idx = tl.load(input_ptr + elem_offsets, mask=elem_mask, other=0)
    # One store per valid element: the hot position within its row.
    target = elem_offsets * num_classes + class_idx
    tl.store(output_ptr + target, 1, mask=elem_mask)

112 

113 

def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """Compute the one-hot encoding of an index tensor.

    Args:
        tensor: LongTensor of class indices, any shape.
        num_classes: Total number of classes. ``-1`` (default) infers it as
            ``tensor.max() + 1``.

    Returns:
        An int64 tensor of shape ``(*tensor.shape, num_classes)`` with a 1 at
        each index position and 0 elsewhere.

    Raises:
        RuntimeError: If ``tensor`` is not int64, if ``num_classes`` cannot be
            inferred from an empty tensor, if ``num_classes`` is not positive,
            or if any class value is negative or ``>= num_classes``.
    """
    logger.debug("GEMS_MTHREADS ONE_HOT")

    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    if tensor.numel() == 0:
        if num_classes <= 0:
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        # The result has zero elements, so an empty allocation suffices.
        shape = (*tensor.shape, num_classes)
        return torch.empty(shape, device=tensor.device, dtype=torch.int64)

    # One pass gives both bounds: the max drives num_classes inference, and
    # validating both ends up front prevents one_hot_set_one_kernel from
    # performing out-of-bounds global stores for invalid class values.
    amin, amax = torch.aminmax(tensor)
    min_index, max_index = int(amin.item()), int(amax.item())
    if min_index < 0:
        raise RuntimeError("Class values must be non-negative.")
    if num_classes == -1:
        num_classes = max_index + 1
    else:
        if num_classes < 1:
            raise RuntimeError("num_classes should be positive")
        if max_index >= num_classes:
            raise RuntimeError("Class values must be smaller than num_classes.")

    # CPU fallback: scatter 1s into a zero-initialized tensor.
    if tensor.device.type == "cpu":
        out = torch.zeros((*tensor.shape, num_classes), device="cpu", dtype=torch.int64)
        out.scatter_(-1, tensor.unsqueeze(-1), 1)
        return out

    # Flatten so kernels treat the input as a 1-D list of rows.
    flat_input = tensor.contiguous().view(-1)
    num_elements = flat_input.numel()

    with torch_device_fn.device(tensor.device):
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
        if num_classes <= 64:
            # Dense tile kernels write every output element of their tile,
            # so uninitialized output memory is fine.
            if num_classes <= 16:
                kernel = one_hot_kernel_16
            elif num_classes <= 32:
                kernel = one_hot_kernel_32
            else:
                kernel = one_hot_kernel_64
            out = torch.empty(
                num_elements * num_classes, device=tensor.device, dtype=torch.int64
            )
            kernel[grid](
                flat_input,
                out,
                num_elements,
                num_classes,
                BLOCK_SIZE=128,
            )
        else:
            # Wide rows: zero-fill once, then scatter a single 1 per row.
            out = torch.zeros(
                num_elements * num_classes, device=tensor.device, dtype=torch.int64
            )
            one_hot_set_one_kernel[grid](
                flat_input,
                out,
                num_elements,
                num_classes,
                BLOCK_SIZE=1024,
            )

    return out.view(*tensor.shape, num_classes)