• 详解pytorch中的max方法


    实际上pytorch官方文档中的应该是torch.max(input)方法,而本文要讲的可能严格意义上不是torch中的,而是针对torch中的张量方法,即input.max(axis)[index]
    其中input表示要求取最大值的张量,axis可以为0(表示求取每列的最大值),也可以为1(每行的最大值)。index为0表示只返回最大值本身,为1表示只返回最大值对应的索引。如下,其中axis可以省去:

    a = torch.Tensor([[0,3,2],[4,0,0]])
    print(a.max(axis=0)[0]) # tensor([4., 3., 2.]),即第一列为[0 4]最大值为4,第二列为[3 0],依此类推
    print(a.max(axis=0)[1]) # tensor([1, 0, 0]),索引也是列的索引
    print(a.max(axis=1)[0]) # tensor([3., 4.]),取各行的最大值
    print(a.max(axis=1)[1]) # tensor([1, 0]),对应的索引
    

    应用

    在求解强化学习中需要qmaxq_{max}qmax对应的action时,通常是输入一个张量即神经网络算出的q值,然后输出q值对应的索引,输出的是int型,如下:

    import torch
    q = torch.Tensor([[0,3,2,1]])
    action=q.max(1)[1].item() # .item()将只有一个元素的张量变为对应的元素
    action=q.max(1)[1].view(1,1).item() # 如果不放心可在前面加view方法shape成只有一个元素的张量
    
  • 相关阅读:
    mvc3在各个IIS版本中的部署
    linq学习
    常用的正则表达式
    Jenkins+Git+Maven+Tomcat的初步学习
    12个用得着的JQuery代码片段
    JQuery原理介绍及学习方法
    【前端学习】javascript面向对象编程(继承和复用)
    c# throw和throw ex
    .net 信息采集ajax数据
    C# FileSystemWatcher 并发
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13281640.html
Copyright © 2020-2023  润新知