Coverage for src/flag_gems/ops/logspace.py: 76%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from ..utils import libentry
9from ..utils import triton_lang_extension as tle
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def logspace_kernel(
    out_ptr,
    out_stride0,
    start,
    step_size,
    steps,
    log2_base: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill out_ptr with base**(start + i*step_size) for i in [0, steps).
    # base is folded into the compile-time constant log2_base = log2(base),
    # so the power is computed as 2**(log2(base) * exponent) via tl.exp2.
    #
    # out_ptr     : pointer to the (1-D) output tensor
    # out_stride0 : element stride of the output along dim 0
    # start       : first exponent
    # step_size   : spacing between consecutive exponents
    # steps       : number of elements to write (used to mask the tail block)
    # log2_base   : log2 of the logspace base, constexpr-specialized
    # BLOCK_SIZE  : elements handled per program, constexpr-specialized
    pid = tle.program_id(0)
    # Global element indices covered by this program instance.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes past the end when steps is not a multiple of BLOCK_SIZE.
    mask = idx < steps
    exponent = start + idx * step_size
    # base**exponent == 2**(log2(base) * exponent); exp2 is the cheap form.
    vals = tl.exp2(log2_base * exponent)
    tl.store(out_ptr + idx * out_stride0, vals, mask=mask)
def logspace(
    start,
    end,
    steps,
    base=10.0,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
) -> torch.Tensor:
    """Return a 1-D tensor of ``steps`` values spaced on a log scale.

    The values are ``base ** e`` for ``steps`` exponents ``e`` evenly spaced
    from ``start`` to ``end`` (inclusive), mirroring ``torch.logspace``.

    Args:
        start: First exponent; a Python number or a 0-dim tensor.
        end: Last exponent; a Python number or a 0-dim tensor.
        steps: Number of elements to produce; must be non-negative.
        base: Base of the log scale (default 10.0).
        dtype: Output dtype; defaults to ``torch.get_default_dtype()``.
        layout, device, pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A tensor of shape ``(steps,)``.

    Raises:
        RuntimeError: If ``steps`` is negative.
    """
    logging.getLogger(__name__).debug("GEMS LOGSPACE")
    # Raise explicitly rather than ``assert``: asserts are stripped under -O,
    # and torch.logspace signals a negative step count with RuntimeError.
    if steps < 0:
        raise RuntimeError("number of steps must be non-negative")
    out_dtype = dtype if dtype is not None else torch.get_default_dtype()
    out = torch.empty(
        steps,
        dtype=out_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        # Nothing to fill; return the empty tensor as-is.
        pass
    elif steps == 1:
        # Single element: no spacing is defined, the value is base**start.
        if isinstance(start, torch.Tensor):
            start = start.item()
        out = torch.fill(out, base**start)
    else:
        # General case: launch the kernel over evenly spaced exponents.
        if isinstance(start, torch.Tensor):
            start = start.item()
        if isinstance(end, torch.Tensor):
            end = end.item()
        # steps >= 2 here, so the division is safe.
        step_size = (float(end) - float(start)) / (steps - 1)
        BLOCK_SIZE = 256  # according to benchmark, 256 is the best block size
        grid = (triton.cdiv(steps, BLOCK_SIZE),)
        logspace_kernel[grid](
            out,
            out.stride(0),
            start,
            step_size,
            steps,
            log2_base=math.log2(float(base)),  # math.log2 require float input
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return out