Coverage for src/flag_gems/ops/eye_m.py: 66%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils import libentry
# Module-level logger; eye_m emits a debug record on every call.
10logger = logging.getLogger(__name__)
# Alias for the imported runtime `device` object. eye_m takes a keyword
# parameter that is also called `device`, which shadows the import inside the
# function body, so the default device name is read through this alias.
11device_ = device
@libentry()
@triton.jit
def eye_kernel(
    out_ptr,
    N,
    M,
    BLOCK_i: tl.constexpr,
    BLOCK_j: tl.constexpr,
):
    """Write one (BLOCK_i, BLOCK_j) tile of an N x M identity-like matrix.

    Grid axis 0 walks row tiles and grid axis 1 walks column tiles. Every
    in-bounds element of the tile is written: 1 where the row index equals
    the column index, 0 elsewhere. The output is addressed as a row-major
    buffer with row stride M.

    Args:
        out_ptr: Pointer to the output buffer.
        N: Number of rows.
        M: Number of columns (also used as the row stride).
        BLOCK_i: Tile height (compile-time constant).
        BLOCK_j: Tile width (compile-time constant).
    """
    # Row indices handled by this program.
    pid_i = tl.program_id(0)
    off_i = pid_i * BLOCK_i + tl.arange(0, BLOCK_i)
    mask_i = off_i < N

    # Column indices handled by this program.
    pid_j = tl.program_id(1)
    off_j = pid_j * BLOCK_j + tl.arange(0, BLOCK_j)
    mask_j = off_j < M

    # 1.0 on the diagonal, 0.0 elsewhere; tl.store converts the value to the
    # element type of out_ptr.
    val = tl.where(off_i[:, None] == off_j[None, :], 1.0, 0.0)
    # Only elements inside the N x M bounds are written (edge tiles are partial).
    mask = mask_i[:, None] & mask_j[None, :]
    # Widen the row index to int64 before scaling by the row stride: the
    # indices above are 32-bit, and row * M wraps around once N * M exceeds
    # 2**31 - 1, which would send the stores to the wrong addresses.
    off_ij = off_i[:, None].to(tl.int64) * M + off_j[None, :]

    tl.store(out_ptr + off_ij, val, mask=mask)
def eye_m(n, m, *, dtype=None, layout=torch.strided, device=None, pin_memory=None):
    """
    Triton-based implementation of ``torch.eye(n, m)``.

    Allocates an ``(n, m)`` tensor and fills it with ones on the main diagonal
    and zeros elsewhere, launching ``eye_kernel`` over a 2D grid of
    32 x 32 tiles.

    Args:
        n: Number of rows.
        m: Number of columns.
        dtype: Element type of the result; defaults to
            ``torch.get_default_dtype()``.
        layout: Must be ``torch.strided``.
        device: Target device; defaults to the runtime device named by the
            module-level ``device_`` object.
        pin_memory: Forwarded unchanged to ``torch.empty``.

    Returns:
        The newly allocated ``(n, m)`` tensor.

    Raises:
        ValueError: If ``layout`` is not ``torch.strided``.
    """
    logger.debug("GEMS EYE_M")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    if layout != torch.strided:
        raise ValueError("Currently only strided layout is supported for eye_m.")

    out = torch.empty(
        (n, m), dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
    )
    # Nothing to fill for a zero-sized result; returning here avoids launching
    # the kernel with an empty grid, which is not guaranteed to be accepted by
    # every backend.
    if out.numel() == 0:
        return out

    BLOCK_SIZE = 32
    # One program per BLOCK_SIZE x BLOCK_SIZE tile: axis 0 covers rows,
    # axis 1 covers columns.
    grid = (triton.cdiv(n, BLOCK_SIZE), triton.cdiv(m, BLOCK_SIZE))

    with torch_device_fn.device(device):
        eye_kernel[grid](
            out,
            n,
            m,
            BLOCK_SIZE,
            BLOCK_SIZE,
        )
    return out