Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/lt.py: 0%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import logging 

2import os 

3 

4import triton 

5import triton.language as tl 

6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

7 

8from ..utils.pointwise_dynamic import pointwise_dynamic 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12config_ = CodeGenConfig( 

13 512, 

14 (65536, 65536, 65536), 

15 32, 

16 True, 

17 prefer_1d_tile=True, 

18 isCloseMemoryAsync=False, 

19 unroll_num=8, 

20 kunlunAutoGrid=True, 

21) 

22 

23 

24@pointwise_dynamic( 

25 promotion_methods=[(0, 1, "ALWAYS_BOOL")], 

26 config=config_, 

27) 

28@triton.jit 

29def lt_func(x, y): 

30 return x.to(tl.float32) < y 

31 

32 

33def lt(A, B): 

34 logger.debug("GEMS LT") 

35 os.environ["TRITONXPU_COMPARE_FUSION"] = "1" 

36 os.environ["TRITONXPU_FP16_FAST"] = "1" 

37 res = lt_func(A, B) 

38 del os.environ["TRITONXPU_COMPARE_FUSION"] 

39 del os.environ["TRITONXPU_FP16_FAST"] 

40 return res 

41 

42 

43@pointwise_dynamic( 

44 is_tensor=[True, False], 

45 promotion_methods=[(0, 1, "ALWAYS_BOOL")], 

46 config=config_, 

47) 

48@triton.jit 

49def lt_func_scalar(x, y): 

50 return x.to(tl.float32) < y 

51 

52 

53def lt_scalar(A, B): 

54 logger.debug("GEMS LT SCALAR") 

55 res = lt_func_scalar(A, B) 

56 return res