Coverage for src/flag_gems/experimental_ops/digamma_.py: 0%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def digamma_(
    x_ptr,  # Pointer to input/output tensor (in-place)
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,  # Elements handled per program instance
):
    # Elementwise in-place digamma (psi) over a flat buffer: each program
    # instance loads one BLOCK_SIZE-wide, mask-guarded slice, computes
    # psi in float32, and stores the result back over the input.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute in float32 regardless of the storage dtype; cast back on store.
    x_f32 = x.to(tl.float32)

    pi = 3.1415926535897932384626433832795028841971

    # Reflection for x < 0.5: psi(x) = psi(1 - x) - pi * cot(pi * x)
    reflect_mask = x_f32 < 0.5
    xr = tl.where(reflect_mask, 1.0 - x_f32, x_f32)

    # Use recurrence to shift xr to >= 8 for better asymptotic precision:
    # psi(y) = psi(y + 1) - 1/y, accumulating the -1/y terms in `s`.
    # After reflection xr >= 0.5, so 8 unit steps always reach y >= 8.
    s = tl.zeros_like(x_f32)
    y = xr
    for _ in range(8):
        m = y < 8.0
        s = s - tl.where(m, 1.0 / y, 0.0)
        y = tl.where(m, y + 1.0, y)

    # Asymptotic expansion for digamma at large y:
    # psi(y) ~ ln(y) - 1/(2y) - 1/(12 y^2) + 1/(120 y^4)
    #          - 1/(252 y^6) + 1/(240 y^8)
    r = 1.0 / y
    r2 = r * r
    t2 = r2          # y^-2
    t4 = t2 * t2     # y^-4
    t6 = t4 * t2     # y^-6
    t8 = t4 * t4     # y^-8
    series = (
        (-0.5 * r)
        + (-1.0 / 12.0) * t2
        + (1.0 / 120.0) * t4
        + (-1.0 / 252.0) * t6
        + (1.0 / 240.0) * t8
    )
    psi_y = tl.log(y) + s + series

    # Apply reflection if needed.  At digamma's poles (x a non-positive
    # integer) sin(pi * x) = 0, so the cot term produces inf/nan, which
    # matches the function's singularities there.
    cot_term = tl.cos(pi * x_f32) / tl.sin(pi * x_f32)
    result = tl.where(reflect_mask, psi_y - pi * cot_term, psi_y)

    result = result.to(x.dtype)
    tl.store(x_ptr + offsets, result, mask=mask)


# Keep a handle on the Triton kernel: the module-level name `digamma_` is
# re-bound below by the Python wrapper of the same name.
_KERNEL_DIGAMMA_INPLACE = digamma_

59 

60 

def digamma_(*args, **kwargs):
    """Apply the digamma function to a tensor in place.

    Computes ``psi(x)`` elementwise, writing the result back into ``x``
    and returning ``x``. CUDA tensors are handled by the Triton kernel;
    all other tensors fall back to ``torch.ops.aten.digamma_``.

    Args:
        *args: the first positional argument must be a ``torch.Tensor``.
        **kwargs: accepted for signature compatibility; unused.

    Returns:
        The input tensor ``x``, modified in place.

    Raises:
        TypeError: if no arguments are given or the first argument is not
            a ``torch.Tensor``.
    """
    # Guard explicitly: bare `args[0]` would raise IndexError on no args,
    # which is the wrong exception type for a bad call signature.
    if not args:
        raise TypeError("digamma_ expects a torch.Tensor as the first argument")
    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("digamma_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        return torch.ops.aten.digamma_(x)

    # Nothing to do for empty tensors; check before any contiguous() copy.
    n_elements = x.numel()
    if n_elements == 0:
        return x

    if x.is_contiguous():
        _launch_digamma_kernel(x, n_elements)
    else:
        # The kernel assumes flat contiguous memory: run it on a contiguous
        # copy and copy the result back into the original layout.
        y = x.contiguous()
        _launch_digamma_kernel(y, n_elements)
        x.copy_(y)
    return x


def _launch_digamma_kernel(t, n_elements):
    """Launch the in-place digamma kernel over a contiguous CUDA tensor."""
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _KERNEL_DIGAMMA_INPLACE[grid](t, n_elements, BLOCK_SIZE=1024)