Coverage for src/flag_gems/ops/eye.py: 90%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.ops.eye_m import eye_kernel 

7from flag_gems.runtime import device, torch_device_fn 

8 

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
# `eye` below declares a keyword parameter named `device`, which shadows the
# imported runtime `device` object inside the function body; this alias keeps
# the runtime device reachable there as `device_`.
device_ = device

11 

12 

def eye(size, *, dtype=None, layout=torch.strided, device=None, pin_memory=None):
    """
    Triton-based implementation of ``torch.eye(n, n)`` (square identity matrix).

    The ``(size, size)`` output is split into 2D tiles of
    ``BLOCK_SIZE x BLOCK_SIZE`` and ``eye_kernel`` is launched over a
    ``cdiv(size, BLOCK_SIZE) x cdiv(size, BLOCK_SIZE)`` grid.

    Args:
        size: Number of rows (and columns) of the returned matrix.
        dtype: Element dtype; defaults to ``torch.get_default_dtype()``.
        layout: Must be ``torch.strided``. ``None`` is also accepted and treated
            as strided, since the ATen schema declares ``layout`` as optional.
        device: Target device; defaults to the FlagGems runtime device.
        pin_memory: Forwarded unchanged to ``torch.empty``.

    Returns:
        A newly allocated ``(size, size)`` tensor written by ``eye_kernel``.

    Raises:
        ValueError: If ``layout`` is neither ``torch.strided`` nor ``None``.
    """
    logger.debug("GEMS EYE")

    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    # Calls routed through the dispatcher may pass layout=None (the schema
    # default); that means "strided", so it must not be rejected below.
    if layout is None:
        layout = torch.strided
    if layout != torch.strided:
        raise ValueError("Currently only strided layout is supported for eye.")

    out = torch.empty(
        (size, size), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    if size == 0:
        # Empty matrix: nothing to write, so skip the kernel launch entirely.
        return out

    BLOCK_SIZE = 32
    num_tiles = triton.cdiv(size, BLOCK_SIZE)
    grid = (num_tiles, num_tiles)

    with torch_device_fn.device(device):
        eye_kernel[grid](out, size, size, BLOCK_SIZE, BLOCK_SIZE)
    return out