Coverage for src/flag_gems/ops/fmin.py: 70%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def fmin_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise minimum over two flat buffers, one BLOCK_SIZE chunk per program."""
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements  # guard the ragged tail of the last block
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    # NOTE(review): torch.fmin prefers the non-NaN operand; whether
    # tl.minimum matches that NaN behavior is not visible here — confirm
    # against the Triton docs.
    tl.store(out_ptr + idx, tl.minimum(lhs, rhs), mask=in_bounds)

23 

24 

25def _to_tensor(x, device=None, dtype=None): 

26 if isinstance(x, torch.Tensor): 

27 t = x 

28 if device is not None and t.device != device: 

29 t = t.to(device) 

30 if dtype is not None and t.dtype != dtype: 

31 t = t.to(dtype) 

32 return t 

33 return torch.tensor(x, device=device, dtype=dtype) 

34 

35 

def _prepare_inputs(a, b, out=None):
    """Broadcast *a* and *b* onto a common device and resolve dtypes.

    Returns ``(a_contig, b_contig, out_dtype, compute_dtype)`` where
    ``out_dtype`` is the promoted result dtype and ``compute_dtype``
    substitutes int8 for bool so the kernel operates on a native integer type.

    Raises:
        TypeError: if the promoted dtype is complex.
    """
    # Device resolution priority: out, then a, then b, else default to CUDA.
    if isinstance(out, torch.Tensor):
        device = out.device
    elif isinstance(a, torch.Tensor):
        device = a.device
    elif isinstance(b, torch.Tensor):
        device = b.device
    else:
        device = torch.device("cuda")
    a_t = _to_tensor(a, device=device)
    b_t = _to_tensor(b, device=device)
    a_bc, b_bc = torch.broadcast_tensors(a_t, b_t)
    out_dtype = torch.result_type(a_bc, b_bc)
    if out_dtype.is_complex:
        raise TypeError("fmin does not support complex dtypes.")
    compute_dtype = torch.int8 if out_dtype == torch.bool else out_dtype
    # broadcast_tensors yields stride-0 views; contiguous() materializes them
    # into dense buffers the flat kernel can index.
    return (
        a_bc.to(compute_dtype).contiguous(),
        b_bc.to(compute_dtype).contiguous(),
        out_dtype,
        compute_dtype,
    )

57 

58 

def fmin(a, b):
    """Elementwise fmin of *a* and *b* with broadcasting; returns a new tensor."""
    logger.debug("GEMS FMIN")
    x, y, out_dtype, compute_dtype = _prepare_inputs(a, b, out=None)
    result = torch.empty(x.shape, dtype=out_dtype, device=x.device)
    # For bool output the kernel computes in int8, so stage into a scratch
    # buffer and cast back after the launch; otherwise write result directly.
    if compute_dtype == out_dtype:
        kernel_out = result
    else:
        kernel_out = torch.empty(x.shape, dtype=compute_dtype, device=x.device)
    total = kernel_out.numel()
    grid = lambda meta: (triton.cdiv(total, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(x.device):
        fmin_kernel[grid](x, y, kernel_out, total, BLOCK_SIZE=1024)
    if kernel_out.dtype != result.dtype:
        result.copy_(kernel_out.to(out_dtype))
    return result

76 

77 

def fmin_out(a, b, out):
    """Elementwise fmin of *a* and *b*, written into *out* and returned.

    Args:
        a, b: tensors or scalars, broadcast together.
        out: destination tensor; must match the promoted dtype, broadcast
            shape, and device of the inputs.

    Raises:
        TypeError: *out* is not a tensor or has the wrong dtype.
        ValueError: *out* is on the wrong device or has the wrong shape.
    """
    logger.debug("GEMS FMIN_OUT")
    if not isinstance(out, torch.Tensor):
        raise TypeError("out must be a Tensor")
    a_c, b_c, out_dtype, compute_dtype = _prepare_inputs(a, b, out=out)
    expected_shape = a_c.shape
    if out.device != a_c.device:
        raise ValueError("out tensor must be on the same device as inputs.")
    if out.dtype != out_dtype:
        raise TypeError(f"out tensor has dtype {out.dtype}, expected {out_dtype}.")
    if tuple(out.shape) != tuple(expected_shape):
        raise ValueError(
            f"out tensor has shape {tuple(out.shape)}, expected {tuple(expected_shape)} after broadcasting."
        )
    # The kernel can write straight into out only when the dtype matches and
    # the layout is contiguous; otherwise stage into a scratch buffer.
    if compute_dtype == out_dtype and out.is_contiguous():
        out_c = out
    else:
        out_c = torch.empty(expected_shape, dtype=compute_dtype, device=out.device)
    n_elements = out_c.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(out.device):
        fmin_kernel[grid](a_c, b_c, out_c, n_elements, BLOCK_SIZE=1024)
    if out_c is not out:
        # Tensor.copy_ casts dtype and handles non-contiguous destinations on
        # its own. This replaces the previous tail, whose
        # `if out.is_contiguous()` branch was unreachable (a contiguous out
        # with matching dtype already aliased out_c) and whose fallback
        # `out.view_as(out.contiguous()).copy_(...)` relied on fragile
        # view() behavior for non-contiguous tensors.
        out.copy_(out_c)
    return out

108 return out