• 计算两张图的余弦相似度


    # Compare result arrays: cosine similarity plus several other error metrics.
    import numpy as np
    import pdb
     
    class _SizeMismatchError(ValueError, AssertionError):
        """Arrays stored under the same key hold a different number of elements.

        It derives from ValueError because that is the idiomatic type for bad
        input. It also derives from AssertionError so that code written against
        the earlier ``assert``-based check keeps catching it.
        """


    def count_difference(groundtruth, inputs):
        """Compare each array in ``inputs`` with its reference in ``groundtruth``.

        Only keys present in both mappings are compared, in ``groundtruth``
        order. Each pair is flattened to a column vector and converted to
        float64 before any metric runs. This keeps integer data (e.g. uint8
        images) from wrapping around on subtraction or squaring, and makes
        float32 data accumulate in double precision.

        Args:
            groundtruth: mapping of name -> array-like reference data.
            inputs: mapping of name -> array-like data to be checked.

        Returns:
            dict mapping every shared key to a dict with the entries
            'cosine_similarity', 'maximum_absolute_error',
            'accumulated_relative_error', 'relative_euclidean_distance',
            'kullback_leibler_divergence' and 'standard_deviation'
            (see the comments on each metric below).

        Raises:
            ValueError: (also an AssertionError) when the two arrays under a
                shared key have a different number of elements.
        """
        statistical_method = {
            # sum(gt * x) / (|gt| * |x|); nan when either vector is all zeros.
            'cosine_similarity':
            lambda X1, X2: np.sum(X1 * X2) /
            (np.sqrt(np.sum(X1**2)) * np.sqrt(np.sum(X2**2))),
            # Largest element-wise |gt - x|, returned as a Python float.
            'maximum_absolute_error':
            lambda X1, X2: np.max(np.abs(X1 - X2)).tolist(),
            # sum |(gt - x) / gt|. nan_to_num maps 0/0 to 0 and +/-inf to the
            # largest finite float, so zeros in gt where x != 0 dominate the sum.
            'accumulated_relative_error':
            lambda X1, X2: np.sum(np.abs(np.nan_to_num((X1 - X2) / X1))),
            # L2 distance between the two unit-normalised vectors.
            'relative_euclidean_distance':
            lambda X1, X2: np.sqrt(
                np.sum((X1 / np.sqrt(np.sum(X1**2)) - X2 / np.sqrt(np.sum(X2**2)))
                       **2)).tolist(),
            # sum(gt * log(gt / x)). This is a true KL divergence only when both
            # arrays are non-negative, normalised distributions; for raw
            # activations treat it as a heuristic.
            'kullback_leibler_divergence':
            lambda X1, X2: np.sum(X1 * np.nan_to_num(np.log(X1 / X2))),
            # [(mean, std) of gt, (mean, std) of x] as Python floats.
            'standard_deviation':
            lambda *X: [(np.mean(x).tolist(), np.std(x).tolist()) for x in X]
        }

        reports = {}
        # The metrics divide by values and norms that may be zero, and take logs
        # of non-positive ratios. The resulting nan/inf are intentional (they are
        # absorbed by nan_to_num or reported as nan), so silence those warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            for key, gt_value in groundtruth.items():
                if key not in inputs:
                    continue
                # Flatten to a column and promote to float64 (see docstring).
                gt_input = np.asarray(gt_value, dtype=np.float64).reshape(-1, 1)
                compare_input = np.asarray(inputs[key], dtype=np.float64).reshape(-1, 1)
                # Both arrays must hold the same number of elements. This is an
                # explicit check rather than ``assert`` so it also runs under
                # ``python -O``.
                if gt_input.size != compare_input.size:
                    raise _SizeMismatchError(
                        "size mismatch for {!r}: groundtruth has {} elements, "
                        "inputs has {}".format(key, gt_input.size, compare_input.size))
                reports[key] = {
                    name: metric(gt_input, compare_input)
                    for name, metric in statistical_method.items()
                }
        return reports
     
     
    def main(input_path="/home/wangmaolin/for_test/tofile/conv_82_memory",
             groundtruth_path="/home/wangmaolin/for_test/onnx_output/onnx_output_conv_82",
             dtype=np.float32):
        """Load two raw binary dumps, compare them, and print the report.

        The defaults reproduce the original hard-coded conv_82 comparison, so a
        bare ``main()`` behaves as before. Pass other paths or a different dtype
        to compare other layers.

        Args:
            input_path: headerless binary file holding the output under test.
            groundtruth_path: headerless binary file holding the reference
                output (an ONNX dump, judging by the default path; confirm).
            dtype: element type used to decode both files.

        Returns:
            The report dict produced by ``count_difference``.
        """
        # np.fromfile reads raw bytes and stores no dtype in the file, so the
        # caller's dtype must match how the dumps were written.
        compare_data = np.fromfile(input_path, dtype=dtype)
        print(compare_data.shape)
        print(compare_data.dtype)
        inputs = {"data": compare_data}

        groundtruth_data = np.fromfile(groundtruth_path, dtype=dtype)
        print(groundtruth_data.shape)
        print(groundtruth_data.dtype)
        gt_inputs = {"data": groundtruth_data}

        report = count_difference(gt_inputs, inputs)
        print(report)
        return report


    if __name__ == '__main__':
        main()
    
    转载请注明出处
  • 相关阅读:
    scss-数据类型
    scss-@import
    scss-&父选择器标识符
    scss-嵌套属性
    Python之NumPy(axis=0 与axis=1)区分
    Java map 详解
    java之JDBC多条语句执行
    p-value值的认识
    numpy.random之常用函数
    Python random模块sample、randint、shuffle、choice随机函数
  • 原文地址:https://www.cnblogs.com/lnlin/p/15490644.html
Copyright © 2020-2023  润新知