• t-SNE可视化(MNIST例子)


    用 tsne 包的 bh_sne 对 pickle 文件中保存的特征做 t-SNE 降维,并按类别标签着色,代码如下所示:

    import pickle as pkl
    import numpy as np
    from matplotlib import pyplot as plt
    from tsne import bh_sne
    import sys 
    
    def _load_pickle(path):
        """Return the object stored in the pickle file at ``path``.

        NOTE: unpickling can execute arbitrary code; only load files you
        produced yourself.
        """
        with open(path, 'rb') as f:
            if sys.version_info[0] >= 3:
                # encoding='latin1' is what lets Python 3 read pickles
                # (numpy arrays in particular) written by Python 2.
                return pkl.load(f, encoding='latin1')
            return pkl.load(f)


    # Features to embed, cast to float64 before they are handed to bh_sne.
    data = _load_pickle("data").astype('float64')

    # Labels: np.where(...)[1] takes the column index of every entry equal to 1
    # (assumes one-hot style rows -- TODO confirm), then the indices are scaled
    # by 9/classNum so they spread over the 0-9 range used by the colour bar.
    classNum = 6
    y_data = _load_pickle("label")
    y_data = np.where(y_data == 1)[1] * (9.0 / classNum)

    vis_data = bh_sne(data)

    # plot the result
    vis_x = vis_data[:, 0]
    vis_y = vis_data[:, 1]

    fig = plt.figure()
    # plt.get_cmap is used instead of plt.cm.get_cmap, which newer matplotlib
    # releases no longer provide.
    plt.scatter(vis_x, vis_y, c=y_data, s=1, cmap=plt.get_cmap("jet", 10))
    plt.colorbar(ticks=range(10))
    plt.clim(-0.5, 9.5)
    # Write the image before show() so the file exists even if the window is
    # never closed, and does not depend on the figure's state afterwards.
    fig.savefig('test.png')
    plt.show()

    结果:

    以MNIST为例,先用TruncatedSVD(作用类似PCA)降到50维,再做t-SNE:

    from time import time
    from tsne import bh_sne
    import numpy as np
    import matplotlib.pyplot as plt
    from tensorflow.examples.tutorials.mnist import input_data
    from matplotlib import offsetbox
    from sklearn import (manifold, datasets, decomposition, ensemble,
                         discriminant_analysis, random_projection)
    from sklearn import decomposition
    
    
    # Load MNIST through the TensorFlow tutorial helper; one_hot=False asks for
    # labels that are not one-hot encoded.
    mnist = input_data.read_data_sets('./input_data', one_hot=False)

    # Only the first `sub_sample` training examples are embedded and plotted.
    sub_sample = 5000
    X = mnist.train.images[:sub_sample]
    y = mnist.train.labels[:sub_sample]

    (n_samples, n_features) = X.shape
    n_neighbors = 30
    
    
    #----------------------------------------------------------------------
    # Scale and visualize the embedding vectors
    def plot_embedding(X_emb, title=None):
        """Draw a 2-D embedding as coloured label text with image thumbnails.

        Every column of ``X_emb`` is rescaled to [0, 1]; each point is drawn as
        the text of its label ``y[i]``, and a down-sampled thumbnail of ``X[i]``
        is overlaid unless the point lies too close to a thumbnail already
        shown.  Reads the module-level ``y`` (labels) and ``X`` (flattened
        28x28 images), which must have at least as many rows as ``X_emb``.

        Args:
            X_emb: array of shape (n_points, 2) holding the embedding.
            title: optional title for the figure.
        """
        n_points = X_emb.shape[0]

        x_min, x_max = np.min(X_emb, 0), np.max(X_emb, 0)
        span = x_max - x_min
        # A constant column would produce 0/0; divide by 1 there so the
        # column simply maps to 0 instead of NaN.
        span = np.where(span == 0, 1, span)
        X_emb = (X_emb - x_min) / span

        plt.figure()
        ax = plt.subplot(111)
        for i in range(n_points):
            plt.text(X_emb[i, 0], X_emb[i, 1], str(y[i]),
                     color=plt.cm.Set1(y[i] / 10.),
                     fontdict={'weight': 'bold', 'size': 9})

        if hasattr(offsetbox, 'AnnotationBbox'):
            # only print thumbnails with matplotlib > 1.0
            shown_images = np.array([[1., 1.]])  # just something big
            # Walk the embedding's own rows rather than a global sample count,
            # so the function works for any number of embedded points.
            for i in range(n_points):
                dist = np.sum((X_emb[i] - shown_images) ** 2, 1)
                if np.min(dist) < 8e-3:
                    # don't show points that are too close
                    continue
                shown_images = np.r_[shown_images, [X_emb[i]]]
                # [::2, ::2] keeps every other pixel: a 14x14 thumbnail.
                imagebox = offsetbox.AnnotationBbox(
                    offsetbox.OffsetImage(X[i].reshape(28, 28)[::2, ::2],
                                          cmap=plt.cm.gray_r),
                    X_emb[i])
                ax.add_artist(imagebox)
        plt.xticks([])
        plt.yticks([])
        if title is not None:
            plt.title(title)
    
    
    #----------------------------------------------------------------------
    # Plot images of the digits
    # Tile the first n_img_per_row**2 images of X into one big canvas:
    # each 28x28 image sits in a 30x30 cell, offset by 1 pixel.
    n_img_per_row = 20
    img_side = 28
    cell_side = 30
    img = np.zeros((cell_side * n_img_per_row, cell_side * n_img_per_row))
    for i in range(n_img_per_row):
        ix = cell_side * i + 1
        for j in range(n_img_per_row):
            iy = cell_side * j + 1
            img[ix:ix + img_side, iy:iy + img_side] = (
                X[i * n_img_per_row + j].reshape((img_side, img_side)))

    plt.imshow(img, cmap=plt.cm.binary)
    plt.xticks([])
    plt.yticks([])
    # Report the real feature count of X instead of a hard-coded "64", which
    # does not match images that reshape to 28x28.
    plt.title('A selection from the %d-dimensional MNIST digits dataset'
              % X.shape[1])
    
    # t-SNE embedding of the MNIST digits
    print("Computing t-SNE embedding")
    t0 = time()
    # Reduce to 50 components with TruncatedSVD first, then run bh_sne on that.
    X_pca = decomposition.TruncatedSVD(n_components=50).fit_transform(X)
    # Give bh_sne float64 input whatever dtype TruncatedSVD returned.
    # NOTE(review): the tsne package appears to require float64 -- confirm.
    X_tsne = bh_sne(X_pca.astype('float64'))

    # The reported time covers both the SVD step and the t-SNE step.
    plot_embedding(X_tsne,
                   "t-SNE embedding of the digits (time %.2fs)" %
                   (time() - t0))

    plt.show()

    结果如下:

     

    更多降维的可视化参考:http://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html#sphx-glr-auto-examples-manifold-plot-lle-digits-py

  • 相关阅读:
    java远程调用rmi入门实例
    POJ2752 Seek the Name, Seek the Fame 【KMP】
    Scala入门到精通——第十六节 泛型与注解
    js:简单的拖动效果
    Android拍照、摄像方向旋转的问题 代码具体解释
    对dispatch_async到主线程的逻辑封装成C/C++接口类型
    Oracle password expire notices
    CentOS bridge br0 kvm libvirt-xml
    国内常用ntp服务器ip地址
    C Deepin指针
  • 原文地址:https://www.cnblogs.com/huangshiyu13/p/6945239.html
Copyright © 2020-2023  润新知