Coverage for src/flag_gems/experimental_ops/threshold.py: 0%
55 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def threshold_kernel(
    x_ptr,  # pointer to input tensor data
    y_ptr,  # pointer to output tensor data
    n_elements,  # total element count to process
    threshold,  # scalar cutoff: positions where x <= threshold are replaced
    value,  # scalar written wherever x <= threshold
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise threshold: y[i] = x[i] if x[i] > threshold else value."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    src = tl.load(x_ptr + idx, mask=in_bounds)
    dst = tl.where(src > threshold, src, value)
    tl.store(y_ptr + idx, dst, mask=in_bounds)
25def _coerce_scalars_for_dtype(dtype, threshold, value):
26 if dtype.is_complex:
27 raise TypeError("aten.threshold does not support complex dtypes.")
28 if dtype == torch.bool:
29 raise TypeError("aten.threshold does not support bool dtype.")
30 if dtype.is_floating_point:
31 thr = float(threshold)
32 val = float(value)
33 else:
34 thr = int(threshold)
35 val = int(value)
36 return thr, val
def threshold(input: torch.Tensor, threshold, value):
    """Triton-backed, out-of-place aten.threshold.

    Returns a new contiguous tensor in which every element <= ``threshold``
    is replaced by ``value`` and every other element is copied unchanged.

    Raises:
        RuntimeError: if ``input`` is not a CUDA tensor.
        TypeError: if the dtype is complex or bool (via scalar coercion).
    """
    if input.device.type != "cuda":
        raise RuntimeError("This Triton implementation requires CUDA tensors.")

    src = input.contiguous()
    result = torch.empty_like(src)
    numel = src.numel()

    # Nothing to launch for an empty tensor.
    if numel == 0:
        return result

    thr, val = _coerce_scalars_for_dtype(src.dtype, threshold, value)

    def grid(meta):
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    threshold_kernel[grid](src, result, numel, thr, val, BLOCK_SIZE=1024)
    return result
def threshold_out(input: torch.Tensor, threshold, value, out: torch.Tensor):
    """Triton-backed aten.threshold writing into ``out``; returns ``out``.

    ``out`` must be a CUDA tensor matching ``input`` in shape and dtype.
    A non-contiguous ``out`` is handled by computing into a contiguous
    staging buffer and copying back.

    Raises:
        RuntimeError: if either tensor is not CUDA, or if ``out``'s shape
            or dtype disagrees with ``input``.
    """
    if input.device.type != "cuda" or out.device.type != "cuda":
        raise RuntimeError("This Triton implementation requires CUDA tensors.")
    if out.shape != input.shape:
        raise RuntimeError(
            f"out shape {out.shape} must match input shape {input.shape}"
        )
    if out.dtype != input.dtype:
        raise RuntimeError(
            f"out dtype {out.dtype} must match input dtype {input.dtype}"
        )

    src = input.contiguous()
    numel = src.numel()

    # The kernel needs a contiguous destination; stage through a temp
    # buffer when out is laid out non-contiguously.
    staging = out if out.is_contiguous() else torch.empty_like(src)

    if numel:
        thr, val = _coerce_scalars_for_dtype(src.dtype, threshold, value)
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
        threshold_kernel[grid](src, staging, numel, thr, val, BLOCK_SIZE=1024)

    if staging is not out:
        out.copy_(staging)
    return out