Coverage for src/flag_gems/experimental_ops/zero.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def zero_kernel(
    out_ptr,  # *Pointer* to tensor to be zeroed
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,
):
    """Store zeros into the first ``n_elements`` elements behind ``out_ptr``.

    Each program instance covers one ``BLOCK_SIZE``-wide slice of the flat
    buffer, selected by its program id along axis 0. Lanes whose offset is
    not below ``n_elements`` are masked out of the store.
    """
    # Widen the program id before multiplying so the block offset cannot wrap
    # around for buffers holding more than 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Read the element dtype from the pointer type instead of issuing a dummy
    # load: a fill should only write memory, never read it first.
    z = tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, z, mask=mask)
def _launch_zero_kernel(tensor: torch.Tensor):
    """Zero ``tensor`` in place by launching ``zero_kernel`` over it.

    Args:
        tensor: A contiguous CUDA tensor. Its contents are overwritten.

    Returns:
        The same ``tensor`` object that was passed in. An empty tensor is
        returned immediately without launching a kernel.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor``, is not on a
            CUDA device, or is not contiguous.
    """
    # Raise explicitly instead of using ``assert`` so the checks still run
    # under ``python -O``; AssertionError is kept so callers see the same
    # exception type as before.
    if not isinstance(tensor, torch.Tensor):
        raise AssertionError("Expected a torch.Tensor")
    if not tensor.is_cuda:
        raise AssertionError("Tensor must be on CUDA device")
    if not tensor.is_contiguous():
        raise AssertionError("Tensor must be contiguous")

    n_elements = tensor.numel()
    if n_elements == 0:
        return tensor

    def grid(meta):
        # One program per BLOCK_SIZE-wide slice, rounding up.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Make the tensor's device current for the launch so a tensor living on a
    # GPU other than the current one is zeroed on its own device.
    with torch.cuda.device(tensor.device):
        zero_kernel[grid](tensor, n_elements, BLOCK_SIZE=1024)
    return tensor
def zero(*args, **kwargs):
    """Zero a tensor in place and return it.

    The target is the first of these that is a ``torch.Tensor``, checked in
    this order: the first positional argument, then the ``self``, ``input``
    and ``out`` keyword arguments. Any other arguments are ignored.

    Raises:
        ValueError: If none of those candidates is a ``torch.Tensor``.
    """
    # Ordered by priority; a missing candidate shows up as None and is
    # rejected by the isinstance check like any other non-tensor.
    candidates = (
        args[0] if args else None,
        kwargs.get("self"),
        kwargs.get("input"),
        kwargs.get("out"),
    )
    for candidate in candidates:
        if isinstance(candidate, torch.Tensor):
            return _launch_zero_kernel(candidate)
    raise ValueError(
        "zero expects a Tensor as the first argument or in kwargs as 'self', 'input', or 'out'"
    )
def zero_out(*args, **kwargs):
    """Out-variant of ``zero``: zero the output tensor in place and return it.

    The ``out`` keyword argument is used when it is a ``torch.Tensor``;
    otherwise the first positional argument is used when it is one.

    Raises:
        ValueError: If neither candidate is a ``torch.Tensor``.
    """
    keyword_out = kwargs.get("out")
    if isinstance(keyword_out, torch.Tensor):
        return _launch_zero_kernel(keyword_out)

    if args and isinstance(args[0], torch.Tensor):
        return _launch_zero_kernel(args[0])

    raise ValueError(
        "zero_out expects an output Tensor as the first positional argument or 'out' kwarg"
    )