Coverage for src/flag_gems/ops/zero.py: 62%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Module-scoped logger; zero() and zero_out() emit a debug record on every call.
logger = logging.getLogger(__name__)

9 

10 

# Fill the first ``n_elements`` elements behind ``out_ptr`` with zeros.
#
# Launched on a 1-D grid: program ``pid`` covers flat offsets
# [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE); lanes at or beyond
# ``n_elements`` are masked out, so the last block may be partial.
@triton.jit
def zero_kernel(
    out_ptr,  # *Pointer* to tensor to be zeroed
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,
):
    # program_id is int32; promote before scaling so the offsets do not wrap
    # around for tensors with more than 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Take the zero's dtype from the pointer's element type. Doing a masked load
    # just to discover the dtype would add a full read of the buffer to what
    # is otherwise a write-only kernel.
    z = tl.zeros([BLOCK_SIZE], dtype=out_ptr.dtype.element_ty)
    tl.store(out_ptr + offsets, z, mask=mask)

25 

26 

27def _launch_zero_kernel(tensor: torch.Tensor): 

28 assert isinstance(tensor, torch.Tensor), "Expected a torch.Tensor" 

29 assert tensor.is_cuda, "Tensor must be on CUDA device" 

30 assert tensor.is_contiguous(), "Tensor must be contiguous" 

31 assert tensor.numel() >= 0 

32 n_elements = tensor.numel() 

33 if n_elements == 0: 

34 return tensor 

35 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

36 zero_kernel[grid](tensor, n_elements, BLOCK_SIZE=1024) 

37 return tensor 

38 

39 

def zero(*args, **kwargs):
    """Zero a tensor in place via the Triton kernel and return that same tensor.

    The target is the first positional argument when it is a Tensor;
    otherwise the first Tensor among the ``self``, ``input`` and ``out``
    keyword arguments, checked in that order. If none qualifies, a
    ValueError is raised.

    NOTE(review): this mutates its argument. If it is registered for the
    functional ``aten::zero`` overload (which returns a fresh tensor), that
    is a semantic mismatch — confirm against the op registration.
    """
    logger.debug("GEMS ZERO")
    leading = args[0] if args else None
    candidates = (
        leading,
        kwargs.get("self"),
        kwargs.get("input"),
        kwargs.get("out"),
    )
    for candidate in candidates:
        if isinstance(candidate, torch.Tensor):
            return _launch_zero_kernel(candidate)
    raise ValueError(
        "zero expects a Tensor as the first argument or in kwargs as 'self', 'input', or 'out'"
    )

57 

58 

def zero_out(*args, **kwargs):
    """Out-variant of ``zero``: zero the output tensor in place and return it.

    An ``out`` keyword argument that is a Tensor takes precedence; otherwise
    the first positional argument is used if it is a Tensor. Anything else
    raises ValueError.

    NOTE(review): if a caller passes ``(self, out)`` both positionally, the
    first positional (``self``) is what gets zeroed — confirm callers always
    pass ``out`` by keyword.
    """
    logger.debug("GEMS ZERO_OUT")
    keyword_out = kwargs.get("out")
    if isinstance(keyword_out, torch.Tensor):
        return _launch_zero_kernel(keyword_out)
    if args and isinstance(args[0], torch.Tensor):
        return _launch_zero_kernel(args[0])
    raise ValueError(
        "zero_out expects an output Tensor as the first positional argument or 'out' kwarg"
    )