Coverage for src/flag_gems/runtime/backend/_mthreads/ops/utils.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import os 

2from collections import OrderedDict 

3 

4import numpy as np 

5import torch 

6import triton 

7import triton.language as tl 

8 

# Maximum number of device descriptors kept by get_cached_tma_device_descriptor;
# once exceeded, the least recently used entry is dropped.
_TMA_DESCRIPTOR_CACHE_MAXSIZE = 256
# Cache key tuple -> device descriptor tensor. OrderedDict order doubles as
# recency order (hits are moved to the end, evictions pop from the front).
_tma_descriptor_cache = OrderedDict()

# Resolve the driver's TMA fill helper once at import time. Its last argument
# changed in triton 3.2: older releases take the numpy buffer itself, newer
# ones take the buffer's raw address as an int.
_fill_2d_tma = triton.runtime.driver.active.utils.fill_2d_tma_descriptor
_triton_major_minor = tuple(map(int, triton.__version__.split(".")[:2]))
_tma_desc_wants_ptr = _triton_major_minor >= (3, 2)

16 

17 

def _tma_desc_arg(desc_np):
    """Convert a host descriptor buffer into the form ``_fill_2d_tma`` expects.

    When ``_tma_desc_wants_ptr`` is true (triton >= 3.2), the int address of
    the numpy buffer's data is returned. Otherwise the numpy array is passed
    through unchanged.
    """
    if not _tma_desc_wants_ptr:
        return desc_np
    # Only the address is handed over, so the caller must keep ``desc_np``
    # alive for as long as the callee writes through it.
    return int(desc_np.ctypes.data)

20 

21 

def create_tma_device_descriptor(tensor, block_m, block_n, device):
    """Build a 2D TMA descriptor for ``tensor`` and return it as a device tensor.

    The descriptor bytes are written into a 64-byte host-side numpy buffer by
    the active triton driver's ``fill_2d_tma_descriptor``. The buffer is then
    copied to ``device`` with ``torch.tensor``.

    Args:
        tensor: 2D torch tensor. It must be either contiguous, or a plain
            transpose of a contiguous tensor, i.e. strides ``(1, shape[0])``.
        block_m: block size passed as the first block dimension of the fill call.
        block_n: block size passed as the second block dimension of the fill call.
        device: device on which the returned descriptor tensor is placed.

    Returns:
        torch.Tensor (int8, converted from the numpy buffer) holding the raw
        descriptor bytes.

    Raises:
        ValueError: if ``tensor`` is not 2D, or is neither contiguous nor a
            plain transpose.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``. An unsupported layout would then silently produce a
    # descriptor that does not match the tensor's memory.
    if tensor.dim() != 2:
        raise ValueError("TMA descriptor only supports 2D tensors")

    TMA_DESCRIPTOR_SIZE = 64
    desc_np = np.empty(TMA_DESCRIPTOR_SIZE, dtype=np.int8)
    shapes = [tensor.shape[0], tensor.shape[1]]
    if not tensor.is_contiguous():
        # Column-major (transposed) view: describe the underlying row-major
        # storage by swapping the two dimensions.
        is_plain_transpose = tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
        if not is_plain_transpose:
            raise ValueError(
                "TMA descriptor only supports contiguous or transposed 2D tensors"
            )
        shapes.reverse()
        # NOTE(review): block_m/block_n are not swapped here; presumably the
        # caller already passes them in storage order for transposed inputs —
        # confirm at call sites.

    _fill_2d_tma(
        tensor.data_ptr(),
        shapes[0],
        shapes[1],
        block_m,
        block_n,
        tensor.element_size(),
        _tma_desc_arg(desc_np),
    )
    desc = torch.tensor(desc_np, device=device)
    return desc

43 

44 

def _tma_descriptor_cache_key(tensor, block_m, block_n, device):
    """Return a hashable key identifying the descriptor ``tensor`` would get.

    The key combines:
    - the data address;
    - the shape and strides (which decide whether dimensions are swapped);
    - the dtype (which fixes the element size);
    - the two block sizes;
    - the target device.

    The dtype and device are stringified.
    """
    layout = (tuple(tensor.shape), tuple(tensor.stride()))
    return (
        tensor.data_ptr(),
        *layout,
        str(tensor.dtype),
        block_m,
        block_n,
        str(device),
    )

55 

56 

def get_cached_tma_device_descriptor(tensor, block_m, block_n, device):
    """Return a TMA device descriptor for ``tensor``, reusing a cached one if possible.

    Lookups are keyed by ``_tma_descriptor_cache_key``. On a hit, the entry is
    marked most recently used. On a miss, a new descriptor is created and
    stored. If the cache then exceeds ``_TMA_DESCRIPTOR_CACHE_MAXSIZE``
    entries, the least recently used one is evicted.

    The same tensor object is returned to every caller that hits the cache,
    so callers should treat it as read-only.
    """
    cache = _tma_descriptor_cache
    key = _tma_descriptor_cache_key(tensor, block_m, block_n, device)
    hit = cache.get(key)
    if hit is None:
        hit = cache[key] = create_tma_device_descriptor(tensor, block_m, block_n, device)
        if len(cache) > _TMA_DESCRIPTOR_CACHE_MAXSIZE:
            # Front of the OrderedDict is the least recently used entry.
            cache.popitem(last=False)
    else:
        cache.move_to_end(key)
    return hit

69 

70 

def get_triton_dtype(dtype):
    """Translate a torch floating-point dtype into its ``triton.language`` equivalent.

    Only float16, bfloat16 and float32 are recognised. Any other value
    yields ``None``.
    """
    torch_dtypes = (torch.float16, torch.bfloat16, torch.float32)
    triton_dtypes = (tl.float16, tl.bfloat16, tl.float32)
    return dict(zip(torch_dtypes, triton_dtypes)).get(dtype)

78 

79 

80def should_enable_sqmma(a_dtype, b_dtype, M, N, K): 

81 return ( 

82 (os.getenv("MUSA_ENABLE_SQMMA", "0") == "1") 

83 and (a_dtype in [torch.float16, torch.bfloat16] and a_dtype.itemsize == 2) 

84 and ((M, N, K) not in [(1, 1, 32), (15, 160, 1024), (495, 5333, 71)]) 

85 )