Coverage for src/flag_gems/runtime/backend/_mthreads/ops/utils.py: 0%
21 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import os
3import numpy as np
4import torch
5import triton
6import triton.language as tl
def create_tma_device_descriptor(tensor, block_m, block_n, device):
    """Build a 2D TMA descriptor for *tensor* and return it as a device tensor.

    Args:
        tensor: 2D torch tensor; must be contiguous (row-major) or
            transposed (column-major) in memory.
        block_m: tile extent along the first descriptor dimension.
        block_n: tile extent along the second descriptor dimension.
        device: device on which the returned descriptor tensor is placed.

    Returns:
        A 64-byte ``int8`` torch tensor holding the filled TMA descriptor.
    """
    assert tensor.dim() == 2, "TMA descriptor only supports 2D tensors"
    TMA_DESCRIPTOR_SIZE = 64
    desc_buf = np.empty(TMA_DESCRIPTOR_SIZE, dtype=np.int8)
    rows, cols = tensor.shape[0], tensor.shape[1]
    if not tensor.is_contiguous():
        # Only the transposed (column-major) layout is accepted here; swap
        # the logical dimensions so the descriptor reflects physical memory.
        assert (
            tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
        ), "TMA descriptor only supports contiguous or transposed 2D tensors"
        rows, cols = cols, rows
    # Delegate descriptor encoding to the active triton driver.
    triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
        tensor.data_ptr(),
        rows,
        cols,
        block_m,
        block_n,
        tensor.element_size(),
        desc_buf,
    )
    return torch.tensor(desc_buf, device=device)
def get_triton_dtype(dtype):
    """Map a torch floating dtype to its triton-language counterpart.

    Returns the matching ``tl`` dtype for float16 / bfloat16 / float32,
    or ``None`` for any other input.
    """
    if dtype == torch.float16:
        return tl.float16
    if dtype == torch.bfloat16:
        return tl.bfloat16
    if dtype == torch.float32:
        return tl.float32
    return None
41def should_enable_sqmma(a_dtype, b_dtype, M, N, K):
42 return (
43 (os.getenv("MUSA_ENABLE_SQMMA", "0") == "1")
44 and (a_dtype in [torch.float16, torch.bfloat16] and a_dtype.itemsize == 2)
45 and ((M, N, K) not in [(1, 1, 32), (15, 160, 1024), (495, 5333, 71)])
46 )