Coverage for src/flag_gems/runtime/backend/_mthreads/ops/log.py: 0%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2from typing import Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.ops.log import log as default_log # fallback 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils.triton_lang_helper import tl_extra_shim 

12 

# Per-op logger; the name resolves to
# "flag_gems.runtime.backend._mthreads.ops.log" for this module.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Dtypes eligible for the Triton fast path below; anything else falls back
# to the default flag_gems implementation.
_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}

18 

19 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_SIZE": 4096}, num_warps=16, num_stages=2),
    ],
    key=["n_elements", "dtype_size"],
)
@triton.jit
def log_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    dtype_size,
    BLOCK_SIZE: tl.constexpr,
    USE_APPROX: tl.constexpr,
):
    """Elementwise natural logarithm over a flat buffer of n_elements.

    All arithmetic is done in fp32. With USE_APPROX set (chosen by the
    launcher for 2-byte dtypes), log(x) is computed by splitting the fp32
    bit pattern into exponent and mantissa and evaluating a degree-7
    atanh-based polynomial; otherwise the device's native log is used.
    The store casts back to out_ptr's element type.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x_fp32 = x.to(tl.float32)
    if USE_APPROX:
        pos_mask = x_fp32 > 0
        zero_mask = x_fp32 == 0
        # Reinterpret the fp32 bits as x = m * 2^k with m in [1, 2).
        ix = x_fp32.to(tl.int32, bitcast=True)
        exp = (ix >> 23) & 0xFF
        mant = (ix & 0x7FFFFF) | 0x3F800000
        m = mant.to(tl.float32, bitcast=True)
        k = exp.to(tl.int32) - 127
        # log(m) = 2*atanh(t) with t = (m-1)/(m+1); truncated odd series
        # 2*(t + t^3/3 + t^5/5 + t^7/7), accurate enough for half precision.
        t = (m - 1.0) / (m + 1.0)
        t2 = t * t
        log_m = 2.0 * (t + t2 * t * (1.0 / 3.0 + t2 * (1.0 / 5.0 + t2 * (1.0 / 7.0))))
        # Recombine: log(x) = log(m) + k * ln(2).
        log_val = log_m + k.to(tl.float32) * 0.6931471805599453
        # x == 0 -> -inf; x < 0 and NaN (NaN > 0 is false) -> NaN.
        nan_or_inf = tl.where(zero_mask, -float("inf"), float("nan"))
        y = tl.where(pos_mask, log_val, nan_or_inf)
        # BUGFIX: +inf satisfies pos_mask and has biased exponent 255, so the
        # decomposition above produced ~128*ln(2) (~88.7) instead of +inf.
        # Override it explicitly.
        y = tl.where(x_fp32 == float("inf"), float("inf"), y)
        # NOTE(review): subnormal fp32 inputs (reachable from bf16, whose
        # exponent range matches fp32) still take the normal-number
        # decomposition and come out wrong — confirm whether values below
        # 2^-126 matter for this backend.
    else:
        y = tl_extra_shim.log(x_fp32)
    tl.store(out_ptr + offsets, y, mask=mask)

64 

65 

66def _use_triton_kernel(x: torch.Tensor) -> Tuple[bool, int]: 

67 if not isinstance(x, torch.Tensor): 

68 return False, 0 

69 if x.device.type != "musa" or x.dtype not in _SUPPORTED_DTYPES: 

70 return False, 0 

71 if x.numel() == 0 or not x.is_contiguous(): 

72 return False, 0 

73 return True, x.element_size() 

74 

75 

def _launch_log(x: torch.Tensor, out: torch.Tensor, dtype_size: int):
    """Launch log_kernel over every element of out and return out."""
    total = out.numel()

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    # 2-byte dtypes (fp16/bf16) take the polynomial approximation path.
    with torch_device_fn.device(out.device):
        log_kernel[grid](x, out, total, dtype_size, USE_APPROX=dtype_size == 2)
    return out

82 

83 

def log(x):
    """Natural logarithm with a MUSA Triton fast path.

    Eligible tensors (see _use_triton_kernel) are processed by the custom
    kernel into a fresh output tensor; everything else is delegated to the
    default flag_gems implementation.
    """
    logger.debug("GEMS_MTHREADS LOG")
    eligible, elem_size = _use_triton_kernel(x)
    if eligible:
        result = torch.empty_like(x)
        return _launch_log(x, result, elem_size)
    return default_log(x)