Coverage for src/flag_gems/experimental_ops/hypot_.py: 0%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def hypot_(
    x_ptr,  # First operand; also the destination when called in-place.
    y_ptr,  # Second operand (already broadcast + contiguous on the host side).
    out_ptr,  # Destination buffer.
    n_elements,  # Total element count to process.
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise hypot: out[i] = sqrt(x[i]^2 + y[i]^2).
    # Each program instance handles one BLOCK_SIZE-wide slice of the flat range.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    a = tl.load(x_ptr + offs, mask=in_bounds)
    b = tl.load(y_ptr + offs, mask=in_bounds)

    # Accumulate in float32 for accuracy, then cast back to the input dtype
    # (in-place semantics keep the original element type).
    a_f32 = a.to(tl.float32)
    b_f32 = b.to(tl.float32)
    result = tl.sqrt(a_f32 * a_f32 + b_f32 * b_f32)

    tl.store(out_ptr + offs, result.to(a.dtype), mask=in_bounds)

28 

29 

30_hypot_kernel = hypot_ 

31 

32 

def hypot_(*args, **kwargs):
    """In-place hypot: ``self = sqrt(self**2 + other**2)`` via a Triton kernel.

    Mirrors ``torch.ops.aten.hypot_(self, other)``: the two operands may be
    passed positionally or as the ``input``/``self`` and ``other`` keywords.
    ``other`` may be a tensor or a Python scalar; it is cast to ``self``'s
    dtype (in-place ops keep the destination dtype) and broadcast to
    ``self``'s shape.

    Returns:
        The ``self`` tensor, modified in place.

    Raises:
        TypeError: if either operand is missing or ``self`` is not a Tensor.
        ValueError: if ``self`` is not a CUDA tensor.
        RuntimeError: if ``other`` cannot be broadcast to ``self``'s shape.
    """
    x = args[0] if len(args) >= 1 else None
    other = args[1] if len(args) >= 2 else None
    if x is None:
        x = kwargs.get("input", kwargs.get("self", None))
    if other is None:
        other = kwargs.get("other", None)

    if x is None or other is None:
        raise TypeError("hypot_ expects two arguments: self and other")
    if not isinstance(x, torch.Tensor):
        raise TypeError("self must be a torch.Tensor")
    if not x.is_cuda:
        raise ValueError("hypot_ Triton kernel only supports CUDA tensors")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; bail out before allocating any buffers.
        return x

    # Materialize `other` on self's device with self's dtype (in-place ops
    # never change the destination dtype).  as_tensor is a no-op for a
    # tensor already matching device/dtype.
    other_t = torch.as_tensor(other, device=x.device)
    if other_t.dtype != x.dtype:
        other_t = other_t.to(x.dtype)

    # In-place semantics: `other` must broadcast to self's shape; self is
    # never expanded.  Let broadcast_to raise on incompatible shapes — the
    # previous fallback to broadcast_tensors could silently yield a buffer
    # whose size differs from x.numel(), making the kernel read the wrong
    # (or out-of-bounds) elements.
    other_b = torch.broadcast_to(other_t, x.shape)

    # The kernel indexes flat contiguous memory.  .contiguous() is a no-op
    # on an already-contiguous tensor; broadcast views generally are not.
    x_c = x if x.is_contiguous() else x.contiguous()
    other_c = other_b.contiguous()

    # If x is contiguous we can write straight into it; otherwise the kernel
    # writes a scratch buffer that is copied back through x's strides below.
    out_buf = x_c if x.is_contiguous() else torch.empty_like(x_c)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _hypot_kernel[grid](x_c, other_c, out_buf, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if not x.is_contiguous():
        x.copy_(out_buf)
    return x