Coverage for src/flag_gems/ops/minimum.py: 78%
18 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.runtime import device
7from flag_gems.utils import pointwise_dynamic
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Rebind the imported `device` object to its `.name` attribute; from here on
# `device` is that value, which `minimum` compares against `tensor.device.type`.
device = device.name
# Scalar body of the element-wise minimum, wrapped by `pointwise_dynamic`
# with both arguments declared as tensors.
# NOTE(review): the promotion spec lists input index 0 twice, so presumably
# only X takes part in result-dtype promotion; confirm this is intended for
# mixed-dtype inputs (an (0, 1, "DEFAULT") spec would include Y).
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 0, "DEFAULT")])
@triton.jit
def minimum_kernel(X, Y):
    # Only X's dtype is inspected: when it is bfloat16, both operands are
    # converted to float32 before the comparison.
    if X.dtype == tl.bfloat16:
        X = X.to(tl.float32)
        Y = Y.to(tl.float32)
    # NOTE(review): tl.minimum is called with its default NaN handling;
    # verify that this matches the NaN semantics callers expect.
    return tl.minimum(X, Y)
def minimum(X, Y):
    """Return the element-wise minimum of ``X`` and ``Y`` via ``minimum_kernel``.

    Logs a debug message, checks that both inputs report the module-level
    ``device`` value as their ``device.type``, then delegates to
    ``minimum_kernel``.

    Args:
        X: First input; must expose ``.device.type``.
        Y: Second input; must expose ``.device.type``.

    Returns:
        Whatever ``minimum_kernel(X, Y)`` returns.

    Raises:
        AssertionError: If either input's ``device.type`` differs from
            ``device`` (only when assertions are enabled).
    """
    logger.debug("GEMS MINIMUM")
    # Static message only: Y is not touched when the check on X already
    # fails, so the short-circuit behaviour of the condition is unchanged.
    assert (
        X.device.type == device and Y.device.type == device
    ), f"minimum expects both inputs on device type {device!r}"
    return minimum_kernel(X, Y)