Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/lerp.py: 0%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import triton
4import triton.language as tl
6from ..utils.pointwise_dynamic import pointwise_dynamic
8logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Element-wise lerp where `weight` is itself a tensor operand.
# All arithmetic is carried out in float32 and the result is cast back to the
# dtype of `input` before being returned.
@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def lerp_tensor_kernel(input, end, weight):
    # Upcast every operand to float32 before doing any arithmetic.
    start_f32 = input.to(tl.float32)
    stop_f32 = end.to(tl.float32)
    w_f32 = weight.to(tl.float32)

    # Two algebraically equivalent forms of the interpolation; which one is
    # used is decided per element by the magnitude of the weight.
    # NOTE(review): presumably chosen for accuracy near weight == 1, matching
    # PyTorch's lerp -- confirm.
    from_start = start_f32 + w_f32 * (stop_f32 - start_f32)
    from_end = stop_f32 - (stop_f32 - start_f32) * (1 - w_f32)
    weight_is_small = tl.abs(w_f32) < 0.5

    blended = tl.where(weight_is_small, from_start, from_end)
    return blended.to(input.dtype)
# Scalar-weight lerp, "head" form: input + weight * (end - input).
# `weight` is a non-tensor float argument and is excluded from Triton
# specialization; the math runs in float32 and is cast back to input's dtype.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit(do_not_specialize=["weight"])
def lerp_scalar_kernel_head(input, end, weight):
    start_f32 = input.to(tl.float32)
    stop_f32 = end.to(tl.float32)
    w_f32 = weight.to(tl.float32)
    # Interpolate starting from `input`.
    blended = start_f32 + w_f32 * (stop_f32 - start_f32)
    return blended.to(input.dtype)
# Scalar-weight lerp, "tail" form: end - (end - input) * (1 - weight).
# `weight` is a non-tensor float argument and is excluded from Triton
# specialization; the math runs in float32 and is cast back to input's dtype.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit(do_not_specialize=["weight"])
def lerp_scalar_kernel_tail(input, end, weight):
    start_f32 = input.to(tl.float32)
    stop_f32 = end.to(tl.float32)
    w_f32 = weight.to(tl.float32)
    # Interpolate working backwards from `end`.
    blended = stop_f32 - (stop_f32 - start_f32) * (1 - w_f32)
    return blended.to(input.dtype)
def lerp_tensor(input, end, weight):
    """Interpolate between ``input`` and ``end`` using a tensor ``weight``.

    Logs a debug message and forwards all three arguments to
    ``lerp_tensor_kernel``, returning whatever the kernel returns.
    """
    logger.debug("GEMS LERP TENSOR")
    return lerp_tensor_kernel(input, end, weight)
def lerp_tensor_(input, end, weight):
    """In-place variant of ``lerp_tensor``.

    Calls ``lerp_tensor_kernel`` with ``out0=input`` so that ``input`` is
    passed as the kernel's output argument, and returns the kernel's result.
    """
    logger.debug("GEMS LERP INPLACE TENSOR")
    result = lerp_tensor_kernel(input, end, weight, out0=input)
    return result
def lerp_scalar(input, end, weight):
    """Interpolate between ``input`` and ``end`` using a scalar ``weight``.

    The formula is selected on the host from the magnitude of ``weight``:

    * ``abs(weight) < 0.5`` -> ``lerp_scalar_kernel_head``
      (``input + weight * (end - input)``)
    * otherwise             -> ``lerp_scalar_kernel_tail``
      (``end - (end - input) * (1 - weight)``)

    Args:
        input: Tensor holding the start values.
        end: Tensor holding the end values.
        weight: Scalar interpolation weight.

    Returns:
        The result produced by the selected kernel.
    """
    logger.debug("GEMS LERP SCALAR")
    # Compare the magnitude, not the signed value, so that the branch matches
    # the `tl.abs(weight) < 0.5` test used by the tensor-weight kernel; with a
    # signed comparison, weights <= -0.5 would take a different formula than
    # the tensor path does for the same weight.
    if abs(weight) < 0.5:
        return lerp_scalar_kernel_head(input, end, weight)
    return lerp_scalar_kernel_tail(input, end, weight)
def lerp_scalar_(input, end, weight):
    """In-place variant of scalar-weight lerp.

    Selects the formula from the magnitude of ``weight`` and calls the chosen
    kernel with ``out0=input`` so that ``input`` is passed as the kernel's
    output argument:

    * ``abs(weight) < 0.5`` -> ``lerp_scalar_kernel_head``
    * otherwise             -> ``lerp_scalar_kernel_tail``

    Args:
        input: Tensor holding the start values; also used as the output.
        end: Tensor holding the end values.
        weight: Scalar interpolation weight.

    Returns:
        The result produced by the selected kernel.
    """
    logger.debug("GEMS LERP INPLACE SCALAR")
    # Use the magnitude of the weight so the branch agrees with the
    # `tl.abs(weight) < 0.5` test in the tensor-weight kernel; a signed
    # comparison would route weights <= -0.5 to a different formula.
    if abs(weight) < 0.5:
        return lerp_scalar_kernel_head(input, end, weight, out0=input)
    return lerp_scalar_kernel_tail(input, end, weight, out0=input)