Coverage for src/flag_gems/experimental_ops/native_dropout_backward.py: 0%

37 statements  

coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

import torch
import triton
import triton.language as tl


@triton.jit
def _native_dropout_backward_kernel(
    grad_ptr,  # *Pointer* to grad_output tensor
    mask_ptr,  # *Pointer* to mask tensor (cast to same dtype as grad)
    out_ptr,  # *Pointer* to output grad_input tensor
    n_elements,  # Number of elements
    scale,  # Scaling factor (float)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements

    g = tl.load(grad_ptr + offsets, mask=in_bounds, other=0)
    m = tl.load(mask_ptr + offsets, mask=in_bounds, other=0)

    # grad_input = grad_output * mask * scale
    out = g * m * scale
    tl.store(out_ptr + offsets, out, mask=in_bounds)


def _launch_native_dropout_backward(
    grad_output: torch.Tensor, mask: torch.Tensor, scale: float, out: torch.Tensor
):
    assert (
        grad_output.is_cuda and mask.is_cuda and out.is_cuda
    ), "All tensors must be CUDA tensors"
    assert (
        grad_output.numel() == mask.numel() == out.numel()
    ), "grad_output, mask, and out must have the same number of elements"
    assert grad_output.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ), "Supported dtypes: float16, bfloat16, float32"
    assert out.dtype == grad_output.dtype, "Output dtype must match grad_output dtype"
    assert (
        grad_output.device == mask.device == out.device
    ), "All tensors must be on the same device"

    go = grad_output.contiguous()
    m = mask.contiguous()
    if m.dtype != go.dtype:
        m = m.to(dtype=go.dtype)

    out_contig = out if out.is_contiguous() else torch.empty_like(go)

    n_elements = go.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _native_dropout_backward_kernel[grid](
        go, m, out_contig, n_elements, float(scale), BLOCK_SIZE=BLOCK_SIZE
    )

    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)
    return out


def native_dropout_backward(
    grad_output: torch.Tensor, mask: torch.Tensor, scale: float
):
    """
    Wrapper for aten::native_dropout_backward.
    Computes grad_input = grad_output * mask.to(grad_output.dtype) * scale
    """
    out = torch.empty_like(grad_output)
    return _launch_native_dropout_backward(grad_output, mask, scale, out)


def native_dropout_backward_out(
    grad_output: torch.Tensor, mask: torch.Tensor, scale: float, out: torch.Tensor
):
    """
    Wrapper for aten::native_dropout_backward.out
    Writes the result into 'out'.
    """
    _launch_native_dropout_backward(grad_output, mask, scale, out)
    return out
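
A minimal usage sketch (not part of the module above): it assumes a CUDA device with Triton available, builds a Boolean keep-mask the way aten::native_dropout would, and checks the result against the reference formula grad_output * mask.to(grad_output.dtype) * scale. The import path is inferred from the coverage header; the shapes, dropout probability, and scale are illustrative choices, not values taken from the source.

import torch
# Hypothetical import path inferred from src/flag_gems/experimental_ops/native_dropout_backward.py
from flag_gems.experimental_ops.native_dropout_backward import native_dropout_backward

p = 0.2                      # illustrative dropout probability
scale = 1.0 / (1.0 - p)      # scale used by aten::native_dropout: 1 / (1 - p)

grad_output = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
mask = torch.rand_like(grad_output, dtype=torch.float32) >= p  # Boolean keep-mask

grad_input = native_dropout_backward(grad_output, mask, scale)

# Reference computation for comparison
expected = grad_output * mask.to(grad_output.dtype) * scale
torch.testing.assert_close(grad_input, expected)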