Coverage for src/flag_gems/experimental_ops/eye.py: 0%
53 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def eye_kernel(
    out_ptr,  # pointer to the output 2D tensor
    n_rows,  # number of rows (n)
    n_cols,  # number of cols (m)
    stride_row,  # element stride along the row dimension
    stride_col,  # element stride along the column dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Fill a 2D tensor with an identity pattern: 1 on the diagonal, 0 elsewhere.

    Each program instance writes one BLOCK_M x BLOCK_N tile of the output;
    the 2D launch grid covers the full (n_rows, n_cols) matrix.
    """
    block_row = tl.program_id(axis=0)
    block_col = tl.program_id(axis=1)

    # Absolute row/column indices of this tile, shaped for 2D broadcasting.
    rows = (block_row * BLOCK_M + tl.arange(0, BLOCK_M))[:, None]
    cols = (block_col * BLOCK_N + tl.arange(0, BLOCK_N))[None, :]

    # Mask off the tile's overhang past the matrix edges.
    mask = (rows < n_rows) & (cols < n_cols)

    # 1 where the row index equals the column index, 0 elsewhere; the store
    # casts these integer values to the pointer's element dtype.
    tile = tl.where(rows == cols, 1, 0)
    tl.store(out_ptr + rows * stride_row + cols * stride_col, tile, mask=mask)
34# Shared implementation
35def _eye_impl(n, m=None, dtype=None, device=None, out: torch.Tensor = None):
36 if m is None:
37 m = n
39 if out is None:
40 if dtype is None:
41 dtype = torch.get_default_dtype()
42 if device is None:
43 device = (
44 torch.device("cuda")
45 if torch.cuda.is_available()
46 else torch.device("cpu")
47 )
48 out = torch.empty((n, m), dtype=dtype, device=device)
49 else:
50 if out.dim() != 2:
51 raise ValueError("out tensor must be 2D")
52 # Resize to expected shape if necessary
53 if out.shape[0] != n or out.shape[1] != m:
54 out.resize_(n, m)
56 # Handle empty tensors
57 if n == 0 or m == 0:
58 out.zero_()
59 return out
61 # CUDA path uses Triton
62 if out.is_cuda:
63 BLOCK_M = 64
64 BLOCK_N = 64
65 grid = (triton.cdiv(n, BLOCK_M), triton.cdiv(m, BLOCK_N))
66 eye_kernel[grid](
67 out,
68 n,
69 m,
70 out.stride(0),
71 out.stride(1),
72 BLOCK_M=BLOCK_M,
73 BLOCK_N=BLOCK_N,
74 )
75 return out
77 # CPU fallback without calling torch.eye
78 out.zero_()
79 k = min(n, m)
80 if k > 0:
81 idx = torch.arange(k, device=out.device)
82 one = torch.ones(k, dtype=out.dtype, device=out.device)
83 out[idx, idx] = one
84 return out
# Wrappers for ATen operator interfaces
def eye(n, m=None, dtype=None, device=None):
    """ATen ``eye`` entry point: identity matrix of shape (n, m or n)."""
    return _eye_impl(n, m=m, dtype=dtype, device=device)
def eye_m(n, m, dtype=None, device=None):
    """ATen ``eye.m`` entry point: identity matrix of shape (n, m)."""
    return _eye_impl(n, m=m, dtype=dtype, device=device)
def eye_out(n, out: torch.Tensor):
    """ATen ``eye.out`` entry point: fill ``out`` as a square (n, n) identity."""
    return _eye_impl(n, m=n, dtype=out.dtype, device=out.device, out=out)
def eye_m_out(n, m, out: torch.Tensor):
    """ATen ``eye.m_out`` entry point: fill ``out`` as an (n, m) identity."""
    return _eye_impl(n, m=m, dtype=out.dtype, device=out.device, out=out)