Coverage for src/flag_gems/ops/one_hot.py: 81%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import torch 

4 

5from flag_gems.ops.scatter import scatter_ 

6 

logger = logging.getLogger(__name__)


def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """Encode an int64 index tensor as a one-hot int64 tensor.

    The result has shape ``(*tensor.shape, num_classes)``.
    ``out[..., k]`` is 1 where ``tensor[...] == k`` and 0 elsewhere.

    Args:
        tensor: Index tensor. It must have dtype ``torch.int64`` and contain
            only non-negative values.
        num_classes: Size of the new trailing dimension. ``-1`` (the default)
            infers it as ``tensor.max() + 1``.

    Returns:
        An int64 tensor on ``tensor.device`` with shape
        ``(*tensor.shape, num_classes)``.

    Raises:
        RuntimeError: If ``tensor`` is not int64. Also if ``tensor`` is empty
            and ``num_classes`` is not positive, if any value is negative, if
            ``num_classes`` is neither ``-1`` nor positive, or if any value is
            ``>= num_classes``.
    """
    logger.debug("GEMS ONE_HOT")
    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    if tensor.numel() == 0:
        # There are no values to inspect, so the class count cannot be inferred.
        if num_classes <= 0:
            raise RuntimeError(
                "Can not infer total number of classes from empty tensor."
            )
        # The result has zero elements, so there is nothing to initialise.
        return torch.empty(
            (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
        )

    # Bring min and max to the host in one transfer. Two separate `.item()`
    # calls would each be a device-to-host copy (and a sync point on
    # accelerators).
    minv, maxv = torch.stack((tensor.min(), tensor.max())).tolist()
    if minv < 0:
        raise RuntimeError("Class values must be non-negative.")

    if num_classes == -1:
        num_classes = maxv + 1
    else:
        if num_classes < 1:
            raise RuntimeError("num_classes should be positive")
        if maxv >= num_classes:
            raise RuntimeError("Class values must be smaller than num_classes.")

    out = torch.zeros(
        (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
    )
    index = tensor.unsqueeze(-1)
    if tensor.device.type == "cpu":
        # CPU tensors use PyTorch's native in-place scatter with a scalar value.
        out.scatter_(-1, index, 1)
    else:
        src = torch.ones_like(index, dtype=torch.int64)
        scatter_(out, -1, index, src, reduce=None)
    return out