Coverage for src/flag_gems/ops/digamma_.py: 30%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
@triton.jit
def digamma_kernel_(
    x_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Overwrite the first ``n_elements`` values at ``x_ptr`` with digamma(x).

    Each program handles ``BLOCK_SIZE`` consecutive elements starting at
    ``program_id(0) * BLOCK_SIZE``; lanes past ``n_elements`` are masked off.
    Values are evaluated in float32 and cast back to the buffer's dtype.

    Special values (chosen to match ``torch.digamma``):
      * ``+0`` -> ``-inf`` and ``-0`` -> ``+inf``,
      * negative integers and ``-inf`` -> NaN,
      * ``+inf`` -> ``+inf``; NaN propagates.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # NOTE(review): all math runs in float32, so float64 buffers are computed
    # at single precision before being cast back.
    x_f32 = x.to(tl.float32)

    pi = 3.1415926535897932384626433832795028841971

    # Reflection for x < 0.5: psi(x) = psi(1 - x) - pi * cot(pi * x).
    reflect_mask = x_f32 < 0.5
    xr = tl.where(reflect_mask, 1.0 - x_f32, x_f32)

    # Recurrence psi(y) = psi(y + 1) - 1 / y. Here xr >= 0.5 (or NaN), so
    # eight steps always lift y to >= 8, where the series below is accurate.
    s = tl.zeros_like(x_f32)
    y = xr
    for _ in range(8):
        m = y < 8.0
        s = s - tl.where(m, 1.0 / y, 0.0)
        y = tl.where(m, y + 1.0, y)

    # Asymptotic expansion for large y:
    # psi(y) ~ ln y - 1/(2y) - 1/(12y^2) + 1/(120y^4) - 1/(252y^6) + 1/(240y^8)
    r = 1.0 / y
    r2 = r * r
    r4 = r2 * r2
    r6 = r4 * r2
    r8 = r4 * r4
    series = (
        (-0.5 * r)
        + (-1.0 / 12.0) * r2
        + (1.0 / 120.0) * r4
        + (-1.0 / 252.0) * r6
        + (1.0 / 240.0) * r8
    )
    psi_y = tl.log(y) + s + series

    # cot(pi * x) has period 1 in x, so evaluate it on the fractional part
    # x - floor(x), which lies in [0, 1) and is exact in floating point.
    # Forming pi * x directly would discard most significant bits for large
    # |x|. Zero keeps its own (signed) value so +0 -> -inf and -0 -> +inf.
    frac = tl.where(x_f32 == 0.0, x_f32, x_f32 - tl.floor(x_f32))
    cot_term = tl.cos(pi * frac) / tl.sin(pi * frac)
    result = tl.where(reflect_mask, psi_y - pi * cot_term, psi_y)

    # Negative integers (and -inf) are poles. In float32, sin(pi * k) is not
    # exactly 0, so without this mask they would yield large finite values.
    neg_int = (x_f32 < 0.0) & (x_f32 == tl.floor(x_f32))
    result = tl.where(neg_int, float("nan"), result)

    tl.store(x_ptr + offsets, result.to(x.dtype), mask=mask)
def _digamma_contiguous_(t):
    """Run ``digamma_kernel_`` in place over the contiguous, non-empty tensor ``t``."""
    n_elements = t.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(t.device):
        digamma_kernel_[grid](t, n_elements, BLOCK_SIZE=1024)


def digamma_(*args, **kwargs):
    """In-place digamma: replace every element of ``args[0]`` with psi(x).

    Only the first positional argument is used; any other positional or
    keyword arguments are accepted and ignored.

    Returns:
        The same tensor object that was passed in, after modification.

    Raises:
        TypeError: if no positional argument is given or it is not a
            ``torch.Tensor``.
    """
    if not args or not isinstance(args[0], torch.Tensor):
        raise TypeError("digamma_ expects a torch.Tensor as the first argument")
    x = args[0]

    # Nothing to launch (and nothing to copy) for an empty tensor.
    if x.numel() == 0:
        return x

    if x.is_contiguous():
        _digamma_contiguous_(x)
    else:
        # The kernel addresses memory as one flat buffer, so a strided view
        # is processed on a contiguous copy whose results are written back.
        y = x.contiguous()
        _digamma_contiguous_(y)
        x.copy_(y)
    return x