Coverage for src/flag_gems/experimental_ops/heaviside.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _heaviside_kernel(x_ptr, v_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise Heaviside step over a flat buffer.

    out[i] = 0 if x[i] < 0, 1 if x[i] > 0, v[i] if x[i] == 0.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    x = tl.load(x_ptr + idx, mask=in_bounds)
    v = tl.load(v_ptr + idx, mask=in_bounds)

    # Build the 0 and 1 constants in x's own dtype via arithmetic on x.
    # For a NaN input, `x - x` is NaN, both comparisons below are False,
    # and the NaN "zero" is selected — so NaN propagates to the output.
    zero = x - x
    one = zero + 1
    result = tl.where(x == zero, v, tl.where(x > zero, one, zero))

    tl.store(out_ptr + idx, result, mask=in_bounds)

23 

24 

def heaviside(input, values):
    """Compute the Heaviside step function of ``input``.

    out = 0 where input < 0, 1 where input > 0, and ``values`` where
    input == 0. ``values`` may be a tensor (broadcast against ``input``)
    or a Python number. Returns a new contiguous tensor in the promoted
    result dtype of ``input`` and ``values``.
    """
    if not isinstance(values, torch.Tensor):
        # Promote BEFORE materializing the scalar: torch.result_type treats
        # Python numbers as weakly typed (float32 tensor + 2.0 -> float32),
        # whereas torch.as_tensor(2.0) would create a float64 tensor and
        # force an unwanted upcast of the whole result.
        out_dtype = torch.result_type(input, values)
        values = torch.as_tensor(values, dtype=out_dtype, device=input.device)
    else:
        out_dtype = torch.result_type(input, values)
    # Broadcast to a common shape, then make both operands contiguous in the
    # promoted dtype so the kernel can use flat 1-D indexing.
    x_b, v_b = torch.broadcast_tensors(input, values)
    x = x_b.to(dtype=out_dtype).contiguous()
    v = v_b.to(dtype=out_dtype).contiguous()

    # Allocate output (same shape/dtype/device as the prepared input).
    out = torch.empty_like(x)

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return out
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _heaviside_kernel[grid](x, v, out, n_elements, BLOCK_SIZE=1024)
    return out

44 

45 

def heaviside_out(input, values, out):
    """Heaviside step function writing the result into ``out``.

    Same semantics as ``heaviside`` but stores into the caller-supplied
    ``out`` tensor, which must already have the broadcast shape, promoted
    dtype, and device of the inputs.

    Raises:
        ValueError: if ``out``'s device, dtype, or shape does not match
            what the inputs require.
    """
    if not isinstance(values, torch.Tensor):
        # Same weak scalar promotion as heaviside(): compute the result
        # dtype from the Python number before it becomes a (strongly
        # typed) tensor, so e.g. float32 input + 2.0 expects float32 out.
        expected_dtype = torch.result_type(input, values)
        values = torch.as_tensor(values, dtype=expected_dtype, device=input.device)
    else:
        expected_dtype = torch.result_type(input, values)
    # Broadcast to determine the required output shape/device.
    x_b, v_b = torch.broadcast_tensors(input, values)
    expected_shape = x_b.shape
    device = x_b.device
    # Validate the caller-supplied output tensor.
    if out.device != device:
        raise ValueError("out tensor device must match input device")
    if out.dtype != expected_dtype:
        raise ValueError("out tensor dtype must be the result type of input and values")
    if out.shape != expected_shape:
        raise ValueError(
            "out tensor shape must be the broadcasted shape of input and values"
        )

    x = x_b.to(dtype=expected_dtype).contiguous()
    v = v_b.to(dtype=expected_dtype).contiguous()

    # If out is contiguous, write directly; otherwise use a temp and copy.
    if out.is_contiguous():
        target = out
    else:
        target = torch.empty_like(out, memory_format=torch.contiguous_format)

    n_elements = target.numel()
    if n_elements == 0:
        # Empty result: validation already passed, nothing to write.
        return out
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _heaviside_kernel[grid](x, v, target, n_elements, BLOCK_SIZE=1024)

    if target is not out:
        out.copy_(target)
    return out