Coverage for src/flag_gems/runtime/backend/_cambricon/ops/ones.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9from flag_gems.utils.shape_utils import volume 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

# Module logger nested under the shared "flag_gems" logger; lstrip(".") removes
# any leading dots from __name__ so the child name is a valid logger path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias the imported runtime device descriptor so the `device` keyword argument
# of ones() below does not shadow it.
device_ = device

15 

16 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
    strategy=["align32"],
)
@triton.jit
def ones_kernel(
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write 1.0 into ``n_elements`` contiguous slots starting at ``output_ptr``.

    Each program handles one BLOCK_SIZE chunk per iteration and then advances
    by the total number of launched programs times BLOCK_SIZE, so the whole
    range is covered even when the grid is capped below the block count.
    """
    job_id = tl.program_id(axis=0)
    job_count = tl.num_programs(axis=0)
    # Promote the starting offset to int64 so pointer arithmetic does not
    # overflow for very large tensors.
    start = (job_id * BLOCK_SIZE).to(tl.int64)
    stride = job_count * BLOCK_SIZE
    for base in range(start, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        # Mask guards the final, possibly partial, chunk.
        tl.store(output_ptr + idx, 1.0, mask=idx < n_elements)

43 

44 

def ones(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape ``size`` filled with ones.

    Mirrors ``torch.ones``: ``dtype`` defaults to the global default dtype and
    ``device`` to the runtime's device. ``layout`` and ``pin_memory`` are
    accepted for signature compatibility but are not consulted here —
    presumably handled (or ignored) upstream; TODO confirm against callers.
    """
    logger.debug("GEMS_CAMBRICON ONES")
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device

    result = torch.empty(size, device=device, dtype=dtype)
    n_elements = volume(size)

    def grid(meta):
        # Cap the launch at TOTAL_CORE_NUM programs; the kernel's stride loop
        # picks up any remaining blocks.
        return (min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        ones_kernel[grid](result, n_elements)
    return result