• tensorflow添加自定义的auc计算operator


    tensorflow可以很方便的添加用户自定义的operator(如果不添加也可以采用sklearnauc计算函数或者自己写一个 但是会在python执行,这里希望在graph中也就是c++端执行这个计算)

    这里根据工作需要添加一个计算aucoperator,只给出最简单实现,后续高级功能还是参考官方wiki

    https://www.tensorflow.org/versions/r0.7/how_tos/adding_an_op/index.html

    注意tensorflow现在和最初的官方wiki有变化,原wiki貌似是需要重新bazel编译整个tensorflow,然后使用比如tf.user_op.auc这样。

    目前wiki给出的方式>=0.6.0版本,采用plug-in的方式,更加灵活可以直接用g++编译一个so载入,解耦合,省去了编译tensorflow过程,即插即用。

       

    首先aucoperator计算的文件

       

    tensorflow/core/user_ops/auc.cc

       

    /* Copyright 2015 Google Inc. All Rights Reserved.

       

    Licensed under the Apache License, Version 2.0 (the "License");

    you may not use this file except in compliance with the License.

    You may obtain a copy of the License at

       

    http://www.apache.org/licenses/LICENSE-2.0

       

    Unless required by applicable law or agreed to in writing, software

    distributed under the License is distributed on an "AS IS" BASIS,

    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

    See the License for the specific language governing permissions and

    limitations under the License.

    ==============================================================================*/

       

    // An auc Op.

       

    #include "tensorflow/core/framework/op.h"

    #include "tensorflow/core/framework/op_kernel.h"

       

    using namespace tensorflow;

    using std::vector;

    //@TODO add weight as optional input

    REGISTER_OP("Auc")

    .Input("predicts: T1")

    .Input("labels: T2")

    .Output("z: float")

    .Attr("T1: {float, double}")

    .Attr("T2: {float, double}")

    //.Attr("T1: {float, double}")

    //.Attr("T2: {int32, int64}")

    .SetIsCommutative()

    .Doc(R"doc(

    Given preidicts and labels output it's auc

    )doc");

       

    class AucOp : public OpKernel {

    public:

    explicit AucOp(OpKernelConstruction* context) : OpKernel(context) {}

       

    template<typename ValueVec>

    void index_sort(const ValueVec& valueVec, vector<int>& indexVec)

    {

    indexVec.resize(valueVec.size());

    for (size_t i = 0; i < indexVec.size(); i++)

    {

    indexVec[i] = i;

    }

    std::sort(indexVec.begin(), indexVec.end(),

    [&valueVec](const int l, const int r) { return valueVec(l) > valueVec(r); });

    }

       

    void Compute(OpKernelContext* context) override {

    // Grab the input tensor

    const Tensor& predicts_tensor = context->input(0);

    const Tensor& labels_tensor = context->input(1);

    auto predicts = predicts_tensor.flat<float>(); //输入能接受float double那么这里如何都处理?

    auto labels = labels_tensor.flat<float>();

       

    vector<int> indexes;

    index_sort(predicts, indexes);

    typedef float Float;

       

    Float oldFalsePos = 0;

    Float oldTruePos = 0;

    Float falsePos = 0;

    Float truePos = 0;

    Float oldOut = std::numeric_limits<Float>::infinity();

    Float result = 0;

       

    for (size_t i = 0; i < indexes.size(); i++)

    {

    int index = indexes[i];

    Float label = labels(index);

    Float prediction = predicts(index);

    Float weight = 1.0;

    //Pval3(label, output, weight);

    if (prediction != oldOut) //存在相同值得情况是特殊处理的

    {

    result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);

    oldOut = prediction;

    oldFalsePos = falsePos;

    oldTruePos = truePos;

    }

    if (label > 0)

    truePos += weight;

    else

    falsePos += weight;

    }

    result += 0.5 * (oldTruePos + truePos) * (falsePos - oldFalsePos);

    Float AUC = result / (truePos * falsePos);

       

    // Create an output tensor

    Tensor* output_tensor = NULL;

    TensorShape output_shape;

       

    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

    output_tensor->scalar<float>()() = AUC;

    }

    };

       

    REGISTER_KERNEL_BUILDER(Name("Auc").Device(DEVICE_CPU), AucOp);

       

       

    编译:

    $cat gen-so.sh

       

    TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')

    TF_LIB=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')

    i=$1

    o=${i/.cc/.so}

    g++ -std=c++11 -shared $i -o $o -I $TF_INC -l tensorflow_framework -L $TF_LIB -fPIC -Wl,-rpath $TF_LIB

       

    $sh gen-so.sh auc.cc

    会生成auc.so

       

    使用的时候

    auc_module = tf.load_op_library('auc.so')

    #auc = tf.user_ops.auc #0.6.0之前的tensorflow 自定义op方式

    auc = auc_module.auc

       

    evaluate_op = auc(py_x, Y) #py_x is predicts, Y is labels

       

       

       

       

       

       

  • 相关阅读:
    有人向我反馈了一个bug
    java.lang.ClassNotFoundException: org.springframework.core.SpringProperties
    Maven pom文件提示Missing artifact org.springframework:spring-context-support:jar:3.2.2.RELEASE:compile
    在业务逻辑中如何进行数据库的事务管理。
    about to fork child process, waiting until server is ready for connections. forked process: 2676 ERROR: child process failed, exited with error number 100
    tomcat底层原理实现
    springmvc 动态代理 JDK实现与模拟JDK纯手写实现。
    纯手写SpringMVC架构,用注解实现springmvc过程
    数据库连接池原理 与实现(动脑学院Jack老师课后自己的练习有感)
    定时器中实现数据库表数据移动的功能,Exception in thread "Timer-0" isExist java.lang.NullPointerException定时器中线程报错。
  • 原文地址:https://www.cnblogs.com/rocketfan/p/5201593.html
Copyright © 2020-2023  润新知