Coverage for src/flag_gems/runtime/backend/_cambricon/ops/zeros.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils.shape_utils import volume 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

# Module logger: a child of the shared "flag_gems" logger, named after this
# module (any leading dots in __name__ are stripped first).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Keep a second handle on the imported runtime `device` object: `zeros` below
# takes a keyword argument that is also called `device` and shadows the import.
device_ = device

14 

15 

@triton.autotune(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block_size}, num_stages=1, num_warps=1)
        for block_size in (1024, 4096, 16384, 65536)
    ],
    key=["n_elements"],
)
@triton.jit
def zeros_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Write 0.0 to the first `n_elements` positions addressed from `output_ptr`.
    # The launch grid is 1D; each program starts at its own block and then
    # advances by (number of programs * BLOCK_SIZE) until the range is covered.
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)
    stride = program_count * BLOCK_SIZE
    # Promote the starting offset to int64 so the loop counter and the derived
    # element offsets are 64-bit.
    first_block = (program_idx * BLOCK_SIZE).to(tl.int64)
    lanes = tl.arange(0, BLOCK_SIZE)
    for base in range(first_block, n_elements, stride):
        positions = base + lanes
        # The mask keeps the last, partially filled block from writing past the end.
        in_bounds = positions < n_elements
        tl.store(output_ptr + positions, 0.0, mask=in_bounds)

40 

41 

def zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Allocate a tensor of shape ``size`` and fill it with zeros via ``zeros_kernel``.

    Args:
        size: Shape of the tensor; passed to ``torch.empty`` and ``volume``.
        dtype: Element type. ``None`` selects ``torch.get_default_dtype()``.
        layout: Accepted but not used by this function.
        device: Target device. ``None`` selects ``torch.device(device_.name)``.
        pin_memory: Accepted but not used by this function.

    Returns:
        The newly allocated, zero-filled tensor.
    """
    logger.debug("GEMS_CAMBRICON ZEROS")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to write: skip building a zero-sized grid and the autotuned
        # kernel launch altogether.
        return out

    def grid_fn(meta):
        # One program per block, capped at TOTAL_CORE_NUM; the kernel itself
        # loops so that a capped grid still covers all N elements.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        zeros_kernel[grid_fn](out, N)
    return out

55 

56 

def zero_(x: torch.Tensor) -> torch.Tensor:
    """Set every element of ``x`` to zero in place and return ``x``.

    Contiguous tensors are zeroed directly by ``zeros_kernel``. Non-contiguous
    tensors are zeroed through ``Tensor.copy_`` from a contiguous zero tensor,
    because the kernel addresses elements linearly and ignores strides.

    Args:
        x: Tensor to overwrite.

    Returns:
        The same tensor object ``x``.
    """
    logger.debug("GEMS_CAMBRICON ZERO_")
    N = x.numel()
    if N == 0:
        # Nothing to write: skip the zero-sized grid and the kernel launch.
        return x

    if not x.is_contiguous():
        # zeros_kernel stores to output_ptr + [0, N), i.e. N consecutive
        # elements. For a strided view that range is not the view's own
        # elements, so a direct launch would zero the wrong memory. Use a
        # stride-aware copy from a contiguous zero tensor instead.
        x.copy_(torch.zeros_like(x, memory_format=torch.contiguous_format))
        return x

    def grid_fn(meta):
        # One program per block, capped at TOTAL_CORE_NUM; the kernel itself
        # loops so that a capped grid still covers all N elements.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(x.device):
        zeros_kernel[grid_fn](x, N)
    return x