Coverage for src/flag_gems/experimental_ops/native_dropout_backward.py: 0%
37 statements
coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
import torch
import triton
import triton.language as tl


@triton.jit
def _native_dropout_backward_kernel(
    grad_ptr,  # *Pointer* to grad_output tensor
    mask_ptr,  # *Pointer* to mask tensor (cast to same dtype as grad)
    out_ptr,  # *Pointer* to output grad_input tensor
    n_elements,  # Number of elements
    scale,  # Scaling factor (float)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    g = tl.load(grad_ptr + offsets, mask=in_bounds, other=0)
    m = tl.load(mask_ptr + offsets, mask=in_bounds, other=0)

    # grad_input = grad_output * mask * scale
    out = g * m * scale
    tl.store(out_ptr + offsets, out, mask=in_bounds)
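
# Grid-coverage note (illustrative, not part of the original listing): each
# program instance handles one BLOCK_SIZE-wide slice of the flattened tensor,
# and the launch grid is sized with triton.cdiv so a trailing partial block is
# handled by the `in_bounds` mask. For example, n_elements = 2500 with
# BLOCK_SIZE = 1024 launches triton.cdiv(2500, 1024) = 3 programs; program 2
# produces offsets 2048..3071, of which only 2048..2499 are loaded and stored.
# The `scale` argument follows the aten::native_dropout_backward convention and
# is typically 1 / (1 - p) for dropout probability p (e.g. p = 0.2 -> scale = 1.25).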

def _launch_native_dropout_backward(
    grad_output: torch.Tensor, mask: torch.Tensor, scale: float, out: torch.Tensor
):
    assert (
        grad_output.is_cuda and mask.is_cuda and out.is_cuda
    ), "All tensors must be CUDA tensors"
    assert (
        grad_output.numel() == mask.numel() == out.numel()
    ), "grad_output, mask, and out must have the same number of elements"
    assert grad_output.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ), "Supported dtypes: float16, bfloat16, float32"
    assert out.dtype == grad_output.dtype, "Output dtype must match grad_output dtype"
    assert (
        grad_output.device == mask.device == out.device
    ), "All tensors must be on the same device"

    # The kernel indexes flat, contiguous memory, so make contiguous copies if
    # needed and cast the (typically bool) mask to the gradient dtype.
    go = grad_output.contiguous()
    m = mask.contiguous()
    if m.dtype != go.dtype:
        m = m.to(dtype=go.dtype)

    # Write into `out` directly when it is contiguous; otherwise compute into a
    # temporary buffer and copy back afterwards.
    out_contig = out if out.is_contiguous() else torch.empty_like(go)

    n_elements = go.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _native_dropout_backward_kernel[grid](
        go, m, out_contig, n_elements, float(scale), BLOCK_SIZE=BLOCK_SIZE
    )

    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)
    return out

def native_dropout_backward(
    grad_output: torch.Tensor, mask: torch.Tensor, scale: float
):
    """
    Wrapper for aten::native_dropout_backward.
    Computes grad_input = grad_output * mask.to(grad_output.dtype) * scale.
    """
    out = torch.empty_like(grad_output)
    return _launch_native_dropout_backward(grad_output, mask, scale, out)


def native_dropout_backward_out(
    grad_output: torch.Tensor, mask: torch.Tensor, scale: float, out: torch.Tensor
):
    """
    Wrapper for aten::native_dropout_backward.out.
    Writes the result into `out`.
    """
    _launch_native_dropout_backward(grad_output, mask, scale, out)
    return out
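

# Minimal usage sketch (illustrative, not part of the covered module). It
# assumes a CUDA device is available and checks the Triton path against the
# eager formula grad_output * mask.to(dtype) * scale.
if __name__ == "__main__":
    p = 0.25
    scale = 1.0 / (1.0 - p)
    grad_output = torch.randn(4, 1000, device="cuda", dtype=torch.float32)
    keep_mask = torch.rand_like(grad_output) >= p  # bool mask of kept elements

    grad_input = native_dropout_backward(grad_output, keep_mask, scale)
    expected = grad_output * keep_mask.to(grad_output.dtype) * scale
    torch.testing.assert_close(grad_input, expected)

    # Out-variant: write into a preallocated tensor.
    out = torch.empty_like(grad_output)
    native_dropout_backward_out(grad_output, keep_mask, scale, out)
    torch.testing.assert_close(out, expected)
    print("native_dropout_backward: OK")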