Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/embedding.py: 0%
98 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
# Child logger under the package-wide "flag_gems" logger.
# NOTE(review): str.lstrip(".") strips leading "." characters only — a no-op
# for ordinary dotted module paths; presumably meant to drop a relative-import
# prefix. Left as-is to match the rest of the project.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Forward embedding lookup: one program per flattened index. Program `pid`
# reads its index from in_ptr[pid], then copies row `row_idx` of the weight
# matrix (N contiguous elements) into row `pid` of the output.
# BLOCK_SIZE must be >= N (callers pass next_power_of_2(N)); the mask guards
# the tail lanes. Comments only — no docstring inside the @triton.jit body.
@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output
    in_ptr,  # pointer to the input
    weight_ptr,  # pointer to the weights
    N: tl.constexpr,  # number of columns in X
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    # Advance to this program's output row and its single input index.
    out_ptr += pid * N
    in_ptr += pid

    mask = tl.arange(0, BLOCK_SIZE) < N  # guard lanes past the row width
    cols = tl.arange(0, BLOCK_SIZE)

    row_idx = tl.load(in_ptr)
    weight_ptr += row_idx * N  # start of the selected weight row
    embedding_weight = tl.load(weight_ptr + cols, mask, other=0.0)
    tl.store(out_ptr + cols, embedding_weight, mask)
# Histogram of index occurrences: for each index value seen in `indices`,
# atomically increments indices_freq[value] by 1. Used by the backward pass
# when scale_grad_by_freq=True. indices_freq must be zero-initialized by the
# caller (embedding_backward allocates it with torch.zeros).
@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,  # pointer to the int32 per-row frequency counters
    indices,  # pointer to the input
    elem_cnt: tl.constexpr,  # total number of indices (flattened)
    INDICE_BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    block_start = pid * INDICE_BLOCK_SIZE

    offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)
    mask = offsets < elem_cnt  # guard the final partial block

    index_element = tl.load(indices + offsets, mask=mask)
    # Scatter-increment: concurrent programs may hit the same counter,
    # hence the atomic add.
    tl.atomic_add(indices_freq + index_element, 1, mask=mask)
# Backward embedding scatter: one program per flattened index. Program `pid`
# reads index `row_idx` and atomically accumulates its grad_out row into
# grad_in[row_idx] (atomics because duplicate indices collide across
# programs). When HAS_PADDING_IDX is set, the row matching padding_idx
# receives no gradient. bf16 gradients are widened to fp32 before the
# atomic add — the host side allocates grad_in as fp32 in that case.
@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient input
    grad_out,  # pointer to the gradient output
    indices,  # pointer to the input
    padding_idx,  # padding_idx
    HAS_PADDING_IDX: tl.constexpr,
    N: tl.constexpr,  # number of columns in X
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    grad_out += pid * N
    indices += pid

    mask = tl.arange(0, BLOCK_SIZE) < N  # guard lanes past the row width
    cols = tl.arange(0, BLOCK_SIZE)

    # NOTE(review): narrowing to int32 assumes num_weights * N fits in
    # int32 for the pointer arithmetic below — confirm for large tables.
    row_idx = tl.load(indices).to(tl.int32)
    if not HAS_PADDING_IDX:
        grad_in += row_idx * N
        embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
        if tl.constexpr(embedding_grad.dtype.is_bf16()):
            # Accumulate bf16 grads in fp32 (atomic add target is fp32).
            embedding_grad = embedding_grad.to(tl.float32)
        tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
    else:
        # Skip gradient accumulation for the padding row entirely.
        if row_idx != padding_idx:
            grad_in += row_idx * N
            embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
            if tl.constexpr(embedding_grad.dtype.is_bf16()):
                embedding_grad = embedding_grad.to(tl.float32)
            tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
