Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/eye.py: 0%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
1import logging
3import torch
4import triton
6from flag_gems.runtime import device, torch_device_fn
8from .eye_m import eye_kernel
# Alias the runtime ``device`` descriptor under a distinct name: ``eye`` below
# accepts its own ``device`` keyword argument, which would otherwise shadow
# the imported module-level ``device``.
device_ = device
# Module logger, created as a child of the package-level "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def eye(size, *, dtype=None, layout=torch.strided, device=None, pin_memory=None):
    """
    Triton-based replacement for the square form ``torch.eye(n, n)``.

    The ``size x size`` output is covered by a 2D launch grid of
    ``32 x 32`` tiles, and ``eye_kernel`` writes the result into a freshly
    allocated tensor.

    Args:
        size: number of rows and columns of the result.
        dtype: element type; ``torch.get_default_dtype()`` when omitted.
        layout: memory layout; only ``torch.strided`` is accepted.
        device: target device; when omitted, a ``torch.device`` built from
            ``flag_gems.runtime.device.name``.
        pin_memory: forwarded unchanged to ``torch.empty``.

    Returns:
        The ``(size, size)`` tensor written by ``eye_kernel``.

    Raises:
        ValueError: if ``layout`` is not ``torch.strided``.
    """
    logger.debug("GEMS EYE")

    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device(device_.name) if device is None else device
    if layout != torch.strided:
        raise ValueError("Currently only strided layout is supported for eye.")

    result = torch.empty(
        (size, size), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )

    # Square matrix and square tiles: the tile count is the same along both
    # grid axes, so compute it once.
    tile = 32
    tiles_per_dim = triton.cdiv(size, tile)
    launch_grid = (tiles_per_dim, tiles_per_dim)

    # Launch under the target device's context so the kernel runs on the
    # device that owns ``result``.
    with torch_device_fn.device(device):
        eye_kernel[launch_grid](result, size, size, tile, tile)
    return result