深度学习caffe–手写字体识别例程(五)—— convert_mnist_data.cpp文件详解

释放双眼,带上耳机,听听看~!

        我们在《深度学习caffe–手写字体识别例程(四)》中,用到了convert_mnist_data.bin文件进行数据集格式的转换,命令如下


1
2
3
1$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \
2  $DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
3

它的作用是将mnist数据集转换为lmdb或leveldb格式的文件,以便用于深度学习的训练。这篇文章我们就来研究convert_mnist_data.bin这个文件是如何实现的。convert_mnist_data.bin文件的源文件在example/mnist/目录下,文件名为convert_mnist_data.cpp,由于这个文件中的代码比较长,我们下面把代码贴出来,并在每行或几行的代码下面进行解释。


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
1#include <gflags/gflags.h>
2#include <glog/logging.h>
3#include <google/protobuf/text_format.h>
4
5#if defined(USE_LEVELDB) && defined(USE_LMDB)
6#include <leveldb/db.h>
7#include <leveldb/write_batch.h>
8#include <lmdb.h>
9#endif
10
11#include <stdint.h>
12#include <sys/stat.h>
13
14#include <fstream>  // NOLINT(readability/streams)
15#include <string>
16
17#include "boost/scoped_ptr.hpp"
18#include "caffe/proto/caffe.pb.h"
19#include "caffe/util/db.hpp"
20#include "caffe/util/format.hpp"
21

        这些代码是文件包含的头文件,是文件中需要使用到的头文件。


1
2
1#if defined(USE_LEVELDB) && defined(USE_LMDB)
2

        这是一个判断的宏,如果满足判断条件,则编译下方的代码,否则编译#else下面的代码。我们总览这个文件,发现#else在文件的结尾处,只包含了几行代码。这个宏的根本作用在于,判断是否定义了USE_LEVELDB和USE_LMDB,如果定义了则进行文件格式转换的操作,否则,不操作。这两个宏是在编译caffe源码的时候定义的。


1
2
3
4
5
1using namespace caffe;  // NOLINT(build/namespaces)
2using boost::scoped_ptr;
3using std::string;
4
5

        这3行是这个文件需要用到的库。


1
2
1DEFINE_string(backend, "lmdb", "The backend for storing the result");
2

        这行代码在这个文件中没能找到DEFINE_string的定义。其实它是在gflags.h文件中定义的,这个文件在/usr/include/gflags/目录下,有兴趣可以打开文件研究一下,DEFINE_string是一个宏定义,这里我们只介绍一下它的作用。调用DEFINE_string之后,会生成基于backend生成一个变量FLAGS_backend,并且变量的取值为“lmdb”,"The backend for storing the result"是这个变量的说明。


1
2
3
4
5
1uint32_t swap_endian(uint32_t val) {
2    val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
3    return (val << 16) | (val >> 16);
4}
5

        这段代码是一个函数,它的作用是对32位的整形变量进行大小端转换,在《深度学习caffe–手写字体识别例程(三)》中,我们介绍了,在mnist数据集中,多字节的数据是按照大端模式存储的,也就是数据的高字节存在低地址,如果我们进行数据读取数据读出来之后,字节顺序是反的。比如一个32字节的数据0x12345678,它在mnist文件中存储时,相对地址0地址为12,1地址为34,2地址为56,3地址为78。当从文件中读取32位的数据时,读出来的是0x78563412,与原始数据正好是反的。所以需要用这个函数进行转换。

       还是以0x12345678为例,从mnist中读出的值为0x78563412,调用这个函数时,将0x78563412赋值给val。在函数中,((val << 8) & 0xFF00FF00)将val左移8位并与0xFF00FF00按位做与运算,得到的结果是0x56001200;((val >> 8) & 0xFF00FF)将val右移8位并与0xFF00FF按位做与运算,得到的结果是0x00780034。两个结果再做按位或运算并赋值给val,则val =0x56001200 | 0x00780034=0x56781234。

       函数的最后一行(val << 16)将val左移16位,注意此时val的值已经是0x56781234,它左移16位的结果为0x12340000,(val >> 16)将val右移16位得到的结果为0x00005678,两个结果做按位或运算,得到的结果为0x12345678,并将结果返回。经过这一系列的操作,实现了数据的转换。


