Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/lerp.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") guards
# against a leading dot in __name__ when imported as a relative module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

9 

10 

@pointwise_dynamic(is_tensor=[True, True, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def lerp_tensor_kernel(input, end, weight):
    """Elementwise linear interpolation: input + weight * (end - input).

    All three operands are tensors. Math is done in float32 and the result
    is cast back to input's dtype. Two algebraically equivalent forms are
    used, selected per element by |weight| < 0.5, which anchors the result
    on the nearer endpoint for better numerical accuracy (same rule as
    torch.lerp).
    """
    x = input.to(tl.float32)
    y = end.to(tl.float32)
    w = weight.to(tl.float32)
    diff = y - x
    # Small |w|: anchor on x. Large |w|: anchor on y.
    near_start = x + w * diff
    near_end = y - diff * (1 - w)
    out = tl.where(tl.abs(w) < 0.5, near_start, near_end)
    return out.to(input.dtype)

23 

24 

@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit(do_not_specialize=["weight"])
def lerp_scalar_kernel_head(input, end, weight):
    """Lerp with a scalar weight, input-anchored form: x + w * (y - x).

    Intended for small |weight|, where anchoring on `input` is the more
    accurate of the two equivalent formulations. Computed in float32,
    cast back to input's dtype.
    """
    x = input.to(tl.float32)
    y = end.to(tl.float32)
    w = weight.to(tl.float32)
    result = x + w * (y - x)
    return result.to(input.dtype)

36 

37 

@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit(do_not_specialize=["weight"])
def lerp_scalar_kernel_tail(input, end, weight):
    """Lerp with a scalar weight, end-anchored form: y - (y - x) * (1 - w).

    Intended for large |weight|, where anchoring on `end` is the more
    accurate of the two equivalent formulations. Computed in float32,
    cast back to input's dtype.
    """
    x = input.to(tl.float32)
    y = end.to(tl.float32)
    w = weight.to(tl.float32)
    result = y - (y - x) * (1 - w)
    return result.to(input.dtype)

49 

50 

def lerp_tensor(input, end, weight):
    """Out-of-place lerp with a tensor weight; returns a new tensor."""
    logger.debug("GEMS LERP TENSOR")
    return lerp_tensor_kernel(input, end, weight)

55 

56 

def lerp_tensor_(input, end, weight):
    """In-place lerp with a tensor weight; writes the result into `input`."""
    logger.debug("GEMS LERP INPLACE TENSOR")
    result = lerp_tensor_kernel(input, end, weight, out0=input)
    return result

60 

61 

def lerp_scalar(input, end, weight):
    """Out-of-place lerp with a Python-scalar weight; returns a new tensor.

    Dispatches to one of two algebraically equivalent kernels based on the
    weight's magnitude, anchoring the result on the nearer endpoint for
    numerical accuracy.
    """
    # Fixed: log said "TENSOR" in the scalar entry point (copy-paste).
    logger.debug("GEMS LERP SCALAR")
    # Fixed: branch on |weight| (not signed weight) to match the tensor
    # kernel's selection rule and torch.lerp's numerical-stability rule;
    # with a large negative weight the signed test picked the less
    # accurate formulation.
    if abs(weight) < 0.5:
        out = lerp_scalar_kernel_head(input, end, weight)
    else:
        out = lerp_scalar_kernel_tail(input, end, weight)
    return out

69 

70 

def lerp_scalar_(input, end, weight):
    """In-place lerp with a Python-scalar weight; writes the result into `input`.

    Mirrors lerp_scalar, routing the output back into `input` via out0.
    """
    # Fixed: log said "INPLACE TENSOR" in the scalar entry point (copy-paste).
    logger.debug("GEMS LERP INPLACE SCALAR")
    # Fixed: branch on |weight| (not signed weight) to match the tensor
    # kernel's selection rule and torch.lerp's numerical-stability rule.
    if abs(weight) < 0.5:
        return lerp_scalar_kernel_head(input, end, weight, out0=input)
    else:
        return lerp_scalar_kernel_tail(input, end, weight, out0=input)