• faiss 没有提供余弦距离怎么办


    参考:https://zhuanlan.zhihu.com/p/40236865 ,但最后观点不同

    faiss是Facebook开源的用于快速计算海量向量距离的库,但是没有提供余弦距离,而余弦距离的使用率还是很高的,那怎么解决呢

    import faiss
    from faiss import normalize_L2
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    
    def faiss_cos_similar_search(x, k=None):
        # 这个不是真的用faiss计算cos,而是找邻居的结果跟用cos得到的邻居结果是很接近,但是距离还是不同的哦
        assert len(x.shape) == 2, "仅支持2维向量的距离计算"
        nb, d = x.shape
        x = x.astype('float32')
        k_search = k if k else nb
        normalize_L2(x)
        index=faiss.IndexFlatIP(d)
        index.train(x)
    
        # index=faiss.IndexFlatL2(d)
        
        index.add(x)
        D, I =index.search(x, k=k_search)
        return I
    
    def sklearn_cos_search(x, k=None):
        assert len(x.shape) == 2, "仅支持2维向量的距离计算"
        nb, d = x.shape
        ag=cosine_similarity(x)
        np.argsort(-ag, axis=1)
        k_search = k if k else nb
    
        return np.argsort(-ag, axis=1)[:, :k_search]
    
    def test_IndexFlatIP_only(nb = 1000, d = 100, kr = 0.005, n_times=10):
        k = int(nb * kr)
        print("recall count is %d" % (k))
        for i in range(n_times):
            
            
            x = np.random.random((nb, d)).astype('float32')
            # x = np.random.randint(0,2, (nb,d))
            # faiss_I = faiss_cos_similar_search(x, k)
            index=faiss.IndexFlatIP(d)
            index.train(x)
            index.add(x)
            D, faiss_I =index.search(x, k=k)
    
            sklearn_I = sklearn_cos_search(x, k)
    
            cmp_result = faiss_I == sklearn_I
            
            print("is all correct: %s, correct batch rate: %d/%d, correct sample rate: %d/%d" % 
                (np.all(cmp_result), 
                np.all(cmp_result, axis=1).sum(),cmp_result.shape[0], 
                cmp_result.sum(),cmp_result.shape[0]*cmp_result.shape[1] ) )
    
    def test_embedding(nb = 1000, d = 100, kr = 0.005, n_times=10):
        k = int(nb * kr)
        print("recall count is %d" % (k))
        for i in range(n_times):
            
            
            x = np.random.random((nb, d)).astype('float32')
            # x = np.random.randint(0,2, (nb,d))
            faiss_I = faiss_cos_similar_search(x, k)
            sklearn_I = sklearn_cos_search(x, k)
    
            cmp_result = faiss_I == sklearn_I
            
            print("is all correct: %s, correct batch rate: %d/%d, correct sample rate: %d/%d" % 
                (np.all(cmp_result), 
                np.all(cmp_result, axis=1).sum(),cmp_result.shape[0], 
                cmp_result.sum(),cmp_result.shape[0]*cmp_result.shape[1] ) )
    
    def test_one_hot(nb = 1000, d = 100, kr = 0.005, n_times=10):
        k = int(nb * kr)
        print("recall count is %d" % (k))
        for i in range(n_times):
            
            
            # x = np.random.random((nb, d)).astype('float32')
            x = np.random.randint(0,2, (nb,d))
            faiss_I = faiss_cos_similar_search(x, k)
            sklearn_I = sklearn_cos_search(x, k)
    
            cmp_result = faiss_I == sklearn_I
            
            print("is all correct: %s, correct batch rate: %d/%d, correct sample rate: %d/%d" % 
                (np.all(cmp_result), 
                np.all(cmp_result, axis=1).sum(),cmp_result.shape[0], 
                cmp_result.sum(),cmp_result.shape[0]*cmp_result.shape[1] ) )
    if __name__ == "__main__":
        
        print("test use IndexFlatIP only")
        test_IndexFlatIP_only()
        print("-"*100 + "
    
    ")
        print("test when one hot")
        test_one_hot()
        print("-"*100 + "
    
    ")
        print("test use normalize_L2 + IndexFlatIP")
        test_embedding()
        print("-"*100 + "
    
    ")
    
    

    下面是实验结果

    分析:第一份结果(横线隔开),是仅用IndexFlatIP的时候,跟余弦距离的结果相差非常大

    第二份结果,是当数据是 one hot 的时候,用 normalize_L2 + IndexFlatIP,结果跟余弦距离结果基本上对的上了,但是也错了不少

    第二份结果,是当数据是 embedding 的向量的时候,用 normalize_L2 + IndexFlatIP,结果跟余弦距离结果基本上对的上了,错的也非常少

    需要注意,这里改方法对数据进行预处理,然后用欧氏距离去模拟余弦距离,并不是等价的,因为从结果来看,尽管差不多,但还是有不一样的地方,特别是召回调大的时候,更是相差变大

  • 相关阅读:
    access 导数据到sql server 2008
    axis2 调用.net基于https的WebService接口
    android layout 属性大全
    sqlite-manager
    android Permission 访问权限许可
    ImageSwitcher 右向左滑动的实现方式
    java 全角半角转换函数
    Delphi中使用Office中VBA的优缺点
    Delphi中控制VBA 宏
    Delphi 与 Word_VBA
  • 原文地址:https://www.cnblogs.com/paiandlu/p/12123859.html
Copyright © 2020-2023  润新知