Coverage for src/flag_gems/experimental_ops/zero.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def zero_kernel(
    out_ptr,  # *Pointer* to tensor to be zeroed
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,
):
    """Store zeros into the first ``n_elements`` elements behind ``out_ptr``.

    Each program instance along axis 0 covers one run of ``BLOCK_SIZE``
    consecutive offsets; offsets at or beyond ``n_elements`` are masked out
    so the final, partially filled block does not write past the end.

    Args:
        out_ptr: Pointer to the first element to overwrite.
        n_elements: Number of elements to set to zero.
        BLOCK_SIZE: Compile-time number of offsets handled per program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # The element dtype is part of the pointer's type, so it can be read
    # statically; no load of the (about to be overwritten) data is needed
    # just to discover it.
    z = tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, z, mask=mask)

20 

21 

22def _launch_zero_kernel(tensor: torch.Tensor): 

23 assert isinstance(tensor, torch.Tensor), "Expected a torch.Tensor" 

24 assert tensor.is_cuda, "Tensor must be on CUDA device" 

25 assert tensor.is_contiguous(), "Tensor must be contiguous" 

26 assert tensor.numel() >= 0 

27 n_elements = tensor.numel() 

28 if n_elements == 0: 

29 return tensor 

30 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

31 zero_kernel[grid](tensor, n_elements, BLOCK_SIZE=1024) 

32 return tensor 

33 

34 

def zero(*args, **kwargs):
    """Zero a tensor in place, accepting several calling conventions.

    The target is the first of these that is a ``torch.Tensor``: the first
    positional argument, then the ``self``, ``input`` and ``out`` keyword
    arguments, in that order.

    Returns:
        Whatever ``_launch_zero_kernel`` returns for the selected tensor.

    Raises:
        ValueError: If none of the candidates is a ``torch.Tensor``.
    """
    # Build the candidates in priority order: positional first, then kwargs.
    candidates = list(args[:1])
    candidates.extend(
        kwargs[key] for key in ("self", "input", "out") if key in kwargs
    )
    for candidate in candidates:
        if isinstance(candidate, torch.Tensor):
            return _launch_zero_kernel(candidate)
    raise ValueError(
        "zero expects a Tensor as the first argument or in kwargs as 'self', 'input', or 'out'"
    )

51 

52 

def zero_out(*args, **kwargs):
    """Out-variant of zero: zero the destination tensor in place.

    The ``out`` keyword argument is used when it is a ``torch.Tensor``;
    otherwise the first positional argument is used when it is one.

    Returns:
        Whatever ``_launch_zero_kernel`` returns for the destination tensor.

    Raises:
        ValueError: If neither candidate is a ``torch.Tensor``.
    """
    destination = kwargs.get("out")
    if not isinstance(destination, torch.Tensor):
        # No usable 'out' kwarg: fall back to the first positional argument.
        destination = args[0] if args else None
    if not isinstance(destination, torch.Tensor):
        raise ValueError(
            "zero_out expects an output Tensor as the first positional argument or 'out' kwarg"
        )
    return _launch_zero_kernel(destination)