Coverage for src/flag_gems/experimental_ops/fmin.py: 0%
82 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def fmin_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise minimum of two flat buffers; each program handles one
    BLOCK_SIZE-wide tile, with a bounds mask for the ragged final tile."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, tl.minimum(lhs, rhs), mask=in_bounds)
18def _to_tensor(x, device=None, dtype=None):
19 if isinstance(x, torch.Tensor):
20 t = x
21 if device is not None and t.device != device:
22 t = t.to(device)
23 if dtype is not None and t.dtype != dtype:
24 t = t.to(dtype)
25 return t
26 return torch.tensor(x, device=device, dtype=dtype)
def _prepare_inputs(a, b, out=None):
    """Resolve device and dtypes, broadcast, and return kernel-ready operands.

    Returns ``(a_c, b_c, out_dtype, compute_dtype)`` where ``out_dtype`` is
    the promoted result dtype and ``compute_dtype`` is what the kernel
    actually operates on (bool is computed as int8 to avoid bool in-kernel).

    Raises:
        ValueError: if the operands do not end up on a CUDA device.
        TypeError: if dtype promotion yields a complex dtype.
    """
    # Device priority: out's device, then a's, then b's, else default CUDA.
    if isinstance(out, torch.Tensor):
        dev = out.device
    elif isinstance(a, torch.Tensor):
        dev = a.device
    elif isinstance(b, torch.Tensor):
        dev = b.device
    else:
        dev = torch.device("cuda")
    a = _to_tensor(a, device=dev)
    b = _to_tensor(b, device=dev)
    if not (a.device.type == "cuda" and b.device.type == "cuda"):
        raise ValueError(
            "Inputs must be CUDA tensors or convertible to CUDA tensors for Triton kernels."
        )
    a_bc, b_bc = torch.broadcast_tensors(a, b)
    out_dtype = torch.result_type(a_bc, b_bc)
    if out_dtype.is_complex:
        raise TypeError("fmin does not support complex dtypes.")
    compute_dtype = torch.int8 if out_dtype == torch.bool else out_dtype
    return (
        a_bc.to(compute_dtype).contiguous(),
        b_bc.to(compute_dtype).contiguous(),
        out_dtype,
        compute_dtype,
    )
def fmin(a, b):
    """Return the elementwise minimum of *a* and *b*, computed on CUDA via Triton.

    Inputs are broadcast and promoted per torch rules; bool inputs are
    computed as int8 and cast back on return.
    """
    x, y, out_dtype, compute_dtype = _prepare_inputs(a, b, out=None)
    shape = x.shape  # x and y share the broadcast shape
    result = torch.empty(shape, dtype=out_dtype, device=x.device)
    # The kernel writes in compute_dtype; reuse `result` directly unless a
    # post-kernel cast (bool case) is required.
    if compute_dtype == out_dtype:
        buffer = result
    else:
        buffer = torch.empty(shape, dtype=compute_dtype, device=x.device)
    numel = buffer.numel()
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    fmin_kernel[grid](x, y, buffer, numel, BLOCK_SIZE=1024)
    if buffer.dtype != result.dtype:
        result.copy_(buffer.to(out_dtype))
    return result
def fmin_out(a, b, out):
    """Elementwise minimum of *a* and *b*, written into *out*.

    Args:
        a, b: CUDA tensors or values convertible to CUDA tensors.
        out: Pre-allocated destination tensor. Must live on the inputs'
            device, have the broadcast result shape, and carry the promoted
            result dtype.

    Returns:
        *out*.

    Raises:
        TypeError: if *out* is not a tensor or has the wrong dtype, or if
            dtype promotion yields a complex dtype.
        ValueError: if *out* is on the wrong device or has the wrong shape,
            or if inputs cannot be placed on CUDA.
    """
    if not isinstance(out, torch.Tensor):
        raise TypeError("out must be a Tensor")
    a_c, b_c, out_dtype, compute_dtype = _prepare_inputs(a, b, out=out)
    # Validate out tensor shape/dtype/device against the broadcast result.
    expected_shape = a_c.shape
    if out.device != a_c.device:
        raise ValueError("out tensor must be on the same device as inputs.")
    if out.dtype != out_dtype:
        raise TypeError(f"out tensor has dtype {out.dtype}, expected {out_dtype}.")
    if tuple(out.shape) != tuple(expected_shape):
        raise ValueError(
            f"out tensor has shape {tuple(out.shape)}, expected {tuple(expected_shape)} after broadcasting."
        )
    # The kernel needs a contiguous destination in compute_dtype; write
    # straight into `out` when possible, otherwise stage through a buffer.
    if compute_dtype == out_dtype and out.is_contiguous():
        out_c = out
    else:
        out_c = torch.empty(expected_shape, dtype=compute_dtype, device=out.device)
    n_elements = out_c.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    fmin_kernel[grid](a_c, b_c, out_c, n_elements, BLOCK_SIZE=1024)
    if out_c is not out:
        # Tensor.copy_ performs the dtype cast and handles non-contiguous
        # destinations itself; the previous `out.view_as(out.contiguous())`
        # dance viewed `out` to its own shape (a no-op wrapper around this
        # exact call) and the explicit `.to(out.dtype)` was redundant.
        out.copy_(out_c)
    return out
111 return out