# src/flag_gems/experimental_ops/rrelu_with_noise_backward.py

import torch
import triton
import triton.language as tl
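
# Gradient semantics (matching aten::rrelu_with_noise_backward):
#   training: grad_input = grad_output * noise, where the forward pass stored
#             the sampled per-element slope in `noise` (1.0 where input > 0).
#   eval:     grad_input = grad_output                       where input > 0
#             grad_input = grad_output * (lower + upper) / 2 otherwise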



@triton.jit
def rrelu_with_noise_backward_kernel(
    grad_out_ptr,  # *Pointer* to grad_output
    input_ptr,  # *Pointer* to input (or to the result if self_is_result; either works, since they share signs)
    noise_ptr,  # *Pointer* to noise
    grad_in_ptr,  # *Pointer* to output grad_input
    n_elements,  # Number of elements
    lower,  # float32
    upper,  # float32
    training,  # int32 (1 for training, 0 for eval)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    go = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    x = tl.load(input_ptr + offsets, mask=mask, other=0)
    nz = tl.load(noise_ptr + offsets, mask=mask, other=0)

    # Compute in float32 for accuracy, then cast back to the input dtype.
    go_f32 = go.to(tl.float32)
    x_f32 = x.to(tl.float32)
    nz_f32 = nz.to(tl.float32)

    # In eval mode, RReLU degenerates to LeakyReLU with the mean slope.
    slope = (lower + upper) * 0.5

    grad_train = go_f32 * nz_f32
    grad_eval = go_f32 * tl.where(x_f32 > 0, 1.0, slope)

    # `training` is a runtime scalar; broadcast it to a boolean block so
    # tl.where can select between the two gradients elementwise.
    cond = tl.full(go_f32.shape, training, tl.int1)
    grad_f32 = tl.where(cond, grad_train, grad_eval)

    grad_cast = grad_f32.to(go.dtype)
    tl.store(grad_in_ptr + offsets, grad_cast, mask=mask)



def _launch_rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    out: torch.Tensor,
):
    assert (
        grad_output.is_cuda and input.is_cuda and noise.is_cuda and out.is_cuda
    ), "All tensors must be CUDA"
    assert (
        grad_output.numel() == input.numel() == noise.numel() == out.numel()
    ), "All tensors must have the same number of elements"
    assert (
        grad_output.dtype == input.dtype == noise.dtype == out.dtype
    ), "All tensors must have the same dtype"

    # .contiguous() is a no-op for already-contiguous tensors; otherwise it
    # makes a dense copy that the kernel can index with flat offsets.
    go = grad_output.contiguous()
    x = input.contiguous()
    nz = noise.contiguous()
    out_t = out.contiguous()

    n_elements = out_t.numel()
    # One program instance per BLOCK_SIZE-element chunk.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    rrelu_with_noise_backward_kernel[grid](
        go,
        x,
        nz,
        out_t,
        n_elements,
        float(lower),
        float(upper),
        1 if training else 0,
        BLOCK_SIZE=1024,
    )
    # If `out` was non-contiguous, the kernel wrote into a copy; move the
    # results back into the caller's tensor.
    if out is not out_t:
        out.copy_(out_t)
    return out



def rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool = False,
):
    # `self_is_result` is accepted for schema parity with the ATen op; the
    # kernel's sign test works on either the input or the forward result.
    out = torch.empty_like(grad_output)
    return _launch_rrelu_with_noise_backward(
        grad_output, input, noise, lower, upper, training, out
    )



def rrelu_with_noise_backward_out(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
    out: torch.Tensor,
):
    return _launch_rrelu_with_noise_backward(
        grad_output, input, noise, lower, upper, training, out
    )