Coverage for src/flag_gems/experimental_ops/_log_softmax_backward_data.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import math 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

# Row-wise backward of log_softmax: dx = dy - exp(y) * sum_j(dy_j).
# Each program instance handles one row of a contiguous [M, K] layout,
# walking the row in BLOCK_SIZE chunks so K may exceed BLOCK_SIZE.
@triton.jit
def _log_softmax_bwd_kernel(
    grad_ptr,  # pointer to grad_output (dL/dy), [M, K] contiguous
    y_logsm_ptr,  # pointer to y = log_softmax(x), same layout
    grad_in_ptr,  # pointer to grad_input (dL/dx), same layout
    M,  # number of rows (grid is sized by M; unused in the body)
    K,  # length of the reduction dimension (row width)
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(axis=0)
    base = row * K
    lane = tl.arange(0, BLOCK_SIZE)
    n_iters = tl.cdiv(K, BLOCK_SIZE)

    # Pass 1: accumulate the row sum of grad_output in fp32 for stability.
    row_sum = tl.zeros((), dtype=tl.float32)
    for it in range(0, n_iters):
        idx = it * BLOCK_SIZE + lane
        in_row = idx < K
        dy = tl.load(grad_ptr + base + idx, mask=in_row, other=0.0)
        row_sum += tl.sum(dy.to(tl.float32), axis=0)

    # Pass 2: dx = dy - softmax * row_sum, where softmax = exp(log_softmax).
    for it in range(0, n_iters):
        idx = it * BLOCK_SIZE + lane
        in_row = idx < K
        dy = tl.load(grad_ptr + base + idx, mask=in_row, other=0.0)
        y = tl.load(y_logsm_ptr + base + idx, mask=in_row, other=0.0)
        dx32 = dy.to(tl.float32) - tl.exp(y.to(tl.float32)) * row_sum
        # Store back in the input dtype (fp16/bf16/fp32).
        tl.store(grad_in_ptr + base + idx, dx32.to(dy.dtype), mask=in_row)

43 

44 

45def _normalize_dim(dim: int, ndim: int) -> int: 

46 if dim < 0: 

47 dim += ndim 

48 return dim 

49 

50 

51def _choose_block_size(K: int) -> int: 

52 if K <= 1: 

53 return 1 

54 bs = 1 << (int(math.ceil(math.log2(K)))) 

55 return min(1024, max(1, bs)) 

56 

57 

def _log_softmax_backward_data_impl(
    grad_output: torch.Tensor, output: torch.Tensor, dim: int
) -> torch.Tensor:
    """Compute the gradient of log_softmax: grad - exp(output) * grad.sum(dim).

    Args:
        grad_output: gradient w.r.t. the log_softmax output (dL/dy), CUDA.
        output: forward result y = log_softmax(x); same shape/dtype/device.
        dim: dimension along which log_softmax was taken (may be negative).

    Returns:
        grad_input (dL/dx) with the same shape and dtype as grad_output.
    """
    assert (
        grad_output.shape == output.shape
    ), "grad_output and output must have the same shape"
    assert (
        grad_output.device.type == "cuda" and output.device.type == "cuda"
    ), "Inputs must be CUDA tensors"
    assert grad_output.device == output.device, "Inputs must be on the same device"
    assert grad_output.dtype == output.dtype, "Inputs must have the same dtype"
    # NOTE(review): the original "memory format" assert compared
    # is_contiguous(memory_format=torch.contiguous_format) against
    # is_contiguous() -- identical calls, so it was always true (dead code).
    # It has been removed; the .contiguous() calls below normalize any
    # input layout before the kernel runs.

    if grad_output.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        # Fallback for dtypes the kernel does not support (e.g. float64):
        # grad_input = grad_output - exp(output) * sum(grad_output, dim, keepdim)
        s = grad_output.sum(dim=dim, keepdim=True)
        return grad_output - output.exp() * s

    dim = _normalize_dim(dim, grad_output.ndim)

    # Move the reduction dim last so each row is contiguous: 2D layout [M, K].
    go_last = torch.movedim(grad_output, dim, -1).contiguous()
    y_last = torch.movedim(output, dim, -1).contiguous()
    K = go_last.shape[-1]
    if go_last.numel() == 0 or K == 0:
        # Nothing to reduce; the gradient of an empty tensor is itself (empty).
        return grad_output.clone()

    M = go_last.numel() // K

    go_2d = go_last.view(M, K)
    y_2d = y_last.view(M, K)
    gi_2d = torch.empty_like(go_2d)

    BLOCK_SIZE = _choose_block_size(K)
    grid = (M,)  # one program instance per row

    _log_softmax_bwd_kernel[grid](
        go_2d,
        y_2d,
        gi_2d,
        M,
        K,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Undo the movedim so the result matches the caller's dimension order.
    gi_last = gi_2d.view_as(go_last)
    grad_input = torch.movedim(gi_last, -1, dim)
    return grad_input

110 

111 

def _log_softmax_backward_data(
    grad_output: torch.Tensor, output: torch.Tensor, dim: int, input_dtype: torch.dtype
) -> torch.Tensor:
    """ATen-style entry point for log_softmax backward.

    Mirrors aten::_log_softmax_backward_data, whose result carries
    ``input_dtype`` (the dtype of the forward *input*), which can differ
    from grad_output's dtype under the half-to-float log_softmax path.
    The original implementation ignored ``input_dtype``; the cast below
    restores that contract and is a no-op in the common case where the
    dtypes already match.
    """
    res = _log_softmax_backward_data_impl(grad_output, output, dim)
    return res if res.dtype == input_dtype else res.to(input_dtype)

116 

117 

def _log_softmax_backward_data_out(
    grad_output: torch.Tensor,
    output: torch.Tensor,
    dim: int,
    input_dtype: torch.dtype,
    out: torch.Tensor,
):
    """Out-variant of log_softmax backward: write the result into ``out``.

    Computes the same gradient as the functional entry point and copies
    it into the caller-provided ``out`` tensor, which is then returned.
    ``input_dtype`` is accepted for ATen signature compatibility.
    """
    result = _log_softmax_backward_data_impl(grad_output, output, dim)
    out.copy_(result)
    return out