• TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.


    代码

    # -*- coding: utf-8 -*-
    """
    Created on Sat Feb 19 13:19:30 2022
    
    @author: koneko
    """
    
    from matplotlib  import pyplot as plt
    import torch 
    import math
    
    dtype = torch.float
    device = torch.device("cuda:0")
    
    # Create random input and output data
    x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
    y = torch.sin(x)
    
    # Randomly initialize weights
    a = torch.randn((), device=device, dtype=dtype)
    b = torch.randn((), device=device, dtype=dtype)
    c = torch.randn((), device=device, dtype=dtype)
    d = torch.randn((), device=device, dtype=dtype)
    
    lr = 1e-6
    
    for t in range(2000):
        # Forward pass: compute predicted y
        y_pred = a + b * x + c * x ** 2 + d * x ** 3
    
        # Compute and print loss
        loss = (y_pred - y).pow(2).sum().item()
        if t % 100 == 99:
            print(t, loss)
    
        # Backprop to compute gradients of a, b, c, d with respect to loss
        grad_y_pred = 2.0 * (y_pred - y)
        grad_a = grad_y_pred.sum()
        grad_b = (grad_y_pred * x).sum()
        grad_c = (grad_y_pred * x ** 2).sum()
        grad_d = (grad_y_pred * x ** 3).sum()
    
        # Update weights using gradient descent
        a -= lr * grad_a
        b -= lr * grad_b
        c -= lr * grad_c
        d -= lr * grad_d
    
    
    print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
    
    x = x.numpy()
    y_pred = y_pred.numpy()
    
    plt.plot(x,y_pred)
    
    

    报错信息

    TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
    

    原因

    看信息应该是说数据在显存里plt不能直接调用?所以要先复制到宿主内存里面

    解决方法

    倒数第二三行修改为:

    x = x.cpu().numpy()
    y_pred = y_pred.cpu().numpy()
    
  • 相关阅读:
    java文件下载
    java程序运行原理
    java io流(核心:读进来,写出去)
    oracle操作表和字段的sql复习
    深入理解C/S和B/S模式
    Windows PyCharm永久激活
    MacBook PyCharm永久激活
    百度云同同步盘 mac版
    SJW-遍历我的账户左侧导航页面(句柄切换)
    python-selenium无法调用浏览器的问题==
  • 原文地址:https://www.cnblogs.com/urahyou/p/15916662.html
Copyright © 2020-2023  润新知