Coverage for src/flag_gems/ops/fmin.py: 70%
87 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
@triton.jit
def fmin_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise fmin over flat, contiguous buffers.

    Writes out[i] = fmin(x[i], y[i]) for i < n_elements, where fmin follows
    torch.fmin / C fmin semantics: when exactly one operand is NaN, the
    non-NaN operand is returned (NaN only when both are NaN).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # tl.minimum's NaN behavior is not guaranteed to match fmin, so select
    # the non-NaN operand explicitly. `v != v` is the standard NaN test and
    # is always False for integer dtypes, so this is a no-op for ints.
    out = tl.minimum(x, y)
    out = tl.where(y != y, x, out)  # y is NaN -> take x
    out = tl.where(x != x, y, out)  # x is NaN -> take y (both NaN -> NaN)
    tl.store(out_ptr + offsets, out, mask=mask)
25def _to_tensor(x, device=None, dtype=None):
26 if isinstance(x, torch.Tensor):
27 t = x
28 if device is not None and t.device != device:
29 t = t.to(device)
30 if dtype is not None and t.dtype != dtype:
31 t = t.to(dtype)
32 return t
33 return torch.tensor(x, device=device, dtype=dtype)
def _prepare_inputs(a, b, out=None):
    """Normalize the operands for the fmin kernels.

    Chooses a target device (priority: *out*'s device, then *a*'s, then
    *b*'s, defaulting to CUDA when no tensor operand is given), converts
    both operands to tensors on it, broadcasts them against each other and
    resolves the promoted result dtype. Bool results are computed in int8
    because the kernel cannot operate on bool buffers directly.

    Returns ``(a_contig, b_contig, out_dtype, compute_dtype)``.
    Raises TypeError when type promotion yields a complex dtype.
    """
    if isinstance(out, torch.Tensor):
        device = out.device
    elif isinstance(a, torch.Tensor):
        device = a.device
    elif isinstance(b, torch.Tensor):
        device = b.device
    else:
        device = torch.device("cuda")
    a_t = _to_tensor(a, device=device)
    b_t = _to_tensor(b, device=device)
    a_bc, b_bc = torch.broadcast_tensors(a_t, b_t)
    out_dtype = torch.result_type(a_bc, b_bc)
    if out_dtype.is_complex:
        raise TypeError("fmin does not support complex dtypes.")
    compute_dtype = torch.int8 if out_dtype == torch.bool else out_dtype
    return (
        a_bc.to(compute_dtype).contiguous(),
        b_bc.to(compute_dtype).contiguous(),
        out_dtype,
        compute_dtype,
    )
def fmin(a, b):
    """Elementwise minimum of *a* and *b* with broadcasting (fmin).

    Operands may be tensors or scalars; they are promoted and broadcast by
    ``_prepare_inputs``. Bool inputs are computed in int8 and cast back to
    bool in the returned tensor.
    """
    logger.debug("GEMS FMIN")
    a_c, b_c, out_dtype, compute_dtype = _prepare_inputs(a, b, out=None)
    result = torch.empty(a_c.shape, dtype=out_dtype, device=a_c.device)
    # The kernel writes in compute_dtype; stage through a scratch buffer
    # only when that differs from the public dtype (the bool case).
    if compute_dtype == out_dtype:
        staging = result
    else:
        staging = torch.empty(a_c.shape, dtype=compute_dtype, device=a_c.device)
    n_elements = staging.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(a_c.device):
        fmin_kernel[grid](a_c, b_c, staging, n_elements, BLOCK_SIZE=1024)
    if staging.dtype != result.dtype:
        result.copy_(staging.to(out_dtype))
    return result
def fmin_out(a, b, out):
    """Elementwise fmin of *a* and *b* written into *out*.

    *out* must be a tensor on the inputs' device whose dtype and shape
    match the promoted/broadcast result exactly.

    Raises:
        TypeError: if *out* is not a tensor or has the wrong dtype.
        ValueError: if *out* is on the wrong device or has the wrong shape.

    Returns *out*.
    """
    logger.debug("GEMS FMIN_OUT")
    if not isinstance(out, torch.Tensor):
        raise TypeError("out must be a Tensor")
    a_c, b_c, out_dtype, compute_dtype = _prepare_inputs(a, b, out=out)
    expected_shape = a_c.shape
    if out.device != a_c.device:
        raise ValueError("out tensor must be on the same device as inputs.")
    if out.dtype != out_dtype:
        raise TypeError(f"out tensor has dtype {out.dtype}, expected {out_dtype}.")
    if tuple(out.shape) != tuple(expected_shape):
        raise ValueError(
            f"out tensor has shape {tuple(out.shape)}, expected {tuple(expected_shape)} after broadcasting."
        )
    # The kernel needs a contiguous buffer in compute_dtype; write straight
    # into `out` when it already qualifies, otherwise use a scratch buffer.
    if compute_dtype == out_dtype and out.is_contiguous():
        out_c = out
    else:
        out_c = torch.empty(expected_shape, dtype=compute_dtype, device=out.device)
    n_elements = out_c.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(out.device):
        fmin_kernel[grid](a_c, b_c, out_c, n_elements, BLOCK_SIZE=1024)
    if out_c is not out:
        # Tensor.copy_ performs the dtype cast itself and supports
        # non-contiguous destinations. The previous
        # `out.view_as(out.contiguous()).copy_(out_c)` path raised a
        # RuntimeError for genuinely non-contiguous `out`, since view()
        # requires a contiguity-compatible layout.
        out.copy_(out_c)
    return out
108 return out