Coverage for src/flag_gems/experimental_ops/heaviside.py: 0%
53 statements
coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
import torch
import triton
import triton.language as tl
@triton.jit
def _heaviside_kernel(x_ptr, v_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise Heaviside step: 1 where x > 0, 0 where x < 0, v where x == 0.

    All three pointers address contiguous 1-D buffers of length ``n_elements``
    in the same dtype (the callers below promote and flatten before launch —
    TODO confirm for any new call sites). Each program instance handles one
    BLOCK_SIZE chunk.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block when n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    v = tl.load(v_ptr + offsets, mask=mask)

    # Build 0 and 1 constants in x's own dtype without naming the dtype.
    # NOTE: zeros = x - x also propagates NaN (NaN - NaN == NaN), so a NaN
    # input fails both comparisons below and the final where() yields NaN,
    # matching torch.heaviside's NaN behavior. Do not "simplify" this.
    zeros = x - x
    ones = zeros + 1
    is_pos = x > zeros
    is_zero = x == zeros
    # input == 0 takes the user-supplied value; otherwise step to 1 or 0.
    out = tl.where(is_zero, v, tl.where(is_pos, ones, zeros))

    tl.store(out_ptr + offsets, out, mask=mask)
def heaviside(input, values):
    """Compute ``torch.heaviside(input, values)`` with a Triton kernel.

    The result is 1 where ``input > 0``, 0 where ``input < 0``, and the
    broadcasted ``values`` where ``input == 0``.

    Args:
        input: Input tensor.
        values: Tensor or Python scalar/sequence supplying the result at
            ``input == 0``; broadcast against ``input``.

    Returns:
        A new contiguous tensor of the broadcasted shape, in the promoted
        dtype of ``input`` and ``values``.
    """
    # Accept Python scalars / sequences for `values`, placed on input's device.
    if not isinstance(values, torch.Tensor):
        values = torch.as_tensor(values, device=input.device)
    # Broadcast both operands to a common shape.
    x_b, v_b = torch.broadcast_tensors(input, values)
    # Promote to the common result dtype and flatten to contiguous storage,
    # as the kernel indexes linearly.
    out_dtype = torch.result_type(x_b, v_b)
    x = x_b.to(dtype=out_dtype).contiguous()
    v = v_b.to(dtype=out_dtype).contiguous()

    # Allocate output
    out = torch.empty_like(x)

    n_elements = out.numel()
    # Empty tensors: nothing to compute, and triton.cdiv(0, BLOCK) == 0 would
    # request a zero-sized grid, which a launch is not guaranteed to accept.
    if n_elements == 0:
        return out

    # Launch kernel: 1-D grid, one program per BLOCK_SIZE elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _heaviside_kernel[grid](x, v, out, n_elements, BLOCK_SIZE=1024)
    return out
def heaviside_out(input, values, out):
    """Compute ``torch.heaviside(input, values)`` into a caller-provided ``out``.

    Same semantics as :func:`heaviside`, but the result is written into
    ``out``, which must already have the broadcasted shape, the promoted
    dtype, and the same device as the inputs.

    Args:
        input: Input tensor.
        values: Tensor or Python scalar/sequence used where ``input == 0``;
            broadcast against ``input``.
        out: Destination tensor (any layout; non-contiguous is handled via a
            temporary).

    Returns:
        ``out``.

    Raises:
        ValueError: if ``out``'s device, dtype, or shape does not match the
            broadcasted/promoted result of ``input`` and ``values``.
    """
    # Accept Python scalars / sequences for `values`, placed on input's device.
    if not isinstance(values, torch.Tensor):
        values = torch.as_tensor(values, device=input.device)
    # Broadcast both operands to a common shape.
    x_b, v_b = torch.broadcast_tensors(input, values)
    # Expected properties of `out`, derived from the promoted operands.
    expected_dtype = torch.result_type(x_b, v_b)
    expected_shape = x_b.shape
    device = x_b.device
    # Validate the destination before doing any work.
    if out.device != device:
        raise ValueError("out tensor device must match input device")
    if out.dtype != expected_dtype:
        raise ValueError("out tensor dtype must be the result type of input and values")
    if out.shape != expected_shape:
        raise ValueError(
            "out tensor shape must be the broadcasted shape of input and values"
        )

    # Promote and flatten to contiguous storage, as the kernel indexes linearly.
    x = x_b.to(dtype=expected_dtype).contiguous()
    v = v_b.to(dtype=expected_dtype).contiguous()

    # If out is contiguous, write directly; otherwise use a temp and copy.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)

    n_elements = target.numel()
    # Empty tensors: nothing to compute, and triton.cdiv(0, BLOCK) == 0 would
    # request a zero-sized grid, which a launch is not guaranteed to accept.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _heaviside_kernel[grid](x, v, target, n_elements, BLOCK_SIZE=1024)

    if target is not out:
        out.copy_(target)
    return out