Coverage for src/flag_gems/experimental_ops/digamma_.py: 0%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def digamma_(
    x_ptr,  # Pointer to input/output tensor (modified in place)
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program instance
):
    """Triton kernel: replace each element of x with digamma(x), in place.

    All math is done in float32 and cast back to the input dtype.
    Algorithm:
      1. Reflection for x < 0.5: psi(x) = psi(1 - x) - pi * cot(pi * x),
         so the core evaluation only sees arguments >= 0.5.
      2. Recurrence psi(y) = psi(y + 1) - 1/y, applied until y >= 8
         (8 steps always suffice since the reflected argument is >= 0.5).
      3. Asymptotic expansion at large y:
         psi(y) ~ ln(y) - 1/(2y) - 1/(12 y^2) + 1/(120 y^4)
                  - 1/(252 y^6) + 1/(240 y^8).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the partial tail block
    x = tl.load(x_ptr + offsets, mask=mask)
    x_f32 = x.to(tl.float32)  # compute in float32 regardless of input dtype
    pi = 3.1415926535897932384626433832795028841971
    # Reflection for x < 0.5: psi(x) = psi(1 - x) - pi * cot(pi * x)
    reflect_mask = x_f32 < 0.5
    xr = tl.where(reflect_mask, 1.0 - x_f32, x_f32)
    # Use recurrence to shift xr to >= 8 for better asymptotic precision.
    # s accumulates the -1/y correction terms of psi(y) = psi(y + 1) - 1/y.
    s = tl.zeros_like(x_f32)
    y = xr
    for _ in range(8):  # unrolled at compile time; fixed count keeps it branch-free
        m = y < 8.0
        s = s - tl.where(m, 1.0 / y, 0.0)
        y = tl.where(m, y + 1.0, y)
    # Asymptotic expansion for digamma at large y (Bernoulli-number coefficients)
    r = 1.0 / y
    r2 = r * r
    t2 = r2
    t4 = t2 * t2
    t6 = t4 * t2
    t8 = t4 * t4
    series = (
        (-0.5 * r)
        + (-1.0 / 12.0) * t2
        + (1.0 / 120.0) * t4
        + (-1.0 / 252.0) * t6
        + (1.0 / 240.0) * t8
    )
    psi_y = tl.log(y) + s + series
    # Apply reflection if needed.
    # NOTE(review): at non-positive integers sin(pi * x) == 0, so cot_term
    # diverges; the stored result there is inf/nan, matching digamma's poles.
    cot_term = tl.cos(pi * x_f32) / tl.sin(pi * x_f32)
    result = tl.where(reflect_mask, psi_y - pi * cot_term, psi_y)
    result = result.to(x.dtype)
    tl.store(x_ptr + offsets, result, mask=mask)
# Keep a handle to the Triton kernel before the name `digamma_` is rebound
# to the Python wrapper defined below.
_KERNEL_DIGAMMA_INPLACE = digamma_
def digamma_(*args, **kwargs):
    """In-place digamma backed by a Triton kernel on CUDA tensors.

    Args:
        *args: ``args[0]`` must be a ``torch.Tensor``; it is mutated in place.
        **kwargs: accepted for signature compatibility; not used here.

    Returns:
        The input tensor (mutated in place), matching aten's in-place
        convention of returning ``self``.

    Raises:
        TypeError: if the first argument is not a ``torch.Tensor``.
    """
    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("digamma_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        # Non-CUDA tensors: defer to the reference ATen implementation.
        return torch.ops.aten.digamma_(x)

    n_elements = x.numel()
    if n_elements == 0:
        return x  # nothing to do; avoid launching an empty grid

    # The kernel assumes dense, contiguous memory. For non-contiguous inputs,
    # run on a contiguous copy and write the results back. This unifies the
    # previously duplicated launch logic for the two layouts.
    y = x if x.is_contiguous() else x.contiguous()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _KERNEL_DIGAMMA_INPLACE[grid](y, n_elements, BLOCK_SIZE=1024)
    if y is not x:
        x.copy_(y)
    return x