1
2
3
1void convert_dataset(const char* image_filename, const char* label_filename,
2        const char* db_path, const string&amp; db_backend) {
3

        这段代码是数据转换的函数定义,这个函数是这个文件的核心函数,就是它实现了mnist的二进制文件到lmdb文件的转换。函数的形参分别为

image_filename图片文件名

label_filename标签文件名

db_path生成文件的存储路径

db_backend生成文件的尾缀,指定文件类型,即lmdb还是leveldb。


1
2
3
4
1  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
2  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
3
4

       这两行代码的作用是实例化两个std::ifstream对象,它们是流式文件的对象,这两个对象将图片文件和标签文件输入,以二进制的方式打开,通过这两个对象就可以访问图片文件和标签文件。


1
2
3
1  CHECK(image_file) &lt;&lt; &quot;Unable to open file &quot; &lt;&lt; image_filename;
2  CHECK(label_file) &lt;&lt; &quot;Unable to open file &quot; &lt;&lt; label_filename;
3

        这两行代码的作用是检查image_file和label_file这两个对象,是否为0,如果为0,则打印信息“Unable to open file +文件名”,并退出程序。其中CHECK是个宏定义,它在logging.h中定义,这个文件在/usr/include/glog/目录下,它的作用是判断括号内的条件是否为0,如果为0则打印后边的内容,并退出程序。有兴趣可以翻看logging.h文件研究一下。


1
2
3
4
5
6
7
1  uint32_t magic;
2  uint32_t num_items;
3  uint32_t num_labels;
4  uint32_t rows;
5  uint32_t cols;
6
7

这段代码定义了5个变量,它们分别用来保存魔数、条目数、标签数、图片的行数、图片的列数。这些变量的含义在《深度学习caffe–手写字体识别例程(三)》中有详细介绍,可以参考,它们都是从mnist数据集中读取出来的。


1
2
3
4
1  image_file.read(reinterpret_cast&lt;char*&gt;(&amp;magic), 4);
2  magic = swap_endian(magic);
3  CHECK_EQ(magic, 2051) &lt;&lt; &quot;Incorrect image file magic.&quot;;
4

        这3行代码的作用是从之前实例化的图片文件的流式文件对象中读取4个字节的数据,并且保存到magic变量中,关于mnist数据集中的图片文件的格式定义可以参考《深度学习caffe–手写字体识别例程(三)》。在图片文件中前4个字节的数据就是魔数。

        第2行将magic进行大小端变换,这是因为变量在文件中是按照大端存储的。

        第3行对magic进行检测,看它与2051是否相等,如果不相等,输出信息,并退出程序。CHECK_EQ是宏定义,与CHECK类似,它也是在logging.h中定义的。


1
2
3
4
5
1  label_file.read(reinterpret_cast&lt;char*&gt;(&amp;magic), 4);
2  magic = swap_endian(magic);
3  CHECK_EQ(magic, 2049) &lt;&lt; &quot;Incorrect label file magic.&quot;;
4
5

        这3行代码的作用是读取标签文件的魔数,并做大小端变换,最后检测它与2049是否相等,如果不相等则打印信息并退出。


1
2
3
4
1  image_file.read(reinterpret_cast&lt;char*&gt;(&amp;num_items), 4);
2  num_items = swap_endian(num_items);
3
4

        这2行代码的作用是读取图片文件的图片条目数,并做大小端变换。


1
2
3
4
1  label_file.read(reinterpret_cast&lt;char*&gt;(&amp;num_labels), 4);
2  num_labels = swap_endian(num_labels);
3
4

       这2行代码的作用是读取标签文件的标签条目数,并做大小端变换。


1
2
1  CHECK_EQ(num_items, num_labels);
2

        这行代码的作用是判断图片条目数和标签条目数是否相等,不相等则退出。图片与标签是一一对应的,如果不相等说明原始文件有问题。


1
2
3
4
5
6
1  image_file.read(reinterpret_cast&lt;char*&gt;(&amp;rows), 4);
2  rows = swap_endian(rows);
3  image_file.read(reinterpret_cast&lt;char*&gt;(&amp;cols), 4);
4  cols = swap_endian(cols);
5
6

        这4行的作用是读取图片的行数和列数,并做大小端变换,由《深度学习caffe–手写字体识别例程(三)》中的介绍,我们知道图片的行数和列数都是28。


1
2
3
4
5
1  scoped_ptr&lt;db::DB&gt; db(db::GetDB(db_backend));
2  db-&gt;Open(db_path, db::NEW);
3  scoped_ptr&lt;db::Transaction&gt; txn(db-&gt;NewTransaction());
4
5

        这3行的作用首先定义一个指向db_backend类型数据库的指针,然后新建数据库,并打开。最后定义一个指向数据库事务的指针txn,这个指针指向数据库指针db指向的数据库事务。这个事务下面主要被用作数据的转换存储。


1
2
1  char label;
2

       这一句定义了一个char型变量label,它下面被用作保存标签值。


1
2
1  char* pixels = new char[rows * cols];
2

       这行代码用来定义一个指向char型变量的指针,它指向一个大小为rows * cols的char型数组。它在下面用来保存一副图片的数据。


1
2
3
4
1  int count = 0;
2  string value;
3
4

        这两行分别定义了一个int型count变量和一个string类型的value变量。


1
2
3
4
5
1  Datum datum;
2  datum.set_channels(1);
3  datum.set_height(rows);
4  datum.set_width(cols);
5

        这几行首先定义了一个Datum类型的变量datum,Datum数据类型在caffe.proto文件中定义,这个文件位于caffe根目录的src/caffe/proto/路径下,有兴趣可以对照着caffe.proto文件对这个数据类型进行深入研究,Datum中包含的主要数据有:

channels:图片的通道数,代码中取值为1。

height:图片的高,在手写体识别例程中,取值为28

width:图片的宽,在手写体识别例程中,取值为28

data:图片的数据,在手写体识别例程中,data中包含28*28=784个数据

label:图片的label


1
2
3
4
1  LOG(INFO) &lt;&lt; &quot;A total of &quot; &lt;&lt; num_items &lt;&lt; &quot; items.&quot;;
2  LOG(INFO) &lt;&lt; &quot;Rows: &quot; &lt;&lt; rows &lt;&lt; &quot; Cols: &quot; &lt;&lt; cols;
3
4

        这两行代码的作用是在终端上输出总的条目数和行数和列数。


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
1  for (int item_id = 0; item_id &lt; num_items; ++item_id) {
2    image_file.read(pixels, rows * cols);
3    label_file.read(&amp;label, 1);
4    datum.set_data(pixels, rows*cols);
5    datum.set_label(label);
6    string key_str = caffe::format_int(item_id, 8);
7    datum.SerializeToString(&amp;value);
8
9    txn-&gt;Put(key_str, value);
10
11    if (++count % 1000 == 0) {
12      txn-&gt;Commit();
13    }
14  }
15

      这段代码为一个for循环,循环次数为图片文件的条目数,对每个条目进行遍历。经过上面的读取操作,图片文件对象image_file的指针已经指到了第一幅图片的位置,标签文件对象label_file的指针已经指到了第一个标签的位置。进入到for循环中,首先对图片文件进行读取,读取的大小为一副图片,并保存到pixels指向的存储区,然后读取标签文件,读取一个字节,即一个标签,并保存到label变量中。

      接下来datum.set_data(pixels, rows*cols);将图片数据保存到datum数据结构中。datum.set_label(label);将标签保存到datum数据结构中。

       然后string key_str = caffe::format_int(item_id, 8);定义了一个名字为key_str的字符串,它保存的是调用caffe::format_int()函数生成的字符串,它将item_id的值转换为8个字节的字符串的格式,比如item_id的取值为25时,转换完的字符串为“00000025”,这个字符串被用作下面数据库存储的键值。

       再下面datum.SerializeToString(&value);将datum数据结构中的数据转换为字符串,并保存到value中。

       最后将键值key_str和图像数据value写入数据库。

       for循环的最后,每次count加1,如果count是1000的整数倍时,数据库提交一次。


1
2
3
4
1  if (count % 1000 != 0) {
2      txn-&gt;Commit();
3  }
4

        这段代码在for循环外边,判断如果count不是1000的整数倍,说明for循环退出时,还有没提交的数据,则再提交一次。


1
2
3
4
5
1  LOG(INFO) &lt;&lt; &quot;Processed &quot; &lt;&lt; count &lt;&lt; &quot; files.&quot;;
2  delete[] pixels;
3  db-&gt;Close();
4}
5

        这几行是数据转换函数的末尾,打印转换完成的条目数,释放pixels指向的存储空间,并关闭数据库。


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
1int main(int argc, char** argv) {
2#ifndef GFLAGS_GFLAGS_H_
3  namespace gflags = google;
4#endif
5
6  FLAGS_alsologtostderr = 1;
7
8  gflags::SetUsageMessage(&quot;This script converts the MNIST dataset to\n&quot;
9        &quot;the lmdb/leveldb format used by Caffe to load data.\n&quot;
10        &quot;Usage:\n&quot;
11        &quot;    convert_mnist_data [FLAGS] input_image_file input_label_file &quot;
12        &quot;output_db_file\n&quot;
13        &quot;The MNIST dataset could be downloaded at\n&quot;
14        &quot;    http://yann.lecun.com/exdb/mnist/\n&quot;
15        &quot;You should gunzip them after downloading,&quot;
16        &quot;or directly use data/mnist/get_mnist.sh\n&quot;);
17  gflags::ParseCommandLineFlags(&amp;argc, &amp;argv, true);
18

        这段代码是主函数的开始部分,主要是gflags相关的一些操作,这些对数据转换过程基本不会有影响,只是显示一些信息,所以这里不对这些代码进行深究。


1
2
1  const string&amp; db_backend = FLAGS_backend;
2

        这一行代码定义了一个FLAGS_backend的引用db_backend,FLAGS_backend是在代码的开头调用DEFINE_string宏进行定义的,它的取值为“lmdb”,&表示它后边定义的是一个引用。


1
2
3
4
5
6
7
8
9
10
1  if (argc != 4) {
2    gflags::ShowUsageWithFlagsRestrict(argv[0],
3        &quot;examples/mnist/convert_mnist_data&quot;);
4  } else {
5    google::InitGoogleLogging(argv[0]);
6    convert_dataset(argv[1], argv[2], argv[3], db_backend);
7  }
8  return 0;
9}
10

        这段代码是main函数的结尾,首先判断命令行参数个数是否为4,如果不是4,说明输入的命令不对,则打印信息。否则,输入的命令行参数为4则初始化日志,并调用convert_dataset()函数进行数据转换。


1
2
3
4
5
6
7
1#else
2int main(int argc, char** argv) {
3  LOG(FATAL) &lt;&lt; &quot;This example requires LevelDB and LMDB; &quot; &lt;&lt;
4  &quot;compile with USE_LEVELDB and USE_LMDB.&quot;;
5}
6#endif  // USE_LEVELDB and USE_LMDB
7

        这段是#else的宏,它是相对于#if defined(USE_LEVELDB) && defined(USE_LMDB)的,判读如果没有定义USE_LEVELDB或USE_LMDB,则打印错误信息并退出。

 

 

 

给TA打赏
共{{data.count}}人
人已打赏
安全运维

MySQL到MongoDB的数据同步方法!

2021-12-11 11:36:11

安全运维

Ubuntu上NFS的安装配置

2021-12-19 17:36:11

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索