• 机器学习(2)


    代码

     https://github.com/s055523/MNISTTensorFlowSharp

    数据的获得

    数据可以由http://yann.lecun.com/exdb/mnist/下载。之后,储存在trainDir中,下次就不需要下载了。

    /// <summary>
            /// 如果文件不存在就去下载
            /// </summary>
            /// <param name="urlBase">下载地址</param>
            /// <param name="trainDir">文件目录地址</param>
            /// <param name="file">文件名</param>
            /// <returns></returns>
            public static Stream MaybeDownload(string urlBase, string trainDir, string file)
            {
                if (!Directory.Exists(trainDir))
                {
                    Directory.CreateDirectory(trainDir);
                }
    
                var target = Path.Combine(trainDir, file);
                if (!File.Exists(target))
                {
                    var wc = new WebClient();
                    wc.DownloadFile(urlBase + file, target);
                }
                return File.OpenRead(target);
            }
    View Code

    数据格式处理

    下载下来的文件共有四个,都是扩展名为gz的压缩包。

    train-images-idx3-ubyte.gz  55000张训练图片和5000张验证图片

    train-labels-idx1-ubyte.gz     训练图片对应的数字标签(即答案)

    t10k-images-idx3-ubyte.gz   10000张测试图片

    t10k-labels-idx1-ubyte.gz     测试图片对应的数字标签(即答案)

    处理图片数据压缩包

    每个压缩包的格式为:

    偏移量

    类型

    意义

    0

    Int32

    2051或2049

    一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049)

    4

    Int32

    60000或10000

    压缩包的图片数

    8

    Int32

    28

    每个图片的行数

    12

    Int32

    28

    每个图片的列数

    16

    Unsigned byte

    0 - 255

    第一张图片的第一个像素

    17

    Unsigned byte

    0 - 255

    第一张图片的第二个像素

    因此,我们可以使用一个统一的方式将数据处理。我们只需要那些图片像素。

    /// <summary>
            /// 从数据流中读取下一个int32
            /// </summary>
            /// <param name="s"></param>
            /// <returns></returns>
            int Read32(Stream s)
            {
                var x = new byte[4];
                s.Read(x, 0, 4);
                return DataConverter.BigEndian.GetInt32(x, 0);
            }
    
            /// <summary>
            /// 处理图片数据
            /// </summary>
            /// <param name="input"></param>
            /// <param name="file"></param>
            /// <returns></returns>
            MnistImage[] ExtractImages(Stream input, string file)
            {
                //文件是gz格式的
                using (var gz = new GZipStream(input, CompressionMode.Decompress))
                {
                    //不是2051说明下载的文件不对
                    if (Read32(gz) != 2051)
                    {
                        throw new Exception("不是2051说明下载的文件不对: " + file);
                    }
                    //图片数
                    var count = Read32(gz);
                    //行数
                    var rows = Read32(gz);
                    //列数
                    var cols = Read32(gz);
    
                    Console.WriteLine($"准备读取{count}张图片。");
    
                    var result = new MnistImage[count];
                    for (int i = 0; i < count; i++)
                    {
                        //图片的大小(每个像素占一个bit)
                        var size = rows * cols;
                        var data = new byte[size];
    
                        //从数据流中读取这么大的一块内容
                        gz.Read(data, 0, size);
    
                        //将读取到的内容转换为MnistImage类型
                        result[i] = new MnistImage(cols, rows, data);
                    }
                    return result;
                }
            }
    View Code

    准备一个MnistImage类型:

    /// <summary>
        /// 图片类型
        /// </summary>
        public struct MnistImage
        {
            public int Cols, Rows;
            public byte[] Data;
            public float[] DataFloat;
    
            public MnistImage(int cols, int rows, byte[] data)
            {
                Cols = cols;
                Rows = rows;
                Data = data;
                DataFloat = new float[data.Length];
                for (int i = 0; i < data.Length; i++)
                {
                    //数据归一化(这里将0-255除255变成了0-1之间的小数)
                    //也可以归一为-0.5到0.5之间
                    DataFloat[i] = Data[i] / 255f;
                }
            }
        }
    View Code

    这样一来,图片数据就处理完成了。

    处理数字标签数据压缩包

    数字标签数据压缩包和图片数据压缩包的格式类似。

    偏移量

    类型

    意义

    0

    Int32

    2051或2049

    一个定死的魔术数。用来验证该压缩包是训练集(2051)或测试集(2049)

    4

    Int32

    60000或10000

    压缩包的数字标签数

    5

    Unsigned byte

    0 - 9

    第一张图片对应的数字

    6

    Unsigned byte

    0 - 9

    第二张图片对应的数字

    它的处理更加简单。

    /// <summary>
            /// 处理标签数据
            /// </summary>
            /// <param name="input"></param>
            /// <param name="file"></param>
            /// <returns></returns>
            byte[] ExtractLabels(Stream input, string file)
            {
                using (var gz = new GZipStream(input, CompressionMode.Decompress))
                {
                    //不是2049说明下载的文件不对
                    if (Read32(gz) != 2049)
                    {
                        throw new Exception("不是2049说明下载的文件不对:" + file);
                    }
                    var count = Read32(gz);
                    var labels = new byte[count];
    
                    gz.Read(labels, 0, count);
    
                    return labels;
                }
            }
    View Code

    将数字标签转化为二维数组:one-hot编码

    由于我们的数字为0-9,所以,可以视为有十个class。此时,为了后续的处理方便,我们将数字标签转化为数组。因此,一组标签就转换为了一个二维数组。

    例如,标签0变成[1,0,0,0,0,0,0,0,0,0]

    标签1变成[0,1,0,0,0,0,0,0,0,0]

    以此类推。

    /// <summary>
            /// 将数字标签一维数组转为一个二维数组
            /// </summary>
            /// <param name="labels"></param>
            /// <param name="numClasses">多少个类别,这里是10(0到9)</param>
            /// <returns></returns>
            byte[,] OneHot(byte[] labels, int numClasses)
            {
                var oneHot = new byte[labels.Length, numClasses];
                for (int i = 0; i < labels.Length; i++)
                {
                    oneHot[i, labels[i]] = 1;
                }
                return oneHot;
            }
    View Code

    到此为止,数据格式处理就全部结束了。下面的代码展示了数据处理的全过程。

            /// <summary>
            /// 处理数据集
            /// </summary>
            /// <param name="trainDir">数据集所在文件夹</param>
            /// <param name="numClasses"></param>
            /// <param name="validationSize">拿出多少做验证?</param>
            public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000)
            {
                const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
                const string TrainImagesName = "train-images-idx3-ubyte.gz";
                const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
                const string TestImagesName = "t10k-images-idx3-ubyte.gz";
                const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";
    
                //获得训练数据,然后处理训练数据和测试数据
                TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
                TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
                TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
                TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);
    
                //拿出前面的一部分做验证
                ValidationImages = Pick(TrainImages, 0, validationSize);
                ValidationLabels = Pick(TrainLabels, 0, validationSize);
    
                //拿出剩下的做训练(输入0意味着拿剩下所有的)
                TrainImages = Pick(TrainImages, validationSize, 0);
                TrainLabels = Pick(TrainLabels, validationSize, 0);
    
                //将数字标签转换为二维数组
                //例如,标签3 =》 [0,0,0,1,0,0,0,0,0,0]
                //标签0 =》 [1,0,0,0,0,0,0,0,0,0]
                if (numClasses != -1)
                {
                    OneHotTrainLabels = OneHot(TrainLabels, numClasses);
                    OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
                    OneHotTestLabels = OneHot(TestLabels, numClasses);
                }
            }
    
            /// <summary>
            /// 获得source集合中的一部分,从first开始,到last结束
            /// </summary>
            /// <typeparam name="T"></typeparam>
            /// <param name="source"></param>
            /// <param name="first"></param>
            /// <param name="last"></param>
            /// <returns></returns>
            T[] Pick<T>(T[] source, int first, int last)
            {
                if (last == 0)
                {
                    last = source.Length;
                }
    
                var count = last - first;
                var ret = source.Skip(first).Take(count).ToArray();
                return ret;
            }
    
            public static Mnist Load()
            {
                var x = new Mnist();
                x.ReadDataSets(@"D:人工智能C#代码MNISTTensorFlowSharpMNISTTensorFlowSharpdata");
                return x;
            }
    View Code

    在这里,数据共有下面几部分:

    1. 训练图片数据55000 TrainImages及对应标签TrainLabels
    2. 验证图片数据5000 ValidationImages及对应标签ValidationLabels
    3. 测试图片数据10000 TestImages及对应标签TestLabels

    KNN算法的实现

    现在,我们已经有了所有的数据在手。需要实现的是:

    1. 拿出数据中的一部分(例如,5000张图片)作为KNN的训练数据,然后,再从数据中的另一部分拿一张图片A
    2. 对这张图片A,求它和5000张训练图片的距离,并找出一张训练图片B,它是所有训练图片中,和A距离最小的那张(这意味着K=1)
    3. 此时,就认为A所代表的数字等同于B所代表的数字b
    4. 重复1-3,N次

    首先进行数据的收集:

    //三个Reader分别从总的数据库中获得数据
            public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
            public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
            public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);
    
            /// <summary>
            /// 数据的一部分,包括了所有的有用信息
            /// </summary>
            public class BatchReader
            {
                int start = 0;
                //图片库
                MnistImage[] source;
                //数字标签
                byte[] labels;
                //oneHot之后的数字标签
                byte[,] oneHotLabels;
    
                internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
                {
                    this.source = source;
                    this.labels = labels;
                    this.oneHotLabels = oneHotLabels;
                }
    
                /// <summary>
                /// 返回两个浮点二维数组(C# 7的新语法)
                /// </summary>
                /// <param name="batchSize"></param>
                /// <returns></returns>
                public (float[,], float[,]) NextBatch(int batchSize)
                {
                    //一张图
                    var imageData = new float[batchSize, 784];
                    //标签
                    var labelData = new float[batchSize, 10];
    
                    int p = 0;
                    for (int item = 0; item < batchSize; item++)
                    {
                        Buffer.BlockCopy(source[start + item].DataFloat, 0, imageData, p, 784 * sizeof(float));
                        p += 784 * sizeof(float);
                        for (var j = 0; j < 10; j++)
                            labelData[item, j] = oneHotLabels[item + start, j];
                    }
    
                    start += batchSize;
                    return (imageData, labelData);
                }
            }
    View Code

    然后,在算法中,获取数据:

            static void KNN()
            {
                //取得数据
                var mnist = Mnist.Load();
    
                //拿5000个训练数据,200个测试数据
                const int trainCount = 5000;
                const int testCount = 200;
    
                //获得的数据有两个
                //一个是图片,它们都是28*28的
                //一个是one-hot的标签,它们都是1*10的
                (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount);
                (var testImages, var testLabels) = mnist.GetTestReader().NextBatch(testCount);
    
                Console.WriteLine($"MNIST 1NN");
    View Code

    下面进行计算。这里使用了K=1的L1距离。这是最简单的情况。

                //建立一个图表示计算任务
                using (var graph = new TFGraph())
                {
                    var session = new TFSession(graph);
    
                    //用来feed数据的占位符。trainingInput表示N张用来进行训练的图片,N是一个变量,所以这里使用-1
                    TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-1, 784));
    
                    //xte表示一张用来测试的图片
                    TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape(784));
    
                    //计算这两张图片的L1距离。这很简单,实际上就是把784个数字逐对相减,然后取绝对值,最后加起来变成一个总和
                    var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const(1));
    
                    //这里只是用了最近的那个数据
                    //也就是说,最近的那个数据是什么,那pred(预测值)就是什么
                    TFOutput pred = graph.ArgMin(distance, graph.Const(0));
    View Code

    最后是开启Session计算的过程:

                    var accuracy = 0f;
    
                    //开始循环进行计算,循环trainCount次
                    for (int i = 0; i < testCount; i++)
                    {
                        var runner = session.GetRunner();
    
                        //每次,对一张新的测试图,计算它和trainCount张训练图的距离,并获得最近的那张
                        var result = runner.Fetch(pred).Fetch(distance)
                            //trainCount张训练图(数据是trainingImages)
                            .AddInput(trainingInput, trainingImages)
                            //testCount张测试图(数据是从testImages中拿出来的)
                            .AddInput(xte, Extract(testImages, i))
                            .Run();
                        
                        //最近的点的序号
                        var nn_index = (int)(long)result[0].GetValue();
    
                        //从trainingLabels中找到答案(这是预测值)
                        var prediction = ArgMax(trainingLabels, nn_index);
    
                        //正确答案位于testLabels[i]中
                        var real = ArgMax(testLabels, i);
    
                        //PrintImage(testImages, i);
    
                        Console.WriteLine($"测试 {i}: " +
                            $"预测: {prediction} " +
                            $"正确答案: {real} (最近的点的序号={nn_index})");
                        //Console.WriteLine(testImages);
    
                        if (prediction == real)
                        {
                            accuracy += 1f / testCount;
                        }
                    }
                    Console.WriteLine("准确率: " + accuracy);
    View Code

    对KNN的改进

    本文只是对KNN识别MNIST数据集进行了一个非常简单的介绍。在实现了最简单的K=1的L1距离计算之后,正确率约为91%。大家可以试着将算法进行改进,例如取K=2或者其他数,或者计算L2距离等。L2距离的结果比L1好一些,可以达到93-94%的正确率。

  • 相关阅读:
    C++中的static关键字的总结
    2017上海C++面试
    Vim 跳到上次光标位置
    Windows XP Professional产品序列号
    Centos7 安装sz,rz命令
    Xshell里连接VirtualBox里的Centos7
    什么是位、字节、字、KB、MB
    Centos7 tmux1.6 安装
    Centos7 在 Xshell里 vim的配置
    对JDBC的轻量级封装,Hibernate框架
  • 原文地址:https://www.cnblogs.com/haoyifei/p/9028235.html
Copyright © 2020-2023  润新知