Coverage for src/flag_gems/ops/logit_.py: 52%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def logit_kernel(
    x_ptr,
    n_elements,
    eps,
    has_eps: tl.constexpr,
    COMPUTE_FP32: tl.constexpr,
    COMPUTE_FP64: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """In-place elementwise logit: x <- log(x / (1 - x)).

    Args:
        x_ptr: pointer to the tensor data; read and overwritten in place.
        n_elements: total number of elements to process.
        eps: clamp bound; only read when ``has_eps`` is true.
        has_eps: when set, x is clamped to [eps, 1 - eps] before the log.
        COMPUTE_FP32: compute in fp32 (for fp16/bf16 inputs), casting the
            result back to the input dtype.
        COMPUTE_FP64: retained for interface compatibility; fp64 (like
            fp32) is computed directly in its native dtype, so this flag
            no longer selects a distinct code path.
        BLOCK_SIZE: elements handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    if COMPUTE_FP32:
        # Low-precision inputs: compute in fp32 for accuracy, cast back.
        xc = x.to(tl.float32)
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        out = tl.log(xc / (1.0 - xc)).to(x.dtype)
    else:
        # Native-dtype path. The original COMPUTE_FP64 branch was
        # byte-identical to the fallback branch, so both are collapsed
        # into this single path (behavior unchanged).
        xc = x
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        out = tl.log(xc / (1.0 - xc))

    tl.store(x_ptr + offsets, out, mask=mask)

51 

52 

def logit_(*args, **kwargs):
    """In-place logit: x <- log(x / (1 - x)), mirroring ``torch.Tensor.logit_``.

    Args:
        args: ``(x,)`` or ``(x, eps)`` where ``x`` is the tensor modified
            in place and ``eps`` (optional) clamps ``x`` to
            ``[eps, 1 - eps]`` before the transform.
        kwargs: ``eps`` may alternatively be passed by keyword.

    Returns:
        The input tensor ``x``, modified in place.

    Raises:
        TypeError: if no argument is given, more than two positional
            arguments are given, ``eps`` is supplied both positionally
            and by keyword, ``x`` is not a tensor, or ``x`` is not a
            floating point tensor.
    """
    logger.debug("GEMS LOGIT_")
    if len(args) == 0:
        raise TypeError("logit_ expected at least 1 argument (got 0)")
    if len(args) > 2:
        raise TypeError(f"logit_ expected at most 2 arguments (got {len(args)})")
    x = args[0]
    eps = None
    if len(args) > 1:
        if "eps" in kwargs:
            # Previously the keyword silently overrode the positional
            # value; follow normal Python calling semantics and raise.
            raise TypeError("logit_ got multiple values for argument 'eps'")
        eps = args[1]
    elif "eps" in kwargs:
        eps = kwargs["eps"]

    if not isinstance(x, torch.Tensor):
        raise TypeError("logit_ expects a torch.Tensor as the first argument")
    if not x.is_floating_point():
        raise TypeError("logit_ expects a floating point tensor")

    has_eps = eps is not None
    eps_value = float(eps) if has_eps else 0.0

    # The kernel assumes a dense layout: run it on a contiguous copy and
    # write the result back when the input is non-contiguous.
    needs_copy_back = not x.is_contiguous()
    buf = x.contiguous() if needs_copy_back else x

    n_elements = buf.numel()
    if n_elements == 0:
        # Nothing to do for an empty tensor; skip the kernel launch.
        return x

    dtype = buf.dtype
    # fp16/bf16 are computed in fp32 inside the kernel for accuracy.
    compute_in_fp32 = dtype in (torch.float16, torch.bfloat16)
    compute_in_fp64 = dtype == torch.float64

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        logit_kernel[grid](
            buf,
            n_elements,
            eps_value,
            has_eps=has_eps,
            COMPUTE_FP32=compute_in_fp32,
            COMPUTE_FP64=compute_in_fp64,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    if needs_copy_back:
        x.copy_(buf)

    return x