• Facebook 发布深度学习工具包 PyTorch Hub,让论文复现变得更容易


    近日,PyTorch 社区发布了一个深度学习工具包 PyTorchHub, 帮助机器学习工作者更快实现重要论文的复现工作。PyTorchHub 由一个预训练模型仓库组成,专门用于提高研究工作的复现性以及新的研究。同时它还内置了对Google Colab的支持,并与Papers With Code集成。目前 PyTorchHub 包括了一系列与图像分类、分割、生成以及转换相关的模型。

    可复现性是许多研究领域的基本要求,这其中当然包括基于机器学习技术的研究领域。然而, 许多机器学习相关论文要么无法复现,要么难以重现。随着论文数量的持续增长,包括目前在 arXiv 上预印刷的数万份论文以及提交给会议的论文,研究工作的可复现性变得越来越重要。虽然其中许多论文都附有代码以及训练好的模型,但这种帮助显然非常有限,复现过程中仍有大量需要读者自己摸索的步骤。下面让我们来看一下如何通过 PyTorch Hub 这一利器完成快速的模型发布与工作复现。

    image

    如何快速发布模型

    这部分主要介绍了对于模型发布者来说如何快速高效的将自己的模型加入 PyTorch Hub 库。PyTorch Hub 支持通过添加简单的 hubconf.py 文件将预先训练的模型(模型定义和预先训练重)发布到 GitHub 存储库。这提供了模型列表以及其依赖库列表。一些示例可以在torchvisionhuggingface-bertgan-model-zoo存储库中找到。

    Pytoch 社区给出了 torchvision 的 hubconf.py 文件的示例:

    �复制代码
     
     
    # Optional list of dependencies required by the package
     
    dependencies = ['torch']
       
     
    from torchvision.models.alexnet import alexnet
     
    from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
     
    from torchvision.models.inception import inception_v3
     
    from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d
     
    from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
     
    from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
     
    from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
     
    from torchvision.models.googlenet import googlenet
     
    from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
     
    from torchvision.models.mobilenet import mobilenet_v2

    在 torchvision 中,模型有以下特性:

    • 每个模型文件可以被独立执行或实现某个功能
    • 不需要除了 PyTorch 之外的任何软件包(在 hubconf.py 中编码为 dependencies[‘torch’])
    • 他们不需要单独的入口点,因为模型在创建时可以无缝地开箱即用。

    PyTroch 社区认为最小化包依赖性可减少用户加载模型时遇到的困难。这里他们给出了一个更为复杂的例子——HuggingFace’s BERT 模型,它的 hubconf.py 如下:

    �复制代码
     
     
    dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
       
     
    from hubconfs.bert_hubconf import (
     
    bertTokenizer,
     
    bertModel,
     
    bertForNextSentencePrediction,
     
    bertForPreTraining,
     
    bertForMaskedLM,
     
    bertForSequenceClassification,
     
    bertForMultipleChoice,
     
    bertForQuestionAnswering,
     
    bertForTokenClassification
     
    )

    此外,对于每个模型,PyTorch 官方提到都需要为其创建一个入口点。下面是一个用于指定 bertForMaskedLM 模型的入口点的代码片段,这部分代码完成的功能是返回加载了预训练参数的模型。

    �复制代码
     
     
    def bertForMaskedLM(*args, **kwargs):
     
    """
     
    BertForMaskedLM includes the BertModel Transformer followed by the
     
    pre-trained masked language modeling head.
     
    Example:
     
    ...
     
    """
     
    model = BertForMaskedLM.from_pretrained(*args, **kwargs)
     
    return model

    这些入口点可以看成是复杂的模型结构的一种封装形式。它们可以在提供简洁高效的帮助文档的同时完成下载预训练权重的功能(例如,通过 pretrained = True),也可以集成其他特定功能,例如可视化。

    通过 hubconf.py,模型发布者可以在 Github 上基于template提交他们的合并请求。PyTorch 社区希望通过 PyTorch Hub 创建一系列高质量、易复现且效果好的模型以提高研究工作的复现性。因此,PyTorch 会通过与模型发布者合作的方式以完善请求,并有可能会在某些情况下拒绝发布一些低质量的模型。一旦 PyTorch 社区接受了模型发布者的请求,这些新的模型将会很快出现在 PyTorch Hub 的网页上以供用户浏览。

    用户工作流

    对于想使用 PyTorch Hub 对别人的工作进行复现的用户,PyTorch Hub 提供了以下几个步骤:1)浏览可用的模型;2)加载模型;3)探索已加载的模型。下面让我们来浏览几个例子。

    浏览可用的入口点

    用户可以使用 torch.hub.list() API 列出仓库中的所有可用入口点。

    �复制代码
     
     
    >>> torch.hub.list('pytorch/vision')
     
    >>>
     
    ['alexnet',
     
    'deeplabv3_resnet101',
     
    'densenet121',
     
    ...
     
    'vgg16',
     
    'vgg16_bn',
     
    'vgg19',
     
    'vgg19_bn']

    注意,PyTorch Hub 还允许辅助入口点(除了预训练模型),例如,用于 BERT 模型预处理的 bertTokenizer,它可以使用户工作流程更加顺畅。

    加载模型

    对于 PyTroch Hub 中可用的模型,用户可以使用 torch.hub.load() API 加载模型入口点。此外,torch.hub.help() API 可以提供有关如何实例化模型的有用信息。

    �复制代码
     
     
    print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
     
    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)

    由于仓库的持有者会不断添加错误修复以及性能改进,PyTorch Hub 允许用户通过调用以下内容简单地获取最新更新:

    �复制代码
     
     
    model = torch.hub.load(..., force_reload=True)

    这一举措可以有效地减轻仓库持有者重复发布模型的负担,从而使他们能够更专注于自己的研究工作。同时,也确保了用户可以获得最新版本的模型。

    此外,对于用户来说,稳定性也是一个重要问题。因此,某些模型所有者会从特征的分支或标签为他们提供服务,以确保代码的稳定性。例如,pytorch_GAN_zoo 会从 hub 分支为他们提供服务:

    �复制代码
     
     
    model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)

    这里,传递给 hub.load() 的 * args,** kwargs 用于实例化模型。在上面的示例中,pretrained = True 和 useGPU = False 被传递给模型的入口点。

    探索已加载的模型

    从 PyTorch Hub 加载模型后,用户可以使用以下工作流查看已加载模型的可用方法,并更好地了解运行它所需的参数。

    其中,dir(model) 可以查看模型中可用的方法。下面是 bertForMaskedLM 的一些方法:

    �复制代码
     
     
    >>> dir(model)
     
    >>>
     
    ['forward'
     
    ...
     
    'to'
     
    'state_dict',
     
    ]

    help(model.forward)则会提供使已加载的模型运行时所需参数的视图:

    �复制代码
     
     
    >>> help(model.forward)
     
    >>>
     
    Help on method forward in module pytorch_pretrained_bert.modeling:
     
    forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
     
    ...

    更多细节可以查看BERTDeepLabV3页面:

    其他探索方式与相关资源

    PyTorch Hub 中提供的模型也支持 Colab,并且会直接链接在 Papers With Code 上,用户只需单击链接即可开始使用:

    image

    PyTorch 提供了一些相关资源帮助用户快速上手 PyTorch Hub:

    FAQ

    问:如果我们想贡献一个 Hub 中已经有了的模型,但也许我的模型具有更高的准确性,我还应该贡献吗?
    答:是的,请提交您的模型,Hub 的下一步是开发投票系统以展示最佳模型。

    问:谁负责保管 PyTorch Hub 的模型权重?
    答:作为贡献者,您负责保管模型权重。您可以在您喜欢的云存储中托管您的模型,或者如果它符合限制,则可以在 GitHub 上托管您的模型。 如果您无法保管权重,请通过 Hub 仓库中提交问题的方式与我们联系。

    问:如果我的模型使用了私有化数据进行训练怎么办?我还应该贡献这个模型吗?
    答:请不要提交您的模型!PyTorch Hub 以开源研究为中心,并扩展到使用公开数据集来训练这些模型。如果提交了私有模型的合并请求,我们将恳请您重新提交使用公开数据进行训练后的模型。

    问:我下载的模型保存在哪里?
    答:我们遵循 XDG 基本目录规范,并遵循缓存文件和目录的通用标准。这些位置按以下顺序使用:

    • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)
    • 如果环境变量了 TORCH_HOME,则为 $TORCH_HOME/hub。
    • 如果设置了环境变量 XDG_CACHE_HOME,则为 $ XDG_CACHE_HOME / torch / hub。
    • ~/.cache/torch/hub

    相关推荐:

  • 相关阅读:
    cogs 1272. [AHOI2009] 行星序列
    1027. 打印沙漏(20)
    1026. 程序运行时间(15)
    1023. 组个最小数 (20)
    《C语言程序设计(第四版)》阅读心得(四 文件操作)
    1022. D进制的A+B (20)
    1021. 个位数统计 (15)
    1020. 月饼 (25)
    1015. 德才论 (25)
    1009. 说反话 (20)
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11262511.html
Copyright © 2020-2023  润新知