Coverage for src/flag_gems/experimental_ops/zeros_like.py: 0%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _fill_zero_kernel(
    out_ptr,  # *Pointer* to the output buffer.
    n_elements,  # Total number of elements to write.
    BLOCK_SIZE: tl.constexpr,  # Elements handled by each program instance.
    OUT_DTYPE: tl.constexpr,  # Triton dtype of the values being stored.
):
    # One program zero-fills one BLOCK_SIZE-wide slice of the output.
    program_index = tl.program_id(axis=0)
    slice_start = program_index * BLOCK_SIZE
    idx = slice_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block: only in-bounds lanes may store.
    in_bounds = idx < n_elements
    zero_block = tl.full([BLOCK_SIZE], 0, dtype=OUT_DTYPE)
    tl.store(out_ptr + idx, zero_block, mask=in_bounds)

19 

20 

def _torch_dtype_to_triton_dtype(dtype: torch.dtype):
    """Translate a torch dtype into the corresponding Triton dtype.

    torch.bool maps to tl.int8 because Triton does not directly expose a
    boolean storage type; 0/1 int8 values are used for storage instead.

    Raises:
        NotImplementedError: if `dtype` has no Triton counterpart here.
    """
    mapping = {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float64: tl.float64,
        torch.int8: tl.int8,
        torch.uint8: tl.uint8,
        torch.int16: tl.int16,
        torch.int32: tl.int32,
        torch.int64: tl.int64,
        # Triton bool storage is not directly exposed; use int8 for 0/1 storage.
        torch.bool: tl.int8,
    }
    if dtype not in mapping:
        raise NotImplementedError(f"Unsupported dtype for Triton zeros_like: {dtype}")
    return mapping[dtype]

45 

46 

47def _launch_fill_zero(out: torch.Tensor, block_size: int = 4096): 

48 # Fallback for non-CUDA or empty tensors 

49 n_elements = out.numel() 

50 if n_elements == 0: 

51 return 

52 if not out.is_cuda: 

53 out.zero_() 

54 return 

55 # For simplicity, only handle contiguous tensors with the Triton kernel. 

56 # Fallback to PyTorch for non-contiguous outputs. 

57 if not out.is_contiguous(): 

58 out.zero_() 

59 return 

60 out_dtype = _torch_dtype_to_triton_dtype(out.dtype) 

61 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

62 _fill_zero_kernel[grid](out, n_elements, BLOCK_SIZE=block_size, OUT_DTYPE=out_dtype) 

63 

64 

def zeros_like(*args, **kwargs):
    """Return a zero-filled tensor shaped like the input.

    Mirrors torch.zeros_like: the input tensor is the first positional
    argument or the 'input'/'self' keyword; dtype, layout, device,
    pin_memory and memory_format keywords override the corresponding
    properties of the result.

    Raises:
        ValueError: if no input tensor is supplied.
    """
    inp = args[0] if args else kwargs.get("input", kwargs.get("self", None))
    if inp is None:
        raise ValueError("zeros_like expects an input tensor as the first argument.")

    pin_memory = kwargs.get("pin_memory", None)
    # Allocate with the requested properties, then zero it in place.
    out = torch.empty_like(
        inp,
        dtype=kwargs.get("dtype", None),
        layout=kwargs.get("layout", None),
        device=kwargs.get("device", None),
        pin_memory=pin_memory if pin_memory is not None else False,
        memory_format=kwargs.get("memory_format", torch.preserve_format),
    )
    _launch_fill_zero(out)
    return out

92 

93 

def zeros_like_out(*args, **kwargs):
    """Fill a caller-provided `out` tensor with zeros (zeros_like.out form).

    Expected signature:
    zeros_like.out(input, *, dtype=None, layout=None, device=None,
                   pin_memory=None, memory_format=None, out)

    Raises:
        ValueError: if input/out are missing, or if a provided dtype or
            device disagrees with the corresponding property of `out`.
    """
    inp = args[0] if len(args) >= 1 else kwargs.get("input", kwargs.get("self", None))
    out = kwargs.get("out", None)
    if out is None and len(args) >= 2:
        # Accept `out` as the trailing positional argument.
        out = args[-1]
    if inp is None or out is None:
        raise ValueError("zeros_like_out expects 'input' and 'out' tensors.")

    # Optional consistency checks per .out semantics (only when provided).
    dtype = kwargs.get("dtype", None)
    device = kwargs.get("device", None)
    if dtype is not None and out.dtype != dtype:
        raise ValueError(f"Provided dtype {dtype} does not match out.dtype {out.dtype}")
    if device is not None and str(out.device) != str(device):
        raise ValueError(
            f"Provided device {device} does not match out.device {out.device}"
        )
    # Shape/layout checks could be added; we keep minimal checks for generality.

    _launch_fill_zero(out)
    return out