Coverage for /tmp/tmpfgl4ek8j/_remote_module_non_scriptable.py: 31%

39 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1from typing import * 

2 

3import torch 

4import torch.distributed.rpc as rpc 

5from torch import Tensor 

6from torch._jit_internal import Future 

7from torch.distributed.rpc import RRef 

8from typing import Tuple # pyre-ignore: unused import 

9 

10 

# Type used as the type argument of ``RRef[...]`` in the ``_remote_forward``
# signature below. It is None in this file. NOTE(review): presumably the code
# generator that produced this module is meant to substitute a real module
# interface class here — confirm against the generator.
module_interface_cls = None

12 

13 

def forward_async(self, *args, **kwargs):
    """Asynchronously run ``_remote_forward`` on the owner of ``self.module_rref``.

    The module RRef, the device string and the device-map flag stored on
    ``self`` are prepended to the caller's positional arguments. Keyword
    arguments are forwarded as a shallow copy.

    Returns:
        The future returned by ``rpc.rpc_async``; call ``.wait()`` on it to
        obtain the remote result.
    """
    # ``args`` is always a tuple here, so plain concatenation is equivalent
    # to unpacking it into a new tuple.
    remote_args = (self.module_rref, self.device, self.is_device_map_set) + args
    remote_kwargs = dict(kwargs)
    return rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        remote_args,
        remote_kwargs,
    )

23 

24 

def forward(self, *args, **kwargs):
    """Synchronously run ``_remote_forward`` on the owner of ``self.module_rref``.

    Builds the same RPC request as ``forward_async``: the module RRef, the
    device string and the device-map flag stored on ``self``, followed by the
    caller's positional arguments and a shallow copy of the keyword arguments.
    It then blocks until the remote call completes.

    Returns:
        The value produced by ``_remote_forward`` on the owner.
    """
    remote_args = (self.module_rref, self.device, self.is_device_map_set) + args
    remote_kwargs = dict(kwargs)
    return rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        remote_args,
        remote_kwargs,
    ).wait()

35 

36 

# Functions exported by this module, in order: the non-blocking variant first,
# then the blocking one. NOTE(review): presumably these are attached as methods
# to a remote-module class by the consumer of this list — confirm.
_generated_methods = [forward_async, forward]

41 

42 

43 

44 

def _remote_forward(
    module_rref: RRef[module_interface_cls],
    device: str,
    is_device_map_set: bool,
    *args,
    **kwargs,
):
    """Call ``forward`` on the module held locally by ``module_rref``.

    When ``device`` is not a CUDA device, the module is called with the
    arguments unchanged. Otherwise every ``Tensor`` among the positional and
    keyword arguments is first moved to ``device``. Unless
    ``is_device_map_set`` is true, every ``Tensor`` in the output is then moved
    back to the CPU before returning.

    Args:
        module_rref: RRef whose local value is the module to run.
        device: Device string, parsed with ``torch.device``.
        is_device_map_set: When true on the CUDA path, the module's output is
            returned exactly as produced.
        *args: Positional arguments forwarded to ``module.forward``.
        **kwargs: Keyword arguments forwarded to ``module.forward``.

    Returns:
        The module's output. On the CUDA path without a device map:
        - a bare ``Tensor`` output is returned moved to CPU;
        - any other output is iterated and returned as a tuple, with each
          ``Tensor`` element moved to CPU.
    """
    module = module_rref.local_value()
    device = torch.device(device)

    if device.type != "cuda":
        return module.forward(*args, **kwargs)

    # The module is on a CUDA device: move any tensor inputs to that device.
    # This file is the non-scriptable variant, so generator expressions and
    # comprehensions are fine here.
    out_args = tuple(a.to(device) if isinstance(a, Tensor) else a for a in args)
    out_kwargs = {
        k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs.items()
    }

    if is_device_map_set:
        return module.forward(*out_args, **out_kwargs)

    # If the device map is empty, only CPU tensors are allowed to be sent over
    # the wire, so any GPU tensor in the output has to be moved to CPU.
    result = module.forward(*out_args, **out_kwargs)
    if isinstance(result, Tensor):
        # Do not iterate a bare tensor. Iteration would split it along dim 0
        # into a tuple of slices (and raise TypeError for a 0-d tensor)
        # instead of returning the tensor itself.
        return result.cpu()
    return tuple(i.cpu() if isinstance(i, Tensor) else i for i in result)