Coverage for src/flag_gems/runtime/backend/_cambricon/ops/embedding.py: 0%
84 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,
    indices,  # pointer to the input index tensor
    elem_cnt: tl.constexpr,  # total number of indices
    INDICE_BLOCK_SIZE: tl.constexpr,
):
    """Accumulate an occurrence count (histogram) of each index value.

    Each program handles one INDICE_BLOCK_SIZE-wide slice of `indices` and
    bumps the matching slots of `indices_freq` with atomic adds, so blocks
    that see the same index value do not race.
    """
    block_id = tl.program_id(0)
    idx_offsets = block_id * INDICE_BLOCK_SIZE + tl.arange(0, INDICE_BLOCK_SIZE)
    in_bounds = idx_offsets < elem_cnt

    idx_values = tl.load(indices + idx_offsets, mask=in_bounds)
    tl.atomic_add(indices_freq + idx_values, 1, mask=in_bounds)
@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient input (num_weights x N), accumulated into
    grad_out,  # pointer to the gradient output (one row of width N per program)
    indices,  # pointer to the input indices (one index per program)
    padding_idx,  # index whose gradient row must stay zero (runtime value)
    HAS_PADDING_IDX: tl.constexpr,  # compile-time switch selecting the padding branch
    N: tl.constexpr,  # number of columns in X (embedding dimension)
    BLOCK_SIZE: tl.constexpr,  # power-of-two cover of N for tl.arange
):
    # Scatter-accumulate embedding gradients: program `pid` reads one index and
    # atomically adds the matching grad_out row into grad_in[row_idx].
    pid = tl.program_id(0)
    grad_out += pid * N
    indices += pid

    # BLOCK_SIZE >= N, so mask off the tail columns beyond N.
    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)

    # NOTE(review): the int32 cast assumes every index fits in 32 bits even if
    # the tensor is int64 — TODO confirm num_weights < 2**31 is guaranteed.
    row_idx = tl.load(indices).to(tl.int32)
    if not HAS_PADDING_IDX:
        grad_in += row_idx * N
        embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
        # bf16 gradients are promoted to fp32 before the atomic add (the
        # destination buffer is allocated as fp32 for bf16 inputs).
        if tl.constexpr(embedding_grad.dtype.is_bf16()):
            embedding_grad = embedding_grad.to(tl.float32)
        tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
    else:
        # Skip accumulation entirely for the padding index so its gradient
        # row remains zero.
        if row_idx != padding_idx:
            grad_in += row_idx * N
            embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
            if tl.constexpr(embedding_grad.dtype.is_bf16()):
                embedding_grad = embedding_grad.to(tl.float32)
            tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,  # pointer to the gradient buffer (n_rows x N), scaled in place
    indice_freq,  # pointer to per-row occurrence counts (int32)
    n_rows,  # number of embedding rows to scale
    N,  # row width (embedding dimension)
    BLOCK_SIZE: tl.constexpr,  # power-of-two cover of N for tl.arange
):
    """Rescale each gradient row by 1/frequency when its index occurred
    more than once, implementing scale_grad_by_freq.

    Programs stride over the rows: program p handles rows p, p+step, ...
    """
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)

    # Hoisted out of the loop: the column offsets and bounds mask do not
    # depend on row_idx, so compute them once per program.
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    for row_idx in range(row_start, n_rows, row_step):
        embedding_scale = 1.0
        indice_freq_val = tl.load(indice_freq + row_idx)
        # Rows referenced once (or never) keep their gradient unchanged.
        if indice_freq_val > 1:
            embedding_scale = 1.0 / indice_freq_val

        embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)
        scaled_embedding_grad = embedding_grad * embedding_scale
        tl.store(grad_out + row_idx * N + cols, scaled_embedding_grad, mask=mask)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Embedding lookup: gather rows of `weight` selected by `indices`.

    Args:
        weight: the embedding table; rows are gathered along dim 0.
        indices: integer tensor of arbitrary shape selecting rows.
        padding_idx: accepted for API compatibility; it only affects the
            backward pass, so the forward lookup ignores it.
        scale_grad_by_freq: accepted for API compatibility; backward-only.
        sparse: must be False — sparse gradients are not supported.

    Returns:
        Tensor of shape ``indices.shape + (embedding_dim,)``.
    """
    logger.debug("GEMS_CAMBRICON EMBEDDING FORWARD")
    assert not sparse, "Currently do not support sparse format"

    # The gather kernel assumes contiguous layouts.
    indices = indices.contiguous()
    weight = weight.contiguous()

    from .index_select import index_select

    # Flatten for the row gather, then restore the leading index shape with
    # the embedding dimension appended.
    output = index_select(weight, 0, indices.flatten())
    return output.reshape(indices.shape + (-1,))
def embedding_backward(
    grad_outputs,
    indices,
    num_weights,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
    """Compute the gradient of an embedding lookup w.r.t. the weight table.

    Args:
        grad_outputs: upstream gradient, last dim is the embedding dim N.
        indices: integer tensor of the indices used in the forward pass.
        num_weights: number of rows in the embedding table.
        padding_idx: if not None, that row's gradient stays zero.
        scale_grad_by_freq: divide each row's gradient by how often its
            index occurred.
        sparse: must be False — sparse gradients are not supported.

    Returns:
        Tensor of shape (num_weights, N) with grad_outputs' dtype.
    """
    logger.debug("GEMS_CAMBRICON EMBEDDING BACKWARD")
    assert not sparse, "Currently do not support sparse format"

    # The kernels index raw pointers row-by-row (grad_out += pid * N,
    # indices += pid), which is only valid for contiguous layouts.
    grad_outputs = grad_outputs.contiguous()
    indices = indices.contiguous()

    M = indices.numel()
    N = grad_outputs.shape[-1]

    # Accumulate in fp32 for bf16 inputs: the backward kernel promotes bf16
    # gradients to fp32 before its atomic adds.
    grad_inputs = torch.zeros(
        (num_weights, N),
        device=grad_outputs.device,
        dtype=torch.float32
        if grad_outputs.dtype is torch.bfloat16
        else grad_outputs.dtype,
    )

    if scale_grad_by_freq:
        # Per-index occurrence counts, filled by an atomic histogram kernel.
        indice_freq = torch.zeros(
            (num_weights,),
            requires_grad=False,
            device=grad_outputs.device,
            dtype=torch.int32,
        )
        INDICE_BLOCK_SIZE = 256
        indice_grid = lambda meta: (triton.cdiv(M, INDICE_BLOCK_SIZE),)

        with torch_device_fn.device(grad_outputs.device):
            indice_freq_kernel[indice_grid](indice_freq, indices, M, INDICE_BLOCK_SIZE)
    else:
        indice_freq = None

    BLOCK_SIZE = triton.next_power_of_2(N)
    HAS_PADDING_IDX = padding_idx is not None

    # One program per index: scatter-accumulate each grad row into grad_inputs.
    with torch_device_fn.device(grad_outputs.device):
        embedding_backward_kernel[M,](
            grad_inputs,
            grad_outputs,
            indices,
            padding_idx,
            HAS_PADDING_IDX,
            N,
            BLOCK_SIZE,
        )

    if scale_grad_by_freq:
        # Rescale rows whose index occurred more than once.
        with torch_device_fn.device(grad_outputs.device):
            embedding_grad_scale_kernel[M,](
                grad_inputs, indice_freq, num_weights, N, BLOCK_SIZE
            )
    # Cast the fp32 accumulator back to bf16 when the input was bf16.
    return (
        grad_inputs.to(torch.bfloat16)
        if grad_outputs.dtype is torch.bfloat16
        else grad_inputs
    )