Coverage for src/flag_gems/runtime/backend/_mthreads/ops/utils.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import os 

2 

3import numpy as np 

4import torch 

5import triton 

6import triton.language as tl 

7 

8 

def create_tma_device_descriptor(tensor, block_m, block_n, device):
    """Build a 64-byte TMA descriptor for a 2D tensor and copy it to *device*.

    The tensor must be either contiguous (row-major) or a plain transpose
    (column-major, i.e. stride(0) == 1 and stride(1) == rows); for the
    transposed layout the logical dimensions are swapped before filling
    the descriptor. Returns the descriptor as an int8 tensor on *device*.
    """
    assert tensor.dim() == 2, "TMA descriptor only supports 2D tensors"
    TMA_DESCRIPTOR_SIZE = 64
    descriptor_buf = np.empty(TMA_DESCRIPTOR_SIZE, dtype=np.int8)
    dims = [tensor.shape[0], tensor.shape[1]]
    if not tensor.is_contiguous():
        # Only the column-major (transposed) layout is accepted here.
        assert (
            tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
        ), "TMA descriptor only supports contiguous or transposed 2D tensors"
        dims.reverse()
    # Have the active Triton driver write the descriptor bytes in place.
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
        tensor.data_ptr(),
        dims[0],
        dims[1],
        block_m,
        block_n,
        tensor.element_size(),
        descriptor_buf,
    )
    return torch.tensor(descriptor_buf, device=device)

30 

31 

def get_triton_dtype(dtype):
    """Translate a torch floating-point dtype into its Triton equivalent.

    Returns None when *dtype* has no Triton counterpart in the table.
    """
    mapping = dict(
        (
            (torch.float16, tl.float16),
            (torch.bfloat16, tl.bfloat16),
            (torch.float32, tl.float32),
        )
    )
    return mapping.get(dtype)

39 

40 

41def should_enable_sqmma(a_dtype, b_dtype, M, N, K): 

42 return ( 

43 (os.getenv("MUSA_ENABLE_SQMMA", "0") == "1") 

44 and (a_dtype in [torch.float16, torch.bfloat16] and a_dtype.itemsize == 2) 

45 and ((M, N, K) not in [(1, 1, 32), (15, 160, 1024), (495, 5333, 71)]) 

46 )