• 基于milvus搭建“以图搜图”服务(附代码)


    “以图搜图”服务需要的关键功能和准备工作:

    1 图像向量化功能,可选的模型有很多,本例选用resnet网络提取图像特征;

    2 milvus建表,用milvus存放图像特征,通过唯一ID(此处称:milvus_id)与图像一一对应,sql建表将milvus_id作为唯一索引,存放图像的其他信息(如url,来源等);

    3 异步添加图像,同步搜索图像,添加图像的量通常会很大,因此采用异步批量的方式将图像特征加载到milvus,图像添加服务会将每次的请求信息存到sql,再写个脚本专门用来定时批量加载图像特征到milvus,由于是异步操作,可能会出现重复加载的情况,此处使用redis进行去重。图像搜索的请求通常会比图像添加少很多,因此图像搜索采用同步方式返回结果;

    (总结:需建立三个表:milvus表1,存放图像特征;sql表2,存放图像信息,数据与milvus表1一一对应;sql表3,存放图像添加请求信息,用于图像特征异步批量加载到milvus)

    “以图搜图”服务关键功能及代码(代码仅做参考)

    1 图像向量化

    """
    功能:图像向量化
    """
    from keras.applications.resnet50 import ResNet50
    from keras.preprocessing import image
    from keras.applications.resnet50 import preprocess_input, decode_predictions
    import numpy as np
    from numpy import linalg as LA
    import time
    
    # Build ResNet50 with ImageNet weights once at import time; img2feature() reuses this instance.
    model = ResNet50(weights='imagenet')
    # model.summary()
    
    
    def img2feature(img_path, input_dim=224):
        """Convert the image file at ``img_path`` into a normalised feature array.

        The image is loaded at ``input_dim`` x ``input_dim``, turned into a
        batch of one, preprocessed for ResNet50 and passed through the
        module-level ``model``. The prediction is then divided by its norm.

        Args:
            img_path: path of the image file to load.
            input_dim: side length the image is resized to when loading.

        Returns:
            The ``model.predict`` output for the single image, scaled to unit norm.
        """
        pil_img = image.load_img(img_path, target_size=(input_dim, input_dim))
        batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
        features = model.predict(preprocess_input(batch))
        return features / LA.norm(features)
    
    
    def main():
        """Time a single feature extraction on '1.jpg' and print the seconds and shape."""
        started = time.time()
        features = img2feature('1.jpg')
        elapsed = time.time() - started
        print(elapsed, features.shape)
    
    
    # Run the timing demo only when this file is executed as a script, not on import.
    if __name__ == "__main__":
        main()
    

     2 milvus表的操作

    # coding:utf-8
    from functools import reduce
    import numpy as np
    import time
    from img2feature import img2feature
    from pymilvus import (
        connections, list_collections,
        FieldSchema, CollectionSchema, DataType,
        Collection, utility
    )
    
    
    # NOTE: despite its name, `field_name` is passed as the *collection* name by the functions below.
    field_name = 'image_feature'
    host = '***.***.***.***'  # Milvus server address (masked in the article)
    port = '19530'
    # Vector dimension; presumably the width of the ResNet50 output produced by img2feature -- confirm.
    dim = 1000
    # Collection schema: INT64 primary key, float feature vector, INT64 creation time.
    default_fields = [
        FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),
        FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),
        FieldSchema(name="create_time", dtype=DataType.INT64)
    ]
    
    
    # create_table
    def create_table():
        """Create the collection, index its ``feature`` field and load it.

        Connects to Milvus at ``host``:``port``, creates a collection named
        ``field_name`` from ``default_fields``, builds a FLAT/L2 index on the
        ``feature`` vector field and finally loads the collection.
        """
        connections.connect(host=host, port=port)

        default_schema = CollectionSchema(fields=default_fields, description="test collection")

        print("\nCreate collection...")
        collection = Collection(name=field_name, schema=default_schema)
        print("\nCreate index...")
        # NOTE(review): "nlist" is an IVF-style parameter; FLAT presumably ignores it -- confirm.
        default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
        collection.create_index(field_name="feature", index_params=default_index)
        # Print exactly once: nesting print() inside print() also emits a stray "None" line.
        print("\nCreate index...is OKOKOKOKOK")
        collection.load()
    
    
    # insert data
    def insert_data():
        """Insert the feature of '1.jpg' as one demo row with ``milvus_id`` 123."""
        connections.connect(host=host, port=port)
        schema = CollectionSchema(fields=default_fields, description="test collection")
        collection = Collection(name=field_name, schema=schema)

        feature = img2feature('1.jpg').tolist()[0]
        print(type(feature), len(feature))

        # Column-oriented payload: one list per schema field, in declaration order.
        milvus_ids = [123]
        features = [feature]
        create_times = [int(time.time())]
        collection.insert([milvus_ids, features, create_times])
        print('insert compete')
    
    
    # search data
    def search_data():
        """Search the collection for the 10 nearest neighbours of '1.jpg' and print them.

        Uses the L2 metric on the ``feature`` field, filters on
        ``create_time > 0`` and asks for ``milvus_id`` as an output field.
        """
        print('search')
        connections.connect(host=host, port=port)
        collection = Collection(name=field_name)
        print('连接成功')

        # Before the very first query the index on "feature" has to be created
        # and the collection loaded, e.g. by running create_table().

        query_vector = img2feature('1.jpg').tolist()[0]
        top_k = 10
        params = {"metric_type": "L2", "params": {"nprobe": 10}}
        time_filter = "create_time > {}".format(0)

        results = collection.search(
            [query_vector],
            "feature",
            params,
            top_k,
            time_filter,
            output_fields=["milvus_id"],
        )
        print('>>>', results)
        # One hit list per query vector; print its size, then each hit.
        for hit_list in results:
            print(len(hit_list))
            for single_hit in hit_list:
                print(single_hit)
        print('查询结束')
    
    
    def show_nums():
        """Print 'ok' and then the number of entities held by the collection."""
        connections.connect(host=host, port=port)
        target = Collection(name=field_name)
        print('ok')
        entity_count = target.num_entities
        print(entity_count)
    
    
    # delete data
    def delete_table():
        """Drop the collection, printing whether it exists before and after the drop."""

        def report_exists():
            # Same existence probe is printed on both sides of the drop.
            print('>>>', utility.has_collection(field_name))

        connections.connect(host=host, port=port)
        schema = CollectionSchema(fields=default_fields, description="test collection")
        doomed = Collection(name=field_name, schema=schema)
        report_exists()
        doomed.drop()
        report_exists()
    
    
    # Manual driver: uncomment the one operation to run; its wall-clock time is printed at the end.
    if __name__ == "__main__":
        t1 = time.time()
        # create_table()
        # insert_data()
        # search_data()
        show_nums()
        # delete_table()
        print('time cost: {}'.format(time.time() - t1))
    

     3 创建sql表2、表3

     4 图像添加、搜索服务

    from rest_framework.views import APIView as View
    from kpdjango.response import SucessAPIResponse, ErrorAPIResponse
    from kpmysql.base import Kpmysql
    from core import search_image
    import kplog
    import logging
    log = logging.getLogger("console")
    
    
    class add_image(View):
        """Endpoint that records an image-add request in the add-log table."""

        def post(self, requests):
            """Store the posted ``image_path`` and ``image_info`` as a new log row.

            Args:
                requests: the incoming request; ``image_path`` and ``image_info``
                    are read from its POST data.

            Returns:
                ``SucessAPIResponse`` once the row is committed, otherwise
                ``ErrorAPIResponse`` (missing ``image_path`` or any exception).
            """
            try:
                image_info = requests.POST.get('image_info')
                image_path = requests.POST.get('image_path')
                # A row without a path could never be loaded later, so reject it up front.
                if not image_path:
                    log.warning('添加图像失败:{}'.format('image_path is required'))
                    return ErrorAPIResponse(msg="Fail")
                db = Kpmysql.connect("db168")
                cur = db.cursor()
                sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"
                cur.execute(sql, (image_path, image_info))
                db.commit()
                log.info('添加图像成功:{}-{}'.format(image_path, image_info))
                return SucessAPIResponse(msg="Success")
            except Exception as e:
                # Failures belong at ERROR level; log.exception also records the traceback.
                log.exception('添加图像失败:{}'.format(e))
                return ErrorAPIResponse(msg="Fail")
    
    
    class search_image(View):
        """Endpoint that synchronously searches for images matching the posted one."""

        def post(self, requests):
            """Run the image search for the posted ``image_path``.

            Args:
                requests: the incoming request; ``image_path`` is read from its
                    POST data.

            Returns:
                ``SucessAPIResponse`` carrying the search result under ``data``,
                or ``ErrorAPIResponse`` if anything raises.
            """
            try:
                # This class rebinds the module-level name ``search_image``, hiding
                # the function imported from ``core``. Calling ``search_image(...)``
                # here would instantiate the view instead of searching, so the
                # function is imported again under an alias.
                from core import search_image as search_similar_images

                image_path = requests.POST.get('image_path')
                res = search_similar_images(image_path)
                log.info('查询图像成功:{}-{}'.format(image_path, res))
                return SucessAPIResponse(msg="Success", data={"data": res})
            except Exception as e:
                # The failure branch must report failure (it used to log the
                # success text) and keep the traceback.
                log.exception('查询图像失败:{}'.format(e))
                return ErrorAPIResponse(msg="Fail")
    

     5 图像异步批量加载

    import time, datetime
    from kpmysql.base import Kpmysql
    from core import insert_data_many
    from concurrent.futures import ThreadPoolExecutor
    import redis
    from conf.setting import REDIS
    from core import str2time
    import kplog
    import logging
    
    log = logging.getLogger("console")
    log_addimgs = logging.getLogger("console_addimgs")
    
    
    def worker(datas):
        """Load one batch of pending add-log rows into Milvus and flag them as loaded.

        Args:
            datas: rows whose columns are read by index as ``data[0]`` (row id),
                ``data[1]`` (image path) and ``data[2]`` (create time).

        Rows whose id is already a member of the redis sorted set
        ``image_search`` are skipped; the remaining ids are added to that set,
        their images are passed to ``insert_data_many`` and the matching rows
        get ``is_load=1``. Any exception is logged and swallowed so that a
        failing batch does not propagate out of the pool thread.
        """
        try:
            redis_cli = redis.Redis(host=REDIS.get('host'), port=REDIS.get('port'), password=REDIS.get('password'),
                                    db=REDIS.get('db'))
            dics = []
            ids = []
            for data in datas:
                member = str(data[0])
                # zscore returns None for an unknown member; compare explicitly so
                # that a stored score of 0 still counts as "already seen".
                if redis_cli.zscore('image_search', member) is not None:
                    continue
                dics.append({'image_path': data[1], 'create_time': data[2]})
                # executemany expects one parameter sequence per row, hence the 1-tuple.
                ids.append((data[0],))
                # NOTE(review): the id is marked as seen before the Milvus insert
                # succeeds; a failed insert leaves it marked -- confirm this is intended.
                redis_cli.zadd('image_search', {member: str2time(data[2])})
            if not dics:
                # Every row was a duplicate (or the batch was empty): nothing to load.
                return
            # Insert the new images into Milvus.
            insert_data_many(dics)
            # Flag the loaded rows in t_image_search_image_add_log.
            sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""
            db168 = Kpmysql.connect("db168")
            cur168 = db168.cursor()
            cur168.executemany(sql_update, ids)
            db168.commit()
        except Exception:
            # Use the module logger so the failure and its traceback reach the logs.
            log.exception('failed to load image batch into milvus')
    
    
    def main():
        """Poll the add-log table and hand batches of unloaded rows to a thread pool.

        The loop wakes every 600 seconds. Once more than 3600 seconds have
        passed since the last pass, it selects the rows with ``is_load=0`` and
        ``create_time >= create_time_init``, advances ``create_time_init`` to
        the newest row and submits the batch to ``worker``. It never returns.
        """
        max_workers = 20  # maximum number of pool threads
        pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')
        task_list = []
        # Start 13 hours in the past so that the first iteration runs a pass immediately.
        init_time = datetime.datetime.now() - datetime.timedelta(hours=13)
        create_time_init = '2020-2-22 00:00:00'
        sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""
        while True:
            now = datetime.datetime.now()
            # total_seconds() includes whole days; .seconds alone wraps around every 24h.
            if (now - init_time).total_seconds() > 3600:
                # Fetch the rows that have not been loaded yet.
                db168 = Kpmysql.connect("db168")
                cur168 = db168.cursor()
                # Pass the parameter as a 1-tuple, the form every DB-API driver accepts.
                cur168.execute(sql, (create_time_init,))
                datas = cur168.fetchall()
                # An empty result must not crash the loop (datas[-1] would raise IndexError).
                if datas:
                    create_time_init = datas[-1][2]
                    while True:
                        # Rebuild the list instead of popping while iterating,
                        # which would skip the entry after each removed one.
                        task_list = [task for task in task_list if not task.done()]
                        if len(task_list) < int(max_workers * 0.9):
                            break
                        # Pool is nearly full: wait briefly instead of spinning.
                        time.sleep(1)
                    task_list.append(pool.submit(worker, datas))
                init_time = now
            time.sleep(600)
    
    
    # Start the polling loop only when this file is executed as a script.
    if __name__ == "__main__":
        main()
    

     优化(重点)

    经过实际测试和使用的建议:

    1. keras在调用GPU并开启多线程时不如pytorch方便,pytorch占用显存更少;

    2. 定时从数据库拿数据,改成kafka生产消费模型,代码更简洁,逻辑更简单;

  • 相关阅读:
    crmfuxi
    段子
    wsfenxiang
    生成器、列表推导式
    闭包、迭代器、递归
    函数的参数及返回值
    嵌套、作用域、命名空间
    定义、函数的调用
    测试样式
    进制转换
  • 原文地址:https://www.cnblogs.com/niulang/p/15921786.html
Copyright © 2020-2023  润新知