Dependencies: (The specific version is the version in my environment, and does not need to be exactly the same when being used)
- tensorflow-gpu 1.12.0 (building the model)
- numpy 1.16.2 (for matrix operations such as quadtree decomposition, guided filter, etc.)
- matplotlib 3.0.3 (Image reading and writing)
- h5py 2.8.0 (used to read training set and testing set, not nexessary when performing inference)
Parameter configuration instructions:
When running the code, configure the config(a python dictionary) in the main function to control the entire process. The parameters are described as follows:
mode: string, including 'train' , 'test' , 'inference'
ckpt_dir: string type, the storage path of the check point of the network
train_path: string type, training set path (only .h5 files are supported)
test_path: string type, ,testing set path (only .h5 files are supported)
inference_i: string type, the path of the folder where input hazy images are stored
inference_o: string type, the path of the folder which output clear images are written to
param: parameters for building the model, a list containing two variables, the first variable indicates the number of residual blocks, the second variable indicates how many convolutional layers are in each residual block, and cannot be changed after training
qd_param: quadtree parameter, a list containing two variables, the first variable represents the maximum decomposition level of the quadtree, and the second variable represents the decomposition threshold
batch_size: integer, batch size (here refers to the batch size of the histogram)
epoch: integer, how many epoches needed in training process
alpha: float type, coefficients of l2 regularization constraints, to prevent overfitting
learning_rate: float type, learning rate
运行依赖:(具体版本为我的运行环境中的版本,并不需要完全一样)
- tensorflow-gpu 1.12.0 (用于模型搭建)
- numpy 1.16.2 (用于四叉树分解、导向滤波器等矩阵运算)
- matplotlib 3.0.3 (图像读写)
- h5py 2.8.0 (用于读入训练集和测试集,只用于预测可不需要)
参数配置说明:
运行代码时,在main函数中配置字典config以控制整个流程,各参数具体说明如下:
mode: 字符串类型,包含'train'(训练),'test'(测试),'inference'(预测)三种
ckpt_dir:字符串类型,网络的check point存储路径
train_path:字符串类型,训练集路径(仅支持.h5文件)
test_path:字符串类型,测试集路径(仅支持.h5文件)
inference_i:字符串类型,存放预测输入图像的文件夹的路径
inference_o:字符串类型,存放预测输出图像的文件夹的路径
param:模型参数,包含两个变量的list,第一个变量表示residual blocks的数量,第二个变量表示每个residual block中有多少个卷积层,训练完成后不可改变
qd_param:四叉树参数,包含两个变量的list,第一个变量表示四叉树最大分解层数,第二个变量表示分解的阈值
batch_size:整型,批次大小(这里指直方图的批次大小)
epoch:整型,训练的轮数
alpha: 浮点型,二阶正则化约束之前的系数,防止过拟合
learning_rate:浮点型,学习率