Coverage for src/flag_gems/runtime/backend/_mthreads/ops/one_hot.py: 0%
104 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
# Logger scoped under the flag_gems mthreads-backend namespace; the trailing
# component is this module's own name (e.g. "one_hot").
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)
@libentry()
@triton.jit
def one_hot_kernel_16(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """One-hot encode up to 16 classes.

    Each input index expands to a row of ``actual_classes`` values with a
    single 1 at the index position. Full rows (zeros and the one) are written,
    so the output buffer does not need zero-initialization. Labels outside
    [0, actual_classes) simply never match and leave a row of zeros.
    """
    pid = tle.program_id(axis=0)
    row_ids = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_valid = row_ids < num_elements
    labels = tl.load(input_ptr + row_ids, mask=row_valid, other=0)

    # Fixed tile of 16 class columns; columns beyond actual_classes are masked.
    cols = tl.arange(0, 16)
    col_valid = cols < actual_classes
    dest = row_ids[:, None] * actual_classes + cols[None, :]
    one_hot_vals = tl.where(labels[:, None] == cols[None, :], 1, 0)
    tl.store(
        output_ptr + dest,
        one_hot_vals,
        mask=row_valid[:, None] & col_valid[None, :],
    )
@libentry()
@triton.jit
def one_hot_kernel_32(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """One-hot encode up to 32 classes.

    Writes complete rows (zeros plus the single 1), so the output buffer needs
    no prior initialization. Labels outside [0, actual_classes) never match
    any column and yield an all-zero row.
    """
    pid = tle.program_id(axis=0)
    elem = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    elem_ok = elem < num_elements
    idx = tl.load(input_ptr + elem, mask=elem_ok, other=0)

    # 32-wide class tile; trailing columns past actual_classes are masked.
    cls = tl.arange(0, 32)
    cls_ok = cls < actual_classes
    flat = elem[:, None] * actual_classes + cls[None, :]
    hit = tl.where(idx[:, None] == cls[None, :], 1, 0)
    tl.store(output_ptr + flat, hit, mask=elem_ok[:, None] & cls_ok[None, :])
@libentry()
@triton.jit
def one_hot_kernel_64(
    input_ptr,
    output_ptr,
    num_elements,
    actual_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """One-hot encode up to 64 classes.

    Writes complete rows (zeros plus the single 1), so the output buffer needs
    no prior initialization. Labels outside [0, actual_classes) never match
    any column and yield an all-zero row.
    """
    pid = tle.program_id(axis=0)
    elem = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    elem_ok = elem < num_elements
    idx = tl.load(input_ptr + elem, mask=elem_ok, other=0)

    # 64-wide class tile; trailing columns past actual_classes are masked.
    cls = tl.arange(0, 64)
    cls_ok = cls < actual_classes
    flat = elem[:, None] * actual_classes + cls[None, :]
    hit = tl.where(idx[:, None] == cls[None, :], 1, 0)
    tl.store(output_ptr + flat, hit, mask=elem_ok[:, None] & cls_ok[None, :])
@libentry()
@triton.jit
def one_hot_set_one_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    num_classes,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel that only writes 1s to the correct positions.
    Output tensor should be pre-initialized with zeros.
    """
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    indices = tl.load(input_ptr + offsets, mask=mask, other=0)
    # Bug fix: restrict the store to in-range labels. Previously an index < 0
    # or >= num_classes made `offsets * num_classes + indices` point outside
    # the output buffer, producing an out-of-bounds device store. The fused
    # small-class kernels already mask such labels; this kernel must as well.
    in_range = (indices >= 0) & (indices < num_classes)
    out_offsets = offsets * num_classes + indices
    tl.store(output_ptr + out_offsets, 1, mask=mask & in_range)
def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """One-hot encode an index tensor.

    Args:
        tensor: index tensor of dtype ``torch.int64``, any shape.
        num_classes: total number of classes; ``-1`` (default) infers
            ``tensor.max() + 1`` (requires a device sync).

    Returns:
        ``int64`` tensor of shape ``(*tensor.shape, num_classes)``.

    Raises:
        RuntimeError: if ``tensor`` is not a LongTensor, if the class count
            cannot be inferred (empty input, or all indices negative), or if
            an explicit ``num_classes`` is not positive.
    """
    logger.debug("GEMS_MTHREADS ONE_HOT")

    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    if tensor.numel() == 0:
        if num_classes <= 0:
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        # Zero-element result: the class dimension is known, data is empty.
        shape = (*tensor.shape, num_classes)
        return torch.empty(shape, device=tensor.device, dtype=torch.int64)

    if num_classes == -1:
        # Only compute max when necessary (num_classes=-1) to infer the count.
        num_classes = int(tensor.max().item()) + 1
        if num_classes < 1:
            # All indices negative: no valid class count can be inferred.
            # (Previously this fell through to an allocation with a
            # non-positive size; fail with the same clear error instead.)
            raise RuntimeError("num_classes should be positive")
    elif num_classes < 1:
        raise RuntimeError("num_classes should be positive")

    # CPU fallback: scatter 1s along the trailing class dimension.
    if tensor.device.type == "cpu":
        out = torch.zeros((*tensor.shape, num_classes), device="cpu", dtype=torch.int64)
        out.scatter_(-1, tensor.unsqueeze(-1), 1)
        return out

    # Flatten input for kernel processing.
    flat_input = tensor.contiguous().view(-1)
    num_elements = flat_input.numel()

    with torch_device_fn.device(tensor.device):
        # Small class counts use a fused kernel that writes entire rows
        # (zeros and the one), so the output needs no zero-initialization.
        fused_kernels = (
            (16, one_hot_kernel_16),
            (32, one_hot_kernel_32),
            (64, one_hot_kernel_64),
        )
        kernel = next((k for cap, k in fused_kernels if num_classes <= cap), None)
        if kernel is not None:
            out = torch.empty(
                num_elements * num_classes, device=tensor.device, dtype=torch.int64
            )
            BLOCK_SIZE = 128
        else:
            # Large class counts: zero-fill once, then only set the 1s.
            kernel = one_hot_set_one_kernel
            out = torch.zeros(
                num_elements * num_classes, device=tensor.device, dtype=torch.int64
            )
            BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
        kernel[grid](
            flat_input,
            out,
            num_elements,
            num_classes,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out.view(*tensor.shape, num_classes)