Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/logspace.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@triton.jit
def exp2_tmp(x):
    """Return ``2 ** x`` computed as ``exp(x * ln 2)`` in float32.

    The input is cast to ``tl.float32`` before scaling, so the result is
    float32 regardless of the dtype of ``x``.
    """
    # Natural logarithm of 2, used to turn a base-2 exponent into a base-e one.
    ln2 = 0.69314718056
    x_f32 = x.to(tl.float32)
    return tl.exp(x_f32 * ln2)
@libentry()
@triton.jit
def logspace_kernel(
    out_ptr,
    out_stride0,
    start,
    step_size,
    steps,
    log2_base: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill ``out_ptr`` with ``2 ** (log2_base * (start + i * step_size))``.

    Each program handles ``BLOCK_SIZE`` consecutive indices ``i``; indices at
    or beyond ``steps`` are masked out of the store.

    Args:
        out_ptr: Pointer to the output buffer.
        out_stride0: Element stride between consecutive output entries.
        start: Exponent used for index 0.
        step_size: Increment of the exponent per index.
        steps: Number of elements to write.
        log2_base: Compile-time constant multiplied into the exponent before
            it is passed to ``exp2_tmp``.
        BLOCK_SIZE: Compile-time number of indices handled per program.
    """
    block_id = tle.program_id(0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < steps

    # base ** e is evaluated as 2 ** (log2(base) * e).
    powers = exp2_tmp(log2_base * (start + offsets * step_size))

    dst = out_ptr + offsets * out_stride0
    tl.store(dst, powers, mask=in_bounds)
def logspace(
    start,
    end,
    steps,
    base=10.0,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
) -> torch.Tensor:
    """Return a 1-D tensor of ``steps`` powers of ``base``.

    The exponents are evenly spaced from ``start`` to ``end``; element ``i``
    is ``base ** (start + i * (end - start) / (steps - 1))``.

    Args:
        start: First exponent; a number or a tensor (read with ``.item()``).
        end: Last exponent; a number or a tensor (read with ``.item()``).
        steps: Number of elements to generate; must be non-negative.
        base: Base of the power. Defaults to 10.0.
        dtype: Output dtype; ``torch.get_default_dtype()`` is used when None.
        layout: Forwarded to ``torch.empty``.
        device: Forwarded to ``torch.empty``.
        pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A tensor with ``steps`` elements.

    Raises:
        AssertionError: If ``steps`` is negative.
    """
    logger.debug("GEMS LOGSPACE")
    assert steps >= 0, "number of steps must be non-negative"

    if dtype is not None:
        out_dtype = dtype
    else:
        out_dtype = torch.get_default_dtype()

    out = torch.empty(
        steps,
        dtype=out_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )

    # Nothing to compute for an empty result.
    if steps == 0:
        return out

    if isinstance(start, torch.Tensor):
        start = start.item()

    # A single element is just base ** start; no kernel launch needed.
    if steps == 1:
        return torch.fill(out, base**start)

    if isinstance(end, torch.Tensor):
        end = end.item()

    exponent_step = (float(end) - float(start)) / (steps - 1)
    # According to benchmark, 256 is the best block size.
    block_size = 256
    num_programs = triton.cdiv(steps, block_size)
    logspace_kernel[(num_programs,)](
        out,
        out.stride(0),
        start,
        exponent_step,
        steps,
        # base is converted to float before being handed to math.log2.
        log2_base=math.log2(float(base)),
        BLOCK_SIZE=block_size,
    )
    return out