Coverage for src/flag_gems/fused/outer.py: 77%

26 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4 

5from flag_gems.ops import mul, mv 

6 

7logger = logging.getLogger(__name__) 

8 

9 

class Outer(torch.autograd.Function):
    """Autograd function computing the outer product of two 1-D tensors.

    The forward pass multiplies ``inp`` viewed as a column by ``weight`` viewed
    as a row using ``mul``; the backward pass derives each input gradient with
    a matrix-vector product (``mv``).
    """

    @staticmethod
    def forward(ctx, inp, weight):
        """Return ``mul(inp[:, None], weight[None, :])``.

        Args:
            ctx: autograd context; ``inp`` and ``weight`` are saved on it for
                the backward pass.
            inp: 1-D tensor (length ``n``).
            weight: 1-D tensor (length ``m``).

        Returns:
            The result of ``mul`` on the ``(n, 1)`` and ``(1, m)`` views
            (presumably ``(n, m)`` via broadcasting inside ``mul``).

        Raises:
            AssertionError: if either input is not 1-D.
        """
        logger.debug("GEMS OUTER")
        # Kept as an assert so callers observe the same exception type.
        assert inp.ndim == 1 and weight.ndim == 1, "Invalid input"
        inp1 = inp[:, None].contiguous()
        weight1 = weight[None, :].contiguous()
        out = mul(inp1, weight1)
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        """Compute the gradients of the outer product w.r.t. ``inp`` and ``weight``.

        Args:
            ctx: autograd context holding the tensors saved in ``forward``.
            out_grad: 2-D gradient of the loss w.r.t. the forward output.

        Returns:
            ``(inp_grad, weight_grad)``. An entry is ``None`` when autograd
            reports that the corresponding input does not need a gradient.

        Raises:
            AssertionError: if ``out_grad`` is not 2-D.
        """
        logger.debug("GEMS OUTER VJP")
        assert out_grad.ndim == 2, "invalid out_grad shape"

        inp, weight = ctx.saved_tensors
        # One flag per forward input; skip the mv kernel for gradients that
        # autograd would discard anyway.
        need_inp_grad, need_weight_grad = ctx.needs_input_grad

        inp_grad = mv(out_grad, weight) if need_inp_grad else None
        weight_grad = mv(out_grad.t(), inp) if need_weight_grad else None

        return inp_grad, weight_grad

34 

35 

def outer(inp, weight):
    """Apply the ``Outer`` autograd function to ``inp`` and ``weight``.

    Args:
        inp: first operand, forwarded unchanged to ``Outer.apply``.
        weight: second operand, forwarded unchanged to ``Outer.apply``.

    Returns:
        Whatever ``Outer.apply(inp, weight)`` returns.
    """
    result = Outer.apply(inp, weight)
    return result