Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/lt.py: 0%
28 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
2import os
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the package-wide "flag_gems" logger.
# Leading dots are stripped from the module name before it becomes the
# child suffix.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip("."),
)
# Code-generation settings shared by both pointwise kernels in this module
# (passed as ``config=`` to each ``pointwise_dynamic`` decorator below).
# NOTE(review): the four positional arguments are interpreted by
# CodeGenConfig, which is defined outside this file -- confirm their meaning
# against its signature before changing any of them.
config_ = CodeGenConfig(
    512,
    (65536,) * 3,  # same 3-tuple as (65536, 65536, 65536)
    32,
    True,
    kunlunAutoGrid=True,
    unroll_num=8,
    isCloseMemoryAsync=False,
    prefer_1d_tile=True,
)
# Pointwise "less than" kernel for two operands.
# NOTE(review): "ALWAYS_BOOL" presumably makes the result dtype boolean --
# confirm against pointwise_dynamic's promotion rules.
@pointwise_dynamic(
    config=config_,
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def lt_func(x, y):
    # The left operand is cast to float32 before the comparison; the right
    # operand is used as passed.
    lhs = x.to(tl.float32)
    return lhs < y
def lt(A, B):
    """Return ``lt_func(A, B)`` with two TRITONXPU flags enabled for the call.

    ``TRITONXPU_COMPARE_FUSION`` and ``TRITONXPU_FP16_FAST`` are set to
    ``"1"`` in ``os.environ`` while the kernel runs. Afterwards each variable
    is put back to the value it had before the call, or removed if it was
    not set, even when ``lt_func`` raises.

    Args:
        A: Left operand, forwarded unchanged to ``lt_func``.
        B: Right operand, forwarded unchanged to ``lt_func``.

    Returns:
        Whatever ``lt_func(A, B)`` returns.
    """
    logger.debug("GEMS LT")
    flag_names = ("TRITONXPU_COMPARE_FUSION", "TRITONXPU_FP16_FAST")
    # Remember the caller's settings so they can be restored rather than
    # unconditionally deleted.
    previous = {name: os.environ.get(name) for name in flag_names}
    for name in flag_names:
        os.environ[name] = "1"
    try:
        return lt_func(A, B)
    finally:
        # Runs on success and on failure, so the flags never leak into
        # later, unrelated kernel launches.
        for name, old_value in previous.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value
# Pointwise "less than" kernel declared with is_tensor=[True, False].
# NOTE(review): this presumably means the first operand is a tensor and the
# second a scalar -- confirm against pointwise_dynamic.
@pointwise_dynamic(
    config=config_,
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
    is_tensor=[True, False],
)
@triton.jit
def lt_func_scalar(x, y):
    # The left operand is cast to float32 before the comparison; the right
    # operand is used as passed.
    lhs = x.to(tl.float32)
    return lhs < y
def lt_scalar(A, B):
    """Log the call at debug level and return ``lt_func_scalar(A, B)``.

    Unlike ``lt``, this wrapper does not touch any environment variables.
    """
    logger.debug("GEMS LT SCALAR")
    return lt_func_scalar(A, B)