• caffe: test code for the Deep Learning approach — Caffe's extract_features tool, which runs a trained network forward over an input data layer and saves the requested feature blobs to a database.


#include <stdio.h>  // for snprintf
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "google/protobuf/text_format.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

using caffe::Blob;
using caffe::Caffe;
using caffe::Datum;
using caffe::Net;
using boost::shared_ptr;
using std::string;
namespace db = caffe::db;

template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv);

int main(int argc, char** argv) {
  return feature_extraction_pipeline<float>(argc, argv);
//  return feature_extraction_pipeline<double>(argc, argv);
}

template<typename Dtype>
int feature_extraction_pipeline(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  const int num_required_args = 7;
  if (argc < num_required_args) {
    LOG(ERROR)<<
    "This program takes in a trained network and an input data layer, and then"
    " extracts features of the input data produced by the net.\n"
    "Usage: extract_features  pretrained_net_param"
    "  feature_extraction_proto_file  extract_feature_blob_name1[,name2,...]"
    "  save_feature_dataset_name1[,name2,...]  num_mini_batches  db_type"
    "  [CPU/GPU] [DEVICE_ID=0]\n"
    "Note: you can extract multiple features in one pass by specifying"
    " multiple feature blob names and dataset names separated by ','."
    " The names cannot contain white space characters and the number of blobs"
    " and datasets must be equal.";
    return 1;
  }
  int arg_pos = num_required_args;

  // Parse the optional [CPU/GPU] and [DEVICE_ID] arguments; CPU is the default.
  arg_pos = num_required_args;
  if (argc > arg_pos && strcmp(argv[arg_pos], "GPU") == 0) {
    LOG(ERROR)<< "Using GPU";
    int device_id = 0;
    if (argc > arg_pos + 1) {
      device_id = atoi(argv[arg_pos + 1]);
      CHECK_GE(device_id, 0);
    }
    LOG(ERROR) << "Using Device_id=" << device_id;
    Caffe::SetDevice(device_id);
    Caffe::set_mode(Caffe::GPU);
  } else {
    LOG(ERROR) << "Using CPU";
    Caffe::set_mode(Caffe::CPU);
  }

  arg_pos = 0;  // the name of the executable
  std::string pretrained_binary_proto(argv[++arg_pos]);

  // Expected prototxt contains at least one data layer such as
  //  the layer data_layer_name and one feature blob such as the
  //  fc7 top blob to extract features.
  /*
   layers {
     name: "data_layer_name"
     type: DATA
     data_param {
       source: "/path/to/your/images/to/extract/feature/images_leveldb"
       mean_file: "/path/to/your/image_mean.binaryproto"
       batch_size: 128
       crop_size: 227
       mirror: false
     }
     top: "data_blob_name"
     top: "label_blob_name"
   }
   layers {
     name: "drop7"
     type: DROPOUT
     dropout_param {
       dropout_ratio: 0.5
     }
     bottom: "fc7"
     top: "fc7"
   }
   */
  std::string feature_extraction_proto(argv[++arg_pos]);
  shared_ptr<Net<Dtype> > feature_extraction_net(
      new Net<Dtype>(feature_extraction_proto, caffe::TEST));
  feature_extraction_net->CopyTrainedLayersFrom(pretrained_binary_proto);

  std::string extract_feature_blob_names(argv[++arg_pos]);
  std::vector<std::string> blob_names;
  boost::split(blob_names, extract_feature_blob_names, boost::is_any_of(","));

  std::string save_feature_dataset_names(argv[++arg_pos]);
  std::vector<std::string> dataset_names;
  boost::split(dataset_names, save_feature_dataset_names,
               boost::is_any_of(","));
  CHECK_EQ(blob_names.size(), dataset_names.size()) <<
      " the number of blob names and dataset names must be equal";
  size_t num_features = blob_names.size();

  for (size_t i = 0; i < num_features; i++) {
    CHECK(feature_extraction_net->has_blob(blob_names[i]))
        << "Unknown feature blob name " << blob_names[i]
        << " in the network " << feature_extraction_proto;
  }

  int num_mini_batches = atoi(argv[++arg_pos]);

  // Open one output database (and transaction) per requested feature blob.
  std::vector<shared_ptr<db::DB> > feature_dbs;
  std::vector<shared_ptr<db::Transaction> > txns;
  const char* db_type = argv[++arg_pos];
  for (size_t i = 0; i < num_features; ++i) {
    LOG(INFO)<< "Opening dataset " << dataset_names[i];
    shared_ptr<db::DB> db(db::GetDB(db_type));
    db->Open(dataset_names.at(i), db::NEW);
    feature_dbs.push_back(db);
    shared_ptr<db::Transaction> txn(db->NewTransaction());
    txns.push_back(txn);
  }

  LOG(ERROR)<< "Extracting Features";

  Datum datum;
  const int kMaxKeyStrLength = 100;
  char key_str[kMaxKeyStrLength];
  std::vector<Blob<float>*> input_vec;
  std::vector<int> image_indices(num_features, 0);
  // Run the net forward num_mini_batches times; for every sample, serialize
  // the selected feature blobs into Datum protos keyed by the image index.
  for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
    feature_extraction_net->Forward(input_vec);
    for (int i = 0; i < num_features; ++i) {
      const shared_ptr<Blob<Dtype> > feature_blob = feature_extraction_net
          ->blob_by_name(blob_names[i]);
      int batch_size = feature_blob->num();
      int dim_features = feature_blob->count() / batch_size;
      const Dtype* feature_blob_data;
      for (int n = 0; n < batch_size; ++n) {
        datum.set_height(feature_blob->height());
        datum.set_width(feature_blob->width());
        datum.set_channels(feature_blob->channels());
        datum.clear_data();
        datum.clear_float_data();
        feature_blob_data = feature_blob->cpu_data() +
            feature_blob->offset(n);
        for (int d = 0; d < dim_features; ++d) {
          datum.add_float_data(feature_blob_data[d]);
        }
        int length = snprintf(key_str, kMaxKeyStrLength, "%010d",
            image_indices[i]);
        string out;
        CHECK(datum.SerializeToString(&out));
        txns.at(i)->Put(std::string(key_str, length), out);
        ++image_indices[i];
        // Commit every 1000 entries to keep transactions bounded.
        if (image_indices[i] % 1000 == 0) {
          txns.at(i)->Commit();
          txns.at(i).reset(feature_dbs.at(i)->NewTransaction());
          LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
              " query images for feature blob " << blob_names[i];
        }
      }  // for (int n = 0; n < batch_size; ++n)
    }  // for (int i = 0; i < num_features; ++i)
  }  // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
  // write the last batch
  for (int i = 0; i < num_features; ++i) {
    if (image_indices[i] % 1000 != 0) {
      txns.at(i)->Commit();
    }
    LOG(ERROR)<< "Extracted features of " << image_indices[i] <<
        " query images for feature blob " << blob_names[i];
    feature_dbs.at(i)->Close();
  }

  LOG(ERROR)<< "Successfully extracted the features!";
  return 0;
}
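For reference, an invocation following the usage string above might look like the sketch below. This is only an illustration: the binary path assumes a standard Makefile build of Caffe, and the model file, prototxt, blob name, and output database name are placeholders, not files from this post.

# extract_features  pretrained_net_param  proto_file  blob_names  dataset_names  num_mini_batches  db_type  [CPU/GPU] [DEVICE_ID]
./build/tools/extract_features.bin my_model.caffemodel \
    feature_extraction.prototxt fc7 features_fc7_lmdb 10 lmdb GPU 0

Each comma-separated blob name gets its own output database, and entries are committed in chunks of 1000 keys, as the loops at the end of the program show.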
• Original article: https://www.cnblogs.com/wangxiaocvpr/p/5200078.html