Coverage for src/flag_gems/ops/unfold_backward.py: 64% (45 statements)
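"""Backward pass of torch.Tensor.unfold, implemented with a Triton kernel.

Each element of the unfolded gradient is scattered back to its source
position in the original tensor; overlapping windows are accumulated with
float32 atomics, and the result is cast back to the input dtype.
"""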
import logging

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)
@triton.jit
def _unfold_backward_kernel(
    grad_in_ptr,
    grad_out_ptr,
    numel_in,
    prod_after,
    L,
    size,
    step,
    D,
    inner_total,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < numel_in

    vals = tl.load(grad_in_ptr + offs, mask=mask, other=0)
    vals_f32 = tl.cast(vals, tl.float32)

    # Decompose the flat offset into the contiguous grad_in layout
    # [before..., L, after..., size]: k is the position inside a window,
    # after_lin the flattened dims after the unfolded dim, s the window
    # index, and before_lin the flattened dims before the unfolded dim.
    k = offs % size
    tmp1 = offs // size
    after_lin = tmp1 % prod_after
    tmp2 = offs // (prod_after * size)
    s = tmp2 % L
    before_lin = offs // inner_total

    # Position along the unfolded dimension in the original tensor.
    pos = s * step + k

    # Flat offset into grad_out with layout [before..., D, after...].
    out_id = (before_lin * D + pos) * prod_after + after_lin

    # Overlapping windows (step < size) scatter several gradients onto the
    # same output element, so accumulate atomically in float32.
    tl.atomic_add(grad_out_ptr + out_id, vals_f32, mask=mask)
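
# Worked example of the index math above, with assumed shapes: for
# input_sizes = (2, 10, 3), dim = 1, size = 4, step = 2 we get D = 10,
# L = (10 - 4) // 2 + 1 = 4, prod_after = 3, inner_total = 48, and
# grad_in has shape (2, L, 3, size) = (2, 4, 3, 4). The flat offset of
# grad_in[1, 2, 0, 3] is ((1 * 4 + 2) * 3 + 0) * 4 + 3 = 75, which the
# kernel splits into k = 3, after_lin = 0, s = 2, before_lin = 1, then
# scatters to pos = 2 * 2 + 3 = 7, i.e. grad_out[1, 7, 0].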


def unfold_backward(
    grad_in: torch.Tensor, input_sizes, dim: int, size: int, step: int
) -> torch.Tensor:
    logger.debug("GEMS UNFOLD BACKWARD")
    if step <= 0:
        raise ValueError("step must be > 0")

    # Accept any iterable of sizes (list, tuple, torch.Size, ...).
    input_sizes = [int(s) for s in input_sizes]
    ndim = len(input_sizes)
    d = dim % ndim  # normalize negative dims

    D = int(input_sizes[d])
    # Number of windows unfold produces along dim d.
    L = (D - int(size)) // int(step) + 1

    # Flattened size of the dims after d, and of one "before" slice of
    # grad_in ([L, after..., size]).
    prod_after = 1
    for s_ in input_sizes[d + 1 :]:
        prod_after *= int(s_)
    inner_total = int(L) * int(prod_after) * int(size)

    # The kernel decomposes flat offsets, so grad_in must be contiguous.
    grad_in = grad_in.contiguous()
    device = grad_in.device
    numel_in = grad_in.numel()

    # Accumulate in float32 so atomic_add is supported and accurate for
    # low-precision input dtypes; cast back at the end.
    grad_out_f32 = torch.zeros(input_sizes, dtype=torch.float32, device=device)

    BLOCK = 128
    grid = lambda meta: (triton.cdiv(numel_in, meta["BLOCK"]),)

    _unfold_backward_kernel[grid](
        grad_in,
        grad_out_f32,
        numel_in,
        prod_after,
        L,
        size,
        step,
        D,
        inner_total,
        BLOCK=BLOCK,
    )

    if grad_in.dtype != torch.float32:
        return grad_out_f32.to(grad_in.dtype)
    return grad_out_f32
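

# A minimal self-check sketch, not part of the original module; it assumes
# a CUDA device (the Triton kernel runs on GPU) and compares the kernel
# against autograd through torch.Tensor.unfold.
if __name__ == "__main__":
    x = torch.randn(2, 10, 3, device="cuda", requires_grad=True)
    unfolded = x.unfold(1, 4, 2)  # shape (2, L=4, 3, size=4)
    grad_unfolded = torch.randn_like(unfolded)
    unfolded.backward(grad_unfolded)
    got = unfold_backward(grad_unfolded, x.shape, dim=1, size=4, step=2)
    torch.testing.assert_close(got, x.grad)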