Coverage for src/flag_gems/experimental_ops/threshold_.py: 0%
34 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def threshold_(
    x_ptr,  # Pointer to input/output tensor (updated in place)
    n_elements,  # Number of elements to process
    threshold_ptr,  # Pointer to scalar threshold (0-d tensor)
    value_ptr,  # Pointer to scalar replacement value (0-d tensor)
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise ``x = value if x <= threshold else x``, written back in place.

    Each program instance handles one tile of ``BLOCK_SIZE`` consecutive
    elements; lanes past ``n_elements`` in the last tile are masked off.
    """
    # Promote the program id to int64 before scaling so that
    # ``pid * BLOCK_SIZE`` cannot overflow int32 for tensors with more than
    # 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # The scalars are loaded without a cast. This assumes the caller stores them
    # with the same dtype as x — NOTE(review): the wrapper in this file does so.
    thr = tl.load(threshold_ptr)
    val = tl.load(value_ptr)

    # ``x <= thr`` evaluates to False for NaN, so NaN elements are kept as-is.
    out = tl.where(x <= thr, val, x)

    tl.store(x_ptr + offsets, out, mask=mask)


# Keep a handle to the Triton kernel before the Python wrapper below rebinds
# the name ``threshold_``.
threshold__triton_kernel = threshold_
def threshold_(*args, **kwargs):
    """Apply ``threshold_`` in place using the Triton kernel.

    Follows the ``aten::threshold_(self, threshold, value)`` contract: every
    element ``x <= threshold`` is replaced by ``value``. All other elements,
    including NaN (for which the comparison is false), are left unchanged.

    Args:
        input: Tensor updated in place. Accepted positionally or as the
            ``input`` keyword; the aten argument name ``self`` is also
            accepted as a keyword.
        threshold: Scalar threshold, converted to ``input.dtype``.
        value: Scalar replacement value, converted to ``input.dtype``.
            Defaults to 0.

    Returns:
        The same tensor object that was passed in, after modification.

    Raises:
        ValueError: If ``input`` or ``threshold`` is missing, ``input`` is
            not on a CUDA device, or ``input`` has a complex dtype.
    """
    if "input" in kwargs:
        x = kwargs["input"]
    elif "self" in kwargs:
        x = kwargs["self"]
    else:
        x = args[0] if len(args) > 0 else None
    threshold = kwargs.get("threshold", args[1] if len(args) > 1 else None)
    value = kwargs.get("value", args[2] if len(args) > 2 else 0)

    if x is None or threshold is None:
        raise ValueError("threshold_ requires at least (input, threshold) arguments")

    if not x.is_cuda:
        raise ValueError(
            "Input tensor must be on CUDA device for Triton kernel execution"
        )
    if x.is_complex():
        raise ValueError("Complex dtypes are not supported by this kernel")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to update: skip allocating scalars and launching the kernel.
        return x

    # The kernel addresses memory as a flat 1-D array, so a strided input is
    # processed in a contiguous buffer and the result is copied back into x.
    work = x if x.is_contiguous() else x.contiguous()

    # Scalar operands as 0-d tensors with x's dtype/device, loaded by the kernel.
    thr_t = torch.tensor(threshold, dtype=x.dtype, device=x.device)
    val_t = torch.tensor(value, dtype=x.dtype, device=x.device)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731

    threshold__triton_kernel[grid](
        work, n_elements, thr_t, val_t, BLOCK_SIZE=BLOCK_SIZE
    )

    if work is not x:
        x.copy_(work)
    return x