Coverage for src/flag_gems/testing/__init__.py: 90%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1import torch
3from flag_gems import runtime
4from flag_gems.runtime import torch_device_fn
# Per-dtype relative tolerances used by assert_close.
# Integral and bool dtypes compare exactly; float/complex dtypes get a
# resolution-appropriate rtol. Sub-dicts are merged in an order that keeps
# the resulting key order identical to the original literal tables.
_EXACT_DTYPES = {
    torch.bool: 0,
    torch.uint8: 0,
    torch.int8: 0,
    torch.int16: 0,
    torch.int32: 0,
    torch.int64: 0,
}
_FLOAT8_DTYPES = {
    torch.float8_e4m3fn: 1e-3,
    torch.float8_e5m2: 1e-3,
    torch.float8_e4m3fnuz: 1e-3,
    torch.float8_e5m2fnuz: 1e-3,
}
_FLOAT_DTYPES = {
    torch.float16: 1e-3,
    torch.float32: 1.3e-6,
    torch.bfloat16: 0.016,
    torch.float64: 1e-7,
    torch.complex32: 1e-3,
    torch.complex64: 1.3e-6,
}

if runtime.device.vendor_name == "kunlunxin":
    # kunlunxin builds do not expose the float8 dtypes, so omit them.
    RESOLUTION = {**_EXACT_DTYPES, **_FLOAT_DTYPES}
else:
    RESOLUTION = {**_EXACT_DTYPES, **_FLOAT8_DTYPES, **_FLOAT_DTYPES}
42def _maybe_move_to_cpu(res, ref):
43 if res.device.type == "cpu" or ref.device.type == "cpu":
44 return res, ref
46 required = res.numel() * res.element_size()
48 free_mem = None
49 try:
50 free_mem, _ = torch_device_fn.mem_get_info(res.device)
51 except RuntimeError:
52 pass
54 # torch.isclose allocates an auxiliary tensor roughly the size of the inputs,
55 # so ensure we have enough headroom; otherwise compare on CPU.
56 HUGE_TENSOR_BYTES = 1 << 30 # 1 GiB
57 if (free_mem is not None and required >= free_mem) or (
58 required >= HUGE_TENSOR_BYTES
59 ):
60 return res.cpu(), ref.cpu()
61 return res, ref
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1, atol=1e-4):
    """Assert that ``res`` matches ``ref`` within dtype-specific tolerances.

    ``ref`` is cast to ``dtype`` (float32 when ``dtype`` is None) before the
    comparison; ``res`` must already have that dtype. The absolute tolerance
    is scaled by ``reduce_dim`` to account for error accumulation in
    reductions; the relative tolerance comes from the RESOLUTION table.
    """
    dtype = torch.float32 if dtype is None else dtype
    assert res.dtype == dtype
    res, ref = _maybe_move_to_cpu(res, ref.to(dtype))
    torch.testing.assert_close(
        res,
        ref,
        atol=reduce_dim * atol,
        rtol=RESOLUTION[dtype],
        equal_nan=equal_nan,
    )
def assert_equal(res, ref, equal_nan=False):
    """Assert that ``res`` and ``ref`` are exactly elementwise equal.

    Zero tolerances make this a bit-for-bit value comparison; set
    ``equal_nan=True`` to treat NaNs at matching positions as equal.
    """
    torch.testing.assert_close(res, ref, rtol=0, atol=0, equal_nan=equal_nan)