Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dropout.py: 0%
97 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.random_utils import (
11 philox_backend_seed_offset,
12 uint_to_uniform_float,
13)
# Module logger under the "flag_gems" hierarchy; lstrip(".") is a no-op for
# normal dotted module names but guards against a leading dot in __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,  # input tensor (flattened)
    Y,  # output tensor: kept elements scaled by 1 / (1 - p), dropped set to 0
    dropout_mask,  # output bool mask: nonzero where the element is kept
    N,  # total number of elements
    p,  # drop probability
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Forward dropout over a flattened tensor.

    Each program instance draws BLOCK philox samples (4 x 32-bit words per
    draw) and uses the four words to mask UNROLL = 4 consecutive BLOCK-sized
    slices of X.
    """
    UNROLL: tl.constexpr = 4  # philox generate 128 random bits at a time
    # Split the 64-bit philox offset into two 32-bit counter words.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One distinct counter value per lane of this program instance.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zero vector with the counters' shape/dtype (unused inputs)
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Convert the four raw 32-bit words to uniform floats.
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)

    # Keep an element iff its uniform draw exceeds the drop probability.
    mask0 = r0 > p
    mask1 = r1 > p
    mask2 = r2 > p
    mask3 = r3 > p
    p = 1.0 / (1.0 - p)  # reuse p as the 1/(1-p) rescale factor

    # This program covers UNROLL consecutive BLOCK-sized slices of the data.
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")

    # Multiplying by the bool mask zeroes dropped elements.
    y0 = x0 * p * mask0  # tl.where(mask0, x0 * p, 0.0)
    y1 = x1 * p * mask1  # tl.where(mask1, x1 * p, 0.0)
    y2 = x2 * p * mask2  # tl.where(mask2, x2 * p, 0.0)
    y3 = x3 * p * mask3  # tl.where(mask3, x3 * p, 0.0)

    tl.store(dropout_mask + off_0, mask0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_1, mask1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_2, mask2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_3, mask3, mask=off_3 < N, eviction_policy="evict_first")

    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
# @triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,  # upstream gradient (flattened)
    DX,  # output gradient w.r.t. the dropout input
    dropout_mask,  # bool keep-mask produced by the forward kernel
    N,  # total number of elements
    scale,  # rescale factor (1 / (1 - p) in the host wrapper)
    BLOCK: tl.constexpr,
):
    """Backward dropout: DX = DY * dropout_mask * scale, elementwise."""
    offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offset < N
    m = tl.load(
        dropout_mask + offset, mask=mask, other=0, eviction_policy="evict_first"
    )
    dy = tl.load(DY + offset, mask=mask, other=0, eviction_policy="evict_first")
    dx = dy * m * scale
    # NOTE(review): out-of-range lanes get store offset -1 even though the
    # store is already masked — presumably a workaround for this backend's
    # store-mask simulation (see TRITONXPU_STORE_MASK_SIM in the host
    # wrapper); confirm the mask alone is not sufficient on kunlunxin.
    store_offset = tl.where(mask, offset, -1)
    tl.store(DX + store_offset, dx, mask=mask, eviction_policy="evict_first")
# Elements consumed per philox draw in the forward kernel; must match the
# kernel's UNROLL constexpr (4 x 32-bit random words per draw).
UNROLL = 4
def dropout(input, p, train=True):
    """Native dropout forward.

    Returns a tuple (output, keep_mask) where kept elements of ``output``
    are rescaled by 1 / (1 - p) and ``keep_mask`` is a bool tensor marking
    the kept positions.
    """
    logger.debug("GEMS NATIVE DROPOUT FORWARD")
    # Fast paths: eval mode / p == 0 keep everything; p == 1 drops everything.
    if not train or p == 0:
        return input.clone(), torch.ones_like(input, dtype=torch.bool)
    if p == 1:
        return torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)
    assert p > 0.0 and p < 1.0, "p must be in (0, 1)"
    device = input.device
    # TODO: remove contiguous enforcement
    input = input.contiguous()
    output = torch.empty_like(input)
    keep_mask = torch.empty_like(input, dtype=torch.bool)
    numel = input.numel()
    # Each program instance of the kernel covers BLOCK * UNROLL elements.
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    philox_increment = triton.cdiv(numel, UNROLL)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(philox_increment)
        dropout_forward_kernel[grid](
            input, output, keep_mask, numel, p, philox_seed, philox_offset
        )
    return output, keep_mask
def dropout_backward(grad_output, mask, scale):
    """Native dropout backward: grad_input = grad_output * mask * scale.

    Args:
        grad_output: upstream gradient tensor.
        mask: bool keep-mask produced by the forward pass.
        scale: rescale factor (1 / (1 - p)).

    Returns:
        grad_input tensor with the same shape/dtype as ``grad_output``.
    """
    logger.debug("GEMS NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    grad_input = torch.empty_like(grad_output)
    N = grad_output.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"]),)

    # Temporarily enable the XPU simulation flags this backend's kernel
    # needs. Remember any pre-existing values so a caller's environment is
    # not clobbered; the previous code deleted them outright and also leaked
    # them set when the kernel launch raised (no try/finally).
    _ENV_KEYS = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    saved = {key: os.environ.get(key) for key in _ENV_KEYS}
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        with torch_device_fn.device(grad_output.device):
            # NOTE(review): BLOCK=N launches a single program spanning the
            # whole tensor; tl.arange normally requires a power-of-2 extent —
            # confirm this is valid on the kunlunxin backend for arbitrary N.
            dropout_backward_kernel[grid_fn](
                grad_output, grad_input, mask, N, scale, BLOCK=N
            )
    finally:
        # Restore the environment exactly as we found it.
        for key, old_value in saved.items():
            if old_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old_value
    return grad_input