Coverage for src/flag_gems/fused/outer.py: 77%
26 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
3import torch
5from flag_gems.ops import mul, mv
7logger = logging.getLogger(__name__)
class Outer(torch.autograd.Function):
    """Autograd function for the outer product of two 1-D tensors.

    The forward pass reshapes ``inp`` into a column and ``weight`` into a row
    and combines them with ``mul``; the backward pass derives each input
    gradient from ``out_grad`` with a matrix-vector product (``mv``).
    """

    @staticmethod
    def forward(ctx, inp, weight):
        """Compute the outer product of ``inp`` and ``weight``.

        Args:
            ctx: Autograd context; ``inp`` and ``weight`` are saved on it for
                the backward pass.
            inp: 1-D tensor, viewed as a column of shape ``(len(inp), 1)``.
            weight: 1-D tensor, viewed as a row of shape ``(1, len(weight))``.

        Returns:
            The result of ``mul`` on the column and row views. Presumably
            ``mul`` broadcasts, giving ``out[i, j] = inp[i] * weight[j]``;
            confirm against ``flag_gems.ops.mul``.

        Raises:
            AssertionError: If either argument is not 1-D.
        """
        logger.debug("GEMS OUTER")
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        # Materialize contiguous column / row operands before handing them
        # to the multiply kernel.
        column = inp[:, None].contiguous()
        row = weight[None, :].contiguous()
        out = mul(column, row)
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        """Compute the gradients of the outer product w.r.t. its inputs.

        Args:
            ctx: Autograd context holding the tensors saved by ``forward``.
            out_grad: 2-D gradient of the loss with respect to the output.

        Returns:
            A tuple ``(inp_grad, weight_grad)``. An entry is ``None`` when the
            corresponding forward input does not require a gradient.

        Raises:
            AssertionError: If ``out_grad`` is not 2-D.
        """
        logger.debug("GEMS OUTER VJP")
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors
        # needs_input_grad has one flag per forward input (inp, weight);
        # skip the matrix-vector product for inputs that need no gradient.
        wants_inp_grad, wants_weight_grad = ctx.needs_input_grad[:2]

        # d(out)/d(inp): contract out_grad with weight along its columns.
        inp_grad = mv(out_grad, weight) if wants_inp_grad else None
        # d(out)/d(weight): contract the transposed out_grad with inp.
        weight_grad = mv(out_grad.t(), inp) if wants_weight_grad else None

        return inp_grad, weight_grad
def outer(inp, weight):
    """Return the outer product of ``inp`` and ``weight`` via ``Outer``.

    Both arguments are forwarded unchanged to ``Outer.apply``, so the call is
    recorded by autograd and its result is returned as-is.
    """
    result = Outer.apply(inp, weight)
    return result