首发于Bob学步
TensorFlow Object Detection API 源码(1) DetectionModel

TensorFlow Object Detection API 源码(1) DetectionModel

0. 前言


1. 基本概念

  • 该类定义在 model.py 中。
  • DetectionModel是项目中所有检测模型的公共基类。
  • train.pyeval.py等脚本中,通过DetectionModel子类对象构建计算图。
  • 一般来说,并不直接使用DetectionModel,而会使用它的几个子类,基于特定算法做了进一步改进:

2. 重点方法介绍:

2.1. preprocess

  • 作用:输入数据预处理,如的 scaling/shifting/resizing/padding 操作。
  • 输入:
    • inputs:shape为 [batch, height_in, width_in, channels] 的float32 tensor,数值范围在[0, 255]之间。
  • 输出:
    • preprocessed_inputs:shape为[batch, height_out, width_out, channels]的 float32
      tensor。
    • true_image_shapes:shape为[batch, 3]的int32 tensor,代表preprocessed_inputs中每张图片的真实 resolution(格式为[height, width, channels])。
  • 注意事项:
    • 不能含有 trainable 变量。
    • 不应该对ground truth annotations有影响(即不能进行切片操作,因为切片会影响ground truth bbox)。
    • batchsize可以与之后的predict中不同,建议不适用batch功能(猜测意思是每次preprocess只对一张图片进行操作,batch多个输入数据的工作放到model外进行)。
    • 输出图片的 resolution 可以不同。
    • 输出中preprocessed_inputs的shape是固定的,但可能是padded with zeros的结果,图像实际resolution不同。

2.2. predict

  • 作用:输入preprocess的结果,并获取预测值。
  • 输入:即preprocess函数中的输出preprocessed_inputstrue_image_shapes,不重复介绍了。
  • 输出:prediction_dict,用字典保存了所有预测结果,可以根据模型自己实现。
  • 注意事项:
    • 该函数输出的prediction_dict会用于后续的losspostprocess中。
    • 对于不同的模型,prediction_dict中的key也各不相同,具体可以参考DetectionModel的各个子类。

2.3. postprocess

  • 作用:筛选模型预测结果,获取最终检测结果。
  • 输入:
    • true_image_shapespreprocess函数的输出,prediction_dictpredict函数的输出,不重复介绍了。
    • 为了模型的扩展,还添加了**params,方便二次开发。
  • 输出,一个字典,包含以下参数:
    • detection_boxes: shape为[batch, max_detections, 4]
    • detection_scores: shape为[batch, max_detections]
    • detection_classes: shape为[batch, max_detections](根据模型具体实现,该值可能不存在)。
    • instance_masks: shape为[batch, max_detections, image_height, image_width],可选参数。
    • keypoints: shape为[batch, max_detections, num_keypoints, 2],可选参数。
    • num_detections: shape为[batch]
  • 注意事项:一般用于预测,不用于训练。

2.4. loss

  • 作用:通过预测结果计算损失函数的值。
  • 输入:true_image_shapespreprocess函数的输出,prediction_dictpredict函数的输出,不重复介绍了。
  • 输出:一个字典,保存各类不同损失函数的计算结果tensor,不同算法不同。
  • 注意事项:一般用于训练,不用于预测。

2.5. restore_map

  • 作用:获取需要从ckpt文件中restore的变量。
  • 输入:fine_tune_checkpoint_type表示获取数据的类型。
  • 输出:一个字典,key为变量名,value为变量对象。

3. 预测脚本中 DetectionModel 的使用

  • 预测总体过程:inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
  • 预测脚本eval.py中调用了evaluator.py中的 evaluate方法
def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
             checkpoint_dir, eval_dir, graph_hook_fn=None, evaluator_list=None):

  # model 就是 DetectionModel 子类的实例
  model = create_model_fn()

  if eval_config.ignore_groundtruth and not eval_config.export_path:
    logging.fatal('If ignore_groundtruth=True then an export_path is '
                  'required. Aborting!!!')

  # 要到该方法下继续查看
  tensor_dict, losses_dict = _extract_predictions_and_losses(
      model=model,
      create_input_dict_fn=create_input_dict_fn,
      ignore_groundtruth=eval_config.ignore_groundtruth)
  • 进一步分析 _extract_predictions_and_losses 方法:
def _extract_predictions_and_losses(model,
                                    create_input_dict_fn,
                                    ignore_groundtruth=False):
  # 获取输入数据过程省略,最终输入数据保存在 original_image 中
  original_image = ...

  # 调用 model.preprocess 方法
  preprocessed_image, true_image_shapes = model.preprocess(
      tf.to_float(original_image))

  # 调用 model.predict 方法
  prediction_dict = model.predict(preprocessed_image, true_image_shapes)

  # 调用 model.postprocess 方法
  detections = model.postprocess(prediction_dict, true_image_shapes)

  # 其他操作省略

4. 训练脚本中 DetectionModel 的使用

  • 训练总体过程:inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
  • 训练脚本train.py调用了trainer.py中的train方法。方法很长,但 DetectionModel 子类主要使用在_create_losses函数中
def _create_losses(input_queue, create_model_fn, train_config):
  # 获取子类对象
  detection_model = create_model_fn()

  # 读取输入数据
  (images, _, groundtruth_boxes_list, groundtruth_classes_list,
   groundtruth_masks_list, groundtruth_keypoints_list, _) = get_inputs(
       input_queue,
       detection_model.num_classes,
       train_config.merge_multiple_label_boxes,
       train_config.use_multiclass_scores)

  # 进行preprocess操作
  preprocessed_images = []
  true_image_shapes = []
  for image in images:
    # 调用model.preprocess
    resized_image, true_image_shape = detection_model.preprocess(image)
    preprocessed_images.append(resized_image)
    true_image_shapes.append(true_image_shape)

  # batch操作(在文档中说道了,batch建议在model外进行)
  images = tf.concat(preprocessed_images, 0)
  true_image_shapes = tf.concat(true_image_shapes, 0)

  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
  if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None

  # 保存原始ground truth内容
  detection_model.provide_groundtruth(groundtruth_boxes_list,
                                      groundtruth_classes_list,
                                      groundtruth_masks_list,
                                      groundtruth_keypoints_list)

  # 调用model.predict进行预测
  prediction_dict = detection_model.predict(images, true_image_shapes)

  # 调用 model.loss 计算损失函数,添加到`LOSSES` collection中
  losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
  for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)

编辑于 2018-06-15 13:24