Coverage for src/flag_gems/ops/embedding_dense_backward.py: 53%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def _embedding_dense_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_weight_ptr,
    num_weights,
    padding_idx,
    BLOCK_D: tl.constexpr,
    EMBED_DIM: tl.constexpr,
):
    """Scatter-add one grad_output row into grad_weight (no frequency scaling).

    Grid axis 0 selects a flattened token position, axis 1 selects a
    BLOCK_D-wide chunk of the embedding dimension. Rows whose index equals
    padding_idx or falls outside [0, num_weights) contribute nothing.
    """
    row = tl.program_id(0)
    chunk = tl.program_id(1)

    # Embedding-dimension columns handled by this program.
    cols = chunk * BLOCK_D + tl.arange(0, BLOCK_D)
    in_dim = cols < EMBED_DIM

    # Destination row; skip padding and out-of-range indices.
    token = tl.load(indices_ptr + row)
    keep = (token != padding_idx) & (token >= 0) & (token < num_weights)

    # Accumulate in float32 to match the float32 grad_weight buffer.
    grad = tl.load(
        grad_output_ptr + row * EMBED_DIM + cols, mask=in_dim, other=0
    ).to(tl.float32)

    # Concurrent rows may target the same token, hence the atomic add.
    tl.atomic_add(
        grad_weight_ptr + token * EMBED_DIM + cols, grad, mask=in_dim & keep
    )
@triton.jit
def _embedding_dense_backward_count_kernel(
    indices_ptr,
    counts_ptr,
    N,
    num_weights,
    padding_idx,
    BLOCK_N: tl.constexpr,
):
    """Histogram the indices: counts[i] = number of occurrences of token i.

    Entries equal to padding_idx or outside [0, num_weights) are not counted.
    Each program processes BLOCK_N consecutive index entries.
    """
    block_start = tl.program_id(0) * BLOCK_N
    pos = block_start + tl.arange(0, BLOCK_N)
    in_range = pos < N

    # other=0 gives masked lanes a safe (but uncounted) index value.
    token = tl.load(indices_ptr + pos, mask=in_range, other=0).to(tl.int32)
    countable = (
        in_range
        & (token != padding_idx)
        & (token >= 0)
        & (token < num_weights)
    )
    tl.atomic_add(counts_ptr + token, 1, mask=countable)
@triton.jit
def _embedding_dense_backward_kernel_scale_by_freq(
    grad_output_ptr,
    indices_ptr,
    counts_ptr,
    grad_weight_ptr,
    num_weights,
    padding_idx,
    BLOCK_D: tl.constexpr,
    EMBED_DIM: tl.constexpr,
):
    """Scatter-add grad_output rows into grad_weight, scaled by 1/frequency.

    Like _embedding_dense_backward_kernel, but each row's gradient is divided
    by how often its token appears (counts_ptr, precomputed by the count
    kernel). Grid axis 0 is the flattened token position, axis 1 a
    BLOCK_D-wide chunk of the embedding dimension.
    """
    pid_n = tl.program_id(0)
    pid_d = tl.program_id(1)

    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask_d = offs_d < EMBED_DIM

    idx = tl.load(indices_ptr + pid_n).to(tl.int32)
    valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights)

    go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d
    # Cast to float32: tl.atomic_add needs the value dtype to match the
    # float32 grad_weight buffer (grad_output may be fp16/bf16), and it keeps
    # the division by the frequency count in full precision. This mirrors the
    # non-scaled kernel above.
    go = tl.load(go_ptrs, mask=mask_d, other=0.0).to(tl.float32)

    # other=1 keeps the masked-out lanes' division well-defined (discarded).
    cnt = tl.load(counts_ptr + idx, mask=valid, other=1).to(tl.float32)
    go = go / cnt

    gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d
    mask = mask_d & valid
    tl.atomic_add(gw_ptrs, go, mask=mask)
def embedding_dense_backward(
    grad_output: torch.Tensor,
    indices: torch.Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    """Dense backward pass of an embedding lookup.

    Accumulates grad_output rows into a (num_weights, D) gradient for the
    embedding table, skipping padding_idx and, when scale_grad_by_freq is
    set, dividing each row's contribution by its token's frequency.

    Args:
        grad_output: CUDA tensor of shape (..., D); leading dims are the
            token positions.
        indices: CUDA int32/int64 tensor with the same leading shape.
        num_weights: number of rows in the embedding table.
        padding_idx: index whose gradient is dropped (None disables this).
        scale_grad_by_freq: scale each gradient by 1 / token frequency.

    Returns:
        Tensor of shape (num_weights, D) in grad_output's dtype.
    """
    logger.debug("GEMS: embedding_dense_backward")
    assert indices.dtype in (
        torch.int32,
        torch.int64,
    ), "Indices must be int32 or int64."
    assert (
        grad_output.is_cuda and indices.is_cuda and grad_output.device == indices.device
    ), "Inputs must be CUDA tensors on the same device."
    assert (
        grad_output.dim() >= 2
    ), "grad_output must have embedding dimension as the last dim."

    device = grad_output.device
    embed_dim = grad_output.shape[-1]

    # Flatten to (N, D) rows and a parallel index vector of length N.
    flat_grad = grad_output.contiguous().view(-1, embed_dim)
    flat_idx = indices.contiguous().view(-1)
    n_tokens = flat_idx.numel()
    assert flat_grad.shape[0] == n_tokens, "indices number must match grad_output rows."

    # Accumulate atomically in float32 regardless of the input dtype.
    grad_weight_fp32 = torch.zeros(
        (num_weights, embed_dim), device=device, dtype=torch.float32
    )

    # -1 is an always-invalid sentinel when no padding index is given.
    pad = -1 if padding_idx is None else padding_idx
    BLOCK_D = 128
    grid = (n_tokens, triton.cdiv(embed_dim, BLOCK_D))

    if not scale_grad_by_freq:
        _embedding_dense_backward_kernel[grid](
            flat_grad,
            flat_idx,
            grad_weight_fp32,
            num_weights,
            pad,
            BLOCK_D=BLOCK_D,
            EMBED_DIM=embed_dim,
        )
    else:
        # First pass: per-token occurrence counts.
        counts = torch.zeros((num_weights,), device=device, dtype=torch.int32)
        BLOCK_N = 512
        _embedding_dense_backward_count_kernel[(triton.cdiv(n_tokens, BLOCK_N),)](
            flat_idx,
            counts,
            n_tokens,
            num_weights,
            pad,
            BLOCK_N=BLOCK_N,
        )
        # Second pass: scatter-add gradients scaled by 1 / count.
        _embedding_dense_backward_kernel_scale_by_freq[grid](
            flat_grad,
            flat_idx,
            counts,
            grad_weight_fp32,
            num_weights,
            pad,
            BLOCK_D=BLOCK_D,
            EMBED_DIM=embed_dim,
        )

    if grad_output.dtype == torch.float32:
        return grad_weight_fp32
    return grad_weight_fp32.to(grad_output.dtype)