Coverage for src/flag_gems/ops/dropout.py: 44%
90 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
14logger = logging.getLogger(__name__)
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,  # *ptr* input values (contiguous)
    Y,  # *ptr* output values
    dropout_mask,  # *ptr* output keep-mask (True = element kept)
    N,  # total number of elements
    p,  # drop probability; an element is kept when its uniform draw exceeds p
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # One philox call yields four 32-bit random words per counter, so each
    # program instance covers BLOCK * UNROLL contiguous elements.
    UNROLL: tl.constexpr = 4  # philox generate 128 random bits at a time
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Give every lane a distinct low counter word so lanes draw
    # independent random streams.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # block of zeros for the remaining philox arguments
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Convert raw 32-bit words to uniform floats.
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    # Keep decision per element: survives iff its uniform draw exceeds p.
    mask0 = r0 > p
    mask1 = r1 > p
    mask2 = r2 > p
    mask3 = r3 > p
    # NOTE: `p` is deliberately reused from here on as the inverted-dropout
    # rescale factor 1 / (1 - p) applied to kept elements.
    p = 1.0 / (1.0 - p)
    # The four random words of lane j map to the four BLOCK-strided
    # sub-blocks of this program's BLOCK * UNROLL element span.
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")

    # Multiplying by the boolean mask zeroes dropped elements.
    y0 = x0 * p * mask0  # tl.where(mask0, x0 * p, 0.0)
    y1 = x1 * p * mask1  # tl.where(mask1, x1 * p, 0.0)
    y2 = x2 * p * mask2  # tl.where(mask2, x2 * p, 0.0)
    y3 = x3 * p * mask3  # tl.where(mask3, x3 * p, 0.0)

    # Persist the keep-mask (needed by the backward pass) and the output.
    tl.store(dropout_mask + off_0, mask0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_1, mask1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_2, mask2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_3, mask3, mask=off_3 < N, eviction_policy="evict_first")

    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,  # *ptr* incoming gradient
    DX,  # *ptr* gradient w.r.t. the dropout input
    dropout_mask,  # *ptr* keep-mask saved by the forward pass
    N,  # total number of elements
    scale,  # rescale factor applied to kept elements
    BLOCK: tl.constexpr,
):
    # dx = dy * mask * scale, elementwise over this program's BLOCK span.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = offsets < N
    grad_out = tl.load(DY + offsets, mask=in_bounds, other=0, eviction_policy="evict_first")
    keep = tl.load(
        dropout_mask + offsets, mask=in_bounds, other=0, eviction_policy="evict_first"
    )
    grad_in = grad_out * keep * scale
    tl.store(DX + offsets, grad_in, mask=in_bounds, eviction_policy="evict_first")
# Elements per philox counter: the forward kernel draws 4 random words per
# counter, so each program instance covers BLOCK * UNROLL elements.
UNROLL = 4
def dropout(input, p, train=True):
    """Native dropout forward pass.

    Args:
        input: input tensor (made contiguous internally).
        p: drop probability in [0, 1].
        train: when False, dropout is the identity.

    Returns:
        Tuple ``(out, mask)`` where ``out`` is the dropped-and-rescaled
        tensor and ``mask`` is a bool tensor marking the *kept* elements.

    Raises:
        ValueError: if ``p`` lies outside [0, 1].
    """
    logger.debug("GEMS NATIVE DROPOUT FORWARD")
    if not train or p == 0:
        # Identity: every element kept, no rescaling.
        out = input.clone()
        mask = torch.ones_like(input, dtype=torch.bool)
        return out, mask
    if p == 1:
        # Everything dropped.
        out = torch.zeros_like(input)
        mask = torch.zeros_like(input, dtype=torch.bool)
        return out, mask
    # Explicit validation instead of `assert`: asserts are stripped under
    # `python -O`, which would let an out-of-range p reach 1 / (1 - p).
    if not (0.0 < p < 1.0):
        raise ValueError(f"p must be in (0, 1), but got {p}")
    device = input.device
    # TODO: remove contiguous enforcement
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    # One philox counter is consumed per UNROLL elements.
    increment = triton.cdiv(N, UNROLL)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input, out, mask, N, p, philox_seed, philox_offset
        )
    return out, mask
def dropout_backward(grad_output, mask, scale):
    """Native dropout backward: ``grad_input = grad_output * mask * scale``.

    ``mask`` is the bool keep-mask produced by the forward pass and
    ``scale`` is the rescale factor applied to kept elements.
    """
    logger.debug("GEMS NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    grad_input = torch.empty_like(grad_output)
    total = grad_output.numel()

    def grid(meta):
        # One program per BLOCK-sized slice of the flattened gradient.
        return (triton.cdiv(total, meta["BLOCK"]),)

    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid](grad_output, grad_input, mask, total, scale)
    return grad_input