Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/ne.py: 0%
28 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
2import os
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Per-module logger, created as a child of the package-wide "flag_gems"
# logger so it sits inside that logger hierarchy.
_parent_logger = logging.getLogger("flag_gems")
logger = _parent_logger.getChild(__name__.lstrip("."))
# Code-generation settings passed to both pointwise_dynamic kernels in this
# module. The positional values are forwarded unchanged to CodeGenConfig;
# their meaning is defined by that class, which is not visible here.
# NOTE(review): document what 512 / the 3-tuple of 65536 / 32 / True mean
# once confirmed against CodeGenConfig's signature.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,
    32,
    True,
    isCloseMemoryAsync=False,
    prefer_1d_tile=True,
)
# Elementwise "not equal" kernel for the two-operand case. Both operands are
# cast to float32 and then compared. The "ALWAYS_BOOL" promotion method over
# arguments 0 and 1 presumably makes the output tensor bool.
# NOTE(review): float32 has a 24-bit significand, so distinct int32/int64 or
# float64 values that round to the same float32 will compare as equal here
# (e.g. 2**24 and 2**24 + 1). Confirm whether the cast is required on XPU.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    config=config_,
)
@triton.jit
def ne_func(x, y):
    lhs = x.to(tl.float32)
    rhs = y.to(tl.float32)
    return lhs != rhs
def ne(A, B):
    """Compute elementwise ``A != B`` using the ``ne_func`` kernel.

    For the duration of the kernel call, the XPU Triton switches
    ``TRITONXPU_COMPARE_FUSION`` and ``TRITONXPU_FP16_FAST`` are set to
    ``"1"``. Whatever state those variables had before the call is restored
    afterwards:
    - a previous value is written back;
    - a previously unset variable is removed again.
    Restoration happens even if the kernel raises.

    Args:
        A: First operand, forwarded unchanged to ``ne_func``.
        B: Second operand, forwarded unchanged to ``ne_func``.

    Returns:
        Whatever ``ne_func(A, B)`` returns.
    """
    logger.debug("GEMS NE")
    overrides = {
        "TRITONXPU_COMPARE_FUSION": "1",
        "TRITONXPU_FP16_FAST": "1",
    }
    # Remember the caller's settings so we restore them instead of blindly
    # deleting the variables (which would erase a user-provided value).
    saved = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    # NOTE(review): os.environ is process-wide, so concurrent callers can
    # observe each other's overrides.
    try:
        return ne_func(A, B)
    finally:
        # Runs on success and on error so the flags never leak into
        # unrelated kernels launched later.
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
# Elementwise "not equal" kernel for the tensor/scalar case. is_tensor marks
# `x` as a tensor and `y` as a non-tensor argument. Both operands are cast to
# float32 and then compared. The "ALWAYS_BOOL" promotion method presumably
# makes the output tensor bool.
# NOTE(review): same float32 precision caveat as the tensor/tensor kernel.
# Large integer or float64 scalars may compare equal to nearby values.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    config=config_,
)
@triton.jit
def ne_func_scalar(x, y):
    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)
    return x_f32 != y_f32
def ne_scalar(A, B):
    """Compute elementwise ``A != B`` using the ``ne_func_scalar`` kernel.

    ``B`` is passed as the kernel's non-tensor argument, since that kernel
    declares ``is_tensor=[True, False]``. Unlike ``ne``, this path does not
    set any ``TRITONXPU_*`` environment variables.
    NOTE(review): confirm that skipping those flags here is intentional.

    Returns:
        Whatever ``ne_func_scalar(A, B)`` returns.
    """
    logger.debug("GEMS NE SCALAR")
    return ne_func_scalar(A, B)