Coverage for src/flag_gems/ops/one_hot.py: 81%
32 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
5from flag_gems.ops.scatter import scatter_
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def one_hot(tensor: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """Encode an int64 index tensor as one-hot rows along a new last dimension.

    The result has shape ``(*tensor.shape, num_classes)`` and dtype
    ``torch.int64``. Each index position holds 1 and every other entry holds 0.

    Args:
        tensor: Index tensor. Its dtype must be ``torch.int64``.
        num_classes: Size of the appended dimension. ``-1`` (the default)
            infers it as ``tensor.max() + 1``. For an empty ``tensor`` it
            must be passed explicitly as a positive value.

    Returns:
        An int64 tensor on ``tensor.device``. For an empty input, the result
        is an (element-less) ``torch.empty`` tensor of the output shape.

    Raises:
        RuntimeError: If ``tensor`` is not int64; if ``tensor`` is empty and
            ``num_classes <= 0``; if any value is negative; or, for an
            explicit ``num_classes``, if it is < 1 or not greater than every
            value in ``tensor``.
    """
    logger.debug("GEMS ONE_HOT")

    if tensor.dtype != torch.int64:
        raise RuntimeError(
            "one_hot is only applicable to index tensor of type LongTensor."
        )

    # Empty input: nothing to scan, so the class count cannot be inferred.
    if not tensor.numel():
        if num_classes > 0:
            return torch.empty(
                (*tensor.shape, num_classes),
                device=tensor.device,
                dtype=torch.int64,
            )
        raise RuntimeError(
            "Can not infer total number of classes from empty tensor."
        )

    # Each .item() call copies a scalar back to the host.
    lowest = int(tensor.min().item())
    if lowest < 0:
        raise RuntimeError("Class values must be non-negative.")
    highest = int(tensor.max().item())

    if num_classes != -1:
        if num_classes < 1:
            raise RuntimeError("num_classes should be positive")
        if highest >= num_classes:
            raise RuntimeError("Class values must be smaller than num_classes.")
    else:
        num_classes = highest + 1

    result = torch.zeros(
        (*tensor.shape, num_classes), device=tensor.device, dtype=torch.int64
    )
    positions = tensor.unsqueeze(-1)

    if tensor.device.type == "cpu":
        # CPU tensors use PyTorch's native in-place scatter with a scalar
        # fill. Presumably the FlagGems op is not meant for CPU; confirm.
        result.scatter_(-1, positions, 1)
    else:
        # Other devices go through the FlagGems scatter_ op, with an
        # all-ones int64 source tensor shaped like the index.
        ones = torch.ones_like(positions, dtype=torch.int64)
        scatter_(result, -1, positions, ones, reduce=None)
    return result