Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/gt.py: 0%
28 statements
Generated by coverage.py v7.6.9 at 2026-03-19 02:32 +0800
1import logging
2import os
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from ..utils.pointwise_dynamic import pointwise_dynamic
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation settings handed to ``pointwise_dynamic`` for both ``>``
# kernels in this module.  The four leading positional arguments are
# forwarded to ``CodeGenConfig`` in order; what each one controls is decided
# by ``CodeGenConfig`` itself, which is not visible here.
# NOTE(review): switch them to keyword arguments once the signature is
# confirmed, so the tuned values are self-describing.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=True,
    isCloseMemoryAsync=False,
    kunlunAutoGrid=True,
    unroll_num=8,
)
# Elementwise body of the tensor-tensor ``>`` kernel.  ``pointwise_dynamic``
# builds the elementwise wrapper around it using ``config_``; the
# ``ALWAYS_BOOL`` promotion method presumably makes the output dtype bool.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    config=config_,
)
@triton.jit
def gt_func(x, y):
    # ``x`` is cast to float32 before comparing.  ``y`` is left in its own
    # dtype, so Triton's type promotion presumably picks the compare type.
    # NOTE(review): float32 has a 24-bit significand, so integer inputs with
    # magnitude above 2**24 (and float64 inputs) may round before the compare,
    # e.g. 16777217 > 16777216 could yield False. Confirm this trade-off is
    # intended on this backend.
    return x.to(tl.float32) > y
def gt(A, B):
    """Compute ``A > B`` elementwise with the ``gt_func`` pointwise kernel.

    ``TRITONXPU_COMPARE_FUSION`` and ``TRITONXPU_FP16_FAST`` are set to
    ``"1"`` only for the duration of the ``gt_func`` call. Presumably the
    Kunlunxin Triton backend reads them when the kernel is compiled or
    launched (TODO confirm against the backend).

    Each variable's previous value is restored afterwards, or the variable is
    removed if it was unset before. This also happens when ``gt_func``
    raises, so the overrides never leak into unrelated kernels and a value
    set by the user is not erased.

    NOTE(review): ``os.environ`` is process-global, so concurrent calls from
    several threads can still observe each other's overrides.

    Args:
        A: Left-hand operand, passed straight to ``gt_func``.
        B: Right-hand operand, passed straight to ``gt_func``.

    Returns:
        The result of ``gt_func(A, B)``. Presumably this is a bool tensor,
        given the ``ALWAYS_BOOL`` promotion method on ``gt_func``.
    """
    logger.debug("GEMS GT")
    overrides = {
        "TRITONXPU_COMPARE_FUSION": "1",
        "TRITONXPU_FP16_FAST": "1",
    }
    # Snapshot the caller's values first (None means "was unset") so that
    # they can be put back exactly, whatever happens below.
    previous = {name: os.environ.get(name) for name in overrides}
    try:
        os.environ.update(overrides)
        return gt_func(A, B)
    finally:
        for name, value in previous.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
# Elementwise body of the tensor-scalar ``>`` kernel; ``gt_scalar`` calls it.
# ``is_tensor=[True, False]`` tells ``pointwise_dynamic`` that ``x`` is a
# tensor and ``y`` is not. ``ALWAYS_BOOL`` presumably makes the output bool.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    config=config_,
)
@triton.jit
def gt_func_scalar(x, y):
    # ``x`` is cast to float32 before comparing against the scalar ``y``.
    # NOTE(review): float32 has a 24-bit significand, so integer tensors with
    # values above 2**24 (and float64 tensors) may round before the compare
    # and give a wrong answer near the threshold. Confirm this is intended.
    return x.to(tl.float32) > y
def gt_scalar(A, B):
    """Compute ``A > B`` elementwise, where ``B`` is presumably a non-tensor scalar.

    Dispatches directly to ``gt_func_scalar``. Unlike ``gt``, no
    ``TRITONXPU_*`` environment overrides are applied around the call.
    NOTE(review): confirm that this difference is intentional.
    """
    logger.debug("GEMS GT SCALAR")
    return gt_func_scalar(A, B)