Coverage for src/flag_gems/ops/full.py: 90%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Dtype groups for which check_dtype performs a range check on the fill value.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate ``fill_value`` against ``dtype`` and normalise it for the kernels.

    Args:
        fill_value: The value to fill with (bool, int, float, or anything that
            supports ordering comparisons against Python numbers).
        dtype: Target ``torch.dtype`` of the output tensor.
        device: Device on which a float64 fill value tensor is created.

    Returns:
        * ``int(fill_value)`` when ``fill_value`` is a bool and ``dtype`` is
          not ``torch.bool``;
        * a 0-dim tensor of ``dtype`` on ``device`` when ``dtype`` is
          ``torch.float64``;
        * ``fill_value`` unchanged otherwise.

    Raises:
        RuntimeError: If ``dtype`` is in ``ALL_INT_DTYPES`` or
            ``ALL_FLOAT_DTYPES`` and a non-bool ``fill_value`` lies outside the
            representable range. For float dtypes, ``inf`` and ``nan`` are
            always accepted.
    """
    if isinstance(fill_value, bool):
        # Bools always fit; they only need widening to int for non-bool outputs.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    else:
        overflow = False
        if dtype in ALL_INT_DTYPES:
            limits = torch.iinfo(dtype)
            overflow = bool(fill_value < limits.min or fill_value > limits.max)
        elif dtype in ALL_FLOAT_DTYPES:
            limits = torch.finfo(dtype)
            overflow = bool(fill_value < limits.min or fill_value > limits.max)
            # Compare first, then classify: math.isinf() raises OverflowError
            # for Python ints that do not fit in a C double (e.g. 10**400),
            # which would leak the wrong exception type. An int can never be
            # inf, so only non-int values need the infinity exemption. NaN
            # never reaches this point because its comparisons are all False.
            if overflow and not isinstance(fill_value, int):
                overflow = not math.isinf(fill_value)
        if overflow:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )

    if dtype == torch.float64:
        # float64 fill values are handed to the kernel as a tensor rather than
        # a Python scalar.
        # NOTE(review): presumably to keep full double precision - confirm.
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)

    return fill_value

38 

39 

# Tensor-valued fill: both arguments are declared as tensors to
# pointwise_dynamic, and the kernel body simply forwards the fill value.
# NOTE(review): the (0, "DEFAULT") promotion presumably ties the result dtype
# to argument 0 (`out`) - confirm against pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func(out, fill_value):
    return fill_value

44 

45 

# Scalar-valued fill: `fill_value` is declared as a non-tensor argument and is
# expanded to a block with the shape and dtype of `out`.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func_scalar(out, fill_value):
    filled = tl.full(out.shape, fill_value, out.dtype)
    return filled

50 

51 

def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape ``size`` filled with ``fill_value``.

    Args:
        size: Shape of the output tensor, forwarded to ``torch.empty``.
        fill_value: Value to fill with; a ``torch.Tensor`` is routed to the
            tensor kernel, anything else to the scalar kernel.
        dtype: Output dtype. When ``None`` it is inferred from ``fill_value``:
            bool -> ``torch.bool``, int -> ``torch.int64``, anything else ->
            ``torch.get_default_dtype()``.
        layout: Accepted for signature compatibility; not used here.
        device: Output device. Defaults to CPU when ``None``.
        pin_memory: Accepted for signature compatibility; not used here.

    Returns:
        The filled output tensor produced by the pointwise kernel.

    Raises:
        RuntimeError: If ``fill_value`` cannot be represented in the resolved
            dtype (raised by ``check_dtype``).
    """
    logger.debug("GEMS FULL")
    if device is None:
        device = torch.device("cpu")

    if dtype is None:
        # The bool test must come first: bool is a subclass of int.
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()

    # Validate and normalise for inferred dtypes as well as explicit ones, so
    # both paths share the same overflow check and float64 fill values are
    # wrapped in a tensor either way (previously skipped when dtype was None).
    fill_value = check_dtype(fill_value, dtype, device)

    out = torch.empty(size, device=device, dtype=dtype)

    if isinstance(fill_value, torch.Tensor):
        return full_func(out, fill_value, out0=out)
    return full_func_scalar(out, fill_value, out0=out)

70 else: 

71 return full_func_scalar(out, fill_value, out0=out)