Coverage for src/flag_gems/ops/rrelu_with_noise_backward.py: 56%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def rrelu_with_noise_backward_kernel(
    grad_out_ptr,  # pointer to incoming gradient, flat contiguous layout
    input_ptr,  # pointer to the forward input tensor
    noise_ptr,  # pointer to the noise tensor written by the forward pass
    grad_in_ptr,  # pointer to the output gradient buffer (same numel)
    n_elements,  # total number of elements to process
    lower,  # lower bound of the RReLU slope range (float)
    upper,  # upper bound of the RReLU slope range (float)
    training,  # 0/1 flag: 1 -> use noise, 0 -> use fixed mean slope
    BLOCK_SIZE: tl.constexpr,  # elements handled per program instance
):
    """Elementwise backward of rrelu_with_noise.

    Training mode: grad_in = grad_out * noise (noise holds the sampled
    slope, or 1 where input was positive — assumed filled by the forward
    pass; TODO confirm against the forward kernel).
    Eval mode: grad_in = grad_out where input > 0, else
    grad_out * (lower + upper) / 2.
    """
    # One program handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    go = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    x = tl.load(input_ptr + offsets, mask=mask, other=0)
    nz = tl.load(noise_ptr + offsets, mask=mask, other=0)

    # Compute in float32 regardless of storage dtype, cast back on store.
    go_f32 = go.to(tl.float32)
    x_f32 = x.to(tl.float32)
    nz_f32 = nz.to(tl.float32)

    # Eval-mode negative slope is the mean of the sampling range.
    slope = (lower + upper) * 0.5

    # Both branches are computed; tl.where selects per element below.
    grad_train = go_f32 * nz_f32
    grad_eval = go_f32 * tl.where(x_f32 > 0, 1.0, slope)

    # Broadcast the scalar runtime flag to a boolean block so tl.where
    # can select between the two precomputed gradients.
    cond = tl.full(go_f32.shape, training, tl.int1)
    grad_f32 = tl.where(cond, grad_train, grad_eval)

    # Cast back to the gradient's storage dtype before writing out.
    grad_cast = grad_f32.to(go.dtype)
    tl.store(grad_in_ptr + offsets, grad_cast, mask=mask)

48 

49 

def _launch_rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    out: torch.Tensor,
):
    """Flatten inputs, launch the backward kernel, and fill ``out``.

    All tensors are made contiguous before the launch; if ``out`` itself
    was not contiguous, the kernel writes into a contiguous scratch copy
    that is copied back into ``out`` afterwards. Returns ``out``.
    """
    grad_out_c = grad_output.contiguous()
    input_c = input.contiguous()
    noise_c = noise.contiguous()
    # May be `out` itself (already contiguous) or a fresh copy.
    result = out.contiguous()

    BLOCK_SIZE = 1024
    numel = result.numel()
    # BLOCK_SIZE is fixed, so the launch grid can be computed up front.
    grid = (triton.cdiv(numel, BLOCK_SIZE),)

    with torch_device_fn.device(grad_output.device):
        rrelu_with_noise_backward_kernel[grid](
            grad_out_c,
            input_c,
            noise_c,
            result,
            numel,
            float(lower),
            float(upper),
            1 if training else 0,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    # If .contiguous() had to copy, propagate the result back into `out`.
    if result is not out:
        out.copy_(result)
    return out

81 

82 

def rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool = False,
):
    """Compute the gradient of rrelu_with_noise w.r.t. its input.

    Allocates a fresh gradient tensor shaped like ``grad_output`` and
    delegates to the Triton launcher. ``self_is_result`` is accepted for
    ATen signature compatibility but is not consulted here — NOTE(review):
    confirm no in-place caller depends on it.
    """
    logger.debug("GEMS RRELU_WITH_NOISE_BACKWARD")
    grad_input = torch.empty_like(grad_output)
    return _launch_rrelu_with_noise_backward(
        grad_output,
        input,
        noise,
        lower,
        upper,
        training,
        grad_input,
    )