Coverage for src/flag_gems/experimental_ops/reciprocal.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def reciprocal_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise out = 1 / x over flat buffers of n_elements values.
    # Each program instance handles one chunk of BLOCK_SIZE consecutive
    # elements; lanes at or past n_elements are masked off on load and store.

    # Promote the program id to int64 so that pid * BLOCK_SIZE cannot
    # overflow int32 for tensors with more than 2**31 - 1 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Numerator block built with the same dtype as x.
    one = tl.full([BLOCK_SIZE], 1, x.dtype)
    y = one / x
    # tl.store converts y to out_ptr's element type.
    tl.store(out_ptr + offsets, y, mask=mask)
def _reciprocal_compute_input(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` in the floating dtype that ``torch.reciprocal`` computes in.

    ``torch.reciprocal`` promotes integral and bool inputs to the default
    floating dtype (``torch.get_default_dtype()``). Floating inputs are
    returned unchanged (the same object, no copy).
    """
    if x.is_floating_point():
        return x
    return x.to(torch.get_default_dtype())


def _reciprocal_impl(x: torch.Tensor, out: torch.Tensor = None):
    """Compute ``1 / x`` elementwise, optionally writing into ``out``.

    CUDA inputs go through ``reciprocal_kernel``. Non-CUDA or complex inputs
    are delegated to ``torch.ops.aten.reciprocal``. Integral and bool CUDA
    inputs are first promoted to the default floating dtype, mirroring
    ``torch.reciprocal``'s type promotion.

    Args:
        x: Input tensor.
        out: Optional destination tensor. On the Triton path it must be on
            ``x``'s device, have the (promoted) result dtype and contain the
            same number of elements as ``x``.

    Returns:
        The tensor holding the result (``out`` itself when it was given).

    Raises:
        AssertionError: If ``out`` violates the constraints above on the
            Triton path.
    """
    # Fallback for unsupported dtypes/devices
    if not x.is_cuda or x.is_complex():
        if out is None:
            return torch.ops.aten.reciprocal(x)
        return torch.ops.aten.reciprocal.out(x, out=out)

    # Without this promotion an integral input would get an integral `out`
    # from empty_like and the quotient would be truncated on store
    # (e.g. 1 / 2 -> 0), whereas torch.reciprocal returns a floating result.
    x = _reciprocal_compute_input(x)

    if out is None:
        out = torch.empty_like(x)

    # Ensure same device and dtype
    assert out.device == x.device, "Input and output must be on the same device"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"
    assert (
        out.numel() == x.numel()
    ), "Output must have the same number of elements as input"

    n_elements = x.numel()
    if n_elements == 0:
        return out  # nothing to do

    # The kernel indexes flat buffers, so hand it contiguous tensors.
    x_contig = x.contiguous()
    out_contig = out.contiguous()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    reciprocal_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024)

    # A non-contiguous `out` was written through a temporary copy.
    if out is not out_contig:
        out.copy_(out_contig)
    return out
# ('reciprocal', <Autograd.disable: False>)
def reciprocal(*args, **kwargs):
    """Return the elementwise reciprocal ``1 / x`` of a single tensor.

    The tensor is the first positional argument. If no positional argument
    is given, the first keyword present among ``input``, ``self``, ``a`` and
    ``args`` (checked in that order) is used.

    Raises:
        ValueError: If no tensor was supplied, or the selected value is None.
    """
    if args:
        tensor = args[0]
    else:
        # The first alias present wins, even when its value is None.
        aliases = ("input", "self", "a", "args")
        tensor = next((kwargs[k] for k in aliases if k in kwargs), None)
    if tensor is None:
        raise ValueError("reciprocal expects a tensor as the first argument")
    return _reciprocal_impl(tensor)
# ('reciprocal.out', <Autograd.disable: False>)
def reciprocal_out(*args, **kwargs):
    """Write the elementwise reciprocal of an input tensor into ``out``.

    Accepted call forms:
      * ``reciprocal_out(x, out)``: both positional; any ``out=`` keyword is
        ignored;
      * ``reciprocal_out(x, out=out)``;
      * ``reciprocal_out(input=x, out=out)``: the input may also be passed as
        ``self=`` or ``a=``; the first keyword present, in that order, wins.

    Returns:
        ``out``.

    Raises:
        ValueError: If the input or ``out`` is missing or None.
    """
    if len(args) >= 2:
        src, dst = args[:2]
    else:
        if args:
            src = args[0]
        else:
            aliases = ("input", "self", "a")
            src = next((kwargs[k] for k in aliases if k in kwargs), None)
        dst = kwargs.get("out")
    if src is None or dst is None:
        raise ValueError("reciprocal_out expects arguments (input, out)")
    _reciprocal_impl(src, out=dst)
    return dst