Coverage for src/flag_gems/experimental_ops/xlogy_.py: 0%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def xlogy_inplace_tensor_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise xlogy: x[i] <- x[i] * log(y[i]).

    Matches torch.xlogy semantics:
      * result is 0 where x == 0 (even though log(y) may be -inf/inf),
      * EXCEPT that NaN in y propagates to the output even where x == 0.
    Math is performed in float32 and cast back to x's dtype on store.

    Args:
        x_ptr: pointer to the self tensor (read and overwritten).
        y_ptr: pointer to the other tensor (same element count as x).
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time block width.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)

    logy = tl.log(y_f32)
    res = x_f32 * logy
    # xlogy(0, y) is defined as 0 (avoids 0 * -inf = NaN from log(0) etc.) ...
    res = tl.where(x_f32 == 0.0, 0.0, res)
    # ... but torch.xlogy returns NaN whenever y is NaN, even where x == 0.
    # (y != y) is true exactly for NaN; reuse y_f32 to propagate the NaN payload.
    res = tl.where(y_f32 != y_f32, y_f32, res)

    tl.store(x_ptr + offsets, res.to(x.dtype), mask=mask)

24 

25 

@triton.jit
def xlogy_inplace_scalar_kernel(x_ptr, y_scalar, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise xlogy with a scalar other: x[i] <- x[i] * log(y_scalar).

    Matches torch.xlogy semantics:
      * result is 0 where x == 0,
      * EXCEPT that a NaN y_scalar propagates to every output element.
    Math is performed in float32 and cast back to x's dtype on store.

    Args:
        x_ptr: pointer to the self tensor (read and overwritten).
        y_scalar: Python float broadcast against every element of x.
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time block width.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x_f32 = x.to(tl.float32)

    # Broadcast the scalar to a block-wide vector so tl.log gets a tensor operand.
    y_vec = tl.full((BLOCK_SIZE,), y_scalar, tl.float32)
    logy = tl.log(y_vec)

    res = x_f32 * logy
    # xlogy(0, y) is defined as 0 ...
    res = tl.where(x_f32 == 0.0, 0.0, res)
    # ... but torch.xlogy returns NaN whenever y is NaN, even where x == 0.
    res = tl.where(y_vec != y_vec, y_vec, res)

    tl.store(x_ptr + offsets, res.to(x.dtype), mask=mask)

43 

44 

def _ensure_supported_dtype(t: torch.Tensor) -> None:
    """Reject tensors whose dtype the kernels cannot handle.

    Raises:
        TypeError: if *t* is not float16, bfloat16, or float32.
    """
    allowed = (torch.float16, torch.bfloat16, torch.float32)
    if t.dtype in allowed:
        return
    raise TypeError(
        f"Unsupported dtype {t.dtype}. Supported: float16, bfloat16, float32."
    )

50 

51 

def _ensure_cuda_contiguous(t: torch.Tensor, name: str) -> None:
    """Validate that *t* is a contiguous CUDA tensor.

    Args:
        t: tensor to validate.
        name: label used in the error message (e.g. "self", "other").

    Raises:
        RuntimeError: if *t* is not on CUDA, or not contiguous.
    """
    # Checked in order: device first, then layout.
    checks = (
        (t.is_cuda, f"{name} must be a CUDA tensor."),
        (t.is_contiguous(), f"{name} must be contiguous."),
    )
    for ok, message in checks:
        if not ok:
            raise RuntimeError(message)

57 

58 

def xlogy__Tensor(*args, **kwargs):
    """In-place xlogy_ overload where 'other' is a Tensor: self <- xlogy(self, other).

    Accepts (self, other) positionally, or via the 'self'/'input' and 'other'
    keywords. Both tensors must be contiguous CUDA tensors of a supported float
    dtype. A one-element 'other' is treated as a scalar; otherwise its shape
    must match 'self' exactly (no broadcasting).

    Returns:
        The mutated 'self' tensor.

    Raises:
        ValueError: if 'self' or 'other' cannot be resolved from the arguments.
        TypeError: if 'other' is not a Tensor, or a dtype is unsupported.
        RuntimeError: on device/layout/shape violations.
    """
    if len(args) >= 2:
        x, other = args[0], args[1]
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        other = kwargs.get("other", None)
    if x is None or other is None:
        raise ValueError("xlogy__Tensor expects (self, other) where both are tensors.")

    if not isinstance(other, torch.Tensor):
        raise TypeError(
            "xlogy__Tensor expects 'other' to be a Tensor. Use xlogy__Scalar_Other for scalar 'other'."
        )

    # Both operands must satisfy the same device/layout/dtype constraints.
    for tensor, label in ((x, "self"), (other, "other")):
        _ensure_cuda_contiguous(tensor, label)
        _ensure_supported_dtype(tensor)

    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    if other.numel() == 1:
        # One-element 'other': fold it to a host float and launch the scalar kernel.
        scalar_y = other.to(torch.float32).item()
        xlogy_inplace_scalar_kernel[grid](x, scalar_y, n_elements, BLOCK_SIZE=1024)
    else:
        if x.numel() != other.numel() or x.shape != other.shape:
            raise RuntimeError(
                "For xlogy__Tensor, 'other' must have the same shape as 'self' or be a scalar tensor."
            )
        xlogy_inplace_tensor_kernel[grid](x, other, n_elements, BLOCK_SIZE=1024)

    return x

94 

95 

def xlogy__Scalar_Other(*args, **kwargs):
    """In-place xlogy_ overload where 'other' is a Python scalar.

    Accepts (self, other) positionally, or via the 'self'/'input' and 'other'
    keywords. 'self' must be a contiguous CUDA tensor of a supported float
    dtype; 'other' must be a plain number, never a Tensor.

    Returns:
        The mutated 'self' tensor.

    Raises:
        ValueError: if 'self' cannot be resolved from the arguments.
        TypeError: if 'other' is missing, is a Tensor, or 'self' has an
            unsupported dtype.
        RuntimeError: on device/layout violations.
    """
    if len(args) >= 2:
        x, other = args[0], args[1]
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        other = kwargs.get("other", None)
    if x is None:
        raise ValueError("xlogy__Scalar_Other expects 'self' tensor.")
    if other is None or isinstance(other, torch.Tensor):
        raise TypeError(
            "xlogy__Scalar_Other expects 'other' to be a Python scalar (not a Tensor)."
        )

    _ensure_cuda_contiguous(x, "self")
    _ensure_supported_dtype(x)

    count = x.numel()
    # The kernel takes the scalar as float32; convert once on the host.
    launch_grid = lambda meta: (triton.cdiv(count, meta["BLOCK_SIZE"]),)
    xlogy_inplace_scalar_kernel[launch_grid](x, float(other), count, BLOCK_SIZE=1024)
    return x