Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/ones.py: 0%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.shape_utils import volume 

11 

# Module-local alias for the runtime's active device descriptor, so the
# `device` keyword parameter of `ones` below does not shadow it.
device_ = device
# Child logger under the package-wide "flag_gems" logger; lstrip(".")
# drops any leading dot from the module name before it is used as a
# logger suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

@libentry()
@triton.jit
def ones_kernel(
    output_ptr,
    n_elements,
    value,
    BLOCK_SIZE: tl.constexpr,
):
    """Write `value` into every in-bounds slot of a contiguous buffer.

    Each program instance covers one BLOCK_SIZE-wide slice along axis 0;
    offsets at or past `n_elements` are masked out of the store.
    """
    block_id = tle.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    tl.store(output_ptr + idx, value, mask=in_bounds)

29 

30 

def ones(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape `size` filled with ones.

    Mirrors the ``torch.ones`` signature. ``layout`` and ``pin_memory`` are
    accepted for API compatibility but are not consulted here.

    Args:
        size: Shape of the output tensor.
        dtype: Element type; defaults to ``torch.get_default_dtype()``.
        layout: Ignored (API-compatibility placeholder).
        device: Target device; defaults to the runtime's active device.
        pin_memory: Ignored (API-compatibility placeholder).

    Returns:
        A tensor of ones on the requested device.
    """
    logger.debug("GEMS ONES")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Empty tensor: nothing to fill, and cdiv(0, ...) would yield a
        # zero block size — skip the kernel launch entirely.
        return out

    # The 1-D launch grid and the per-program block size must use the same
    # program count, or parts of the tensor would be left unfilled; name
    # the constant once so the two cannot drift apart. The value 12 is
    # presumably tied to the Kunlunxin device's cluster count — TODO confirm.
    NUM_PROGRAMS = 12
    grid = (NUM_PROGRAMS, 1, 1)
    block_size = triton.next_power_of_2(triton.cdiv(N, NUM_PROGRAMS))
    with torch_device_fn.device(device):
        ones_kernel[grid](
            out,
            N,
            1.0,
            BLOCK_SIZE=block_size,
            # Backend-specific launch options — presumably Kunlunxin
            # extensions to the triton launcher; verify against the
            # backend's kernel-launch API.
            buffer_size_limit=2048,
            isCloseDtypeConvert=True,
        )
    return out