Coverage for src/flag_gems/runtime/backend/_cambricon/ops/dropout.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import torch_mlu # noqa: F401 

5import triton 

6import triton.language as tl 

7from triton.language.extra.mlu.libdevice import philox as _philox 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils.random_utils import ( 

12 philox_backend_seed_offset, 

13 uint_to_uniform_float, 

14) 

15 

16from ..utils import TOTAL_CORE_NUM 

17 

# Child logger under the shared "flag_gems" root; lstrip(".") drops any
# leading dots from __name__ so the child name is never empty-prefixed.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

19 

20 

@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
def dropout_forward_kernel(
    X,
    Y,
    dropout_mask,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Dropout forward over a flat tensor of N elements.

    Writes Y = X * mask / (1 - p) and stores the boolean keep-mask to
    dropout_mask. Random bits come from the Philox counter-based RNG, so the
    mask is reproducible from (philox_seed, philox_offset).
    """
    UNROLL: tl.constexpr = 4  # philox generates 128 random bits (4 x 32) at a time
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)

    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    i4_start = pid * BLOCK
    block_start = pid * UNROLL * BLOCK
    step = num_jobs * BLOCK * UNROLL
    mp = 1.0 / (1.0 - p)  # rescale kept elements so the expected value is preserved

    for block_offset in range(block_start, N, step):
        # Split the 64-bit seed/offset into the four 32-bit Philox inputs.
        sl = (philox_seed & 0xFFFFFFFF).to(tl.uint32)
        sh = ((philox_seed >> 32) & 0xFFFFFFFF).to(tl.uint32)
        c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
        c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
        r = _philox(BLOCK, sl, sh, c0 + i4_start, c1, 0, 0, 10)
        r = uint_to_uniform_float(r)

        mask = r > p

        off = block_offset + tl.arange(0, UNROLL * BLOCK)
        # Reshape exactly ONCE and reuse the result. can_reorder=True lets the
        # compiler pick an arbitrary element order, so two separate reshape
        # calls could choose different orders and the stored mask would no
        # longer correspond to the elements actually zeroed in Y.
        mask_flat = tl.reshape(mask, [UNROLL * BLOCK], can_reorder=True)
        x = tl.load(X + off, mask=off < N, other=0.0)
        y = x * mp * mask_flat  # equivalent to tl.where(mask_flat, x * mp, 0.0)
        tl.store(dropout_mask + off, mask_flat, mask=off < N)
        tl.store(Y + off, y, mask=off < N)
        i4_start += num_jobs * BLOCK

63 

64 

@triton.heuristics(runtime.get_heuristic_config("dropout"))
@triton.jit(do_not_specialize=["scale"])
def dropout_backward_kernel(
    DY,
    DX,
    dropout_mask,
    N,
    scale,
    BLOCK: tl.constexpr,
):
    """Dropout backward over N flat elements: DX = DY * dropout_mask * scale."""
    UNROLL: tl.constexpr = 4

    start = tl.program_id(0) * UNROLL * BLOCK
    stride = tl.num_programs(0) * UNROLL * BLOCK
    for base in range(start, N, stride):
        offsets = base + tl.arange(0, UNROLL * BLOCK)
        in_bounds = offsets < N
        # Streaming access pattern: each element is touched once, so hint the
        # cache to evict these lines first.
        keep = tl.load(
            dropout_mask + offsets,
            mask=in_bounds,
            other=0,
            eviction_policy="evict_first",
        )
        grad_out = tl.load(
            DY + offsets, mask=in_bounds, other=0.0, eviction_policy="evict_first"
        )
        tl.store(
            DX + offsets,
            grad_out * keep * scale,
            mask=in_bounds,
            eviction_policy="evict_first",
        )

90 

91 

# Elements consumed per Philox draw; keep in sync with the UNROLL constexpr
# declared inside both kernels above (used here when sizing the launch grid).
UNROLL = 4

93 

94 

def dropout(input, p, train=True):
    """Native dropout forward.

    Returns (out, mask): mask is a bool tensor marking kept elements and
    out equals input * mask / (1 - p) when training, input unchanged otherwise.
    """
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT FORWARD")
    if not train or p == 0:
        # Nothing is dropped: identity output, all-True mask.
        return input.clone(), torch.ones_like(input, dtype=torch.bool)
    if p == 1:
        # Everything is dropped: zero output, all-False mask.
        return torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)
    assert 0.0 < p < 1.0, "p must be in (0, 1)"

    device = input.device
    # TODO: remove contiguous enforcement
    input = input.contiguous()
    out = torch.empty_like(input)
    mask = torch.empty_like(input, dtype=torch.bool)
    N = input.numel()

    def grid_fn(meta):
        # Cap the launch at the number of physical cores; each program then
        # grid-strides over the remaining elements.
        return (min(triton.cdiv(N, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),)

    # (TODO) Using the Triton autotuner makes kernel parameters opaque to the
    # caller, hence we cannot obtain the per-thread offset as in PyTorch.
    increment = triton.cdiv(N, UNROLL)
    with torch_device_fn.device(device):
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        dropout_forward_kernel[grid_fn](
            input,
            out,
            mask,
            N,
            p,
            philox_seed,
            philox_offset,
            num_warps=1,
            num_stages=3,
        )
    return out, mask

132 

133 

def dropout_backward(grad_output, mask, scale):
    """Native dropout backward: grad_input = grad_output * mask * scale."""
    logger.debug("GEMS_CAMBRICON NATIVE DROPOUT BACKWARD")
    grad_output = grad_output.contiguous()
    grad_input = torch.empty_like(grad_output)
    total = grad_output.numel()

    def grid_fn(meta):
        # One program per core at most; programs stride over leftover work.
        return (min(triton.cdiv(total, meta["BLOCK"] * UNROLL), TOTAL_CORE_NUM),)

    with torch_device_fn.device(grad_output.device):
        dropout_backward_kernel[grid_fn](
            grad_output,
            grad_input,
            mask,
            total,
            scale,
            num_stages=3,
            num_warps=1,
        )
    return grad_input