Coverage for src/flag_gems/ops/digamma_.py: 30%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

@triton.jit
def digamma_kernel_(
    x_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Overwrite each of the first ``n_elements`` values at ``x_ptr`` with digamma(x).

    Each program instance handles one BLOCK_SIZE-wide slice of the flat buffer.
    Values are loaded, evaluated in float32, cast back to the element type of
    ``x_ptr`` and stored to the same locations (in place).

    Special values (derived from the formulas below):
      * psi(+0) = -inf, psi(-0) = +inf
      * psi(negative integer) = NaN, psi(-inf) = NaN
      * psi(+inf) = +inf, psi(NaN) = NaN
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x_f32 = x.to(tl.float32)

    pi = 3.1415926535897932384626433832795028841971

    # Reflection for x < 0.5: psi(x) = psi(1 - x) - pi * cot(pi * x).
    # After this, xr >= 0.5 (or is +inf / NaN).
    reflect_mask = x_f32 < 0.5
    xr = tl.where(reflect_mask, 1.0 - x_f32, x_f32)

    # Recurrence psi(y) = psi(y + 1) - 1/y: shift xr up to y >= 8 so the
    # asymptotic series below is accurate. Eight steps are enough because
    # xr >= 0.5. The correction terms accumulate in s.
    s = tl.zeros_like(x_f32)
    y = xr
    for _ in range(8):
        m = y < 8.0
        s = s - tl.where(m, 1.0 / y, 0.0)
        y = tl.where(m, y + 1.0, y)

    # Asymptotic expansion for large y:
    # psi(y) ~ ln y - 1/(2y) - 1/(12y^2) + 1/(120y^4) - 1/(252y^6) + 1/(240y^8)
    r = 1.0 / y
    t2 = r * r
    t4 = t2 * t2
    t6 = t4 * t2
    t8 = t4 * t4
    series = (
        (-0.5 * r)
        + (-1.0 / 12.0) * t2
        + (1.0 / 120.0) * t4
        + (-1.0 / 252.0) * t6
        + (1.0 / 240.0) * t8
    )
    psi_y = tl.log(y) + s + series

    # cot(pi * x) has period 1 in x. Evaluate it at frac = x - nearest_int(x),
    # with |frac| <= ~0.5. This avoids forming pi * x directly, which loses
    # all precision for large |x| in float32. The subtraction is exact in the
    # normal range, and it keeps the sign of zero: -0 - (+0) = -0, so
    # psi(-0) = +inf and psi(+0) = -inf as before.
    n = tl.floor(x_f32 + 0.5)
    frac = x_f32 - n
    cot_term = tl.cos(pi * frac) / tl.sin(pi * frac)
    result = tl.where(reflect_mask, psi_y - pi * cot_term, psi_y)

    # Negative integers are poles. frac is exactly 0 there, which would give
    # -inf above. Return NaN instead, matching torch.digamma.
    is_neg_int = (x_f32 < 0.0) & (frac == 0.0)
    result = tl.where(is_neg_int, float("nan"), result)

    tl.store(x_ptr + offsets, result.to(x.dtype), mask=mask)

59 

60 

def _launch_digamma_(t):
    """Run digamma_kernel_ in place over the contiguous, non-empty tensor ``t``."""
    n_elements = t.numel()
    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(t.device):
        digamma_kernel_[grid](t, n_elements, BLOCK_SIZE=block_size)


def digamma_(*args, **kwargs):
    """Replace every element of ``args[0]`` with its digamma value, in place.

    Args:
        args[0]: floating-point ``torch.Tensor`` to update. Non-contiguous
            tensors are processed through a contiguous copy that is then
            copied back into the original storage.
        kwargs: accepted for dispatcher compatibility; ignored.

    Returns:
        The same tensor object that was passed in.

    Raises:
        TypeError: if no positional argument is given or it is not a tensor.
        RuntimeError: if the tensor is not floating point. torch.digamma_
            refuses to cast the float result into an integer or bool tensor.
    """
    if not args or not isinstance(args[0], torch.Tensor):
        raise TypeError("digamma_ expects a torch.Tensor as the first argument")
    x = args[0]
    if not x.is_floating_point():
        raise RuntimeError(
            f"digamma_: result type Float can't be cast to the desired output type {x.dtype}"
        )

    # Nothing to compute; also avoids allocating a copy for empty views.
    if x.numel() == 0:
        return x

    if x.is_contiguous():
        _launch_digamma_(x)
    else:
        # The kernel indexes a flat, dense buffer, so operate on a contiguous
        # copy and write the results back into x's own layout.
        y = x.contiguous()
        _launch_digamma_(y)
        x.copy_(y)
    return x