手记

caffe函数入口caffe.cpp详解

概览

这篇博客解析caffe函数入口caffe.cpp,主要内容为caffe启动框架,基本不涉及深度学习的具体内容,内容十分基础,适合新手阅读。下面所有的代码解析都以训练lenet手写数字体识别为例,其运行参数为:

caffe train --solver=examples/mnist/lenet_solver.prototxt $@

main函数

先把main函数贴上来

int main(int argc, char** argv) {
  // Print output to stderr (while still logging).
  FLAGS_alsologtostderr = 1;
  // Set version
  gflags::SetVersionString(AS_STRING(CAFFE_VERSION));
  // Usage message.
  gflags::SetUsageMessage("command line brew\n"
      "usage: caffe <command> <args>\n\n"
      "commands:\n"
      "  train           train or finetune a model\n"
      "  test            score a model\n"
      "  device_query    show GPU diagnostic information\n"
      "  time            benchmark model execution time");
  // Run tool or show usage.
  caffe::GlobalInit(&argc, &argv);
  if (argc == 2) {
#ifdef WITH_PYTHON_LAYER
    try {
#endif
      return GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER
    } catch (bp::error_already_set) {
      PyErr_Print();
      return 1;
    }
#endif
  } else {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
  }
}

main函数上来就是一个变量FLAGS_alsologtostderr,但vscode找不到该变量的定义。其实这个变量包括其他带有FLAGS前缀的变量是由gflags定义的,gflags 是 google 开源的用于处理命令行参数的项目。alsologtostderr指将日志输出到标准错误流中去。后面SetVersionString 的作用是当你使用caffe --version时能打印出caffe的版本信息,CAFFE_VERSION由Makefile指定.紧接着SetUsageMessage实际上设置了caffe的帮助信息,当运行caffe参数不正确或者使用--help参数时打印出usage信息。caffe::GlobalInit函数会根据命令行参数做一些初始化的工作,其定义在common.cpp中,具体如下:

void GlobalInit(int* pargc, char*** pargv) {
  // Google flags.
  ::gflags::ParseCommandLineFlags(pargc, pargv, true);
  // Google logging.
  ::google::InitGoogleLogging(*(pargv)[0]);
  // Provide a backtrace on segfault.
  ::google::InstallFailureSignalHandler();
}

对于训练手写数字体识别:

只有一个参数solver =examples/mnist/lenet_solver.prototxt 解析后可以以FLAGS_solver来访问。包括solver model等用户自定义的命令行参数(非gflags默认的参数)定义在caffe.cpp里:

DEFINE_string(gpu, "",
    "Optional; run in GPU mode on given device IDs separated by ','."
    "Use '-gpu all' to run on all available GPUs. The effective training "
    "batch size is multiplied by the number of devices.");
DEFINE_string(solver, "",
    "The solver definition protocol buffer text file.");
DEFINE_string(model, "",
    "The model definition protocol buffer text file.");

对于gflags更详细的信息可以参考google gflags 库完全使用

后面的InitGoogleLogging和InstallFailureSignalHandler用来处理日志和运行错误。

那么main函数怎么根据train test等参数进入到相应的train函数或test函数中去呢?

看这一行代码:

return GetBrewFunction(caffe::string(argv[1]))();

这个函数可以根据第一个参数argv[1](argv[0]是caffe本身的路径)来返回相应的函数,接下来我们来看GetBrewFunction是怎么实现这个功能的。

typedef int (*BrewFunction)(); //定义了一个函数指针类型,该类型指针指向一个参数为空返回值为int的函数
typedef std::map<caffe::string, BrewFunction> BrewMap;//定义了一个map类型,该类型的变量维护一个字典,函数名称(string)作为key,函数指针(BrewFunction)作为value
BrewMap g_brew_map;

#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \   //##表示合并字符串
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \ #为字符串
  } \
}; \
__Registerer_##func g_registerer_##func; \
}

static BrewFunction GetBrewFunction(const caffe::string& name) {
  if (g_brew_map.count(name)) {
    return g_brew_map[name];//根据name中的具体内容返回相应的函数指针
  } else {
    LOG(ERROR) << "Available caffe actions:";
    for (BrewMap::iterator it = g_brew_map.begin();
         it != g_brew_map.end(); ++it) {
      LOG(ERROR) << "\t" << it->first;
    }
    LOG(FATAL) << "Unknown action: " << name;
    return NULL;  // not reachable, just to suppress old compiler warnings.
  }
}
//下面是一个例子,详细说明train函数怎么填充到g_brew_map中
int train(){
}
RegisterBrewFunction(train)//这一句会根据宏定义被替换成下面的内容

namespace{
class __Registerer_train{
    public:
        __Registerer_train(){
            g_brew_map["train"] = &train;
        }
};
__Registerer_train g_registerer_train; //实例化的过程中将train函数填充到字典g_brew_map中去了
}


根据上面一些注释,我们可以看出一个大概的框架:

    1 定义一个字典,存储函数名到函数指针的映射。


    2 通过RegisterBrewFunction(func)的宏定义来填充这个字典。

    3 调用GetBrewFunction根据函数名返回相应的函数指针。


train函数

下面具体看train函数

