• 线性模型第一课


    import random
    import torch
    import matplotlib as pl
    #生成样本,特征和标签
    def synthetic_data(w,b,num_examples):
        x=torch.normal(0,1,(num_examples,len(w)))  #均值为0方差为1,w列
        y=torch.matmul(x,w)+b  #x*w+b
        y+=torch.normal(0,0.01,y.shape) #均值为0,方差为0.01的随机噪音
        return x,y.reshape((-1,1))
    true_w=torch.tensor([2,-3.4])
    true_b=4.2
    features,lables=synthetic_data(true_w,true_b,1000)
    
    #定义函数,接受批量大小,特征矩阵和标签向量作为输入,生成大小为batch_size的小批量
    def data_iter(batch_size,features,lables):
        num_examples=len(features)#多个样本
        indices=list(range(num_examples))#每个样本的index
        random.shuffle(indices)#打乱顺序
        for i in range(0,num_examples,batch_size):#从0开始,每次跳batch_size个大小
            batch_indics=torch.tensor(
                indices[i:min(i+batch_size,num_examples)]
            )#每次拿b个index。如果不够拿一个最小值
            yield features[batch_indics],lables[batch_indics]
    batch_size=10
    for x,y in data_iter(batch_size,features,lables):
        print(x,'\n',y)
        break
    w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
    b=torch.zeros(1,requires_grad=True)
    #定义线性模型
    def linreg(x,w,b):
        return torch.matmul(x,w)+b
    #定义损失函数
    #1/2(y-y_hat)(y-y_hat)
    def squared_loss(y_hat,y):
        return (y_hat-y.reshape(y_hat.shape))**2/2
    #定义最优化算法
    #给定所有参数,学习率和批量
    def sgd(params,lr,batch_size):
        with torch.no_grad():
            for param in params:
                param-=lr*param.grad/batch_size
                param.grad.zero_()
    lr=0.03
    num_epochs=3#整个模型扫3遍
    net=linreg#线性模型
    loss=squared_loss
    for epoch in range(num_epochs):
        for x,y in data_iter(batch_size,features,lables):
            l=loss(net(x,w,b),y)#求y_hat,再求y_hat和y的损失
            l.sum().backward()
            sgd([w,b],lr,batch_size)
        with torch.no_grad():
            train_1=loss(net(features,w,b),lables)
            print(f'epoch{epoch+1} ,loss{float(train_1.mean()):f}')
    View Code
  • 相关阅读:
    crontab 启动supervisor爬虫
    frida初体验
    Protobuf 的数据反解析
    adb
    突破SSL Pinning抓app的数据包
    Charles下载与配置
    替换小技巧
    docker 使用
    pandas读取excel
    docker 安装
  • 原文地址:https://www.cnblogs.com/wangzhaojun1670/p/16172955.html
Copyright © 2020-2023  润新知