Coverage for src/flag_gems/ops/one_hot.py: 60%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.jit
def one_hot_kernel(
    index_ptr,
    out_ptr,
    num_classes,
    numel,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program instance owns BLOCK_M consecutive entries of the flattened
    # index tensor and fills their full one-hot rows in the output.
    block_id = tl.program_id(0)

    rows = block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    valid_rows = rows < numel

    # Out-of-range lanes load class 0; their stores are masked out below.
    labels = tl.load(index_ptr + rows, mask=valid_rows, other=0)

    # Sweep the class dimension in tiles of BLOCK_N columns.
    for col_base in range(0, num_classes, BLOCK_N):
        cols = col_base + tl.arange(0, BLOCK_N)
        valid_cols = cols < num_classes
        # 1 where the column index equals the row's label, else 0.
        hot = (labels[:, None] == cols[None, :]).to(tl.int64)
        flat_offsets = rows[:, None] * num_classes + cols[None, :]
        tl.store(
            out_ptr + flat_offsets,
            hot,
            mask=valid_rows[:, None] & valid_cols[None, :],
        )

34 

35 

def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """One-hot encode an integer tensor.

    Mirrors ``torch.nn.functional.one_hot``: returns an int64 tensor of shape
    ``(*tensor.shape, num_classes)`` with a 1 at each element's class index and
    0 elsewhere.

    Args:
        tensor: integer-valued index tensor.
        num_classes: total number of classes; ``-1`` (default) infers it as
            ``tensor.max() + 1``.

    Returns:
        ``torch.int64`` tensor with one extra trailing dimension of size
        ``num_classes``.

    Raises:
        RuntimeError: if ``num_classes`` is ``-1`` on an empty tensor (nothing
            to infer from), or if the resolved ``num_classes`` is not positive.
    """
    logger.debug("GEMS ONE_HOT")
    if not tensor.is_cuda:
        # Non-CUDA tensors fall back to the eager reference implementation.
        return torch.nn.functional.one_hot(tensor, num_classes)
    if not tensor.is_contiguous():
        # The kernel indexes the input as a flat contiguous buffer.
        tensor = tensor.contiguous()
    numel = tensor.numel()
    if num_classes == -1:
        if numel == 0:
            # tensor.max() on an empty tensor raises a confusing reduction
            # error; surface the same clear message torch uses instead.
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        num_classes = int(tensor.max().item()) + 1
    if num_classes < 1:
        # Rejects explicit non-positive values (e.g. 0, -5) as well as a
        # non-positive inference from all-negative indices, matching the
        # validation torch.nn.functional.one_hot performs.
        raise RuntimeError("num_classes should be positive")

    out = torch.empty(
        (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
    )
    if numel == 0:
        # Nothing to encode; skip the zero-size kernel launch.
        return out

    # Tile the class dimension; cap at 128 columns per inner-loop iteration.
    BLOCK_N = triton.next_power_of_2(num_classes)
    BLOCK_N = min(BLOCK_N, 128)
    BLOCK_M = 32

    grid = (triton.cdiv(numel, BLOCK_M),)

    one_hot_kernel[grid](
        tensor,
        out,
        num_classes,
        numel,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
    return out