// Train / Finetune a model.
int train() {
  CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; //FLAGS_solver <= 0 会输出
  CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())// snapshot 和 weight参数都没有,不管
      << "Give a snapshot to resume training or weights to finetune "
      "but not both.";
  vector<string> stages = get_stages_from_flags(); //stages参数也没有,跳过

  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//该行从lenet_solver.prototxt读取参数到solver_param中

  solver_param.mutable_train_state()->set_level(FLAGS_level); //level参数也没有,跳过
  for (int i = 0; i < stages.size(); i++) {
    solver_param.mutable_train_state()->add_stage(stages[i]);
  }

  // If the gpus flag is not provided, allow the mode and device to be set
  // in the solver prototxt.
  if (FLAGS_gpu.size() == 0     //从solverparam中读取GPU的信息,是否使用GPU,GPU的id之类的,初期可以不用特别关注
      && solver_param.has_solver_mode()
      && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
      if (solver_param.has_device_id()) {
          FLAGS_gpu = "" +
              boost::lexical_cast<string>(solver_param.device_id());
      } else {  // Set default GPU if unspecified
          FLAGS_gpu = "" + boost::lexical_cast<string>(0);
      }
  }

  vector<int> gpus;
  get_gpus(&gpus);
  if (gpus.size() == 0) {
    LOG(INFO) << "Use CPU.";
    Caffe::set_mode(Caffe::CPU);
  } else {
    ostringstream s;
    for (int i = 0; i < gpus.size(); ++i) {
      s << (i ? ", " : "") << gpus[i];
    }
    LOG(INFO) << "Using GPUs " << s.str();
#ifndef CPU_ONLY
    cudaDeviceProp device_prop;
    for (int i = 0; i < gpus.size(); ++i) {
      cudaGetDeviceProperties(&device_prop, gpus[i]);
      LOG(INFO) << "GPU " << gpus[i] << ": " << device_prop.name;
    }
#endif
    solver_param.set_device_id(gpus[0]);
    Caffe::SetDevice(gpus[0]);
    Caffe::set_mode(Caffe::GPU);
    Caffe::set_solver_count(gpus.size());
  }

  caffe::SignalHandler signal_handler(
        GetRequestedAction(FLAGS_sigint_effect),
        GetRequestedAction(FLAGS_sighup_effect));

  if (FLAGS_snapshot.size()) {
    solver_param.clear_weights();
  } else if (FLAGS_weights.size()) {
    solver_param.clear_weights();
    solver_param.add_weights(FLAGS_weights);
  }
//根据solver_param,生成solver
  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

  solver->SetActionFunction(signal_handler.GetActionFunction());

  if (FLAGS_snapshot.size()) {
    LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    solver->Restore(FLAGS_snapshot.c_str());
  }

  LOG(INFO) << "Starting Optimization";
  if (gpus.size() > 1) {
#ifdef USE_NCCL
    caffe::NCCL<float> nccl(solver);
    nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL);
#else
    LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL";
#endif
  } else {
    //求解solver
    solver->Solve();
  }
  LOG(INFO) << "Optimization Done.";
  return 0;
}


solver的实例化

这里不涉及任何solver内部的细节,包括生成_net和test_net,具体的求解方法等内容,只剖析caffe怎样根据solverparam.type实例化不同的solver类。实际上这些内容和上面讲的根据命令行参数执行train还是test等函数的方法十分相似,但其过程更加复杂,还是简要的分析一下。


shared_ptr<caffe::Solver<float>>
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver<float>的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。

下面分析SolverRegistry具体是怎么做的:

typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
  typedef std::map<string, Creator> CreatorRegistry;

  static CreatorRegistry& Registry() {
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }


  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }

从上述代码可以看到也是维护了一个map由solverparam.type返回具体的solver<Dtype>指针

SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。 CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*返回。 

Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,可以在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。


那包括SGDSolver等各种solver是怎么注册的呢?下面以注册SGDSolver为例说明

solver_factory.hpp文件中有两个宏定义如下:

#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \

#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

sgd_solver.cpp文件末尾有

REGISTER_SOLVER_CLASS(SGD);

根据宏定义替换的结果如下:

template <typename Dtype>
Solver<Dtype>* Creator_SGD_Solver(const SolverParameter& param)
{
  return new SGDSolver<Dtype>(param);
}
static SolverRegisterer<float> g_creator_f_SGD("SGD",Creator_SGD_Solver<float>);
static SolverRegisterer<double> g_creator_f_SGD("SGD",Creator_SGD_Solver<double>);

即根据宏定义,定义了一个Creator函数指针可指的函数Creator_SGD_Solver,然后通过下面的函数将key和value注册进去:

template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
      Solver<Dtype>* (*creator)(const SolverParameter&)) {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry<Dtype>::AddCreator(type, creator);
  }
};

AddCreator函数的源码不在此展示,具体细节阅读solver_factory.hpp

至此,生成solver的工厂模式应该讲清楚了,caffe的启动框架也差不多清晰了,接下来就是solver怎么根据solver_params生成net,以及net的前向和反向计算了。

参考:

    Caffe中Solver解析

    google gflags 库完全使用

原文出处

0人推荐
随时随地看视频
慕课网APP