Coverage for src/flag_gems/ops/rrelu_with_noise_backward.py: 56%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def rrelu_with_noise_backward_kernel(
    grad_out_ptr,  # pointer: incoming gradient w.r.t. the rrelu output
    input_ptr,  # pointer: forward-pass input values
    noise_ptr,  # pointer: per-element multipliers saved by the forward pass
    grad_in_ptr,  # pointer: output buffer for the gradient w.r.t. the input
    n_elements,  # total element count; guards the tail block
    lower,  # lower bound of the rrelu slope-sampling interval
    upper,  # upper bound of the rrelu slope-sampling interval
    training,  # int flag: nonzero selects the training-mode formula
    BLOCK_SIZE: tl.constexpr,  # elements processed per program instance
):
    """Elementwise backward pass of rrelu_with_noise.

    training mode: grad_in = grad_out * noise
                   (noise is presumably the per-element slope sampled in the
                   forward pass — confirm against the forward kernel).
    eval mode:     grad_in = grad_out * (1 if input > 0 else (lower+upper)/2).

    Math is done in float32 and cast back to the input dtype on store.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off the tail of the last block when n_elements % BLOCK_SIZE != 0.
    mask = offsets < n_elements

    go = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    x = tl.load(input_ptr + offsets, mask=mask, other=0)
    nz = tl.load(noise_ptr + offsets, mask=mask, other=0)

    # Promote to float32 so low-precision dtypes (e.g. fp16) keep accuracy.
    go_f32 = go.to(tl.float32)
    x_f32 = x.to(tl.float32)
    nz_f32 = nz.to(tl.float32)

    # Eval-mode slope: the mean of the sampling interval, per the rrelu spec.
    slope = (lower + upper) * 0.5

    grad_train = go_f32 * nz_f32
    grad_eval = go_f32 * tl.where(x_f32 > 0, 1.0, slope)

    # `training` is a runtime scalar; broadcast it to a boolean block so
    # tl.where can select between the two precomputed gradients.
    cond = tl.full(go_f32.shape, training, tl.int1)
    grad_f32 = tl.where(cond, grad_train, grad_eval)

    # Cast back to the storage dtype of grad_out before writing.
    grad_cast = grad_f32.to(go.dtype)
    tl.store(grad_in_ptr + offsets, grad_cast, mask=mask)
def _launch_rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    out: torch.Tensor,
):
    """Launch the elementwise backward kernel and write the result into ``out``.

    All tensors are flattened to contiguous buffers before the launch. When
    ``out`` itself is not contiguous, the kernel writes into a contiguous
    copy which is then copied back into ``out``. Returns ``out``.
    """
    grad_c, input_c, noise_c, out_c = (
        t.contiguous() for t in (grad_output, input, noise, out)
    )

    total = out_c.numel()

    # One program instance per BLOCK_SIZE-sized chunk of the flat buffer.
    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(grad_output.device):
        rrelu_with_noise_backward_kernel[grid](
            grad_c,
            input_c,
            noise_c,
            out_c,
            total,
            float(lower),
            float(upper),
            int(training),
            BLOCK_SIZE=1024,
        )

    # .contiguous() returned a fresh tensor iff `out` was non-contiguous;
    # in that case mirror the kernel's result back into the caller's tensor.
    if out_c is not out:
        out.copy_(out_c)
    return out
def rrelu_with_noise_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    noise: torch.Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool = False,
):
    """Compute the gradient of rrelu_with_noise w.r.t. its input.

    Allocates a fresh gradient tensor shaped like ``grad_output`` and fills
    it via the Triton kernel. ``self_is_result`` is accepted for signature
    compatibility with the ATen op but is not used by this implementation.
    """
    logger.debug("GEMS RRELU_WITH_NOISE_BACKWARD")
    grad_input = torch.empty_like(grad_output)
    _launch_rrelu_with_noise_backward(
        grad_output, input, noise, lower, upper, training, grad_input
    )
    return grad_input