Coverage for src/flag_gems/experimental_ops/rrelu_with_noise_backward.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def rrelu_with_noise_backward_kernel(
    grad_out_ptr,  # *Pointer* to grad_output
    input_ptr,  # *Pointer* to input (or result if self_is_result, either works)
    noise_ptr,  # *Pointer* to noise
    grad_in_ptr,  # *Pointer* to output grad_input
    n_elements,  # Number of elements
    lower,  # float32
    upper,  # float32
    training,  # int32 (1 for training, 0 for eval)
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise backward of rrelu_with_noise. Each program instance
    # handles one contiguous block of BLOCK_SIZE elements:
    #   training: grad_input = grad_output * noise
    #   eval:     grad_input = grad_output * (1 if x > 0 else (lower+upper)/2)
    # NOTE(review): the training branch assumes `noise` already holds the
    # per-element multiplier (sampled slope where x <= 0, 1.0 where x > 0),
    # as the ATen forward produces it — confirm against the forward op.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the ragged tail block

    go = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    x = tl.load(input_ptr + offsets, mask=mask, other=0)
    nz = tl.load(noise_ptr + offsets, mask=mask, other=0)

    # Do the arithmetic in float32 regardless of the storage dtype
    # (e.g. fp16/bf16 inputs), then cast back on store.
    go_f32 = go.to(tl.float32)
    x_f32 = x.to(tl.float32)
    nz_f32 = nz.to(tl.float32)

    # Eval mode uses the mean of the sampled-slope range as a fixed slope.
    slope = (lower + upper) * 0.5

    grad_train = go_f32 * nz_f32
    grad_eval = go_f32 * tl.where(x_f32 > 0, 1.0, slope)

    # `training` is a runtime scalar flag; broadcast it to a boolean tensor
    # so tl.where can select (both branches are computed unconditionally).
    cond = tl.full(go_f32.shape, training, tl.int1)
    grad_f32 = tl.where(cond, grad_train, grad_eval)

    grad_cast = grad_f32.to(go.dtype)  # back to the storage dtype
    tl.store(grad_in_ptr + offsets, grad_cast, mask=mask)
def _launch_rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    out: torch.Tensor,
):
    """Validate inputs and launch the Triton backward kernel, writing into ``out``.

    All four tensors must live on CUDA and share element count and dtype.
    The kernel indexes memory linearly, so non-contiguous tensors are
    replaced by contiguous copies; if ``out`` itself was non-contiguous the
    kernel result is copied back into it afterwards.

    Returns ``out``.
    """
    assert (
        grad_output.is_cuda and input.is_cuda and noise.is_cuda and out.is_cuda
    ), "All tensors must be CUDA"
    assert (
        grad_output.numel() == input.numel() == noise.numel() == out.numel()
    ), "All tensors must have the same number of elements"
    assert (
        grad_output.dtype == input.dtype == noise.dtype == out.dtype
    ), "All tensors must have the same dtype"

    n_elements = out.numel()
    # Empty tensors: nothing to compute — skip the degenerate (grid of 0
    # blocks) kernel launch and the pointless .contiguous() copies.
    if n_elements == 0:
        return out

    go = grad_output.contiguous()
    x = input.contiguous()
    nz = noise.contiguous()
    out_t = out.contiguous()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    rrelu_with_noise_backward_kernel[grid](
        go,
        x,
        nz,
        out_t,
        n_elements,
        float(lower),
        float(upper),
        1 if training else 0,
        BLOCK_SIZE=1024,
    )
    # .contiguous() returned a copy when `out` was non-contiguous; propagate
    # the kernel's result back into the caller-visible tensor.
    if out is not out_t:
        out.copy_(out_t)
    return out
def rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool = False,
):
    """Compute the rrelu_with_noise backward into a freshly allocated tensor.

    ``self_is_result`` is accepted for ATen signature compatibility; the
    computation here is the same either way, so the flag is unused.
    """
    grad_input = torch.empty_like(grad_output)
    _launch_rrelu_with_noise_backward(
        grad_output, input, noise, lower, upper, training, grad_input
    )
    return grad_input
def rrelu_with_noise_backward_out(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
    out: torch.Tensor,
):
    """Out-variant of the backward: write the gradient into ``out``.

    Delegates directly to the shared launcher and returns ``out``.
    """
    result = _launch_rrelu_with_noise_backward(
        grad_output, input, noise, lower, upper, training, out
    )
    return result