• pytorh .to(device) 和.cuda()


    一、.to(device) 可以指定CPU 或者GPU

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 单GPU或者CPU
    model.to(device)
    #如果是多GPU
    if torch.cuda.device_count() > 1:
      model = nn.DataParallel(model,device_ids=[0,1,2])
    model.to(device)

    mytensor = my_tensor.to(device)

    这行代码的意思是将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。

    这句话需要写的次数等于需要保存GPU上的tensor变量的个数;一般情况下这些tensor变量都是最开始读数据时的tensor变量,后面衍生的变量自然也都在GPU上

     

    二、.cuda() 只能指定GPU

    #指定某个GPU
    os.environ['CUDA_VISIBLE_DEVICE']='1'
    model.cuda()
    #如果是多GPU
    os.environment['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    device_ids = [0,1,2,3]
    net  = torch.nn.Dataparallel(net, device_ids =device_ids)
    net  = torch.nn.Dataparallel(net) # 默认使用所有的device_ids 
    net = net.cuda()
    class DataParallel(Module):
        def __init__(self, module, device_ids=None, output_device=None, dim=0):
            super(DataParallel, self).__init__()
    
            if not torch.cuda.is_available():
                self.module = module
                self.device_ids = []
                return
    
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))
            if output_device is None:
                output_device = device_ids[0]
  • 相关阅读:
    shell编程之变量
    linux更换yum源
    windows系统安装jdk并设置环境变量
    linux安装jdk
    mysql中null与“空值”的坑
    mysql服务器3306端口不能远程连接的解决
    Memcached
    redis memcached MongoDB
    postman进行http接口测试
    C# 开发Chrome内核浏览器(WebKit.net)
  • 原文地址:https://www.cnblogs.com/h694879357/p/15984367.html
Copyright © 2020-2023  润新知