Coverage for src/flag_gems/runtime/backend/_ascend/ops/masked_fill.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import broadcastable_to, libentry
9from flag_gems.utils import triton_lang_extension as tle
11logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel(
    inp,
    expand_mask,
    value,
    out,
    N,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    """Out-of-place masked fill: out[i] = value if expand_mask[i] else inp[i].

    Each program covers BLOCK_SIZE contiguous elements, processed in
    BLOCK_SIZE_SUB-sized sub-blocks (presumably to bound on-chip buffer
    usage on the Ascend backend — BLOCK_SIZE must be a multiple of
    BLOCK_SIZE_SUB for full coverage).
    """
    pid = tle.program_id(axis=0)
    base_offset = pid * BLOCK_SIZE

    # Number of sub-blocks this program iterates over.
    num_sub_blocks = BLOCK_SIZE // BLOCK_SIZE_SUB

    for sub_block_idx in range(num_sub_blocks):
        # Offsets of the current sub-block; guard against running past N.
        offsets = base_offset + sub_block_idx * BLOCK_SIZE_SUB + tl.arange(
            0, BLOCK_SIZE_SUB
        )
        mask = offsets < N

        input_vals = tl.load(inp + offsets, mask=mask, other=0)
        fill_mask_vals = tl.load(expand_mask + offsets, mask=mask, other=0).to(tl.int1)

        # Select `value` where the fill mask is set, otherwise keep the input,
        # and store once. (The original stored the input to `out`, re-loaded
        # `out`, selected, and stored again — a redundant global-memory
        # round trip with no behavioral difference.)
        value_to_write = tl.full([BLOCK_SIZE_SUB], value, dtype=input_vals.dtype)
        result = tl.where(fill_mask_vals, value_to_write, input_vals)
        tl.store(out + offsets, result, mask=mask)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("masked_fill"), key=["N"])
@triton.jit
def masked_fill_kernel_self(
    inp, expand_mask, value, N, BLOCK_SIZE: tl.constexpr, BLOCK_SIZE_SUB: tl.constexpr
):
    """In-place masked fill: inp[i] = value wherever expand_mask[i] is nonzero.

    Each program owns BLOCK_SIZE contiguous elements and walks them in
    BLOCK_SIZE_SUB-sized chunks, loading, selecting, and storing back into
    `inp` itself.
    """
    block_start = tle.program_id(axis=0) * BLOCK_SIZE

    # Walk this program's span chunk by chunk.
    for chunk in range(BLOCK_SIZE // BLOCK_SIZE_SUB):
        offs = block_start + chunk * BLOCK_SIZE_SUB + tl.arange(0, BLOCK_SIZE_SUB)
        in_bounds = offs < N

        # Nonzero mask entries select `value`; zero entries keep the original.
        do_fill = tl.load(expand_mask + offs, mask=in_bounds, other=0).to(tl.int1)
        current = tl.load(inp + offs, mask=in_bounds, other=0)
        fill_vals = tl.full([BLOCK_SIZE_SUB], value, dtype=current.dtype)

        tl.store(inp + offs, tl.where(do_fill, fill_vals, current), mask=in_bounds)
def masked_fill(inp, mask, value):
    """Out-of-place masked fill.

    Returns a new tensor equal to ``inp`` except that positions where
    ``mask`` (broadcast to ``inp``'s shape) is true are set to ``value``.
    ``value`` may be a Python int/float or a 0-dim tensor.
    """
    logger.debug("GEMS_ASCEND MASKED FILL")
    assert isinstance(value, (int, float)) or (
        torch.is_tensor(value) and value.ndim == 0
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Collapse a 0-dim tensor value into a Python scalar for the kernel.
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # Scalar tensor: decide on the host, no kernel launch needed.
        if mask.item():
            return torch.tensor(value, dtype=inp.dtype, device=inp.device)
        return inp.clone()

    inp = inp.contiguous()
    expand_mask = mask.contiguous().expand(inp.shape)
    out = torch.empty_like(inp)

    N = inp.numel()
    if N == 0:
        return out
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel[grid](inp, expand_mask.to(torch.int), value, out, N)
    return out
def masked_fill_(inp, mask, value):
    """In-place masked fill: set inp[i] = value where mask broadcasts to True.

    Mutates ``inp`` and returns it. ``value`` may be a Python int/float or a
    0-dim tensor.

    Fix over the original: for a non-contiguous ``inp``, ``.contiguous()``
    returns a *copy*, and the original launched the in-place kernel on that
    copy — the caller's tensor was never modified, breaking the in-place
    contract. We now copy the kernel's result back into ``inp`` whenever a
    copy was made.
    """
    logger.debug("GEMS_ASCEND MASKED FILL_")
    assert (
        (torch.is_tensor(value) and value.ndim == 0)
        or isinstance(value, int)
        or isinstance(value, float)
    ), "masked_fill_ only supports a 0-dimensional value tensor"
    if torch.is_tensor(value):
        # Value can be a tensor or a scalar
        value = value.item()
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "The shape of mask must be broadcastable with the shape of the underlying tensor"

    if inp.ndim == 0:
        # inp is a single-value tensor; assign on the host, no kernel launch.
        if mask.item():
            inp[()] = value
        return inp

    # `.contiguous()` returns `inp` itself when already contiguous,
    # otherwise a contiguous copy the kernel can write into.
    work = inp.contiguous()
    expand_mask = mask.contiguous().expand(work.shape)

    N = work.numel()
    if N == 0:
        return inp
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    masked_fill_kernel_self[grid](work, expand_mask.to(torch.int), value, N)

    if work is not inp:
        # Kernel wrote into a contiguous copy; propagate the result back
        # into the caller's tensor to honor in-place semantics.
        inp.copy_(work)
    return inp