Coverage for src/flag_gems/ops/eye.py: 90%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
6from flag_gems.ops.eye_m import eye_kernel
7from flag_gems.runtime import device, torch_device_fn
# Per-module logger, so debug output from this op can be filtered by module name.
logger = logging.getLogger(name=__name__)
# Alias of the runtime ``device`` import under a distinct name: inside ``eye``
# the keyword parameter ``device`` shadows the module-level import.
device_ = device
def eye(size, *, dtype=None, layout=torch.strided, device=None, pin_memory=None):
    """
    Triton-based implementation of ``torch.eye(n)``, returning a square identity matrix.

    The ``(size, size)`` output is split into ``BLOCK_SIZE x BLOCK_SIZE`` tiles,
    and ``eye_kernel`` is launched on a 2D grid with one program slot per tile
    along each dimension.

    Args:
        size: Number of rows (and columns) of the result. Must be non-negative.
        dtype: Element dtype; defaults to ``torch.get_default_dtype()``.
        layout: Must be ``torch.strided``. ``None`` is treated as strided,
            matching the ``Layout? layout=None`` default of ``aten::eye``.
        device: Target device; defaults to the FlagGems runtime device.
        pin_memory: Forwarded unchanged to ``torch.empty``.

    Returns:
        A newly allocated ``(size, size)`` tensor written by ``eye_kernel``.

    Raises:
        RuntimeError: If ``size`` is negative (same type and wording as torch.eye).
        ValueError: If a non-strided layout is requested.
    """
    logger.debug("GEMS EYE")

    if size < 0:
        # Fail early with torch.eye's own message instead of a generic
        # "negative dimension" error from torch.empty.
        raise RuntimeError(f"n must be greater or equal to 0, got {size}")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    if layout is None:
        # aten::eye declares the layout as optional; None means the default (strided).
        layout = torch.strided
    if layout != torch.strided:
        raise ValueError("Currently only strided layout is supported for eye.")

    out = torch.empty(
        (size, size), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    if size == 0:
        # Nothing to fill: skip compiling and launching the kernel on an empty grid.
        return out

    BLOCK_SIZE = 32
    tiles_per_dim = triton.cdiv(size, BLOCK_SIZE)
    grid = (tiles_per_dim, tiles_per_dim)
    # Make the output's device current so the launch targets the right device.
    with torch_device_fn.device(device):
        eye_kernel[grid](
            out,
            size,
            size,
            BLOCK_SIZE,
            BLOCK_SIZE,
        )
    return out