Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dropout.py: 0%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import ( 

11 philox_backend_seed_offset, 

12 uint_to_uniform_float, 

13) 

14 

# Module logger, registered as a child of the shared "flag_gems" logger
# (lstrip(".") guards against a leading dot in the module path).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

16 

17 

# Fused dropout forward kernel.  Each program instance covers
# UNROLL * BLOCK contiguous elements: one philox call per lane yields four
# uint32 words, which drive the four BLOCK-sized sub-slices.  The counter
# layout must stay in sync with the host wrapper's offset accounting
# (increment = cdiv(N, UNROLL)) — TODO confirm against philox_backend_seed_offset.
@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,  # input tensor (contiguous)
    Y,  # output tensor, same shape as X
    dropout_mask,  # bool output; True marks elements that were kept
    N,  # total number of elements
    p,  # drop probability, in (0, 1)
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    UNROLL: tl.constexpr = 4  # philox generate 128 random bits at a time
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into the two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4  # per-lane counter offset so every lane draws a distinct stream
    _O = c0 * 0  # zero vector for the unused philox counter/key inputs
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Convert raw uint32 words to uniform floats (range defined by
    # uint_to_uniform_float — presumably (0, 1]; TODO confirm).
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)

    # Keep an element when its uniform draw exceeds p.
    mask0 = r0 > p
    mask1 = r1 > p
    mask2 = r2 > p
    mask3 = r3 > p
    p = 1.0 / (1.0 - p)  # reuse p as the survivor rescale factor 1/(1-p)

    # Four consecutive BLOCK-sized slices handled by this program instance.
    off_0 = tl.program_id(0) * BLOCK * UNROLL + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    x0 = tl.load(X + off_0, mask=off_0 < N, other=0.0, eviction_policy="evict_first")
    x1 = tl.load(X + off_1, mask=off_1 < N, other=0.0, eviction_policy="evict_first")
    x2 = tl.load(X + off_2, mask=off_2 < N, other=0.0, eviction_policy="evict_first")
    x3 = tl.load(X + off_3, mask=off_3 < N, other=0.0, eviction_policy="evict_first")

    # Multiply by the boolean mask instead of tl.where: dropped lanes become 0.
    y0 = x0 * p * mask0  # tl.where(mask0, x0 * p, 0.0)
    y1 = x1 * p * mask1  # tl.where(mask1, x1 * p, 0.0)
    y2 = x2 * p * mask2  # tl.where(mask2, x2 * p, 0.0)
    y3 = x3 * p * mask3  # tl.where(mask3, x3 * p, 0.0)

    tl.store(dropout_mask + off_0, mask0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_1, mask1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_2, mask2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(dropout_mask + off_3, mask3, mask=off_3 < N, eviction_policy="evict_first")

    tl.store(Y + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(Y + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(Y + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")

74 

75 

# Dropout backward kernel: dx = dy * mask * scale, so gradients flow only
# through elements that survived the forward pass.
# @triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,  # upstream gradient (contiguous)
    DX,  # output gradient, same shape as DY
    dropout_mask,  # bool keep-mask written by the forward kernel
    N,  # total number of elements
    scale,  # survivor rescale factor — presumably 1/(1-p); TODO confirm callers
    BLOCK: tl.constexpr,
):
    offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offset < N
    m = tl.load(
        dropout_mask + offset, mask=mask, other=0, eviction_policy="evict_first"
    )
    dy = tl.load(DY + offset, mask=mask, other=0, eviction_policy="evict_first")
    dx = dy * m * scale
    # NOTE(review): out-of-range lanes are redirected to offset -1 before the
    # masked store — looks like an XPU-specific workaround (cf. the
    # TRITONXPU_STORE_MASK_SIM switch in the host wrapper); confirm before changing.
    store_offset = tl.where(mask, offset, -1)
    tl.store(DX + store_offset, dx, mask=mask, eviction_policy="evict_first")

95 

96 

# Elements processed per lane by the forward kernel; must match the UNROLL
# constexpr inside dropout_forward_kernel (philox yields 4 uint32s per call).
UNROLL = 4

98 

99 

def dropout(input, p, train=True):
    """Native dropout forward.

    Returns a ``(output, keep_mask)`` pair: elements are zeroed with
    probability ``p`` and survivors are rescaled by ``1 / (1 - p)``.
    With ``train=False`` or ``p == 0`` the input passes through unchanged;
    with ``p == 1`` everything is dropped.
    """
    logger.debug("GEMS NATIVE DROPOUT FORWARD")

    # Degenerate cases need no kernel launch.
    if not train or p == 0:
        return input.clone(), torch.ones_like(input, dtype=torch.bool)
    if p == 1:
        return torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)
    assert p > 0.0 and p < 1.0, "p must be in (0, 1)"

    # TODO: remove contiguous enforcement
    contig = input.contiguous()
    result = torch.empty_like(contig)
    keep_mask = torch.empty_like(contig, dtype=torch.bool)
    numel = contig.numel()

    def grid_fn(meta):
        return (triton.cdiv(numel, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    rng_increment = triton.cdiv(numel, UNROLL)
    with torch_device_fn.device(contig.device):
        philox_seed, philox_offset = philox_backend_seed_offset(rng_increment)
        dropout_forward_kernel[grid_fn](
            contig, result, keep_mask, numel, p, philox_seed, philox_offset
        )
    return result, keep_mask

127 

128 

def dropout_backward(grad_output, mask, scale):
    """Native dropout backward: ``grad_input = grad_output * mask * scale``.

    Parameters
    ----------
    grad_output : upstream gradient tensor (made contiguous here).
    mask : bool keep-mask produced by the forward pass.
    scale : rescale factor for surviving gradients — presumably ``1/(1-p)``,
        matching the forward rescaling; TODO confirm against callers.

    Returns the gradient w.r.t. the dropout input.
    """
    logger.debug("GEMS NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    grad_input = torch.empty_like(grad_output)
    N = grad_output.numel()
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"]),)

    # XPU simulator switches needed by this kernel launch.  Save any
    # pre-existing values and restore them in a finally block: the original
    # code deleted the variables unconditionally, which clobbered values the
    # caller may have set and leaked the overrides if the launch raised.
    _keys = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    saved = {key: os.environ.get(key) for key in _keys}
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        with torch_device_fn.device(grad_output.device):
            # BLOCK=N: the whole tensor is handled by a single program instance.
            dropout_backward_kernel[grid_fn](
                grad_output, grad_input, mask, N, scale, BLOCK=N
            )
    finally:
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
    return grad_input