Coverage for src/flag_gems/experimental_ops/zeros_like.py: 0%
79 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _fill_zero_kernel(
    out_ptr,  # Destination buffer to be zero-filled.
    n_elements,  # Total number of elements to write.
    BLOCK_SIZE: tl.constexpr,  # Elements handled by each program instance.
    OUT_DTYPE: tl.constexpr,  # Element dtype used for the stored zeros.
):
    """Write zeros into ``out_ptr[0:n_elements]``, one program per block."""
    program = tl.program_id(axis=0)
    idx = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block so we never write past the end of the buffer.
    in_range = idx < n_elements
    tl.store(out_ptr + idx, tl.full([BLOCK_SIZE], 0, dtype=OUT_DTYPE), mask=in_range)
def _torch_dtype_to_triton_dtype(dtype: torch.dtype):
    """Return the Triton dtype corresponding to a torch dtype.

    ``torch.bool`` maps to ``tl.int8`` because Triton does not directly
    expose a bool storage type; 0/1 int8 values are stored instead.

    Raises:
        NotImplementedError: if ``dtype`` has no Triton counterpart.
    """
    mapping = {
        torch.float32: tl.float32,
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float64: tl.float64,
        torch.int8: tl.int8,
        torch.uint8: tl.uint8,
        torch.int16: tl.int16,
        torch.int32: tl.int32,
        torch.int64: tl.int64,
        torch.bool: tl.int8,
    }
    try:
        return mapping[dtype]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported dtype for Triton zeros_like: {dtype}"
        ) from None
def _launch_fill_zero(out: torch.Tensor, block_size: int = 4096):
    """Zero-fill ``out`` in place, using the Triton kernel when possible.

    The kernel path is taken only for non-empty, contiguous CUDA tensors;
    anything else falls back to ``Tensor.zero_()``, and an empty tensor is
    a no-op.
    """
    total = out.numel()
    if total == 0:
        return
    # The Triton kernel only handles contiguous CUDA storage; otherwise
    # defer to PyTorch's own in-place fill.
    if not (out.is_cuda and out.is_contiguous()):
        out.zero_()
        return
    triton_dtype = _torch_dtype_to_triton_dtype(out.dtype)

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _fill_zero_kernel[grid](out, total, BLOCK_SIZE=block_size, OUT_DTYPE=triton_dtype)
def zeros_like(*args, **kwargs):
    """Triton-backed ``torch.zeros_like``.

    Accepts the input tensor either as the first positional argument or via
    the ``input``/``self`` keyword, allocates an output with the requested
    dtype/layout/device/memory format, zero-fills it, and returns it.

    Raises:
        ValueError: if no input tensor is supplied.
    """
    if args:
        inp = args[0]
    else:
        inp = kwargs.get("input", kwargs.get("self", None))
    if inp is None:
        raise ValueError("zeros_like expects an input tensor as the first argument.")

    pin_memory = kwargs.get("pin_memory", None)
    out = torch.empty_like(
        inp,
        dtype=kwargs.get("dtype", None),
        layout=kwargs.get("layout", None),
        device=kwargs.get("device", None),
        pin_memory=False if pin_memory is None else pin_memory,
        memory_format=kwargs.get("memory_format", torch.preserve_format),
    )
    _launch_fill_zero(out)
    return out
def zeros_like_out(*args, **kwargs):
    """Out-variant of ``zeros_like``.

    Expected signature: zeros_like.out(input, *, dtype=None, layout=None,
    device=None, pin_memory=None, memory_format=None, out)  # noqa: E501

    Zero-fills the provided ``out`` tensor in place and returns it. When
    ``dtype``/``device`` are supplied, they are validated against ``out``.

    Raises:
        ValueError: if ``input``/``out`` are missing, or if a provided
            dtype/device conflicts with ``out``.
    """
    # Extract input and out tensors (positional or keyword forms).
    inp = None
    if len(args) >= 1:
        inp = args[0]
    else:
        inp = kwargs.get("input", kwargs.get("self", None))
    out = kwargs.get("out", None)
    if out is None and len(args) >= 2:
        out = args[-1]
    if inp is None or out is None:
        raise ValueError("zeros_like_out expects 'input' and 'out' tensors.")

    # Optional consistency checks per .out semantics (if provided).
    dtype = kwargs.get("dtype", None)
    device = kwargs.get("device", None)
    if dtype is not None and out.dtype != dtype:
        raise ValueError(f"Provided dtype {dtype} does not match out.dtype {out.dtype}")
    if device is not None:
        # Normalize through torch.device so a bare device type ("cuda")
        # matches an indexed device ("cuda:0"); the previous string
        # comparison rejected that valid combination. Only compare the
        # index when the caller explicitly requested one.
        wanted = torch.device(device)
        if out.device.type != wanted.type or (
            wanted.index is not None and out.device.index != wanted.index
        ):
            raise ValueError(
                f"Provided device {device} does not match out.device {out.device}"
            )
    # Shape/layout checks could be added; we keep minimal checks for generality.
    _launch_fill_zero(out)
    return out