Coverage for src/flag_gems/experimental_ops/logit_.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def logit_(
    x_ptr,          # pointer to the data buffer; read and overwritten in place
    n_elements,     # total number of elements in the buffer
    eps,            # clamp bound; only read when has_eps is True
    has_eps: tl.constexpr,       # compile-time: clamp input to [eps, 1 - eps] first
    COMPUTE_FP32: tl.constexpr,  # compile-time: promote fp16/bf16 input to fp32 for the math
    COMPUTE_FP64: tl.constexpr,  # compile-time: input is fp64; compute natively
    BLOCK_SIZE: tl.constexpr,    # elements processed per program instance
):
    """In-place logit kernel: x <- log(x / (1 - x)).

    When ``has_eps`` is set, x is first clamped to [eps, 1 - eps]
    (matching torch.logit_'s eps semantics as driven by the wrapper below).
    Exactly one precision path is kept at compile time via the constexpr
    flags; the branches are pruned by Triton's specialization.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block: only lanes inside the buffer load/store.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Promote to higher precision for computation if needed
    if COMPUTE_FP32:
        # fp16/bf16 input: do the math in fp32, cast back for the store.
        xc = x.to(tl.float32)
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        y = tl.log(xc / (1.0 - xc))
        out = y.to(x.dtype)
    elif COMPUTE_FP64:
        xc = x  # already float64
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        out = tl.log(xc / (1.0 - xc))
    else:
        # float32 compute
        xc = x
        if has_eps:
            xc = tl.maximum(xc, eps)
            xc = tl.minimum(xc, 1.0 - eps)
        out = tl.log(xc / (1.0 - xc))

    # In-place: write the result back over the input.
    tl.store(x_ptr + offsets, out, mask=mask)

46 

47 

# Keep a handle to the Triton kernel before defining the Python wrapper with
# the same name: the wrapper below shadows `logit_`, so kernel launches must
# go through this alias.
logit___kernel = logit_

50 

51 

def logit_(*args, **kwargs):
    """In-place logit: x <- log(x / (1 - x)), mirroring ``torch.Tensor.logit_``.

    Accepted call forms: ``logit_(input)``, ``logit_(input, eps)``, or
    ``logit_(input, eps=...)``. When ``eps`` is given, the input is clamped
    to [eps, 1 - eps] before the logit is taken.

    Returns:
        The input tensor, modified in place.

    Raises:
        TypeError: on missing, extra, or duplicate arguments, a non-tensor
            first argument, or a non-floating-point tensor.
        ValueError: if the tensor is not on a CUDA device.
    """
    # --- argument parsing, mimicking torch.logit_(input, eps=None) ---
    if len(args) == 0:
        raise TypeError("logit_ expected at least 1 argument (got 0)")
    if len(args) > 2:
        # Previously extra positionals were silently ignored; reject them.
        raise TypeError(
            f"logit_ expected at most 2 positional arguments (got {len(args)})"
        )
    x = args[0]
    eps = args[1] if len(args) > 1 else None
    if "eps" in kwargs:
        if len(args) > 1:
            # Previously the keyword silently overrode the positional value.
            raise TypeError("logit_ got multiple values for argument 'eps'")
        eps = kwargs["eps"]
    unexpected = set(kwargs) - {"eps"}
    if unexpected:
        # Previously unknown keywords were silently ignored; reject them.
        raise TypeError(
            f"logit_ got unexpected keyword argument(s): {sorted(unexpected)}"
        )

    if not isinstance(x, torch.Tensor):
        raise TypeError("logit_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        raise ValueError("logit_ Triton implementation requires a CUDA tensor")
    if not x.is_floating_point():
        raise TypeError("logit_ expects a floating point tensor")

    has_eps = eps is not None
    eps_value = float(eps) if has_eps else 0.0

    # Work on a contiguous buffer; copy back afterwards to preserve in-place
    # semantics (x.contiguous() returns a copy for non-contiguous tensors).
    needs_copy_back = not x.is_contiguous()
    buf = x if not needs_copy_back else x.contiguous()

    n_elements = buf.numel()
    if n_elements == 0:
        return x  # nothing to do; still return the input per in-place convention

    dtype = buf.dtype
    # fp16/bf16 are promoted to fp32 inside the kernel for accuracy;
    # fp64 is computed natively; fp32 needs no promotion.
    compute_in_fp32 = dtype in (torch.float16, torch.bfloat16)
    compute_in_fp64 = dtype == torch.float64

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    logit___kernel[grid](
        buf,
        n_elements,
        eps_value,
        has_eps=has_eps,
        COMPUTE_FP32=compute_in_fp32,
        COMPUTE_FP64=compute_in_fp64,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if needs_copy_back:
        x.copy_(buf)

    return x