Coverage for src/flag_gems/experimental_ops/heaviside_.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# In-place Heaviside step kernel: overwrites x with heaviside(x, v), i.e.
# 0 where x < 0, v where x == 0, 1 where x > 0. Each program instance
# covers one BLOCK_SIZE-wide slice of the flat element range; `mask`
# guards the ragged final block. Comments (not a docstring) are used here
# because the body is parsed by Triton's AST compiler.
@triton.jit
def heaviside_(x_ptr, v_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Out-of-range lanes in the last block are masked off for load/store.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    v = tl.load(v_ptr + offsets, mask=mask)

    is_zero = x == 0
    is_neg = x < 0
    is_pos = x > 0

    # For NaN handling on floating types: if none of the comparisons are true, use x + x to propagate NaN.
    res = tl.where(is_zero, v, tl.where(is_neg, 0, tl.where(is_pos, 1, x + x)))
    # Result is written back over the input buffer (in-place semantics).
    tl.store(x_ptr + offsets, res, mask=mask)

23 

24 

# Keep a handle to the kernel (its __name__ is "heaviside_") before the
# Python-level wrapper below shadows the name `heaviside_`.
heaviside__kernel = heaviside_

27 

28 

def heaviside_(*args, **kwargs):
    """In-place Heaviside step: x becomes 0 where x < 0, values where x == 0,
    and 1 where x > 0; NaN inputs propagate (handled inside the kernel).

    Mirrors ``torch.ops.aten.heaviside_(self, values)``. ``values`` may be a
    Python scalar or a tensor broadcastable to ``x``'s shape; it is converted
    to ``x``'s dtype/device. Returns ``x``, modified in place.
    """
    # Parse arguments similar to torch.ops.aten.heaviside_(self, values).
    # Accept any mix of positional and keyword arguments: previously the
    # call heaviside_(x, values=v) discarded the positional x and failed.
    if len(args) >= 2:
        x, values = args[0], args[1]
    else:
        x = args[0] if args else kwargs.get("input", kwargs.get("self", None))
        values = kwargs.get("values", None)
    assert (
        x is not None and values is not None
    ), "heaviside_ requires two arguments: input tensor and values."

    # Ensure CUDA tensors.
    # NOTE(review): asserts are stripped under `python -O`; kept as asserts so
    # callers catching AssertionError keep working — consider raising instead.
    assert x.is_cuda, "Input tensor must be on CUDA device."
    assert x.is_contiguous(), "Input tensor must be contiguous."

    # Prepare values tensor (support scalar or tensor), broadcast to input
    # shape and ensure same dtype/device as x.
    if not torch.is_tensor(values):
        v_tensor = torch.as_tensor(values, device=x.device, dtype=x.dtype)
    else:
        v_tensor = values.to(device=x.device, dtype=x.dtype)

    # contiguous() materializes the broadcast so the kernel can index flat.
    v_tensor = v_tensor.expand_as(x).contiguous()
    assert (
        v_tensor.is_cuda and v_tensor.is_contiguous()
    ), "Values tensor must be CUDA and contiguous after expansion."

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching a zero-size grid.
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    heaviside__kernel[grid](x, v_tensor, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x