Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/eq.py: 0%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2import os
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from flag_gems.runtime import device
10from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the shared "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip("."),
)

# Rebinds `device` from the object imported from flag_gems.runtime to its
# `.name` attribute; `eq` compares `A.device.type` against this value.
device = device.name

# Code-generation settings handed to `pointwise_dynamic` for the tensor/tensor
# kernel. The positional arguments are passed in their original order.
# NOTE(review): presumably tile size, per-axis grid limits, warp count and a
# block-pointer preference -- confirm against the CodeGenConfig signature.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=True,
    isCloseMemoryAsync=False,
)
@pointwise_dynamic(
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    config=config_,
)
@triton.jit
def eq_func(x, y):
    """Compare ``x`` and ``y`` for equality after casting both to float32.

    NOTE(review): values that are not exactly representable in float32 may
    compare equal after the cast -- confirm this is intended for the backend.
    """
    lhs = x.to(tl.float32)
    rhs = y.to(tl.float32)
    return lhs == rhs
def eq(A, B):
    """Compute ``A == B`` for two tensors through the ``eq_func`` kernel.

    When the operands live on different devices, one of them is moved so that
    both share a device: ``B`` is moved to ``A.device`` if ``A.device.type``
    equals the module-level ``device`` name, otherwise ``A`` is moved to
    ``B.device``.

    ``TRITONXPU_COMPARE_FUSION`` and ``TRITONXPU_FP16_FAST`` are forced to
    ``"1"`` only for the duration of the kernel call. Their previous state
    (a value, or being unset) is restored afterwards, including when the
    kernel raises, so a failed call cannot leak the overrides into the rest
    of the process and a value exported by the user is not discarded.

    Args:
        A: Left-hand operand; must expose ``.device`` and ``.to``.
        B: Right-hand operand; must expose ``.device`` and ``.to``.

    Returns:
        The result of ``eq_func(A, B)`` (declared with the ``ALWAYS_BOOL``
        promotion method).
    """
    if A.device != B.device:
        if A.device.type == device:
            B = B.to(A.device)
        else:
            A = A.to(B.device)
    logger.debug("GEMS EQ")

    overrides = {
        "TRITONXPU_COMPARE_FUSION": "1",
        "TRITONXPU_FP16_FAST": "1",
    }
    # Snapshot the current state so it can be put back exactly as it was.
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        return eq_func(A, B)
    finally:
        for key, previous in saved.items():
            if previous is None:
                # Was unset before the call: remove the override again.
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def eq_func_scalar(x, y):
    """Compare tensor ``x`` with non-tensor ``y`` after casting both to float32.

    NOTE(review): values that are not exactly representable in float32 may
    compare equal after the cast -- confirm this is intended for the backend.
    """
    lhs = x.to(tl.float32)
    rhs = y.to(tl.float32)
    return lhs == rhs
def eq_scalar(A, B):
    """Compute ``A == B`` for tensor ``A`` and scalar ``B``.

    Logs a debug marker and delegates to the ``eq_func_scalar`` kernel,
    returning its result unchanged.
    """
    logger.debug("GEMS EQ SCALAR")
    result = eq_func_scalar(A, B)
    return result