Coverage for src/flag_gems/testing/__init__.py: 90%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2 

3from flag_gems import runtime 

4from flag_gems.runtime import torch_device_fn 

5 

# Relative tolerance per dtype, consumed by assert_close() below.
# Exact-match dtypes (bool and all integers) get 0; floating dtypes get a
# tolerance roughly proportional to their precision.
RESOLUTION = {
    torch.bool: 0,
    torch.uint8: 0,
    torch.int8: 0,
    torch.int16: 0,
    torch.int32: 0,
    torch.int64: 0,
    torch.float16: 1e-3,
    torch.float32: 1.3e-6,
    torch.bfloat16: 0.016,
    torch.float64: 1e-7,
    torch.complex32: 1e-3,
    torch.complex64: 1.3e-6,
}

# The original table omitted the float8 dtypes for the "kunlunxin" vendor
# (presumably unsupported there — confirm against that backend); keep them
# out for kunlunxin and add them for everyone else.
if runtime.device.vendor_name != "kunlunxin":
    RESOLUTION.update(
        {
            torch.float8_e4m3fn: 1e-3,
            torch.float8_e5m2: 1e-3,
            torch.float8_e4m3fnuz: 1e-3,
            torch.float8_e5m2fnuz: 1e-3,
        }
    )

40 

41 

42def _maybe_move_to_cpu(res, ref): 

43 if res.device.type == "cpu" or ref.device.type == "cpu": 

44 return res, ref 

45 

46 required = res.numel() * res.element_size() 

47 

48 free_mem = None 

49 try: 

50 free_mem, _ = torch_device_fn.mem_get_info(res.device) 

51 except RuntimeError: 

52 pass 

53 

54 # torch.isclose allocates an auxiliary tensor roughly the size of the inputs, 

55 # so ensure we have enough headroom; otherwise compare on CPU. 

56 HUGE_TENSOR_BYTES = 1 << 30 # 1 GiB 

57 if (free_mem is not None and required >= free_mem) or ( 

58 required >= HUGE_TENSOR_BYTES 

59 ): 

60 return res.cpu(), ref.cpu() 

61 return res, ref 

62 

63 

def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1, atol=1e-4):
    """Compare ``res`` against ``ref`` with dtype-aware tolerances.

    ``ref`` is cast to ``dtype`` (float32 when ``dtype`` is None) before the
    comparison, and ``res`` must already carry that dtype. The absolute
    tolerance is scaled by ``reduce_dim``; the relative tolerance comes from
    the module-level RESOLUTION table. Raises AssertionError on mismatch.
    """
    target_dtype = torch.float32 if dtype is None else dtype
    assert res.dtype == target_dtype
    cast_ref = ref.to(target_dtype)
    res, cast_ref = _maybe_move_to_cpu(res, cast_ref)
    torch.testing.assert_close(
        res,
        cast_ref,
        atol=atol * reduce_dim,
        rtol=RESOLUTION[target_dtype],
        equal_nan=equal_nan,
    )

74 

75 

def assert_equal(res, ref, equal_nan=False):
    """Assert that ``res`` and ``ref`` match exactly (zero tolerance)."""
    zero_tolerance = {"atol": 0, "rtol": 0}
    torch.testing.assert_close(res, ref, equal_nan=equal_nan, **zero_tolerance)