Coverage for src/flag_gems/runtime/backend/_cambricon/ops/logspace.py: 0%
37 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import libentry
10from ..utils import TOTAL_CORE_NUM
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def logspace_kernel(
    out_ptr,
    out_stride0,
    start,
    step_size,
    steps,
    log2_base: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill out[i] = base**(start + i * step_size) for i in [0, steps).

    ``base**x`` is evaluated as ``exp2(log2(base) * x)`` so only the
    base-2 exponential is needed; ``log2_base`` is precomputed on the host.
    """
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    # The launcher caps the grid at TOTAL_CORE_NUM, which can be fewer
    # programs than there are BLOCK_SIZE chunks.  A grid-stride loop is
    # required so every element is written even when
    # steps > num_programs * BLOCK_SIZE (the original single-chunk body
    # left the tail of the torch.empty output uninitialized).
    for block_start in range(pid * BLOCK_SIZE, steps, num_programs * BLOCK_SIZE):
        idx = block_start + tl.arange(0, BLOCK_SIZE)
        mask = idx < steps
        exponent = start + idx * step_size
        vals = tl.exp2(log2_base * exponent)
        tl.store(out_ptr + idx * out_stride0, vals, mask=mask)
def logspace(
    start,
    end,
    steps,
    base=10.0,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
) -> torch.Tensor:
    """Return a 1-D tensor of ``steps`` points spaced evenly on a log scale.

    Matches ``torch.logspace``: values run from ``base**start`` to
    ``base**end`` inclusive.

    Args:
        start: First exponent (number or 0-d tensor).
        end: Last exponent (number or 0-d tensor).
        steps: Number of points; must be non-negative.
        base: Base of the logarithmic scale (default 10.0).
        dtype: Output dtype; defaults to ``torch.get_default_dtype()``.
        layout, device, pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A 1-D tensor of length ``steps``.

    Raises:
        RuntimeError: If ``steps`` is negative (same class torch raises).
    """
    logger.debug("GEMS_CAMBRICON LOGSPACE")
    # An `assert` here would be stripped under `python -O`; validate
    # explicitly.  RuntimeError matches native torch.logspace behavior.
    if steps < 0:
        raise RuntimeError("number of steps must be non-negative")
    out_dtype = dtype if dtype is not None else torch.get_default_dtype()
    out = torch.empty(
        steps,
        dtype=out_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        # Nothing to fill; return the empty tensor as-is.
        return out
    if isinstance(start, torch.Tensor):
        start = start.item()
    if steps == 1:
        # A single point is just base**start; no kernel launch needed.
        return torch.fill(out, base**start)
    if isinstance(end, torch.Tensor):
        end = end.item()
    # Exponent spacing chosen so the last element is exactly base**end.
    step_size = (float(end) - float(start)) / (steps - 1)
    BLOCK_SIZE = 256  # according to benchmark, 256 is the best block size
    # Launch at most TOTAL_CORE_NUM programs (one per physical core).
    grid = lambda meta: (min(triton.cdiv(steps, BLOCK_SIZE), TOTAL_CORE_NUM),)
    logspace_kernel[grid](
        out,
        out.stride(0),
        start,
        step_size,
        steps,
        log2_base=math.log2(float(base)),  # math.log2 requires a float input
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out