• tensorflow C++手写数字识别


    model.pb下载

    #include<fstream>
    #include<iostream>
    #include<limits>
    #include<memory>
    #include<string>
    #include<utility>
    #include<vector>
    
    #include<Eigen/Core>
    #include<Eigen/Dense>
    
    #include<tensorflow/cc/ops/const_op.h>
    #include<tensorflow/cc/ops/image_ops.h>
    #include<tensorflow/cc/ops/standard_ops.h>
    #include<tensorflow/core/framework/graph.pb.h>
    #include<tensorflow/core/graph/default_device.h>
    #include<tensorflow/core/framework/tensor.h>
    #include<tensorflow/core/graph/graph_def_builder.h>
    #include<tensorflow/core/lib/core/errors.h>
    #include<tensorflow/core/lib/core/stringpiece.h>
    #include<tensorflow/core/lib/core/threadpool.h>
    #include<tensorflow/core/lib/io/path.h>
    #include<tensorflow/core/lib/strings/stringprintf.h>
    #include<tensorflow/core/platform/env.h>
    #include<tensorflow/core/platform/init_main.h>
    #include<tensorflow/core/platform/logging.h>
    #include<tensorflow/core/platform/types.h>
    #include<tensorflow/core/public/session.h>
    #include<tensorflow/core/util/command_line_flags.h>
    
    using namespace std;
    using namespace tensorflow;
    using namespace tensorflow::ops;
    using tensorflow::Flag;
    using tensorflow::Tensor;
    using tensorflow::Status;
    using tensorflow::string;
    using tensorflow::int32;
    
    // Reads the entire file `filename` through `env` and stores its raw bytes as
    // the scalar string value of `*output` (expected to be a DT_STRING scalar).
    // Returns DataLoss if fewer bytes were read than GetFileSize reported;
    // any error from GetFileSize / NewRandomAccessFile / Read is propagated.
    static Status ReadEntireFile(tensorflow::Env* env,const string& filename,Tensor* output) {
    	tensorflow::uint64 byte_count = 0;
    	TF_RETURN_IF_ERROR(env->GetFileSize(filename, &byte_count));
    
    	string scratch(byte_count, '\0');
    	std::unique_ptr<tensorflow::RandomAccessFile> reader;
    	TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &reader));
    
    	// Read() may point `bytes_read` into `scratch`; its size tells us how much arrived.
    	tensorflow::StringPiece bytes_read;
    	TF_RETURN_IF_ERROR(reader->Read(0, byte_count, &bytes_read, &scratch[0]));
    	const bool complete = (bytes_read.size() == byte_count);
    	if (!complete) {
    		return tensorflow::errors::DataLoss("Truncated read of '", filename,
    			"' expected ", byte_count, " got ", bytes_read.size());
    	}
    
    	output->scalar<string>()() = bytes_read.ToString();
    	return Status::OK();
    }
    // Loads the image at `file_name` and converts it into a normalized float tensor,
    // written to *out_tensors by Session::Run (one tensor for the single fetch).
    //
    // A small graph is built and run in its own session:
    //   decode -> cast to float -> add batch dim -> bilinear resize to
    //   input_height x input_width -> divide by 255 -> subtract input_mean ->
    //   divide by input_std.
    // With input_mean = 0 and input_std = 1 the last two steps leave the values
    // unchanged, i.e. the result is the resized image divided by 255.
    //
    // The decoder is picked from the file extension: .png, .gif and .bmp have their
    // own decoders; anything else is decoded as JPEG. PNG and JPEG are decoded to
    // `wanted_channels` (1) channels; GIF and BMP keep whatever their decoder yields.
    //
    // Returns InvalidArgument for a non-positive target size or input_std == 0;
    // otherwise propagates errors from reading the file, building the graph or
    // running the session.
    Status ReadTensorFromImageFile(const string& file_name,const int input_height,
    	const int input_width,const int input_mean,const int input_std,
    	std::vector<Tensor>* out_tensors) {
    	if (input_height <= 0 || input_width <= 0) {
    		return tensorflow::errors::InvalidArgument(
    			"Requested image size must be positive, got ", input_height, "x", input_width);
    	}
    	if (input_std == 0) {
    		return tensorflow::errors::InvalidArgument("input_std must be non-zero");
    	}
    
    	auto root = tensorflow::Scope::NewRootScope();
    	const string output_name = "normalized";
    
    	Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
    	TF_RETURN_IF_ERROR(ReadEntireFile(tensorflow::Env::Default(), file_name, &input));
    	// The encoded file bytes are fed into the graph through this placeholder.
    	auto file_reader = Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
    	const std::vector<std::pair<string, tensorflow::Tensor>> inputs = { {"input", input} };
    
    	const int wanted_channels = 1;
    	const tensorflow::StringPiece name_piece(file_name);
    	tensorflow::Output image_reader;
    	if (name_piece.ends_with(".png")) {
    		image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
    			DecodePng::Channels(wanted_channels));
    	}
    	else if (name_piece.ends_with(".gif")) {
    		// The GIF decoder returns a 4-D tensor; Squeeze drops the leading frame
    		// dimension for single-frame GIFs. NOTE(review): DecodeGif has no channel
    		// option, so this path presumably yields RGB rather than 1 channel.
    		image_reader = Squeeze(root.WithOpName("squeeze_first_dim"),
    			DecodeGif(root.WithOpName("gif_reader"), file_reader));
    	}
    	else if (name_piece.ends_with(".bmp")) {
    		// NOTE(review): decoded with the file's own channel count; a colour BMP will
    		// presumably not match a single-channel model input — confirm if BMPs are used.
    		image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
    	}
    	else {
    		// Neither PNG, GIF nor BMP: assume JPEG.
    		image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
    			DecodeJpeg::Channels(wanted_channels));
    	}
    
    	// Cast to float so the arithmetic below is done in floating point.
    	auto float_caster =
    		Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
    	// Image ops expect [batch, height, width, channel]; add a batch dimension of 1.
    	auto dims_expander = ExpandDims(root.WithOpName("expand"), float_caster, 0);
    	// Resize to the requested size (previously the size arguments were ignored).
    	// For an image already at that size the scale factor is 1, so pixel values
    	// should come through unchanged.
    	auto resized = ResizeBilinear(root.WithOpName("resize"), dims_expander,
    		Const(root.WithOpName("size"), {input_height, input_width}));
    	// Divide by 255, then apply (v - input_mean) / input_std.
    	const float input_max = 255.0f;
    	auto scaled = Div(root.WithOpName("div"), resized, input_max);
    	auto centered = Sub(root.WithOpName("subtract_mean"), scaled,
    		static_cast<float>(input_mean));
    	Div(root.WithOpName(output_name), centered, static_cast<float>(input_std));
    
    	// Run the graph just built and fetch the normalized image.
    	tensorflow::GraphDef graph;
    	TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
    
    	std::unique_ptr<tensorflow::Session> session(
    		tensorflow::NewSession(tensorflow::SessionOptions()));
    	TF_RETURN_IF_ERROR(session->Create(graph));
    	TF_RETURN_IF_ERROR(session->Run(inputs, { output_name }, {}, out_tensors));
    	return Status::OK();
    }
    
    int main() {
    	//创建新回话
    	Session* session;
    	Status status = NewSession(SessionOptions(), &session);
    
    	string model_path = "model.pb";
    	GraphDef graphdef; //定义一个图
    	Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef);
    	if (!status_load.ok()) {
    		std::cout << "ERROR:Loading model failed..." << model_path << endl;
    		std::cout << status_load.ToString() << "
    ";
    		return -1;
    	}
    	//将模型导入session
    	Status status_create = session->Create(graphdef);
    	if (!status_create.ok()) {
    		std::cout << "ERROR:create graph in session failed..." << status_create.ToString() << '
    ';
    		return -1;
    	}
    	std::cout << "Session successfully created." << '
    ';
    	string image_path = "digit.jpg";
    	int input_height = 28, input_width = 28;
    	int input_mean = 0, input_std = 1;
    	std::vector<Tensor> resized_tensors;
    	Status read_tensor_status = ReadTensorFromImageFile(image_path, input_height, input_width,
    		input_mean, input_std,&resized_tensors);
    	if (!read_tensor_status.ok()) {
    		LOG(ERROR) << read_tensor_status;
    		cout << "resing error" << '
    ';
    		return -1;
    	}
    	const Tensor& resized_tensor = resized_tensors[0];
    	std::cout << resized_tensor.DebugString() << endl;
    
    	vector<tensorflow::Tensor> outputs;
    	string output_node = "softmax";
    
    	/*
    	virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
                         const std::vector<string>& output_tensor_names,
                         const std::vector<string>& target_node_names,
                         std::vector<Tensor>* outputs)
    	*/
    	Status status_run = session->Run({ {"inputs",resized_tensor} }, 
    		{ output_node }, {}, &outputs);
    
    	if (!status_run.ok()) {
    		std::cout << "ERROR: RUN failed..." << std::endl;
    		std::cout << status_run.ToString() << "
    ";
    		return -1;
    	}
    
    	std::cout << "Output tensor size:" << outputs.size() << std::endl;
    	for (std::size_t i = 0; i < outputs.size(); i++) {
    		std::cout << outputs[i].DebugString() << endl;
    	}
    
    	Tensor t = outputs[0];
    	int ndim = t.shape().dims();
    	auto tmap = t.tensor<float, 2>();
    	int output_dim = t.shape().dim_size(1);
    	std::vector<double> tout;
    	int output_class_id = -1;
    	double output_prob = 0.0;
    	for (int j = 0; j < output_dim; j++) {
    		std::cout << "Class " << j << "prob:" << tmap(0, j) << "," << endl;
    		if (tmap(0, j) >= output_prob) {
    			output_class_id = j;
    			output_prob = tmap(0, j);
    		}
    	}
    	std::cout << "Final class id:" << output_class_id << endl;
    	std::cout << "Final prob:" << output_prob << endl;
    
    	return 0;
    }
    
    天生我材必有用,千金散尽还复来!
  • 相关阅读:
    第一道题:无头苍蝇装头术(望不吝赐教)
    jdk8 list是否包含某值的一些应用
    Failed to close server connection after message failures; nested exception is javax.mail.MessagingException: Can't send command to SMTP host
    itext pdf加密
    TiDB-禁用遥测功能
    TiDB-配置调整
    DM-表空间
    DM-INI参数配置
    DM-DSC集群配置
    PG-并行查询
  • 原文地址:https://www.cnblogs.com/lutaishi/p/13436230.html
Copyright © 2020-2023  润新知