Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/eye.py: 0%

21 statements  

Report generated by coverage.py v7.6.9 at 2026-03-12 02:21 +0800.

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.runtime import device, torch_device_fn 

7 

8from .eye_m import eye_kernel 

9 

# Child of the "flag_gems" logger, named after this module (leading dots stripped).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Module-level alias for the runtime `device` object imported from
# flag_gems.runtime. Inside `eye`, the keyword argument named `device` shadows
# the import, so the function reaches the runtime device via `device_` instead.
device_ = device

12 

13 

def eye(size, *, dtype=None, layout=torch.strided, device=None, pin_memory=None):
    """Return a ``size x size`` identity matrix, like ``torch.eye(size)``.

    The output is allocated with ``torch.empty`` and then filled by the Triton
    ``eye_kernel``, launched on a 2D grid of ``BLOCK_SIZE x BLOCK_SIZE`` tiles.

    Args:
        size: Number of rows and columns of the square output.
        dtype: Element dtype; defaults to ``torch.get_default_dtype()``.
        layout: Must be ``torch.strided``; ``None`` is treated as strided.
        device: Target device; defaults to the FlagGems runtime device.
        pin_memory: Forwarded unchanged to ``torch.empty``.

    Returns:
        The newly allocated ``(size, size)`` tensor.

    Raises:
        ValueError: If ``layout`` is neither ``None`` nor ``torch.strided``.
    """
    logger.debug("GEMS EYE")

    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        # `device_` is the module-level alias of the runtime device object;
        # the `device` parameter shadows the imported name inside this function.
        device = torch.device(device_.name)
    # The ATen `eye` schema types `layout` as optional (`Layout? layout=None`),
    # so a dispatched call may hand us None; treat it as the default strided
    # layout rather than rejecting it.
    if layout is None:
        layout = torch.strided
    if layout != torch.strided:
        raise ValueError("Currently only strided layout is supported for eye.")

    out = torch.empty(
        (size, size), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    # An empty matrix has nothing to fill; avoid launching the kernel on a
    # zero-sized grid.
    if out.numel() == 0:
        return out

    # Tile edge length used for both grid dimensions and both kernel block sizes.
    BLOCK_SIZE = 32
    grid = (triton.cdiv(size, BLOCK_SIZE), triton.cdiv(size, BLOCK_SIZE))

    with torch_device_fn.device(device):
        eye_kernel[grid](
            out,
            size,
            size,
            BLOCK_SIZE,
            BLOCK_SIZE,
        )
    return out