Coverage for src/flag_gems/experimental_ops/xlogy_.py: 0%
77 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def xlogy_inplace_tensor_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise xlogy: x[i] <- xlogy(x[i], y[i]).

    Matches torch.xlogy semantics:
      - NaN wherever y is NaN (even when x == 0),
      - 0 wherever x == 0 (and y is not NaN),
      - x * log(y) otherwise.
    Both buffers are assumed flat and contiguous with n_elements entries.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Compute in float32 for accuracy with fp16/bf16 inputs.
    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)

    res = x_f32 * tl.log(y_f32)
    # xlogy(0, y) == 0 by definition, overriding 0 * log(y) (which may be NaN).
    res = tl.where(x_f32 == 0.0, 0.0, res)
    # BUGFIX: torch.xlogy propagates NaN from y even when x == 0, so restore
    # NaN where y is NaN (y != y is the NaN test).
    res = tl.where(y_f32 != y_f32, y_f32, res)

    tl.store(x_ptr + offsets, res.to(x.dtype), mask=mask)
@triton.jit
def xlogy_inplace_scalar_kernel(x_ptr, y_scalar, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise xlogy against a scalar: x[i] <- xlogy(x[i], y_scalar).

    Matches torch.xlogy semantics:
      - NaN everywhere if y_scalar is NaN (even when x == 0),
      - 0 wherever x == 0 (and y_scalar is not NaN),
      - x * log(y_scalar) otherwise.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute in float32 for accuracy with fp16/bf16 inputs.
    x_f32 = x.to(tl.float32)

    # Broadcast the scalar so tl.log/tl.where operate on a block-shaped tensor.
    y_vec = tl.full((BLOCK_SIZE,), y_scalar, tl.float32)
    logy = tl.log(y_vec)

    res = x_f32 * logy
    # xlogy(0, y) == 0 by definition, overriding 0 * log(y) (which may be NaN).
    res = tl.where(x_f32 == 0.0, 0.0, res)
    # BUGFIX: torch.xlogy propagates NaN from y even when x == 0, so restore
    # NaN where the scalar is NaN (y != y is the NaN test).
    res = tl.where(y_vec != y_vec, y_vec, res)

    tl.store(x_ptr + offsets, res.to(x.dtype), mask=mask)
45def _ensure_supported_dtype(t: torch.Tensor):
46 if t.dtype not in (torch.float16, torch.bfloat16, torch.float32):
47 raise TypeError(
48 f"Unsupported dtype {t.dtype}. Supported: float16, bfloat16, float32."
49 )
52def _ensure_cuda_contiguous(t: torch.Tensor, name: str):
53 if not t.is_cuda:
54 raise RuntimeError(f"{name} must be a CUDA tensor.")
55 if not t.is_contiguous():
56 raise RuntimeError(f"{name} must be contiguous.")
def xlogy__Tensor(*args, **kwargs):
    """In-place xlogy with a tensor 'other': self <- xlogy(self, other).

    Accepts (self, other) positionally or via the 'self' (or 'input') and
    'other' keywords. A one-element 'other' is treated as a scalar; otherwise
    it must match 'self' exactly in shape (no broadcasting). Both tensors must
    be contiguous CUDA tensors of a supported floating dtype.

    Returns:
        'self', modified in place.

    Raises:
        ValueError: if 'self' or 'other' is missing/None.
        TypeError: if 'other' is not a Tensor, or a dtype is unsupported.
        RuntimeError: on device/contiguity/shape violations.
    """
    # Expecting signature: (self, other)
    if len(args) >= 2:
        x, other = args[0], args[1]
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        other = kwargs.get("other", None)
    # ROBUSTNESS: validate unconditionally (not only on the keyword path) so an
    # explicit positional None raises a clear error instead of an AttributeError
    # deeper in the device checks.
    if x is None or other is None:
        raise ValueError("xlogy__Tensor expects (self, other) where both are tensors.")

    if not isinstance(other, torch.Tensor):
        raise TypeError(
            "xlogy__Tensor expects 'other' to be a Tensor. Use xlogy__Scalar_Other for scalar 'other'."
        )

    _ensure_cuda_contiguous(x, "self")
    _ensure_supported_dtype(x)
    _ensure_cuda_contiguous(other, "other")
    _ensure_supported_dtype(other)

    n_elements = x.numel()
    # Same launch geometry for both kernels; hoisted out of the branches.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    if other.numel() == 1:
        # One-element tensor: dispatch to the scalar kernel.
        y_scalar = other.to(torch.float32).item()
        xlogy_inplace_scalar_kernel[grid](x, y_scalar, n_elements, BLOCK_SIZE=1024)
    else:
        if x.numel() != other.numel() or x.shape != other.shape:
            raise RuntimeError(
                "For xlogy__Tensor, 'other' must have the same shape as 'self' or be a scalar tensor."
            )
        xlogy_inplace_tensor_kernel[grid](x, other, n_elements, BLOCK_SIZE=1024)

    return x
def xlogy__Scalar_Other(*args, **kwargs):
    """In-place xlogy with a Python-scalar 'other': self <- xlogy(self, other).

    Accepts (self, other_scalar) positionally or via the 'self' (or 'input')
    and 'other' keywords. 'self' must be a contiguous CUDA tensor of a
    supported floating dtype; 'other' must be a Python scalar, not a Tensor.

    Returns:
        'self', modified in place.

    Raises:
        ValueError: if 'self' is missing/None.
        TypeError: if 'other' is missing, a Tensor, or 'self' dtype unsupported.
        RuntimeError: on device/contiguity violations.
    """
    # Expecting signature: (self, other_scalar)
    if len(args) >= 2:
        x, other = args[0], args[1]
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        other = kwargs.get("other", None)
    # ROBUSTNESS: validate unconditionally (not only on the keyword path) so an
    # explicit positional None raises the intended error.
    if x is None:
        raise ValueError("xlogy__Scalar_Other expects 'self' tensor.")
    if other is None or isinstance(other, torch.Tensor):
        raise TypeError(
            "xlogy__Scalar_Other expects 'other' to be a Python scalar (not a Tensor)."
        )

    _ensure_cuda_contiguous(x, "self")
    _ensure_supported_dtype(x)

    # Convert scalar to float for kernel
    y_scalar = float(other)

    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    xlogy_inplace_scalar_kernel[grid](x, y_scalar, n_elements, BLOCK_SIZE=1024)
    return x