Coverage for src/flag_gems/runtime/backend/_ascend/ops/resolve_neg.py: 0%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.runtime.driver as driver 

8 

# Module-level logger named after this file, e.g. "flag_gems.runtime._ascend.ops.resolve_neg".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

10 

11 

def get_npu_properties():
    """Return the property dict of the currently selected NPU device.

    Looks up the active device via ``torch.npu`` and queries Triton's
    active driver for its properties (e.g. the "num_vectorcore" count
    used to size launch grids).
    """
    return driver.active.utils.get_device_properties(torch.npu.current_device())

15 

16 

@triton.jit
def resolve_neg_kernel(
    inp,
    out,
    data_len: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    MAX_DATA_SIZE: tl.constexpr,
):
    """Elementwise negation kernel: out[i] = -inp[i] for i < data_len.

    Each program owns one BLOCK_SIZE-wide slice of the flat range and
    walks it in MAX_DATA_SIZE-wide chunks, so a single tile never
    exceeds the per-iteration buffer budget. Out-of-range lanes are
    masked off on both the load and the store.
    """
    block_id = tl.program_id(0)
    # Chunks needed to sweep one block at MAX_DATA_SIZE elements a time.
    n_chunks = tl.cdiv(BLOCK_SIZE, MAX_DATA_SIZE)

    for chunk in tl.range(0, n_chunks):
        offs = block_id * BLOCK_SIZE + chunk * MAX_DATA_SIZE + tl.arange(0, MAX_DATA_SIZE)
        in_bounds = offs < data_len
        vals = tl.load(inp + offs, mask=in_bounds)
        tl.store(out + offs, -vals, mask=in_bounds)

34 

35 

def resolve_neg(A: torch.Tensor):
    """Materialize a tensor whose lazy negation bit is set.

    If ``A.is_neg()`` is True, the negated values are written into a
    fresh tensor (whose neg bit is clear) via a Triton kernel spread
    across the NPU's vector cores; otherwise ``A`` is returned
    unchanged.

    NOTE(review): the kernel addresses ``A`` as a flat buffer of
    ``A.numel()`` elements, which assumes a contiguous layout — confirm
    callers guarantee this.

    Args:
        A: input tensor (any shape/dtype supported by the kernel).

    Returns:
        torch.Tensor: same shape and dtype as ``A``, with the negation
        resolved into actual values.
    """
    logger.debug("GEMS_ASCEND RESOLVE_NEG")

    if not A.is_neg():
        # No pending negation: nothing to materialize.
        return A

    data_len = A.numel()
    out = torch.empty(A.numel(), dtype=A.dtype, device=A.device)

    if data_len == 0:
        # Empty tensor: launch nothing. Without this guard BLOCK_SIZE
        # would be 0 and triton.cdiv(data_len, 0) below would raise
        # ZeroDivisionError.
        return out.view(A.shape)

    CORE_NUM = get_npu_properties()["num_vectorcore"]
    # One roughly equal slice of the flat range per vector core; clamp to
    # at least 1 so the grid divisor can never be zero.
    BLOCK_SIZE = max(1, math.ceil(data_len / CORE_NUM))
    # Per-iteration tile width inside the kernel (elements).
    MAX_DATA_SIZE = 20 * 1024

    grid = lambda meta: (triton.cdiv(data_len, meta["BLOCK_SIZE"]),)
    resolve_neg_kernel[grid](
        A,
        out,
        data_len,
        BLOCK_SIZE,
        MAX_DATA_SIZE,
    )
    return out.view(A.shape)