Coverage for src/flag_gems/runtime/backend/_hygon/ops/fill.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

11logger = logging.getLogger(__name__) 

12 

13 

# Element-wise kernel behind the scalar fill entry points. It is registered
# with pointwise_dynamic using is_tensor=[True, False]: `inp` is flagged as a
# tensor operand and `value_scalar` as a non-tensor operand.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, "DEFAULT")],
    num_outputs=1,
)
@triton.jit
def fill_scalar_func(inp, value_scalar):
    """Return a block with the shape and dtype of ``inp`` filled with ``value_scalar``."""
    # Only the shape and dtype of `inp` are used; its contents are never read.
    filled = tl.full(inp.shape, value_scalar, dtype=inp.dtype)
    return filled

20 

21 

# Element-wise kernel behind the tensor-valued fill entry points. It is
# registered with pointwise_dynamic using is_tensor=[True, True], i.e. both
# operands are flagged as tensors.
@pointwise_dynamic(
    is_tensor=[True, True],
    promotion_methods=[(0, "DEFAULT")],
    num_outputs=1,
)
@triton.jit
def fill_tensor_func(inp, value):
    """Return ``value`` unchanged; ``inp`` is accepted but not read."""
    # NOTE(review): presumably pointwise_dynamic broadcasts the 0-dim `value`
    # against `inp` before this body runs -- confirm against the wrapper.
    result = value
    return result

28 

29 

def fill_scalar(input, value):
    """Return a new tensor laid out like ``input`` and filled with ``value``.

    Args:
        input: Tensor that supplies the device and the ``empty_like`` template
            for the result. It is passed to the kernel but not written to,
            because a separate output buffer is supplied.
        value: Scalar forwarded to ``fill_scalar_func``.

    Returns:
        Whatever ``fill_scalar_func`` returns when given the freshly allocated
        buffer as ``out0``.
    """
    logger.debug("GEMS FILL (Dynamic)")
    destination = torch.empty_like(input)
    device_guard = torch_device_fn.device(input.device)
    # Launch with the input's device selected as the current device.
    with device_guard:
        filled = fill_scalar_func(input, value, out0=destination)
    return filled

35 

36 

def fill_tensor(input, value):
    """Return a new tensor laid out like ``input`` filled with the scalar held by ``value``.

    Args:
        input: Tensor that supplies the device and the ``empty_like`` template
            for the result.
        value: 0-dimensional tensor holding the fill value. When it is not a
            CUDA tensor its Python scalar is extracted with ``.item()`` and the
            call is delegated to ``fill_scalar``.

    Returns:
        The result of ``fill_scalar`` (host-side value) or of
        ``fill_tensor_func`` (device-side value).

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate the rank before choosing a path. Previously the host-side
    # shortcut ran first, so a CPU tensor with ndim >= 1 was either silently
    # accepted (one element) or failed inside .item() with an unrelated message.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        # Host-resident value: read the scalar and reuse the scalar kernel.
        return fill_scalar(input, value.item())
    logger.debug("GEMS FILL (Dynamic)")
    out = torch.empty_like(input)
    # Launch with the input's device selected as the current device.
    with torch_device_fn.device(input.device):
        return fill_tensor_func(input, value, out0=out)

48 

49 

def fill_tensor_(self, value):
    """Fill ``self`` in place with the scalar held by ``value`` and return ``self``.

    Args:
        self: Tensor to overwrite; it is passed to the kernel both as the
            input operand and as the output buffer ``out0``.
        value: 0-dimensional tensor holding the fill value. When it is not a
            CUDA tensor its Python scalar is extracted with ``.item()`` and the
            call is delegated to ``fill_scalar_``.

    Returns:
        ``self`` on the device-side path, or the result of ``fill_scalar_`` on
        the host-side path.

    Raises:
        RuntimeError: If ``value`` is not 0-dimensional.
    """
    # Validate the rank before choosing a path. Previously the host-side
    # shortcut ran first, so a CPU tensor with ndim >= 1 was either silently
    # accepted (one element) or failed inside .item() with an unrelated message.
    if value.ndim != 0:
        raise RuntimeError(
            f"fill_ only supports 0-dimension value tensor but got tensor with {value.ndim} dimensions."
        )
    if not value.is_cuda:
        # Host-resident value: read the scalar and reuse the scalar kernel.
        return fill_scalar_(self, value.item())
    logger.debug("GEMS FILL_TENSOR_")
    # Launch with the destination's device selected as the current device.
    with torch_device_fn.device(self.device):
        fill_tensor_func(self, value, out0=self)
    return self

61 

62 

def fill_scalar_(self, value):
    """Fill ``self`` in place with the scalar ``value`` and return ``self``.

    Args:
        self: Tensor to overwrite; it is passed to the kernel both as the
            input operand and as the output buffer ``out0``.
        value: Scalar forwarded to ``fill_scalar_func``.

    Returns:
        The same ``self`` object that was passed in.
    """
    logger.debug("GEMS FILL_SCALAR_")
    device_guard = torch_device_fn.device(self.device)
    # The kernel's return value is ignored; the destination is `self` itself.
    with device_guard:
        fill_scalar_func(self, value, out0=self)
    return self