Coverage for src/flag_gems/experimental_ops/reciprocal.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def reciprocal_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise reciprocal: out[i] = 1 / x[i] for every i < n_elements."""
    # Each program instance owns one contiguous tile of BLOCK_SIZE elements.
    tile_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)
    # Build the numerator in the input dtype so the division stays in-dtype.
    ones = tl.full([BLOCK_SIZE], 1, vals.dtype)
    tl.store(out_ptr + idx, ones / vals, mask=in_bounds)

17 

18 

19def _reciprocal_impl(x: torch.Tensor, out: torch.Tensor = None): 

20 # Fallback for unsupported dtypes/devices 

21 if not x.is_cuda or x.is_complex(): 

22 if out is None: 

23 return torch.ops.aten.reciprocal(x) 

24 else: 

25 return torch.ops.aten.reciprocal.out(x, out=out) 

26 

27 if out is None: 

28 out = torch.empty_like(x) 

29 

30 # Ensure same device and dtype 

31 assert out.device == x.device, "Input and output must be on the same device" 

32 assert out.dtype == x.dtype, "Output dtype must match input dtype" 

33 assert ( 

34 out.numel() == x.numel() 

35 ), "Output must have the same number of elements as input" 

36 

37 x_contig = x.contiguous() 

38 out_contig = out.contiguous() 

39 

40 n_elements = x_contig.numel() 

41 if n_elements == 0: 

42 return out # nothing to do 

43 

44 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

45 reciprocal_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024) 

46 

47 if out is not out_contig: 

48 out.copy_(out_contig) 

49 return out 

50 

51 

# ('reciprocal', <Autograd.disable: False>)
def reciprocal(*args, **kwargs):
    """Dispatch wrapper: extract the input tensor and delegate to the impl."""
    if args:
        tensor = args[0]
    else:
        # Probe the keyword names callers commonly use for the input.
        tensor = None
        for key in ("input", "self", "a", "args"):
            if key in kwargs:
                tensor = kwargs[key]
                break
    if tensor is None:
        raise ValueError("reciprocal expects a tensor as the first argument")
    return _reciprocal_impl(tensor)

66 

67 

# ('reciprocal.out', <Autograd.disable: False>)
def reciprocal_out(*args, **kwargs):
    """Dispatch wrapper for the out-variant: accepts (input, out) positionally
    or via the keyword names input/self/a plus out."""
    if len(args) >= 2:
        tensor, result = args[0], args[1]
    elif len(args) == 1:
        tensor = args[0]
        result = kwargs.get("out")
    else:
        tensor = kwargs.get("input", kwargs.get("self", kwargs.get("a")))
        result = kwargs.get("out")

    if tensor is None or result is None:
        raise ValueError("reciprocal_out expects arguments (input, out)")

    _reciprocal_impl(tensor, out=result)
    return result