<-- Home

C++深度学习库DLL

DLL(Deep Learning Library)是基于现代C++实现的机器学习库。作者大佬在攻读博士学位的时候认为现有的机器学习库不能满足他的研究需求,于是自己实现了一个基于现代C++新特性的机器学习库(Orz),作者酷爱模板类,几乎整个库都是template和lambda,一出现编译错误那个错误信息长的不忍直视,该库只有头文件,不需要编译,能直接使用。

作者对该库进行了优化,使得其性能比主流的机器学习库还要快,详见这篇Post

这个库的文档还在完善中,记录一下该库的简单使用过程:

首先是获取

git clone --recursive https://github.com/wichtounet/dll.git

这个库里面包含的子模块比较多,要使用--recursive选项来clone。文件比较多,建议挂个代理。

然后是安装,因为该库不需要编译,所以可以拿来直接用,但要注意库文件的包含关系。作者也提供了一个Makefile文件,来把这些头文件放到/usr/include目录下。

cd dll
sudo make install_headers

该库依赖于C++的新特性,所以需要安装新版本的GCC或者clang编译器,我选择的是clang-4.0。

sudo apt-get install clang-4.0

我是使用CPU来计算,为了加快矩阵运算速度,还需要安装BLAS库,这里选择的是atlas。

sudo apt-get install libatlas-base-dev

如果是使用GPU来计算的话,则安装cuBLAS。

安装基本完成,来试试这个库,测试代码如下。

#include <dll/neural/dense_layer.hpp>
#include <dll/network.hpp>
#include <dll/datasets.hpp>

int main() {
    // Load the dataset
    auto dataset = dll::make_mnist_dataset(dll::batch_size<1000>{}, dll::normalize_pre{});

    // Build the network

    using network_t = dll::dyn_network_desc<
        dll::network_layers<
            dll::dense_layer<28 * 28, 500>,
            dll::dense_layer<500, 250>,
            dll::dense_layer<250, 10, dll::softmax>
        >
        , dll::updater<dll::updater_type::NADAM>     // Nesterov Adam (NADAM)
        , dll::batch_size<1000>                       // The mini-batch size
        , dll::shuffle                               // Shuffle before each epoch
    >::network_t;

    auto net = std::make_unique<network_t>();

    net->display_pretty();

    // Train the network for performance sake
    net->fine_tune(dataset.train(), 20);

    // Test the network on test set
    net->evaluate(dataset.test());

    return 0;
}

上面的程序构建了一个三全连接层神经网络,使用mnist数据集来训练。

编译

clang++4.0 -DETL_PARALLEL -DETL_VECTORIZE_FUL -DETL_BLAS_MODE -std=c++1z -lpthread -lblas main.cpp

运行时要保证mnist文件夹和编译好的可执行文件在同一目录下,运行结果如下

 ------------------------------------------------------------
 | Index | Layer                | Parameters | Output Shape |
 ------------------------------------------------------------
 | 0     | Dense(SIGMOID) (dyn) |     392500 | [Bx500]      |
 | 1     | Dense(SIGMOID) (dyn) |     125250 | [Bx250]      |
 | 2     | Dense(SOFTMAX) (dyn) |       2510 | [Bx10]       |
 ------------------------------------------------------------
                Total Parameters:     520260

Train the network with "Stochastic Gradient Descent"
    Updater: NADAM
       Loss: CATEGORICAL_CROSS_ENTROPY
 Early Stop: Goal(error)

With parameters:
          epochs=20
      batch_size=1000
   learning_rate=0.002
           beta1=0.9
           beta2=0.999

epoch   0/20 batch   60/  60 - error: 0.08787 loss: 0.31868 time 42694ms 
epoch   1/20 batch   60/  60 - error: 0.06163 loss: 0.21418 time 45747ms 
epoch   2/20 batch   60/  60 - error: 0.04328 loss: 0.15348 time 44532ms 
epoch   3/20 batch   60/  60 - error: 0.03520 loss: 0.12210 time 45938ms 
epoch   4/20 batch   60/  60 - error: 0.02417 loss: 0.08843 time 46223ms 
epoch   5/20 batch   60/  60 - error: 0.01897 loss: 0.07094 time 42572ms 
epoch   6/20 batch   60/  60 - error: 0.01458 loss: 0.05706 time 42947ms 
epoch   7/20 batch   60/  60 - error: 0.01352 loss: 0.05202 time 43548ms 
epoch   8/20 batch   60/  60 - error: 0.00825 loss: 0.03680 time 45794ms 
epoch   9/20 batch   60/  60 - error: 0.00548 loss: 0.02755 time 43514ms 
epoch  10/20 batch   60/  60 - error: 0.03513 loss: 0.11808 time 44827ms 
epoch  11/20 batch   60/  60 - error: 0.01080 loss: 0.04365 time 48295ms 
epoch  12/20 batch   60/  60 - error: 0.00543 loss: 0.02718 time 47101ms 
epoch  13/20 batch   60/  60 - error: 0.00332 loss: 0.02009 time 47661ms 
epoch  14/20 batch   60/  60 - error: 0.00222 loss: 0.01571 time 50148ms 
epoch  15/20 batch   60/  60 - error: 0.00137 loss: 0.01276 time 45731ms 
epoch  16/20 batch   60/  60 - error: 0.00110 loss: 0.01059 time 44344ms 
epoch  17/20 batch   60/  60 - error: 0.00065 loss: 0.00855 time 44397ms 
epoch  18/20 batch   60/  60 - error: 0.00045 loss: 0.00736 time 44606ms 
epoch  19/20 batch   60/  60 - error: 0.00028 loss: 0.00611 time 45947ms 
Training took 913s

Evaluation Results
   error: 0.01820 
    loss: 0.05845 
evaluation took 2208ms

因为我笔记本的CPU比较老了,还开的是虚拟机,所有时间可能有点长,准确率达到98%,还是不错的。

然后试试训练好的模型,自己用画图手写几个数字测试一下。

然而有很多自己写的数字还是识别错了。

_(:з」∠)_