Coverage for src/flag_gems/experimental_ops/threshold.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def threshold_kernel(
    x_ptr,  # *Pointer* to the input tensor
    y_ptr,  # *Pointer* to the output tensor
    n_elements,  # total number of elements to process
    threshold,  # scalar cutoff
    value,  # scalar substituted wherever x <= threshold
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise threshold: y[i] = x[i] if x[i] > threshold else value.
    # One program instance handles one BLOCK_SIZE-wide contiguous slice.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements  # guard the ragged final block

    vals = tl.load(x_ptr + idx, mask=in_bounds)
    result = tl.where(vals > threshold, vals, value)
    tl.store(y_ptr + idx, result, mask=in_bounds)

23 

24 

25def _coerce_scalars_for_dtype(dtype, threshold, value): 

26 if dtype.is_complex: 

27 raise TypeError("aten.threshold does not support complex dtypes.") 

28 if dtype == torch.bool: 

29 raise TypeError("aten.threshold does not support bool dtype.") 

30 if dtype.is_floating_point: 

31 thr = float(threshold) 

32 val = float(value) 

33 else: 

34 thr = int(threshold) 

35 val = int(value) 

36 return thr, val 

37 

38 

def threshold(input: torch.Tensor, threshold, value):
    """Elementwise ``aten.threshold``: keep x where x > threshold, else use value.

    Args:
        input: CUDA tensor of any shape; any non-complex, non-bool dtype.
        threshold: scalar cutoff, coerced to the input's numeric family.
        value: scalar replacement used where ``input <= threshold``.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        RuntimeError: if ``input`` is not a CUDA tensor.
        TypeError: for complex or bool dtypes (via ``_coerce_scalars_for_dtype``).
    """
    if input.device.type != "cuda":
        raise RuntimeError("This Triton implementation requires CUDA tensors.")
    x = input.contiguous()
    n_elements = x.numel()
    out = torch.empty_like(x)

    # Nothing to launch for an empty tensor.
    if n_elements == 0:
        return out

    thr_scalar, val_scalar = _coerce_scalars_for_dtype(x.dtype, threshold, value)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Fix: pin the launch to the input's device. Triton launches on the
    # *current* CUDA device, so without this guard a tensor living on a
    # non-default GPU would have the kernel run against the wrong device.
    with torch.cuda.device(x.device):
        threshold_kernel[grid](
            x,
            out,
            n_elements,
            thr_scalar,
            val_scalar,
            BLOCK_SIZE=1024,
        )
    return out

61 

62 

def threshold_out(input: torch.Tensor, threshold, value, out: torch.Tensor):
    """Out-of-place variant of :func:`threshold` writing into ``out``.

    Args:
        input: CUDA tensor of any shape; any non-complex, non-bool dtype.
        threshold: scalar cutoff, coerced to the input's numeric family.
        value: scalar replacement used where ``input <= threshold``.
        out: pre-allocated CUDA tensor with the same shape and dtype as
            ``input``. Non-contiguous ``out`` is supported via a contiguous
            scratch buffer that is copied back afterwards.

    Returns:
        ``out`` (for call-chaining convenience).

    Raises:
        RuntimeError: if either tensor is not on CUDA, or if ``out``'s shape
            or dtype does not match ``input``'s.
        TypeError: for complex or bool dtypes (via ``_coerce_scalars_for_dtype``).
    """
    if input.device.type != "cuda" or out.device.type != "cuda":
        raise RuntimeError("This Triton implementation requires CUDA tensors.")
    if out.shape != input.shape:
        raise RuntimeError(
            f"out shape {out.shape} must match input shape {input.shape}"
        )
    if out.dtype != input.dtype:
        raise RuntimeError(
            f"out dtype {out.dtype} must match input dtype {input.dtype}"
        )

    x = input.contiguous()
    n_elements = x.numel()

    # The kernel assumes a flat contiguous layout: write into a contiguous
    # temp when `out` is not contiguous, then copy back (copy_ handles the
    # layout conversion).
    y = out if out.is_contiguous() else torch.empty_like(x)

    if n_elements == 0:
        if y is not out:
            out.copy_(y)
        return out

    thr_scalar, val_scalar = _coerce_scalars_for_dtype(x.dtype, threshold, value)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Fix: pin the launch to the input's device. Triton launches on the
    # *current* CUDA device, so without this guard tensors on a non-default
    # GPU would be processed by a kernel running on the wrong device.
    with torch.cuda.device(x.device):
        threshold_kernel[grid](
            x,
            y,
            n_elements,
            thr_scalar,
            val_scalar,
            BLOCK_SIZE=1024,
        )

    if y is not out:
        out.copy_(y)
    return out