Coverage for src/flag_gems/experimental_ops/heaviside_.py: 0%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def heaviside_(x_ptr, v_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise in-place Heaviside step.

    Overwrites x with: 0 where x < 0, 1 where x > 0, v where x == 0.
    For floating inputs, a NaN x matches none of the three comparisons,
    so the fallback ``x + x`` keeps the NaN in the output.
    """
    block_id = tl.program_id(axis=0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=in_bounds)
    v = tl.load(v_ptr + offsets, mask=in_bounds)

    # Resolve the nonzero branches first, then overlay the x == 0 case.
    nonzero_part = tl.where(x < 0, 0, tl.where(x > 0, 1, x + x))
    result = tl.where(x == 0, v, nonzero_part)
    tl.store(x_ptr + offsets, result, mask=in_bounds)
# Keep a handle to the compiled kernel before the Python wrapper below
# rebinds the module-level name "heaviside_" (the kernel's __name__ stays
# "heaviside_" for tooling that inspects it).
heaviside__kernel = heaviside_
def heaviside_(*args, **kwargs):
    """In-place Heaviside step, mirroring torch.ops.aten.heaviside_(self, values).

    Writes heaviside(x, values) into ``x``: 0 where x < 0, 1 where x > 0,
    ``values`` where x == 0 (NaN inputs propagate for floating dtypes —
    handled inside the Triton kernel).

    Args:
        x (positional, or kwarg ``input``/``self``): contiguous CUDA tensor,
            modified in place.
        values (positional or kwarg): tensor or scalar; cast to ``x``'s dtype,
            moved to ``x``'s device, and broadcast to ``x``'s shape.

    Returns:
        ``x``, modified in place.

    Raises:
        TypeError: if the two required arguments are not supplied.
        ValueError: if ``x`` is not a contiguous CUDA tensor.
    """
    # Accept both positional and keyword calling conventions.
    if len(args) >= 2:
        x, values = args[0], args[1]
    else:
        x = kwargs.get("input", kwargs.get("self", None))
        values = kwargs.get("values", None)
    # Use real exceptions, not assert: asserts are stripped under `python -O`,
    # which would silently skip all input validation.
    if x is None or values is None:
        raise TypeError("heaviside_ requires two arguments: input tensor and values.")
    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device.")
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous.")

    # Normalize `values` (scalar or tensor) to x's dtype/device, broadcast to
    # x's shape; .contiguous() materializes the expanded view so the kernel
    # can index it linearly.
    if not torch.is_tensor(values):
        v_tensor = torch.as_tensor(values, device=x.device, dtype=x.dtype)
    else:
        v_tensor = values.to(device=x.device, dtype=x.dtype)
    v_tensor = v_tensor.expand_as(x).contiguous()

    n_elements = x.numel()
    # Empty tensor: nothing to do, and launching a zero-size grid is invalid.
    if n_elements == 0:
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    heaviside__kernel[grid](x, v_tensor, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x