Coverage for src/flag_gems/experimental_ops/threshold_.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def threshold_(
    x_ptr,  # Base pointer of the tensor that is read and overwritten
    n_elements,  # Total element count; used to mask the tail block
    threshold_ptr,  # Pointer to a single-element threshold tensor
    value_ptr,  # Pointer to a single-element replacement-value tensor
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise in-place threshold over a flat buffer.

    Each program instance handles one ``BLOCK_SIZE``-wide slice of the
    buffer: elements that are ``<=`` the threshold are replaced by the
    value, all other elements are written back unchanged.
    """
    # Work out which slice of the buffer this program instance owns.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements  # guards the partially filled last block

    # The two scalars are read through pointers; the host wrapper in this
    # file builds them as 0-d tensors with the same dtype as x.
    cutoff = tl.load(threshold_ptr)
    fill = tl.load(value_ptr)

    # Read, select, and write back through the same addresses (in place).
    ptrs = x_ptr + idx
    data = tl.load(ptrs, mask=in_bounds)
    result = tl.where(data <= cutoff, fill, data)
    tl.store(ptrs, result, mask=in_bounds)


# The Python wrapper defined further down reuses the name `threshold_`, so
# bind the compiled kernel to a second name before it gets shadowed.
threshold__triton_kernel = threshold_

35 

36 

def threshold_(*args, **kwargs):
    """Apply a threshold to a CUDA tensor in place via the Triton kernel.

    Every element ``e`` of the tensor with ``e <= threshold`` is overwritten
    with ``value``; all other elements are left as they are.

    Arguments may be passed positionally as ``(input, threshold, value)`` or
    by keyword. The tensor is accepted under the keyword ``input`` or, to
    match the aten-style signature, ``self`` (``input`` wins if both are
    given). ``value`` defaults to 0.

    Returns:
        The same tensor object that was passed in, modified in place.

    Raises:
        ValueError: if the tensor or the threshold is missing, or if the
            tensor is not on a CUDA device, is complex, or is not contiguous.
    """
    # Keyword arguments take precedence over positional ones.
    if "input" in kwargs:
        x = kwargs["input"]
    elif "self" in kwargs:
        x = kwargs["self"]
    else:
        x = args[0] if args else None
    threshold = kwargs.get("threshold", args[1] if len(args) > 1 else None)
    value = kwargs.get("value", args[2] if len(args) > 2 else 0)

    if x is None or threshold is None:
        raise ValueError("threshold_ requires at least (input, threshold) arguments")

    if not x.is_cuda:
        raise ValueError(
            "Input tensor must be on CUDA device for Triton kernel execution"
        )
    if x.is_complex():
        raise ValueError("Complex dtypes are not supported by this kernel")
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous for this Triton kernel")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to rewrite; skip the scalar allocations and the launch.
        return x

    BLOCK_SIZE = 1024

    def grid(meta):
        # One program instance per BLOCK_SIZE-sized slice of the tensor.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Make the tensor's device current for the launch so that a tensor living
    # on a non-default GPU is not launched against the wrong device.
    with torch.cuda.device(x.device):
        # The kernel reads both scalars through pointers, so materialise them
        # as 0-d tensors with the tensor's dtype on the tensor's device.
        thr_t = torch.tensor(threshold, dtype=x.dtype, device=x.device)
        val_t = torch.tensor(value, dtype=x.dtype, device=x.device)

        threshold__triton_kernel[grid](
            x, n_elements, thr_t, val_t, BLOCK_SIZE=BLOCK_SIZE
        )

    return x