# Post-pass for scale_grad_by_freq: divides each accumulated gradient row by
# how many times its index occurred (rows seen 0 or 1 times are untouched).
# Grid-strided over rows: each program starts at its program id and steps by
# the total program count, so any grid size covers all n_rows. Modifies
# grad_out in place.
@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,  # pointer to the accumulated gradient table (scaled in place)
    indice_freq,  # pointer to per-row occurrence counts
    n_rows,  # number of rows in the gradient table (num_weights)
    N,  # row width (embedding dimension)
    BLOCK_SIZE: tl.constexpr,
):
    row_start = tle.program_id(0)
    row_step = tle.num_programs(0)

    for row_idx in range(row_start, n_rows, row_step):
        embedding_scale = 1.0
        indice_freq_val = tl.load(indice_freq + row_idx)
        if indice_freq_val > 1:
            # Average the contributions of a repeated index.
            embedding_scale = 1.0 / indice_freq_val

        cols = tl.arange(0, BLOCK_SIZE)
        mask = tl.arange(0, BLOCK_SIZE) < N  # guard the row tail
        embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)
        scaled_embedding_grad = embedding_grad * embedding_scale
        tl.store(grad_out + row_idx * N + cols, scaled_embedding_grad, mask=mask)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Gather rows of ``weight`` selected by ``indices``.

    Launches one kernel program per flattened index; each program copies the
    selected weight row into the output. ``padding_idx`` and
    ``scale_grad_by_freq`` only affect the backward pass and are accepted here
    for signature compatibility; ``sparse`` is unsupported.

    Returns a tensor of shape ``(*indices.shape, weight.shape[-1])`` with
    ``weight``'s dtype on ``indices``' device.
    """
    logger.debug("GEMS EMBEDDING FORWARD")
    assert not sparse, "Currently do not support sparse format"

    num_indices = indices.numel()
    embedding_dim = weight.shape[-1]
    # One block spans a full row; round up so tl.arange gets a power of two.
    block_size = triton.next_power_of_2(embedding_dim)

    # TODO: remove contiguous enforcement
    indices = indices.contiguous()
    weight = weight.contiguous()

    output = torch.empty(
        (*indices.shape, embedding_dim), device=indices.device, dtype=weight.dtype
    )

    grid = (num_indices,)
    with torch_device_fn.device(weight.device):
        embedding_kernel[grid](output, indices, weight, embedding_dim, block_size)

    return output
def embedding_backward(
    grad_outputs,
    indices,
    num_weights,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
    """Compute the gradient w.r.t. the embedding weight table.

    Scatters each row of ``grad_outputs`` into ``grad_inputs[indices[i]]``
    via atomic adds, optionally skipping ``padding_idx`` and optionally
    dividing each row by its index frequency (``scale_grad_by_freq``).
    bf16 gradients are accumulated in an fp32 buffer and cast back at the
    end. ``sparse`` is unsupported. Returns a ``(num_weights, N)`` tensor
    with ``grad_outputs``' dtype.
    """
    logger.debug("GEMS EMBEDDING BACKWARD")
    assert not sparse, "Currently do not support sparse format"

    M = indices.numel()  # number of gradient rows to scatter
    N = grad_outputs.shape[-1]  # embedding dimension

    # Accumulate in fp32 for bf16 inputs: the backward kernel widens bf16
    # grads before its atomic adds, so the target buffer must be fp32.
    grad_inputs = torch.zeros(
        (num_weights, grad_outputs.shape[-1]),
        device=grad_outputs.device,
        dtype=(
            torch.float32
            if grad_outputs.dtype is torch.bfloat16
            else grad_outputs.dtype
        ),
    )

    if scale_grad_by_freq:
        # Count how often each row index occurs; used to average repeated rows.
        indice_freq = torch.zeros(
            (num_weights,),
            requires_grad=False,
            device=grad_outputs.device,
            dtype=torch.int32,
        )
        INDICE_BLOCK_SIZE = 256
        indice_grid = (triton.cdiv(M, INDICE_BLOCK_SIZE),)

        with torch_device_fn.device(grad_outputs.device):
            indice_freq_kernel[indice_grid](
                indice_freq,
                indices,
                M,
                INDICE_BLOCK_SIZE,
                # NOTE(review): kunlunxin (TTXPU) backend-specific launch
                # option — presumably disables atomic simulation; confirm
                # against the triton-xpu runtime docs before changing.
                isCLOSE_TTXPU_O_ATOMIC_SIM=True,
            )
    else:
        indice_freq = None

    BLOCK_SIZE = triton.next_power_of_2(N)  # one block spans a full row

    # padding_idx=None disables padding handling; any integer (the -1
    # default included) enables the skip branch in the kernel.
    HAS_PADDING_IDX = padding_idx is not None

    with torch_device_fn.device(grad_outputs.device):
        # One program per index; duplicate indices are handled by the
        # kernel's atomic adds.
        embedding_backward_kernel[M,](
            grad_inputs,
            grad_outputs,
            indices,
            padding_idx,
            HAS_PADDING_IDX,
            N,
            BLOCK_SIZE,
        )

    if scale_grad_by_freq:
        with torch_device_fn.device(grad_outputs.device):
            # Grid-strided over num_weights rows, so an M-sized grid covers
            # all rows even when M != num_weights.
            embedding_grad_scale_kernel[M,](
                grad_inputs, indice_freq, num_weights, N, BLOCK_SIZE
            )
    # Cast the fp32 accumulator back to bf16 to match the incoming dtype.
    return (
        grad_inputs.to(torch.bfloat16)
        if grad_outputs.dtype is torch.bfloat16
        else grad_inputs
    )