Coverage for src/flag_gems/runtime/backend/_ascend/ops/embedding.py: 0%
110 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger named after this file ("embedding") under the
# flag_gems.runtime._ascend.ops namespace, so log filtering can target it.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output tensor
    in_ptr,  # pointer to the flattened index tensor
    weight_ptr,  # pointer to the embedding weight matrix
    N: tl.constexpr,  # embedding dimension (width of a weight row)
    BLOCK_SIZE: tl.constexpr,  # next power of two >= N
):
    """Gather one embedding row per program: out[pid] = weight[in[pid]]."""
    row = tle.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < N

    token = tl.load(in_ptr + row)
    src = weight_ptr + token * N + col_offsets
    dst = out_ptr + row * N + col_offsets
    tl.store(dst, tl.load(src, col_mask, other=0.0), col_mask)
@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,  # int32 histogram buffer, one counter per weight row
    indices,  # pointer to the flattened index tensor
    elem_cnt: tl.constexpr,  # total number of indices
    INDICE_BLOCK_SIZE: tl.constexpr,  # indices handled per program
):
    """Count how often each weight row appears in `indices` (atomic histogram)."""
    pid = tle.program_id(0)
    base = pid * INDICE_BLOCK_SIZE
    # Scalar loop: each program walks its own contiguous chunk of the indices.
    for step in range(INDICE_BLOCK_SIZE):
        pos = base + step
        if pos < elem_cnt:
            row = tl.load(indices + pos)
            tl.atomic_add(indices_freq + row, 1)
@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the weight-gradient buffer (accumulated into)
    grad_out,  # pointer to the incoming gradient, one row per lookup
    indices,  # pointer to the flattened index tensor
    padding_idx,  # row that receives no gradient (runtime value)
    HAS_PADDING_IDX: tl.constexpr,  # compile-time switch for the padding check
    N: tl.constexpr,  # embedding dimension
    BLOCK_SIZE: tl.constexpr,  # next power of two >= N
):
    """Scatter-add one gradient row per program: grad_in[idx[pid]] += grad_out[pid]."""
    pid = tle.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    valid = cols < N

    target_row = tl.load(indices + pid).to(tl.int32)
    grad_row_ptr = grad_out + pid * N

    # The padding check is specialized away at compile time when no
    # padding_idx is in effect; both branches are otherwise identical.
    if not HAS_PADDING_IDX:
        grad_val = tl.load(grad_row_ptr + cols, valid, other=0.0)
        # bf16 atomic adds are unsupported; accumulate in fp32 instead.
        if tl.constexpr(grad_val.dtype.is_bf16()):
            grad_val = grad_val.to(tl.float32)
        tl.atomic_add(grad_in + target_row * N + cols, grad_val, mask=valid)
    else:
        if target_row != padding_idx:
            grad_val = tl.load(grad_row_ptr + cols, valid, other=0.0)
            if tl.constexpr(grad_val.dtype.is_bf16()):
                grad_val = grad_val.to(tl.float32)
            tl.atomic_add(grad_in + target_row * N + cols, grad_val, mask=valid)
@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,  # pointer to the accumulated weight gradient (scaled in place)
    indice_freq,  # int32 per-row occurrence counts from indice_freq_kernel
    n_rows,  # number of weight rows (num_weights)
    N,  # embedding dimension
    BLOCK_SIZE: tl.constexpr,  # next power of two >= N
):
    """Divide each gradient row by its index frequency (scale_grad_by_freq).

    Rows with frequency 0 or 1 keep scale 1.0. Each program strides over
    rows by the total program count, so any grid size covers all n_rows.
    """
    row_start = tle.program_id(0)
    row_step = tle.num_programs(0)

    # Hoisted out of the row loop: the column offsets and their bounds
    # mask depend only on BLOCK_SIZE/N, not on the row being processed.
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    for row_idx in range(row_start, n_rows, row_step):
        embedding_scale = 1.0
        indice_freq_val = tl.load(indice_freq + row_idx)
        if indice_freq_val > 1:
            embedding_scale = 1.0 / indice_freq_val

        embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)
        tl.store(
            grad_out + row_idx * N + cols,
            embedding_grad * embedding_scale,
            mask=mask,
        )
class Embedding(torch.autograd.Function):
    """Autograd wrapper around the Triton embedding kernels (Ascend backend)."""

    @staticmethod
    def forward(
        ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False
    ):
        """Look up `indices` rows of `weight`; returns shape (*indices.shape, N)."""
        logger.debug("GEMS_ASCEND EMBEDDING FORWARD")
        assert not sparse, "Currently do not support sparse format"

        num_lookups = math.prod(indices.shape)
        embed_dim = weight.shape[-1]
        block = triton.next_power_of_2(embed_dim)

        indices = indices.contiguous()
        weight = weight.contiguous()
        output = torch.empty(
            (*indices.shape, embed_dim), device=indices.device, dtype=weight.dtype
        )

        # One program per lookup gathers a full weight row.
        with torch_device_fn.device(weight.device):
            embedding_kernel[num_lookups,](output, indices, weight, embed_dim, block)

        # Stash everything backward() needs on the autograd context.
        ctx.M = num_lookups
        ctx.N = embed_dim
        ctx.num_weights = weight.shape[0]
        ctx.padding_idx = padding_idx
        ctx.scale_grad_by_freq = scale_grad_by_freq
        ctx.sparse = sparse
        ctx.indices = indices
        return output

    @staticmethod
    def backward(ctx, grad_outputs):
        """Accumulate the weight gradient; all non-weight inputs get None."""
        logger.debug("GEMS_ASCEND EMBEDDING BACKWARD")
        assert not ctx.sparse, "Currently do not support sparse format"

        # bf16 lacks atomic-add support in the kernel, so accumulate in
        # fp32 and cast back to bf16 at the end.
        if grad_outputs.dtype is torch.bfloat16:
            acc_dtype = torch.float32
        else:
            acc_dtype = grad_outputs.dtype
        grad_inputs = torch.zeros(
            (ctx.num_weights, grad_outputs.shape[-1]),
            device=grad_outputs.device,
            dtype=acc_dtype,
        )

        indice_freq = None
        if ctx.scale_grad_by_freq:
            # Histogram of how often each weight row was looked up.
            indice_freq = torch.zeros(
                (ctx.num_weights,),
                requires_grad=False,
                device=grad_outputs.device,
                dtype=torch.int32,
            )
            INDICE_BLOCK_SIZE = 256
            freq_grid = (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),)
            with torch_device_fn.device(grad_outputs.device):
                indice_freq_kernel[freq_grid](
                    indice_freq, ctx.indices, ctx.M, INDICE_BLOCK_SIZE
                )

        BLOCK_SIZE = triton.next_power_of_2(ctx.N)
        has_padding = ctx.padding_idx is not None
        with torch_device_fn.device(grad_outputs.device):
            embedding_backward_kernel[ctx.M,](
                grad_inputs,
                grad_outputs,
                ctx.indices,
                ctx.padding_idx,
                has_padding,
                ctx.N,
                BLOCK_SIZE,
            )

        if ctx.scale_grad_by_freq:
            with torch_device_fn.device(grad_outputs.device):
                embedding_grad_scale_kernel[ctx.M,](
                    grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
                )

        if grad_outputs.dtype is torch.bfloat16:
            grad_inputs = grad_inputs.to(torch.bfloat16)
        return grad_inputs, None, None, None, None
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Embedding lookup with autograd support; thin wrapper over Embedding.apply."""
    return Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse)