• pytorh .to(device) 和.cuda()


    一、.to(device) 可以指定CPU 或者GPU

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 单GPU或者CPU
    model.to(device)
    #如果是多GPU
    if torch.cuda.device_count() > 1:
      model = nn.DataParallel(model,device_ids=[0,1,2])
    model.to(device)

    mytensor = my_tensor.to(device)

    这行代码的意思是将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。

    这句话需要写的次数等于需要保存GPU上的tensor变量的个数;一般情况下这些tensor变量都是最开始读数据时的tensor变量,后面衍生的变量自然也都在GPU上

     

    二、.cuda() 只能指定GPU

    #指定某个GPU
    os.environ['CUDA_VISIBLE_DEVICE']='1'
    model.cuda()
    #如果是多GPU
    os.environment['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    device_ids = [0,1,2,3]
    net  = torch.nn.Dataparallel(net, device_ids =device_ids)
    net  = torch.nn.Dataparallel(net) # 默认使用所有的device_ids 
    net = net.cuda()
    class DataParallel(Module):
        def __init__(self, module, device_ids=None, output_device=None, dim=0):
            super(DataParallel, self).__init__()
    
            if not torch.cuda.is_available():
                self.module = module
                self.device_ids = []
                return
    
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))
            if output_device is None:
                output_device = device_ids[0]
  • 相关阅读:
    远程、标签
    NUnit单元测试资料汇总
    jdk1.6下载页面
    javac: cannot execute binary file
    how to remove MouseListener / ActionListener on a JTextField
    Linux下chkconfig命令详解(转)
    如何让vnc控制由默认的twm界面改为gnome?(转)
    winzip15.0注冊码
    微服务的优缺点
    站点建设10个最好的响应的HTML5滑块插件
  • 原文地址:https://www.cnblogs.com/h694879357/p/15984367.html
Copyright © 2020-2023  润新知