TensorFlow Object Detection API 源码(1) DetectionModel
0. 前言
- 参考资料:
1. 基本概念
- 该类定义在 model.py 中。
DetectionModel
是项目中所有检测模型的公共基类。- 在
train.py
及eval.py
等脚本中,通过DetectionModel
子类对象构建计算图。 - 一般来说,并不直接使用
DetectionModel
,而会使用它的几个子类,基于特定算法做了进一步改进: - ssd_meta_arch.py:基于SSD算法。
- faster_rcnn_meta_arch.py:基于Faster R-CNN算法。
- rfcn_meta_arch.py:基于R-FCN算法。
2. 重点方法介绍:
2.1. preprocess
- 作用:输入数据预处理,如的 scaling/shifting/resizing/padding 操作。
- 输入:
inputs
:shape为[batch, height_in, width_in, channels]
的float32 tensor,数值范围在[0, 255]之间。- 输出:
preprocessed_inputs
:shape为[batch, height_out, width_out, channels]
的 float32
tensor。true_image_shapes
:shape为[batch, 3]
的int32 tensor,代表preprocessed_inputs
中每张图片的真实 resolution(格式为[height, width, channels]
)。- 注意事项:
- 不能含有 trainable 变量。
- 不应该对ground truth annotations有影响(即不能进行切片操作,因为切片会影响ground truth bbox)。
- batchsize可以与之后的
predict
中不同,建议不适用batch功能(猜测意思是每次preprocess只对一张图片进行操作,batch多个输入数据的工作放到model外进行)。 - 输出图片的 resolution 可以不同。
- 输出中
preprocessed_inputs
的shape是固定的,但可能是padded with zeros的结果,图像实际resolution不同。
2.2. predict
- 作用:输入preprocess的结果,并获取预测值。
- 输入:即
preprocess
函数中的输出preprocessed_inputs
和true_image_shapes
,不重复介绍了。 - 输出:
prediction_dict
,用字典保存了所有预测结果,可以根据模型自己实现。 - 注意事项:
- 该函数输出的
prediction_dict
会用于后续的loss
或postprocess
中。 - 对于不同的模型,
prediction_dict
中的key也各不相同,具体可以参考DetectionModel
的各个子类。
2.3. postprocess
- 作用:筛选模型预测结果,获取最终检测结果。
- 输入:
true_image_shapes
是preprocess
函数的输出,prediction_dict
是predict
函数的输出,不重复介绍了。- 为了模型的扩展,还添加了
**params
,方便二次开发。 - 输出,一个字典,包含以下参数:
detection_boxes
: shape为[batch, max_detections, 4]
detection_scores
: shape为[batch, max_detections]
detection_classes
: shape为[batch, max_detections]
(根据模型具体实现,该值可能不存在)。instance_masks
: shape为[batch, max_detections, image_height, image_width]
,可选参数。keypoints
: shape为[batch, max_detections, num_keypoints, 2]
,可选参数。num_detections
: shape为[batch]
。- 注意事项:一般用于预测,不用于训练。
2.4. loss
- 作用:通过预测结果计算损失函数的值。
- 输入:
true_image_shapes
是preprocess
函数的输出,prediction_dict
是predict
函数的输出,不重复介绍了。 - 输出:一个字典,保存各类不同损失函数的计算结果tensor,不同算法不同。
- 注意事项:一般用于训练,不用于预测。
2.5. restore_map
- 作用:获取需要从ckpt文件中restore的变量。
- 输入:
fine_tune_checkpoint_type
表示获取数据的类型。 - 输出:一个字典,key为变量名,value为变量对象。
3. 预测脚本中 DetectionModel 的使用
- 预测总体过程:inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
- 预测脚本
eval.py
中调用了evaluator.py
中的evaluate
方法
def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
checkpoint_dir, eval_dir, graph_hook_fn=None, evaluator_list=None):
# model 就是 DetectionModel 子类的实例
model = create_model_fn()
if eval_config.ignore_groundtruth and not eval_config.export_path:
logging.fatal('If ignore_groundtruth=True then an export_path is '
'required. Aborting!!!')
# 要到该方法下继续查看
tensor_dict, losses_dict = _extract_predictions_and_losses(
model=model,
create_input_dict_fn=create_input_dict_fn,
ignore_groundtruth=eval_config.ignore_groundtruth)
- 进一步分析
_extract_predictions_and_losses
方法:
def _extract_predictions_and_losses(model,
create_input_dict_fn,
ignore_groundtruth=False):
# 获取输入数据过程省略,最终输入数据保存在 original_image 中
original_image = ...
# 调用 model.preprocess 方法
preprocessed_image, true_image_shapes = model.preprocess(
tf.to_float(original_image))
# 调用 model.predict 方法
prediction_dict = model.predict(preprocessed_image, true_image_shapes)
# 调用 model.postprocess 方法
detections = model.postprocess(prediction_dict, true_image_shapes)
# 其他操作省略
4. 训练脚本中 DetectionModel 的使用
- 训练总体过程:inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
- 训练脚本
train.py
调用了trainer.py
中的train
方法。方法很长,但 DetectionModel 子类主要使用在_create_losses
函数中
def _create_losses(input_queue, create_model_fn, train_config):
# 获取子类对象
detection_model = create_model_fn()
# 读取输入数据
(images, _, groundtruth_boxes_list, groundtruth_classes_list,
groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs(
input_queue,
detection_model.num_classes,
train_config.merge_multiple_label_boxes,
train_config.use_multiclass_scores)
# 进行preprocess操作
preprocessed_images = []
true_image_shapes = []
for image in images:
# 调用model.preprocess
resized_image, true_image_shape = detection_model.preprocess(image)
preprocessed_images.append(resized_image)
true_image_shapes.append(true_image_shape)
# batch操作(在文档中说道了,batch建议在model外进行)
images = tf.concat(preprocessed_images, 0)
true_image_shapes = tf.concat(true_image_shapes, 0)
if any(mask is None for mask in groundtruth_masks_list):
groundtruth_masks_list = None
if any(keypoints is None for keypoints in groundtruth_keypoints_list):
groundtruth_keypoints_list = None
# 保存原始ground truth内容
detection_model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list,
groundtruth_masks_list,
groundtruth_keypoints_list)
# 调用model.predict进行预测
prediction_dict = detection_model.predict(images, true_image_shapes)
# 调用 model.loss 计算损失函数,添加到`LOSSES` collection中
losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
for loss_tensor in losses_dict.values():
tf.losses.add_loss(loss_tensor)
编辑于 2018-06-15 13:24