• 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)


    一、概述

     上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断。思路很简单,就是写一个自定义的数据处理通道,输入为文件名,输出为float数字,里面保存的是像素信息。

     样本包括6万张训练图片和1万张测试图片,图片为灰度图片,分辨率为20*20 。train_tags.tsv文件对每个图片的数值进行了标记,如下:

      

    二、源码

     全部代码: 

    namespace MulticlassClassification_Mnist
    {
        class Program
        {
            //Assets files download from:https://gitee.com/seabluescn/ML_Assets
            static readonly string AssetsFolder = @"D:StepByStepBlogsML_AssetsMNIST";
            static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "train_tags.tsv");
            static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "train");
            static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");
    
            static void Main(string[] args)
            {
                MLContext mlContext = new MLContext(seed: 1);
              
                TrainAndSaveModel(mlContext);
                TestSomePredictions(mlContext);
    
                Console.WriteLine("Hit any key to finish the app");
                Console.ReadKey();
            }
    
            public static void TrainAndSaveModel(MLContext mlContext)
            {
                // STEP 1: 准备数据
                var fulldata = mlContext.Data.LoadFromTextFile<InputData>(path: TrainTagsPath, separatorChar: '	', hasHeader: false);
                var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.1);
                var trainData = trainTestData.TrainSet;
                var testData = trainTestData.TestSet;
    
                // STEP 2: 配置数据处理管道        
                var dataProcessPipeline = mlContext.Transforms.CustomMapping(new LoadImageConversion().GetMapping(), contractName: "LoadImageConversionAction")
                   .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue))
                   .Append(mlContext.Transforms.NormalizeMeanVariance( outputColumnName: "FeaturesNormalizedByMeanVar", inputColumnName: "ImagePixels"));
    
    
                // STEP 3: 配置训练算法 (using a maximum entropy classification model trained with the L-BFGS method)
                var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "Label", featureColumnName: "FeaturesNormalizedByMeanVar");
                var trainingPipeline = dataProcessPipeline.Append(trainer)
                     .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictNumber", "Label"));
    
    
                // STEP 4: 训练模型使其与数据集拟合           
                ITransformer trainedModel = trainingPipeline.Fit(trainData);          
    
                // STEP 5:评估模型的准确性           
                var predictions = trainedModel.Transform(testData);
                var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Label", scoreColumnName: "Score");
                PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
              
                // STEP 6:保存模型            
                mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);           
            }
    
            private static void TestSomePredictions(MLContext mlContext)
            {
                // Load Model           
                ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);
    
                // Create prediction engine 
                var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel);
              
                DirectoryInfo TestFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));           
                foreach(var image in TestFolder.GetFiles())
                {
                    count++;
    
                    InputData img = new InputData()
                    {
                        FileName = image.Name
                    };
                    var result = predEngine.Predict(img);
                   
                    Console.WriteLine($"Current Source={img.FileName},PredictResult={result.GetPredictResult()}");                
                }
            }       
        }
    
        class InputData
        {
            [LoadColumn(0)]
            public string FileName;
    
            [LoadColumn(1)]
            public string Number;
    
            [LoadColumn(1)]
            public float Serial;       
        }
    
        class OutPutData : InputData
        {
            public float[] Score;
            public int GetPredictResult()
            {
                float max = 0;
                int index = 0;
                for (int i = 0; i < Score.Length; i++)
                {
                    if (Score[i] > max)
                    {
                        max = Score[i];
                        index = i;
                    }
                }
                return index;
            }       
        }   
    }
    View Code

      

    三、分析

     整个处理流程和上一篇文章基本一致,这里解释两个不一样的地方。

    1、自定义的图片读取处理通道

    namespace MulticlassClassification_Mnist
    {
        public class LoadImageConversionInput
        {
            public string  FileName { get; set; }
        }
     
        public class LoadImageConversionOutput
        {
            [VectorType(400)]
            public float[] ImagePixels { get; set; }
            public string ImagePath;
        }
    
        [CustomMappingFactoryAttribute("LoadImageConversionAction")]
        public class LoadImageConversion : CustomMappingFactory<LoadImageConversionInput, LoadImageConversionOutput>
        {       
            static readonly string TrainDataFolder = @"D:StepByStepBlogsML_AssetsMNIST	rain";
    
            public void CustomAction(LoadImageConversionInput input, LoadImageConversionOutput output)
            {  
                string ImagePath = Path.Combine(TrainDataFolder, input.FileName);
                output.ImagePath = ImagePath;
    
                Bitmap bmp = Image.FromFile(ImagePath) as Bitmap;           
    
                output.ImagePixels = new float[400];
                for (int x = 0; x < 20; x++)
                    for (int y = 0; y < 20; y++)
                    {
                        var pixel = bmp.GetPixel(x, y);
                        var gray = (pixel.R + pixel.G + pixel.B) / 3 / 16;
                        output.ImagePixels[x + y * 20] = gray;
                    }           
                bmp.Dispose();                     
            }
    
            public override Action<LoadImageConversionInput, LoadImageConversionOutput> GetMapping()
                  => CustomAction;
        }
    }

     这里可以看出,我们自定义的数据处理通道,输入为文件名称,输出是一个float数组,这里数组必须要指定宽度,由于图片分辨率为20*20,所以数组宽度指定为400,输出ImagePath为文件详细地址,用来调试使用,没有实际用途。处理思路非常简单,遍历每个Pixel,计算其灰度值,为了减少工作量我们把灰度值进行缩小,除以了16 ,由于后面数据会做归一化,所以这里影响不是太明显。

    2、模型测试

                DirectoryInfo TestFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));
                int count = 0;
                int success = 0;
                foreach(var image in TestFolder.GetFiles())
                {
                    count++;
    
                    InputData img = new InputData()
                    {
                        FileName = image.Name
                    };
                    var result = predEngine.Predict(img);
    
                    if(int.Parse(image.Name.Substring(0,1))==result.GetPredictResult())
                    {
                        success++;
                    }                
                }

     我们把测试目录里的全面图片读出遍历了一遍,将其测试结果和实际结果做了一次验证,实际上是把评估(Evaluate)的事情又重复做了一次,两次测试的成功率基本接近。

    四、关于图片特征提取

    我们是采用图片所有像素的灰度值来作为特征值的,但必须要强调的是:像素值矩阵不是图片的典型特征。虽然有时候对于较规则的图片,通过像素提取方式进行计算,也可以取得很好的效果,但在处理稍微复杂一点的图片的时候,就不管用了,原因很明显,我们人类在分析图片内容时看到的特征更多是线条等信息,绝对不是像素值,看下图:

    我们人类很容易就判断出这两个图片表达的是同一件事情,但其像素值特征却相差甚远。

     传统的图片特征提取方式很多,比如:SIFT、HOG、LBP、Haar等。 现在采用TensorFlow的模型进行特征提取效果非常好。下一篇文章介绍图片分类时再进行详细介绍。 

    五、资源获取

    源码下载地址:https://github.com/seabluescn/Study_ML.NET

    工程名称:MulticlassClassification_Mnist_Useful

    MNIST资源获取:https://gitee.com/seabluescn/ML_Assets

    点击查看机器学习框架ML.NET学习笔记系列文章目录

  • 相关阅读:
    淘宝IP地址库采集
    Android MediaCodec硬编兼容性测试方案
    《Tensorflow实战》之6.3VGGnet学习
    tensorflow问题集锦
    <tensorflow实战>之5.3实现进阶的卷积网咯
    CNN_minist
    tensorflow之MLP学习
    tensorflow学习之等价代码
    tensorflow学习之softmax regression
    NPE进一步学习
  • 原文地址:https://www.cnblogs.com/seabluescn/p/10942116.html
Copyright © 2020-2023  润新知