Coverage for src/flag_gems/ops/mv.py: 56%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
)
@triton.jit
def mv_kernel(
    A,  # pointer to the matrix operand; rows indexed by N, columns by M
    B,  # pointer to the vector operand of length M
    C,  # pointer to the output vector of length N
    N,  # number of rows of A == length of C
    M,  # number of columns of A == length of B
    stride_an,  # element stride of A along the N (row) dimension
    stride_am,  # element stride of A along the M (column) dimension
    stride_bm,  # element stride of B
    stride_cn,  # element stride of C
    BLOCK_N: tl.constexpr,  # rows of C produced per program
    BLOCK_M: tl.constexpr,  # columns reduced per loop iteration
):
    """Matrix-vector product kernel: C = A @ B.

    Each program computes a BLOCK_N-sized slice of the output. The M
    dimension is swept in BLOCK_M-wide chunks, with partial products
    accumulated in float32 before a final reduction along axis 1.
    """
    pid = tle.program_id(0)
    # Row offsets shaped (BLOCK_N, 1) and column offsets shaped
    # (1, BLOCK_M) so the loads below broadcast to a 2D tile.
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    # Accumulate in float32 regardless of input dtype for accuracy.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        # Mask off columns past M on the final (partial) chunk;
        # out-of-bounds lanes contribute 0.0 to the accumulator.
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        # Advance both operand pointers to the next BLOCK_M chunk.
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    # Reduce the tile over the M axis: (BLOCK_N, BLOCK_M) -> (BLOCK_N,).
    acc = tl.sum(acc, axis=1)
    C_ptrs = C + offset_n * stride_cn
    # offset_n kept its trailing unit axis, so reshape acc back to
    # (BLOCK_N, 1) for a shape-matched masked store.
    tl.store(C_ptrs, acc[:, None], mask=n_mask)

52 

53 

def mv(inp, vec):
    """Matrix-vector product: return ``inp @ vec`` via the Triton kernel.

    Args:
        inp: 2D tensor of shape (N, M); strides are forwarded to the
            kernel, so non-contiguous inputs are supported.
        vec: 1D tensor of shape (M,).

    Returns:
        A new 1D tensor of shape (N,) with ``inp``'s dtype and device.
    """
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape
    out = torch.empty((N,), device=inp.device, dtype=inp.dtype)

    # One program per BLOCK_N rows of the output.
    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK_N"]),)

    # Launch on the input's device so the kernel runs where the data lives.
    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out