Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/ge.py: 0%
28 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
2import os
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation settings shared by both pointwise kernels in this module.
# NOTE(review): the meaning of the four positional values is not visible from
# this file; they presumably are the max tile size, the per-axis max grid
# size, the max warps per CTA and a block-pointer preference -- confirm
# against the CodeGenConfig signature before changing any of them.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseMemoryAsync=False,
    kunlunAutoGrid=True,
    unroll_num=8,
)
# Pointwise ``x >= y`` kernel body wrapped by ``pointwise_dynamic`` using the
# shared ``config_``; arguments 0 and 1 take part in "ALWAYS_BOOL" promotion.
@pointwise_dynamic(
    config=config_,
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def ge_func(x, y):
    # Only the left operand is cast explicitly; ``y`` is compared as passed.
    # NOTE(review): casting to float32 may lose precision for wide integer or
    # float64 inputs -- confirm this is intentional for this backend.
    lhs = x.to(tl.float32)
    return lhs >= y
def ge(A, B):
    """Return the element-wise ``A >= B`` result computed by ``ge_func``.

    ``TRITONXPU_COMPARE_FUSION`` and ``TRITONXPU_FP16_FAST`` are forced to
    ``"1"`` only for the duration of the kernel call. Afterwards each variable
    is put back exactly as it was found: restored to its previous value, or
    removed if it was not set before.

    Args:
        A: Left operand, forwarded unchanged to ``ge_func``.
        B: Right operand, forwarded unchanged to ``ge_func``.

    Returns:
        Whatever ``ge_func(A, B)`` returns.

    Raises:
        Any exception raised by ``ge_func``; the environment is restored
        before it propagates.
    """
    logger.debug("GEMS GE")
    # NOTE(review): these flags are presumably read by the Kunlunxin Triton
    # toolchain while compiling/launching the kernel -- confirm.
    overrides = {
        "TRITONXPU_COMPARE_FUSION": "1",
        "TRITONXPU_FP16_FAST": "1",
    }
    # Snapshot first so a value exported by the user is not lost.
    previous = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    try:
        return ge_func(A, B)
    finally:
        # Runs on success and on failure, so the overrides never leak into
        # later kernels launched from this process.
        for name, old_value in previous.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value
# Pointwise ``x >= y`` kernel body for the tensor/scalar case: ``is_tensor``
# flags argument 0 as a tensor and argument 1 as a non-tensor. It shares
# ``config_`` and the "ALWAYS_BOOL" promotion rule with ``ge_func``.
@pointwise_dynamic(
    config=config_,
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    is_tensor=[True, False],
)
@triton.jit
def ge_func_scalar(x, y):
    # Only the tensor operand is cast explicitly; ``y`` is compared as passed.
    # NOTE(review): casting to float32 may lose precision for wide integer or
    # float64 inputs -- confirm this is intentional for this backend.
    lhs = x.to(tl.float32)
    return lhs >= y
def ge_scalar(A, B):
    """Return the element-wise ``A >= B`` result computed by ``ge_func_scalar``.

    Args:
        A: Left operand, forwarded unchanged to ``ge_func_scalar``.
        B: Right operand, forwarded unchanged to ``ge_func_scalar``.

    Returns:
        Whatever ``ge_func_scalar(A, B)`` returns.
    """
    logger.debug("GEMS GE SCALAR")
    # NOTE(review): unlike ``ge``, no TRITONXPU_* environment overrides are
    # applied around this call -- confirm that the difference is intended.
    return ge_func_scalar(A, B)