Coverage for src/flag_gems/ops/lerp.py: 86%
35 statements
coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module logger; the lerp wrappers in this file emit their "GEMS LERP ..."
# debug messages through it before launching a kernel.
logger = logging.getLogger(__name__)
@pointwise_dynamic(is_tensor=[True, True, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def lerp_tensor_kernel(input, end, weight):
    # Elementwise lerp with a tensor weight. Both candidate expressions are
    # algebraically input + weight * (end - input); they differ only in which
    # endpoint they are anchored to:
    #   |weight| <  0.5 -> input + weight * (end - input)      (anchored at input)
    #   otherwise       -> end - (end - input) * (1 - weight)  (anchored at end)
    # promotion_methods lists only args 0 and 1, so the result dtype is derived
    # from input and end, not from weight.
    delta = end - input
    from_input = input + weight * delta
    from_end = end - delta * (1 - weight)
    return tl.where(tl.abs(weight) < 0.5, from_input, from_end)
@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit(do_not_specialize=["weight"])
def lerp_scalar_kernel_head(input, end, weight):
    # Lerp with a scalar (float) weight, anchored at ``input``:
    #     input + weight * (end - input)
    # This is the same expression as the |weight| < 0.5 branch of
    # lerp_tensor_kernel. ``weight`` is a non-tensor argument (is_tensor=False)
    # declared as float, and do_not_specialize asks Triton not to specialize
    # the compiled kernel on its value.
    delta = end - input
    return input + weight * delta
@pointwise_dynamic(
    is_tensor=[True, True, False],
    dtypes=[None, None, float],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit(do_not_specialize=["weight"])
def lerp_scalar_kernel_tail(input, end, weight):
    # Lerp with a scalar (float) weight, anchored at ``end``:
    #     end - (end - input) * (1 - weight)
    # This is the same expression as the non-small-|weight| branch of
    # lerp_tensor_kernel. For finite inputs and weight == 1 the correction
    # term is zero, so the result equals ``end``.
    delta = end - input
    return end - delta * (1 - weight)
def lerp_tensor(input, end, weight):
    """Out-of-place lerp of ``input`` toward ``end`` with a tensor ``weight``.

    Calls ``lerp_tensor_kernel`` without an explicit output buffer and returns
    whatever it produces.
    """
    logger.debug("GEMS LERP TENSOR")
    return lerp_tensor_kernel(input, end, weight)
def lerp_tensor_(input, end, weight):
    """In-place lerp with a tensor ``weight``.

    ``input`` is handed to ``lerp_tensor_kernel`` as ``out0`` (presumably the
    kernel's output buffer, so the result is written back into ``input``).
    The kernel's return value is passed through unchanged.
    """
    logger.debug("GEMS LERP INPLACE TENSOR")
    result = lerp_tensor_kernel(input, end, weight, out0=input)
    return result
def _use_head_formula(weight):
    """Return whether a scalar lerp ``weight`` takes the ``input``-anchored form.

    Applies the same ``|weight| < 0.5`` test as ``lerp_tensor_kernel``, so a
    scalar weight and a tensor weight of equal value pick the same formula.
    A plain ``weight < 0.5`` test would send negative weights (e.g. -1.0) to
    the head form; with ``abs`` they take the ``end``-anchored tail form,
    like large positive weights. NaN compares false and takes the tail form.
    """
    return abs(weight) < 0.5


def lerp_scalar(input, end, weight):
    """Out-of-place lerp of ``input`` toward ``end`` by a Python scalar ``weight``.

    Launches ``lerp_scalar_kernel_head`` (``input + weight * (end - input)``)
    when ``|weight| < 0.5`` and ``lerp_scalar_kernel_tail``
    (``end - (end - input) * (1 - weight)``) otherwise, and returns the
    kernel's result.
    """
    logger.debug("GEMS LERP SCALAR")
    if _use_head_formula(weight):
        return lerp_scalar_kernel_head(input, end, weight)
    return lerp_scalar_kernel_tail(input, end, weight)


def lerp_scalar_(input, end, weight):
    """In-place lerp by a Python scalar ``weight``.

    Formula selection matches ``lerp_scalar``. ``input`` is passed to the
    chosen kernel as ``out0`` (presumably its output buffer), and the
    kernel's return value is passed through.
    """
    logger.debug("GEMS LERP INPLACE SCALAR")
    if _use_head_formula(weight):
        return lerp_scalar_kernel_head(input, end, weight, out0=input)
    return lerp_scalar_kernel_tail(input, end, weight, out0=input)