• faiss计算余弦距离


    faiss是Facebook开源的相似性搜索库,为稠密向量提供高效相似度搜索和聚类,支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库

    faiss不直接提供余弦距离计算,而是提供了欧式距离和点积,利用余弦距离公式,经过L2正则后的向量点积结果即为余弦距离,所以利用faiss计算余弦距离需要先对输入进行L2正则化

    • 安装

      参照官方开源安装https://github.com/facebookresearch/faiss/blob/main/INSTALL.md

      # CPU-only version
      $ conda install -c pytorch faiss-cpu
      $ pip install faiss-cpu
      
      # GPU(+CPU) version
      $ conda install -c pytorch faiss-gpu
      $ pip install faiss-cpu
       
      
    • 常规计算余弦距离方式

      常规一般使用sklearn包的cosine_similarity计算余弦距离,因为该包自动对向量进行L2正则,所以不要求输入必须为正则结果,代码如下:

      ## 计算余弦距离
      from sklearn.metrics.pairwise import cosine_similarity
      from sklearn import preprocessing
      def get_cos_result(embeding_library, persons, embeding_search):
          simi = cosine_similarity(embeding_search, embeding_library)
          max_argmin = np.argmax(simi,axis=1)
          search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)]
          return search_speaker
      ## 对输入进行正则化,可以不用正则
      def l2_normal(embeding):
          return preprocessing.normalize(embeding)
      
    • faiss的精确搜索

      faiss并不提供计算与余弦距离,只提供了点积计算和欧式距离,所以在计算余弦距离时,需要对输入进行L2正则,代码如下:

      import faiss
      from faiss import normalize_L2
      def faiss_precise_search(embeding_library, persons, embeding_search,topk=1):
          ## 这里也可以使用上文的sklearn的包进行正则
          normalize_L2(embeding_search)
          normalize_L2(embeding_library)
          # faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离
          quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
          index = quantizer
          ## 要保证输入为np.float32格式
          index.add(embeding_library.astype(np.float32))
          library = {'persons': persons, 'index': index}
          st = time.time()
          distance,idx = library['index'].search(embeding_search,topk)
          print('precise search:',time.time()-st)
          combined_results = []
          for p in range(len(distance)):
              results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
              combined_results.append(results)
          return combined_results
      
    • faiss快速搜索

      faiss提供了多种快速搜索的方式,这里介绍常用的一种加速搜索的方式:倒排索引,这种方式与ES快速搜索的方式类似,需要先使用k-means建立聚类中心,通过查询最近的聚类中心,然后比较聚类中所有向量得到相似向量,这里需要两个超参数,一个是聚类中心num_cells,一个是查找聚类中心的个数num_cells_in_search,具体代码如下

      def faiss_fast_search(embeding_library, persons, embeding_search,topk=1):
          normalize_L2(embeding_search)
          normalize_L2(embeding_library)
          d = embeding_library.shape[1]
          num_cells = 50
          num_cells_in_search = 5
          # 声明量化器
          quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
           # faiss.METRIC_INNER_PRODUCT计算内积 faiss.METRIC_L2j计算欧式距离
          index = faiss.IndexIVFFlat(quantizer, d,min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT)
          assert not index.is_trained
          index.train(embeding_library.astype(np.float32))
          index.add(embeding_library.astype(np.float32))
          index.nprobe = min(num_cells_in_search,len(persons))
          library = {'persons': persons, 'index': index}
          st = time.time()
          distance, idx = library['index'].search(embeding_search, topk)
          print('fast search:',time.time()-st)
          combined_results = []
          for p in range(len(distance)):
              results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
              combined_results.append(results)
          return combined_results
      
    • 整体代码

      # -*- coding: utf-8 -*-
      import faiss
      from faiss import normalize_L2
      from sklearn.metrics.pairwise import cosine_similarity
      from sklearn import preprocessing
      import numpy as np
      import time
      
      def l2_normal(embeding):
          return preprocessing.normalize(embeding)
      
      def get_cos_result(embeding_search, persons, embeding_library):
          simi = cosine_similarity(embeding_search, embeding_library)
          max_argmin = np.argmax(simi,axis=1)
          search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)]
          return search_speaker
      
      def faiss_precise_search(embeding_library, persons, embeding_search):
          normalize_L2(embeding_search)
          normalize_L2(embeding_library)
          # faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离
          quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
          index = quantizer
          index.add(embeding_library.astype(np.float32))
          library = {'persons': persons, 'index': index}
          st = time.time()
          distance,idx = library['index'].search(embeding_search,1)
          print('precise search:',time.time()-st)
          combined_results = []
          for p in range(len(distance)):
              results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
              combined_results.append(results)
          return combined_results
      
      def faiss_fast_search(embeding_library, persons, embeding_search,topk=1):
          normalize_L2(embeding_search)
          normalize_L2(embeding_library)
          num_cells = 500
          num_cells_in_search = 10
          quantizer = faiss.IndexFlatIP(embeding_library.shape[1])
          index = faiss.IndexIVFFlat(quantizer, embeding_library.shape[1],min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT) #faiss.METRIC_INNER_PRODUCT计算内积 faiss.METRIC_L2j计算欧式距离
          assert not index.is_trained
          index.train(embeding_library.astype(np.float32))
          index.add(embeding_library.astype(np.float32))
          index.nprobe = min(num_cells_in_search,len(persons))
          library = {'persons': persons, 'index': index}
          st = time.time()
          distance, idx = library['index'].search(embeding_search, topk)
          print('fast search:',time.time()-st)
          combined_results = []
          for p in range(len(distance)):
              results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0]
              combined_results.append(results)
          return combined_results
      
      if __name__ == '__main__':
          d = 512
          n_library = 100000
          n_search = 1
          embeding_library = np.random.random((n_library, d)).astype(np.float32)
          persons = ['Speak' + "%0d" % (i + 1) for i in range(n_library)]
          embeding_search = np.random.random((n_search, d)).astype(np.float32)
          print(faiss_fast_search(embeding_library, persons, embeding_search))
          print(faiss_precise_search(embeding_library, persons, embeding_search))
          st = time.time()
          print(get_cos_result(embeding_search, persons, embeding_library))
          en1 = time.time()
          print(en1-st)
      
  • 相关阅读:
    msql 触发器
    微信模板消息推送
    微信朋友朋友圈自定义分享内容
    微信退款
    异步调起微信支付
    微信支付
    第一次作业
    【Linus安装MongoDB及Navicat】
    【前端】ES6总结
    【开发工具】Pycharm使用
  • 原文地址:https://www.cnblogs.com/peng-yuan/p/15476945.html
Copyright © 2020-2023  润新知