Coverage for src/flag_gems/ops/one_hot.py: 60%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
# Module-level logger keyed to this module's import path.
logger = logging.getLogger(__name__)
@triton.jit
def one_hot_kernel(
    index_ptr,  # pointer to the flattened integer index tensor
    out_ptr,  # pointer to the int64 output, addressed as row * num_classes + col
    num_classes,  # number of one-hot columns per row
    numel,  # number of index elements (rows of the output)
    BLOCK_M: tl.constexpr,  # rows handled by one program instance
    BLOCK_N: tl.constexpr,  # columns written per inner-loop iteration
):
    # One-hot encode BLOCK_M consecutive indices per program: for row i,
    # out[i, j] = (index[i] == j), stored as int64 zeros/ones.
    pid = tl.program_id(0)
    row_start = pid * BLOCK_M
    row_offsets = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < numel  # guard rows past the end of the input
    # Masked-off rows load 0; their stores are masked out below, so the
    # placeholder value never reaches memory.
    target_classes = tl.load(index_ptr + row_offsets, mask=row_mask, other=0)
    # Sweep the class dimension in BLOCK_N-wide tiles so num_classes may
    # exceed a single tile's width.
    for col_st in range(0, num_classes, BLOCK_N):
        col_offsets = col_st + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < num_classes
        # Broadcast compare: (BLOCK_M, 1) indices vs (1, BLOCK_N) column ids.
        result = target_classes[:, None] == col_offsets[None, :]
        result = result.to(tl.int64)
        # Flat row-major offsets into the (numel, num_classes) output.
        offs_2d = row_offsets[:, None] * num_classes + col_offsets[None, :]
        tl.store(out_ptr + offs_2d, result, mask=row_mask[:, None] & col_mask[None, :])
def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """One-hot encode an integer index tensor.

    Args:
        tensor: integer tensor of class indices, any shape.
        num_classes: total number of classes; ``-1`` (default) infers it as
            ``tensor.max() + 1``.

    Returns:
        int64 tensor of shape ``(*tensor.shape, num_classes)``.

    Raises:
        RuntimeError: when ``num_classes`` must be inferred from an empty
            tensor (mirrors ``torch.nn.functional.one_hot``).
    """
    logger.debug("GEMS ONE_HOT")
    # Non-CUDA tensors fall back to the eager PyTorch implementation.
    if not tensor.is_cuda:
        return torch.nn.functional.one_hot(tensor, num_classes)
    # The kernel computes flat row-major offsets, so it needs contiguous input.
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    numel = tensor.numel()
    if num_classes == -1:
        if numel == 0:
            # tensor.max() on an empty tensor raises a confusing reduction
            # error; raise the same clear message as the eager op instead.
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        num_classes = int(tensor.max().item()) + 1
    out = torch.empty(
        (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
    )
    if numel == 0:
        # Nothing to encode — avoid launching a zero-sized grid.
        return out
    BLOCK_N = triton.next_power_of_2(num_classes)
    BLOCK_N = min(BLOCK_N, 128)  # cap tile width; wider class dims are looped
    BLOCK_M = 32  # index rows per program
    grid = (triton.cdiv(numel, BLOCK_M),)
    one_hot_kernel[grid](
        tensor,
        out,
        num_classes,
        numel,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    return out