Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/eye_m.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import device, torch_device_fn 

8from flag_gems.utils import libentry 

9 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") guards
# against a leading dot in __name__ (NOTE(review): module names normally have
# no leading dot — presumably defensive, confirm against package layout).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Alias the runtime device descriptor so the local `device` kwarg in eye_m
# does not shadow it.
device_ = device



@libentry()
@triton.jit
def eye_kernel(
    out_ptr,
    N,
    M,
    BLOCK_i: tl.constexpr,
    BLOCK_j: tl.constexpr,
):
    """Fill a row-major (N, M) buffer with the identity pattern.

    Each program instance owns one BLOCK_i x BLOCK_j tile: it writes 1.0 on
    positions where the global row index equals the global column index and
    0.0 everywhere else, masking off out-of-range rows/columns at the edges.
    """
    # Global row/column indices covered by this program's tile.
    rows = tl.program_id(0) * BLOCK_i + tl.arange(0, BLOCK_i)
    cols = tl.program_id(1) * BLOCK_j + tl.arange(0, BLOCK_j)

    # Diagonal test broadcast over the (BLOCK_i, BLOCK_j) tile.
    tile = tl.where(rows[:, None] == cols[None, :], 1.0, 0.0)

    # Keep only in-bounds elements; row-major flat offsets into out_ptr.
    in_bounds = (rows < N)[:, None] & (cols < M)[None, :]
    flat_off = rows[:, None] * M + cols[None, :]
    tl.store(out_ptr + flat_off, tile, mask=in_bounds)

36 

37 

def eye_m(n, m, *, dtype=None, layout=torch.strided, device=None, pin_memory=None):
    """
    Triton-based implementation of ``torch.eye(n, m)``: an (n, m) matrix with
    ones on the main diagonal and zeros elsewhere, built by tiling the output
    into 2D blocks and launching one kernel program per tile.

    Args:
        n: Number of rows.
        m: Number of columns.
        dtype: Output dtype; defaults to ``torch.get_default_dtype()``.
        layout: Only ``torch.strided`` is supported.
        device: Target device; defaults to this backend's runtime device.
        pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A freshly allocated (n, m) tensor containing the identity pattern.

    Raises:
        ValueError: If ``layout`` is not ``torch.strided``.
    """
    logger.debug("GEMS EYE_M")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    if layout != torch.strided:
        raise ValueError("Currently only strided layout is supported for eye_m.")

    out = torch.empty(
        (n, m), dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
    )
    # torch.eye accepts zero-sized dimensions; nothing to fill in that case,
    # and it avoids launching a kernel with a zero-sized grid.
    if n == 0 or m == 0:
        return out

    BLOCK_SIZE = 32
    grid = (triton.cdiv(n, BLOCK_SIZE), triton.cdiv(m, BLOCK_SIZE))

    with torch_device_fn.device(device):
        eye_kernel[grid](
            out,
            n,
            m,
            BLOCK_SIZE,
            BLOCK_SIZE,
        )
    return out