Coverage for src/flag_gems/runtime/backend/_cambricon/ops/dropout.py: 0%
84 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import logging
3import torch
4import torch_mlu # noqa: F401
5import triton
6import triton.language as tl
7from triton.language.extra.mlu.libdevice import philox as _philox
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils.random_utils import (
12 philox_backend_seed_offset,
13 uint_to_uniform_float,
14)
16from ..utils import TOTAL_CORE_NUM
18logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,  # *input* tensor data pointer (contiguous, N elements)
    Y,  # *output* tensor data pointer (same shape as X)
    dropout_mask,  # bool mask output pointer: True where the element is kept
    N,  # total number of elements
    p,  # drop probability in (0, 1)
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Dropout forward: Y = X / (1 - p) where kept, 0 where dropped.

    Each program walks the flat range in a strided loop, drawing BLOCK Philox
    draws per iteration; each draw yields UNROLL 32-bit words, so one
    iteration covers UNROLL * BLOCK elements.
    """
    UNROLL: tl.constexpr = 4  # philox generate 128 random bits at a time
    # Widen so the 64-bit seed/offset can be split into 32-bit halves below.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    # Per-program starting Philox counter offset (one counter per 4 elements).
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL
    # Inverse keep-probability; kept elements are scaled by 1 / (1 - p).
    mp = 1.0 / (1.0 - p)

    for block_offset in range(block_start, N, step):
        # Split 64-bit seed and offset into the four 32-bit Philox inputs.
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        # 10-round Philox; counter advanced by i4_start so programs/iterations
        # consume disjoint streams. NOTE(review): c0 + i4_start may wrap into
        # c1 territory for very large offsets — assumed handled upstream.
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = uint_to_uniform_float(r)

        # Keep an element when its uniform draw exceeds p.
        mask = r > p

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        x = tl.load(X + off, mask=off < N, other=0.0)
        # Flatten the [UNROLL, BLOCK]-shaped mask to match the flat offsets;
        # multiplying by the boolean mask zeroes dropped elements.
        y = (
            x * mp * tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)
        )  # tl.where(mask0, x0 * p, 0.0)
        mask_reshaped = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)
        tl.store(dropout_mask + off, mask_reshaped, mask=off < N)
        tl.store(Y + off, y, mask=off < N)
        # Advance this program's Philox counter past all programs' draws.
        i4_start += num_jobs * BLOCK
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,  # upstream gradient pointer (contiguous, N elements)
    DX,  # output gradient pointer
    dropout_mask,  # bool mask saved by the forward pass (True = kept)
    N,  # total number of elements
    scale,  # 1 / (1 - p), matching the forward scaling
    BLOCK: tl.constexpr,
):
    """Dropout backward: DX = DY * mask * scale, elementwise over N items.

    Programs stride over the flat range, UNROLL * BLOCK elements per step.
    """
    UNROLL: tl.constexpr = 4
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    block_start = pid * UNROLL * BLOCK
    step = num_programs * UNROLL * BLOCK
    for block_offset in range(block_start, N, step):
        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        # evict_first: data is streamed once, so don't keep it in cache.
        mask = tl.load(
            dropout_mask + off, mask=off < N, other=0, eviction_policy="evict_first"
        )
        dy = tl.load(DY + off, mask=off < N, other=0.0, eviction_policy="evict_first")
        # Dropped positions (mask == 0) produce zero gradient.
        dx = dy * mask * scale

        tl.store(DX + off, dx, mask=off < N, eviction_policy="evict_first")
92UNROLL = 4
def dropout(input, p, train=True):
    """Native dropout forward for the Cambricon MLU backend.

    Args:
        input: tensor to apply dropout to (flattened view is processed;
            a contiguous copy is made if needed).
        p: drop probability. 0 and 1 are handled as fast paths; otherwise
            must lie strictly in (0, 1).
        train: when False, dropout is the identity.

    Returns:
        (out, mask): the scaled output tensor and a bool tensor that is
        True where the element was kept.

    Raises:
        ValueError: if p is outside [0, 1].
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT FORWARD")
    if not train or p == 0:
        # Identity: everything kept, no scaling.
        out = input.clone()
        mask = torch.ones_like(input, dtype=torch.bool)
        return out, mask
    if p == 1:
        # Everything dropped.
        out = torch.zeros_like(input)
        mask = torch.zeros_like(input, dtype=torch.bool)
        return out, mask
    # Explicit validation instead of `assert`, which is stripped under -O.
    if not (0.0 < p < 1.0):
        raise ValueError(f"dropout probability must be in (0, 1), got {p}")
    device = input.device
    # TODO: remove contiguous enforcement
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()
    # One program per UNROLL*BLOCK elements, capped at the core count.
    grid_fn = lambda meta: (
        min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),
    )
    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    # Each Philox draw covers UNROLL elements, so reserve N / UNROLL counters.
    increment = triton.cdiv(N, UNROLL)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input,
            out,
            mask,
            N,
            p,
            philox_seed,
            philox_offset,
            num_warps=1,
            num_stages=3,
        )
    return out, mask
def dropout_backward(grad_output, mask, scale):
    """Native dropout backward for the Cambricon MLU backend.

    Computes grad_input = grad_output * mask * scale elementwise, where
    `mask` is the bool keep-mask produced by the forward pass and `scale`
    is the forward scaling factor.
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    grad_input = torch.empty_like(grad_output)
    total = grad_output.numel()

    def grid(meta):
        # One program per UNROLL*BLOCK elements, capped at the core count.
        return (min(triton.cdiv(total, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),)

    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid](
            grad_output,
            grad_input,
            mask,
            total,
            scale,
            num_stages=3,
            num_warps=1,
        )
    return grad_input