Coverage for src/flag_gems/ops/unfold_backward.py: 64%

45 statements  

coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

import logging

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)


@triton.jit
def _unfold_backward_kernel(
    grad_in_ptr,
    grad_out_ptr,
    numel_in,
    prod_after,
    L,
    size,
    step,
    D,
    inner_total,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel_in

    vals = tl.load(grad_in_ptr + offs, mask=mask, other=0)
    vals_f32 = tl.cast(vals, tl.float32)

    # grad_in is read by flat offset, i.e. it is assumed contiguous with layout
    # (before..., L, after..., size). Decompose the offset into the in-window
    # index k, the window index s, and the linear indices of the dimensions
    # before and after the unfolded dimension.
    k = offs % size
    tmp1 = offs // size
    after_lin = tmp1 % prod_after
    tmp2 = offs // (prod_after * size)
    s = tmp2 % L
    before_lin = offs // inner_total

    # Position along the original (unfolded) dimension of length D.
    pos = s * step + k

    out_id = ((before_lin * D) + pos) * prod_after + after_lin

    # Overlapping windows map several grad_in elements to the same output
    # position, so accumulate with an atomic add into the float32 buffer.
    tl.atomic_add(grad_out_ptr + out_id, vals_f32, mask=mask)


def unfold_backward(
    grad_in: torch.Tensor, input_sizes, dim: int, size: int, step: int
) -> torch.Tensor:
    logger.debug("GEMS UNFOLD BACKWARD")
    if step <= 0:
        raise ValueError("step must be > 0")

    if not isinstance(input_sizes, (list, tuple)):
        input_sizes = list(input_sizes)
    input_sizes = [int(s) for s in input_sizes]
    ndim = len(input_sizes)
    d = dim % ndim

    # D is the length of the unfolded dimension; L is the number of windows.
    D = int(input_sizes[d])
    L = (D - int(size)) // int(step) + 1

    prod_after = 1
    for s_ in input_sizes[d + 1 :]:
        prod_after *= int(s_)
    inner_total = int(L) * int(prod_after) * int(size)

    device = grad_in.device
    # Accumulate in float32 to keep atomic adds well-defined, then cast back.
    grad_out_f32 = torch.zeros(input_sizes, dtype=torch.float32, device=device)

    numel_in = grad_in.numel()

    BLOCK = 128
    grid = lambda meta: (triton.cdiv(numel_in, meta["BLOCK"]),)

    _unfold_backward_kernel[grid](
        grad_in,
        grad_out_f32,
        numel_in,
        prod_after,
        L,
        size,
        step,
        D,
        inner_total,
        BLOCK=BLOCK,
    )

    if grad_in.dtype != torch.float32:
        return grad_out_f32.to(grad_in.dtype)
    return grad_out_f32
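
Minimal usage sketch (not part of the covered file): it assumes a CUDA device capable of running Triton and that the module is importable as flag_gems.ops.unfold_backward, and it checks the result against the gradient PyTorch autograd produces for Tensor.unfold.

# Hypothetical usage sketch; assumes a CUDA device and that the module path
# flag_gems.ops.unfold_backward matches the file above.
import torch

from flag_gems.ops.unfold_backward import unfold_backward

x = torch.randn(2, 10, 3, device="cuda", requires_grad=True)
y = x.unfold(1, 4, 2)        # shape (2, L=4, 3, size=4): windows of length 4, stride 2
grad_y = torch.ones_like(y)  # contiguous gradient w.r.t. the unfolded output

grad_x = unfold_backward(grad_y, x.shape, dim=1, size=4, step=2)

y.backward(grad_y)           # reference gradient via autograd
torch.testing.assert_close(grad_x, x.grad)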