Coverage for src/flag_gems/experimental_ops/fmin.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def fmin_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise fmin of two flat buffers of equal length.

    fmin semantics (matching torch.fmin / C fmin): if exactly one of the
    two compared elements is NaN, the non-NaN element is the result; if
    both are NaN, the result is NaN. Plain tl.minimum propagates NaN, so
    the extra tl.where passes below are required for float inputs; for
    integer dtypes `v != v` is always False and they are no-ops.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    out = tl.minimum(x, y)
    # (v != v) is the standard in-kernel NaN test.
    out = tl.where(x != x, y, out)
    out = tl.where(y != y, x, out)
    tl.store(out_ptr + offsets, out, mask=mask)

16 

17 

18def _to_tensor(x, device=None, dtype=None): 

19 if isinstance(x, torch.Tensor): 

20 t = x 

21 if device is not None and t.device != device: 

22 t = t.to(device) 

23 if dtype is not None and t.dtype != dtype: 

24 t = t.to(dtype) 

25 return t 

26 return torch.tensor(x, device=device, dtype=dtype) 

27 

28 

def _prepare_inputs(a, b, out=None):
    """Normalize operands for the Triton fmin kernel.

    Picks a target device (out's device wins, then a's, then b's, else
    the default CUDA device), converts both operands to tensors there,
    broadcasts them, and returns contiguous compute-dtype copies.

    Returns:
        (a_contig, b_contig, out_dtype, compute_dtype) where out_dtype is
        the type-promotion result and compute_dtype substitutes int8 for
        bool so the kernel never operates on bool data.

    Raises:
        ValueError: if either operand does not end up on a CUDA device.
        TypeError: if the promoted dtype is complex.
    """
    # Resolve the target device by priority: out > a > b > default cuda.
    if isinstance(out, torch.Tensor):
        target = out.device
    else:
        target = None
        if isinstance(a, torch.Tensor):
            target = a.device
        if target is None and isinstance(b, torch.Tensor):
            target = b.device
    if target is None:
        target = torch.device("cuda")
    a = _to_tensor(a, device=target)
    b = _to_tensor(b, device=target)
    if not (a.device.type == "cuda" and b.device.type == "cuda"):
        raise ValueError(
            "Inputs must be CUDA tensors or convertible to CUDA tensors for Triton kernels."
        )
    a_bc, b_bc = torch.broadcast_tensors(a, b)
    out_dtype = torch.result_type(a_bc, b_bc)
    if out_dtype.is_complex:
        raise TypeError("fmin does not support complex dtypes.")
    # bool is staged as int8 because the kernel cannot load/store bool.
    compute_dtype = torch.int8 if out_dtype == torch.bool else out_dtype
    a_contig = a_bc.to(compute_dtype).contiguous()
    b_contig = b_bc.to(compute_dtype).contiguous()
    return a_contig, b_contig, out_dtype, compute_dtype

59 

60 

def fmin(a, b):
    """Element-wise fmin of *a* and *b* via the Triton kernel.

    Accepts tensors or values convertible to tensors, broadcasts them,
    and returns a newly allocated tensor of the promoted dtype on the
    inputs' CUDA device.

    Raises:
        ValueError: if the inputs are not (convertible to) CUDA tensors.
        TypeError: for complex dtypes.
    """
    a_c, b_c, out_dtype, compute_dtype = _prepare_inputs(a, b, out=None)
    out_shape = a_c.shape  # a_c and b_c share a shape after broadcasting
    out = torch.empty(out_shape, dtype=out_dtype, device=a_c.device)
    # compute_dtype differs from out_dtype only for bool results, which
    # are staged in an int8 buffer the kernel can write to.
    if compute_dtype == out_dtype:
        out_c = out
    else:
        out_c = torch.empty(out_shape, dtype=compute_dtype, device=a_c.device)
    n_elements = out_c.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    fmin_kernel[grid](a_c, b_c, out_c, n_elements, BLOCK_SIZE=1024)
    if out_c is not out:
        # Cast the staging buffer back (e.g. int8 -> bool) into out.
        out.copy_(out_c.to(out_dtype))
    return out

77 

78 

def fmin_out(a, b, out):
    """Element-wise fmin of *a* and *b* written into preallocated *out*.

    *out* must already have the broadcast shape, the promoted dtype, and
    live on the same device as the inputs; it is returned for chaining.

    Raises:
        TypeError: if *out* is not a tensor or has the wrong dtype.
        ValueError: for device or shape mismatches, or non-CUDA inputs.
    """
    if not isinstance(out, torch.Tensor):
        raise TypeError("out must be a Tensor")
    a_c, b_c, out_dtype, compute_dtype = _prepare_inputs(a, b, out=out)
    # Validate out against the broadcast result.
    expected_shape = a_c.shape
    if out.device != a_c.device:
        raise ValueError("out tensor must be on the same device as inputs.")
    if out.dtype != out_dtype:
        raise TypeError(f"out tensor has dtype {out.dtype}, expected {out_dtype}.")
    if tuple(out.shape) != tuple(expected_shape):
        raise ValueError(
            f"out tensor has shape {tuple(out.shape)}, expected {tuple(expected_shape)} after broadcasting."
        )
    # Write directly into out when it is usable as the kernel's output;
    # otherwise stage into a contiguous buffer of the compute dtype
    # (int8 stands in for bool, which the kernel cannot store).
    if compute_dtype == out_dtype and out.is_contiguous():
        out_c = out
    else:
        out_c = torch.empty(expected_shape, dtype=compute_dtype, device=out.device)
    n_elements = out_c.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    fmin_kernel[grid](a_c, b_c, out_c, n_elements, BLOCK_SIZE=1024)
    if out_c is not out:
        # Tensor.copy_ performs both the dtype conversion and the scatter
        # into a non-contiguous destination, so the previous dead-branch
        # logic (including the identity view `out.view_as(out.contiguous())`)
        # collapses to a single copy.
        out.copy_(out_c)
    return out