Coverage for src/flag_gems/runtime/backend/_mthreads/ops/log.py: 0%
57 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
2from typing import Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.ops.log import log as default_log # fallback
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils.triton_lang_helper import tl_extra_shim
# Logger named after this module so it nests under the flag_gems logger hierarchy.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Dtypes the custom Triton kernel accepts; anything else falls back to default_log.
_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=2),
    ],
    key=["n_elements", "dtype_size"],
)
@triton.jit
def log_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    dtype_size,
    BLOCK_SIZE: tl.constexpr,
    USE_APPROX: tl.constexpr,
):
    """Elementwise natural log over a flat buffer of ``n_elements`` values.

    Each program instance processes one BLOCK_SIZE-wide tile; inputs are
    promoted to fp32 before computation and stored back through ``out_ptr``
    (the store is masked, so trailing out-of-range lanes are ignored).

    When USE_APPROX is set (chosen by the launcher for 2-byte dtypes),
    ln(x) is computed from the IEEE-754 fp32 bit pattern: x = m * 2**k
    with m in [1, 2), so ln(x) = ln(m) + k*ln(2), and ln(m) uses the
    atanh series 2*(t + t^3/3 + t^5/5 + t^7/7) with t = (m-1)/(m+1).
    Otherwise the libdevice/shim log is used.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x_fp32 = x.to(tl.float32)
    if USE_APPROX:
        pos_mask = x_fp32 > 0
        zero_mask = x_fp32 == 0
        # BUGFIX: +inf (exponent 0xFF, mantissa 0) passes pos_mask but the
        # bit decomposition below would decode it as 1.0 * 2**128 and return
        # ~88.72 instead of +inf, so handle it explicitly.
        inf_mask = x_fp32 == float("inf")
        ix = x_fp32.to(tl.int32, bitcast=True)
        exp = (ix >> 23) & 0xFF
        # Reattach the exponent of 1.0 so m lands in [1, 2).
        mant = (ix & 0x7FFFFF) | 0x3F800000
        m = mant.to(tl.float32, bitcast=True)
        k = exp.to(tl.int32) - 127  # unbias the exponent
        t = (m - 1.0) / (m + 1.0)
        t2 = t * t
        # ln(m) = 2*atanh(t), odd series truncated at t^7 — sufficient for
        # fp16/bf16 output precision.
        log_m = 2.0 * (t + t2 * t * (1.0 / 3.0 + t2 * (1.0 / 5.0 + t2 * (1.0 / 7.0))))
        log_val = log_m + k.to(tl.float32) * 0.6931471805599453  # + k*ln(2)
        log_val = tl.where(inf_mask, float("inf"), log_val)
        # log(0) = -inf; log(x) for x < 0 (and for NaN, which fails pos_mask)
        # is NaN — matching torch.log semantics.
        nan_or_inf = tl.where(zero_mask, -float("inf"), float("nan"))
        y = tl.where(pos_mask, log_val, nan_or_inf)
    else:
        y = tl_extra_shim.log(x_fp32)
    tl.store(out_ptr + offsets, y, mask=mask)
66def _use_triton_kernel(x: torch.Tensor) -> Tuple[bool, int]:
67 if not isinstance(x, torch.Tensor):
68 return False, 0
69 if x.device.type != "musa" or x.dtype not in _SUPPORTED_DTYPES:
70 return False, 0
71 if x.numel() == 0 or not x.is_contiguous():
72 return False, 0
73 return True, x.element_size()
def _launch_log(x: torch.Tensor, out: torch.Tensor, dtype_size: int):
    """Launch ``log_kernel`` over the flattened ``out`` buffer and return it.

    A 1-D grid covers ``out.numel()`` elements. The bit-manipulation
    approximation path (USE_APPROX) is enabled only for 2-byte dtypes
    (fp16/bf16); 4-byte floats use the shim log inside the kernel.
    """
    total = out.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(out.device):
        log_kernel[grid](x, out, total, dtype_size, USE_APPROX=dtype_size == 2)
    return out
def log(x):
    """Natural logarithm with an mthreads (MUSA) Triton fast path.

    Eligible inputs (contiguous, non-empty fp16/bf16/fp32 tensors on the
    "musa" device) go through the custom kernel; everything else falls
    back to the generic flag_gems implementation.
    """
    logger.debug("GEMS_MTHREADS LOG")
    eligible, elem_size = _use_triton_kernel(x)
    if eligible:
        result = torch.empty_like(x)
        return _launch_log(x, result, elem_size)
    return default_log(x)