• 基于 ONNX 在 ML.NET 中使用 Pytorch 训练的垃圾分类模型


    ML.NET 在经典机器学习范畴内,对分类、回归、异常检测等问题开发模型已经有非常棒的表现了,我之前的文章都有过介绍。当然我们希望在更高层次的领域加以使用,例如计算机视觉、自然语言处理和信号处理等等领域。

    图像识别是计算机视觉的一类分支,AI研发者们较为熟悉的是使用TensorFlow、Pytorch、Keras、MXNET等框架来训练深度神经网络模型,其中会涉及到CNN(卷积神经网络)、DNN(深度神经网络)的相关算法。

    ML.NET 在较早期的版本是无法支持这类研究的,可喜的是最新的版本不但能很好地集成 TensorFlow 的模型做迁移学习,还可以直接导入 DNN 常见预编译模型:AlexNet、ResNet18、ResNet50、ResNet101 实现对图像的分类、识别等。

    我特别想推荐的是,ML.NET 最新版本对 ONNX 的支持也是非常强劲,通过 ONNX 可以把众多其他优秀深度学习框架的模型引入到 .NET Core 运行时中,极大地扩充了 .NET 应用在智能认知服务的丰富程度。在 Microsoft Docs 中已经提供了一个基于 ONNX 使用 Tiny YOLOv2 做对象检测的例子。为了展现 ML.NET 在其他框架上的通用性,本文将介绍使用 Pytorch 训练的垃圾分类的模型,基于 ONNX 导入到 ML.NET 中完成预测。

    在2019年9月华为云举办了一次人工智能大赛·垃圾分类挑战杯,首次将AI与环保主题结合,展现人工智能技术在生活中的运用。有幸我看到了本次大赛亚军方案的分享,并且在 github 上找到了开源代码,按照 README 说明,我用 Pytorch 训练出了一个模型,并保存为garbage.pt 文件。

    生成 ONNX 模型

    首先,我使用以下 Pytorch 代码来生成一个garbage.pt 对应的文件,命名为 garbage.onnx

    torch_model = torch.load("garbage.pt") # pytorch模型加载
        batch_size = 1  #批处理大小
        input_shape = (3,224,224)   #输入数据
    
        # # set the model to inference mode
        torch_model.eval()
    
        x = torch.randn(batch_size, *input_shape, device='cuda')        # 生成张量
        export_onnx_file = "garbage.onnx"                    # 目的ONNX文件名
     
        
        torch.onnx.export(torch_model.module,
                            x,
                            export_onnx_file,
                            input_names=["input"],        # 输入名
                            output_names=["output"]    # 输出名

    准备 ML.NET 项目

    创建一个 .NET Core 控制台应用,按如下结构创建好合适的目录。assets 目录下的 images 子目录将放置待预测的图片,而 Model 子目录放入前一个步骤生成的 garbage.onnx 模型文件。

    ImageNetData 和 ImageNetPrediction 类定义了输入和输出的数据结构。

    using System.Collections.Generic;
    using System.IO;
    using System.Linq;
    using Microsoft.ML.Data;
    
    namespace GarbageDemo.DataStructures
    {
        public class ImageNetData
        {
            [LoadColumn(0)]
            public string ImagePath;
    
            [LoadColumn(1)]
            public string Label;
    
            public static IEnumerable<ImageNetData> ReadFromFile(string imageFolder)
            {
                return Directory
                   .GetFiles(imageFolder)
                   .Where(filePath => Path.GetExtension(filePath) == ".jpg")
                   .Select(filePath => new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) });
    
            }
        }
    
        public class ImageNetPrediction : ImageNetData
        {
            public float[] Score;
    
            public string PredictedLabelValue;
        }
    }

     OnnxModelScorer 类定义了 ONNX 模型加载、打分预测的过程。注意 ImageNetModelSettings 的属性和第一步中指定的输入输出字段名要一致。

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using Microsoft.ML;
    using Microsoft.ML.Data;
    using Microsoft.ML.Transforms.Onnx;
    using Microsoft.ML.Transforms.Image;
    using GarbageDemo.DataStructures;
    
    namespace GarbageDemo
    {
        class OnnxModelScorer
        {
            private readonly string imagesFolder;
            private readonly string modelLocation;
            private readonly MLContext mlContext;
    
    
            public OnnxModelScorer(string imagesFolder, string modelLocation, MLContext mlContext)
            {
                this.imagesFolder = imagesFolder;
                this.modelLocation = modelLocation;
                this.mlContext = mlContext;
            }
    
            public struct ImageNetSettings
            {
                public const int imageHeight = 224;
                public const int imageWidth = 224;    
                public const float Mean = 127;
                public const float Scale = 1;
                public const bool ChannelsLast = false;
            } 
            
            public struct ImageNetModelSettings
            {
                // input tensor name
                public const string ModelInput = "input";
    
                // output tensor name
                public const string ModelOutput = "output";
            }
    
            private ITransformer LoadModel(string modelLocation)
            {
                Console.WriteLine("Read model");
                Console.WriteLine($"Model location: {modelLocation}");
                Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight})");
    
                // Create IDataView from empty list to obtain input data schema
                var data = mlContext.Data.LoadFromEnumerable(new List<ImageNetData>());
    
                // Define scoring pipeline
                var pipeline = mlContext.Transforms.LoadImages(outputColumnName: ImageNetModelSettings.ModelInput, imageFolder: "", inputColumnName: nameof(ImageNetData.ImagePath))                           
                                .Append(mlContext.Transforms.ResizeImages(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                            imageWidth: ImageNetSettings.imageWidth, 
                                                                            imageHeight: ImageNetSettings.imageHeight, 
                                                                            inputColumnName: ImageNetModelSettings.ModelInput,
                                                                            resizing: ImageResizingEstimator.ResizingKind.IsoCrop,
                                                                            cropAnchor: ImageResizingEstimator.Anchor.Center
                                                                            ))
                                .Append(mlContext.Transforms.ExtractPixels(outputColumnName: ImageNetModelSettings.ModelInput, interleavePixelColors: ImageNetSettings.ChannelsLast))
                                .Append(mlContext.Transforms.NormalizeGlobalContrast(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                     inputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                     ensureZeroMean : true, 
                                                                                     ensureUnitStandardDeviation: true, 
                                                                                     scale: ImageNetSettings.Scale))
                                .Append(mlContext.Transforms.ApplyOnnxModel(modelFile: modelLocation, outputColumnNames: new[] { ImageNetModelSettings.ModelOutput }, inputColumnNames: new[] { ImageNetModelSettings.ModelInput }));
    
                // Fit scoring pipeline
                var model = pipeline.Fit(data);
    
                return model;
            }
    
            private IEnumerable<float[]> PredictDataUsingModel(IDataView testData, ITransformer model)
            {
                Console.WriteLine($"Images location: {imagesFolder}");
                Console.WriteLine("");
                Console.WriteLine("=====Identify the objects in the images=====");
                Console.WriteLine("");
    
                IDataView scoredData = model.Transform(testData);
    
                IEnumerable<float[]> probabilities = scoredData.GetColumn<float[]>(ImageNetModelSettings.ModelOutput);
    
                return probabilities;
            }
    
            public IEnumerable<float[]> Score(IDataView data)
            {
                var model = LoadModel(modelLocation);
    
                return PredictDataUsingModel(data, model);
            }
        }
    }

    Program 类中定义了调用过程,完成预测结果呈现。

    using GarbageDemo.DataStructures;
    using Microsoft.ML;
    using System;
    using System.Collections.Generic;
    using System.IO;
    using System.Linq;
    
    namespace GarbageDemo
    {
        class Program
        {
            static void Main(string[] args)
            {
                var assetsRelativePath = @"../../../assets";
                string assetsPath = GetAbsolutePath(assetsRelativePath);
                var modelFilePath = Path.Combine(assetsPath, "Model", "garbage.onnx");
                var imagesFolder = Path.Combine(assetsPath, "images");// Initialize MLContext
                MLContext mlContext = new MLContext();
    
                try
                {
                    // Load Data
                    IEnumerable<ImageNetData> images = ImageNetData.ReadFromFile(imagesFolder);
                    IDataView imageDataView = mlContext.Data.LoadFromEnumerable(images);
    
                    // Create instance of model scorer
                    var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath, mlContext);
    
                    // Use model to score data
                    IEnumerable<float[]> probabilities = modelScorer.Score(imageDataView);
    
                    int index = 0;
                    foreach (var probable in probabilities)
                    {
                        var scores = Softmax(probable);
    
                        var (topResultIndex, topResultScore) = scores.Select((predictedClass, index) => (Index: index, Value: predictedClass))
                            .OrderByDescending(result => result.Value)
                            .First();
                        Console.WriteLine("图片:{3} 
     分类{2} {0}:{1}", labels[topResultIndex], topResultScore, topResultIndex, images.ElementAt(index).ImagePath);
                        Console.WriteLine("=============================");
                        index++;
                    }
    
                }
                catch (Exception ex)
                {
                    Console.WriteLine(ex.ToString());
                }
    
                Console.WriteLine("========= End of Process..Hit any Key ========");
                Console.ReadLine();
            }
    
            public static string GetAbsolutePath(string relativePath)
            {
                FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
                string assemblyFolderPath = _dataRoot.Directory.FullName;
    
                string fullPath = Path.Combine(assemblyFolderPath, relativePath);
    
                return fullPath;
            }
    
            private static float[] Softmax(float[] values)
            {
                var maxVal = values.Max();
                var exp = values.Select(v => Math.Exp(v - maxVal));
                var sumExp = exp.Sum();
    
                return exp.Select(v => (float)(v / sumExp)).ToArray();
            }
    
            private static string[] labels = new string[]
            {
                "其他垃圾/一次性快餐盒",
                "其他垃圾/污损塑料",
                "其他垃圾/烟蒂",
                "其他垃圾/牙签",
                "其他垃圾/破碎花盆及碟碗",
                "其他垃圾/竹筷",
                "厨余垃圾/剩饭剩菜",
                "厨余垃圾/大骨头",
                "厨余垃圾/水果果皮",
                "厨余垃圾/水果果肉",
                "厨余垃圾/茶叶渣",
                "厨余垃圾/菜叶菜根",
                "厨余垃圾/蛋壳",
                "厨余垃圾/鱼骨",
                "可回收物/充电宝",
                "可回收物/包",
                "可回收物/化妆品瓶",
                "可回收物/塑料玩具",
                "可回收物/塑料碗盆",
                "可回收物/塑料衣架",
                "可回收物/快递纸袋",
                "可回收物/插头电线",
                "可回收物/旧衣服",
                "可回收物/易拉罐",
                "可回收物/枕头",
                "可回收物/毛绒玩具",
                "可回收物/洗发水瓶",
                "可回收物/玻璃杯",
                "可回收物/皮鞋",
                "可回收物/砧板",
                "可回收物/纸板箱",
                "可回收物/调料瓶",
                "可回收物/酒瓶",
                "可回收物/金属食品罐",
                "可回收物/锅",
                "可回收物/食用油桶",
                "可回收物/饮料瓶",
                "有害垃圾/干电池",
                "有害垃圾/软膏",
                "有害垃圾/过期药物",
                "可回收物/毛巾",
                "可回收物/饮料盒",
                "可回收物/纸袋"
            };

    选择一张图片放到 images 目录中,运行结果如下:

    有 0.88 的得分说明照片中的物品属于污损塑料,让我们看一下图片真相。

    果然是相当准确 ,并且把周边的附属物都过滤掉了。

    对于 ML.NET 训练深度神经网络模型支持更复杂的场景是不是更有信心了!

  • 相关阅读:
    杭电2059
    杭电2058
    php错误大集合
    显示IP地址
    超简单好用的屏幕录像工具
    jquery“不再提醒"功能
    KindEditor编辑器中的class自动过滤了
    实用案例:切换面板同时切换内容
    仿51返利用户图解教程
    JavaScript调用dataTable并获取其值(ASP.Net,VS2005)
  • 原文地址:https://www.cnblogs.com/BeanHsiang/p/13176454.html
Copyright © 2020-2023  润新知