• 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类


     一、概述

    通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下:

    之前介绍过,图片的特征是不能采用像素的灰度值的,这部分原理的台阶有点高,还好可以直接使用通过TensorFlow训练过的特征提取模型(美其名曰迁移学习)。

    模型文件为:tensorflow_inception_graph.pb

    二、样本介绍

     我随便在网上找了一些图片,分成6类:男孩、女孩、猫、狗、男人、女人。tags文件标记了每个文件所代表的类型标签(Label)。

    通过对这六类图片的学习,期望输入新的图片时,可以判断出是何种类型。

    三、代码

     全部代码:

    namespace TensorFlow_ImageClassification
    {    
    
        class Program
        {
            //Assets files download from:https://gitee.com/seabluescn/ML_Assets
            static readonly string AssetsFolder = @"D:StepByStepBlogsML_Assets";
            static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "ImageClassification", "train");
            static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "ImageClassification", "train_tags.tsv");
            static readonly string TestDataFolder = Path.Combine(AssetsFolder, "ImageClassification","test");
            static readonly string inceptionPb = Path.Combine(AssetsFolder, "TensorFlow", "tensorflow_inception_graph.pb");
            static readonly string imageClassifierZip = Path.Combine(Environment.CurrentDirectory, "MLModel", "imageClassifier.zip");
    
            //配置用常量
            private struct ImageNetSettings
            {
                public const int imageHeight = 224;
                public const int imageWidth = 224;
                public const float mean = 117;
                public const float scale = 1;
                public const bool channelsLast = true;
            }
    
            static void Main(string[] args)
            {
                TrainAndSaveModel();
                LoadAndPrediction();
    
                Console.WriteLine("Hit any key to finish the app");
                Console.ReadKey();
            }
    
            public static void TrainAndSaveModel()
            {
                MLContext mlContext = new MLContext(seed: 1);
    
                // STEP 1: 准备数据
                var fulldata = mlContext.Data.LoadFromTextFile<ImageNetData>(path: TrainTagsPath, separatorChar: '	', hasHeader: false);
    
                var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.1);
                var trainData = trainTestData.TrainSet;
                var testData = trainTestData.TestSet;
    
                // STEP 2:创建学习管道
                var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelTokey", inputColumnName: "Label")
                    .Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: TrainDataFolder, inputColumnName: nameof(ImageNetData.ImagePath)))
                    .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input"))
                    .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean))
                    .Append(mlContext.Model.LoadTensorFlowModel(inceptionPb).
                         ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
                    .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelTokey", featureColumnName: "softmax2_pre_activation"))
                    .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))
                    .AppendCacheCheckpoint(mlContext);
    
                // STEP 3:通过训练数据调整模型    
                ITransformer model = pipeline.Fit(trainData);
    
                // STEP 4:评估模型
                Console.WriteLine("===== Evaluate model =======");
                var evaData = model.Transform(testData);
                var metrics = mlContext.MulticlassClassification.Evaluate(evaData, labelColumnName: "LabelTokey", predictedLabelColumnName: "PredictedLabel");
                PrintMultiClassClassificationMetrics(metrics);
    
                //STEP 5:保存模型
                Console.WriteLine("====== Save model to local file =========");
                mlContext.Model.Save(model, trainData.Schema, imageClassifierZip);
            }
    
            static void LoadAndPrediction()
            {
                MLContext mlContext = new MLContext(seed: 1);
    
                // Load the model
                ITransformer loadedModel = mlContext.Model.Load(imageClassifierZip, out var modelInputSchema);
    
                // Make prediction function (input = ImageNetData, output = ImageNetPrediction)
                var predictor = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(loadedModel);
                
                DirectoryInfo testdir = new DirectoryInfo(TestDataFolder);
                foreach (var jpgfile in testdir.GetFiles("*.jpg"))
                {
                    ImageNetData image = new ImageNetData();
                    image.ImagePath = jpgfile.FullName;
                    var pred = predictor.Predict(image);
    
                    Console.WriteLine($"Filename:{jpgfile.Name}:	Predict Result:{pred.PredictedLabelValue}");
                }
            }       
        }
    
        public class ImageNetData
        {
            [LoadColumn(0)]
            public string ImagePath;
    
            [LoadColumn(1)]
            public string Label;
        }
    
        public class ImageNetPrediction
        {
            //public float[] Score;
            public string PredictedLabelValue;
        }   
    }
    View Code

      

    四、分析

     1、数据处理通道

    可以看出,其代码流程与结构与上两篇文章介绍的完全一致,这里就介绍一下核心的数据处理模型部分的代码:

    var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelTokey", inputColumnName: "Label")
      .Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: TrainDataFolder, inputColumnName: nameof(ImageNetData.ImagePath)))
      .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input"))
      .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean))
      .Append(mlContext.Model.LoadTensorFlowModel(inceptionPb).
              ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
      .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelTokey", featureColumnName: "softmax2_pre_activation"))
      .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))

    MapValueToKey与MapKeyToValue之前已经介绍过了;
    LoadImages是读取文件,输入为文件名、输出为Image;
    ResizeImages是改变图片尺寸,这一步是必须的,即使所有训练图片都是标准划一的图片也需要这个操作,后面需要根据这个尺寸确定容纳图片像素信息的数组大小;
    ExtractPixels是将图片转换为包含像素数据的矩阵;
    LoadTensorFlowModel是加载第三方模型,ScoreTensorFlowModel是调用模型处理数据,其输入为:“input”,输出为:“softmax2_pre_activation”,由于模型中输入、输出的名称是规定的,所以,这里的名称不可以随便修改。
    分类算法采用的是L-BFGS最大熵分类算法,其特征数据为TensorFlow网络输出的值,标签值为"LabelTokey"。

    2、验证过程
                MLContext mlContext = new MLContext(seed: 1);
                ITransformer loadedModel = mlContext.Model.Load(imageClassifierZip, out var modelInputSchema);           
                var predictor = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(loadedModel);
                            
                ImageNetData image = new ImageNetData();
                image.ImagePath = jpgfile.FullName;
                var pred = predictor.Predict(image);
                Console.WriteLine($"Filename:{jpgfile.Name}:	Predict Result:{pred.PredictedLabelValue}");

     两个实体类代码:

        public class ImageNetData
        {
            [LoadColumn(0)]
            public string ImagePath;
    
            [LoadColumn(1)]
            public string Label;
        }
    
        public class ImageNetPrediction
        {       
            public string PredictedLabelValue;
        } 
    3、验证结果
    我在网络上又随便找了20张图片进行验证,发现验证成功率是非常高的,基本都是准确的,只有两个出错了。

    上图片被识别为girl(我认为是Woman),这个情有可原,本来girl和worman在外貌上也没有一个明确的分界点。

    上图被识别为woman,这个也情有可原,不解释。

    需要了解的是:不管你输入什么图片,最终的结果只能是以上六个类型之一,算法会寻找到和六个分类中特征最接近的一个分类作为结果。


    4、调试
    注意看实体类的话,我们只提供了三个基本属性,如果想看一下在学习过程中数据是如何处理的,可以给ImageNetPrediction类增加一些字段进行调试。
    首先我们需要看一下IDateView有哪些列(Column)
                var predictions = trainedModel.Transform(testData);          
    
                var OutputColumnNames = predictions.Schema.Where(col => !col.IsHidden).Select(col => col.Name);
                foreach (string column in OutputColumnNames)
                {
                    Console.WriteLine($"OutputColumnName:{ column }");
                }

     将我们要调试的列加入到实体对象中去,特别要注意数据类型。

        public class ImageNetPrediction
        {
            public float[] Score;
            public string PredictedLabelValue; 
            public string Label;
           
            public void PrintToConsole()
            {
                //打印字段信息
            }
        }  

     查看数据集详细信息:

               var predictions = trainedModel.Transform(testData); 
                var DataShowList = new List<ImageNetPrediction>(mlContext.Data.CreateEnumerable<ImageNetPrediction>(predictions, false, true));
               foreach (var dataline in DataShowList)
                {                
                        dataline.PrintToConsole();                               
                }
    
    

    五、资源获取 

    源码下载地址:https://github.com/seabluescn/Study_ML.NET

    工程名称:TensorFlow_ImageClassification

    资源获取:https://gitee.com/seabluescn/ML_Assets

    点击查看机器学习框架ML.NET学习笔记系列文章目录

  • 相关阅读:
    常用校验码(奇偶校验,海明校验,CRC)学习总结
    .net获取项目根目录方法集合
    C#读写config配置文件
    C# 将ComboBox设置为禁止编辑的方法
    C#中查询数据库时返回的影响行数等于-1?
    UserControl 的一个值得注意的问题 [属性" * "的代码生成失败.错误是:"程序集"*.Version=1.0.0.0,Culture=neutral,..........无标记为序列化"
    C#实现对象序列化为XML
    螺旋矩阵的几种打印形式
    单例模式
    css-text-decoration-skip
  • 原文地址:https://www.cnblogs.com/seabluescn/p/10944579.html
Copyright © 2020-2023  润新知