博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
用Tensorflow做蝴蝶检测
阅读量:5085 次
发布时间:2019-06-13

本文共 17894 字,大约阅读时间需要 59 分钟。

报名了一个蝴蝶检测比赛,一共给了700多张图,包含94种蝴蝶类别,要求检测出图片中的蝴蝶并正确分类。

1.拿到数据集后,第一部就是将700多张图分成了 483张训练样本和238张测试样本(由于数据集中,有15种类别的蝴蝶只有一张,所以在测试样本中,仅包含了79种蝴蝶类别)

2.利用一个现有的包含蝴蝶类别的模型直接对测试集中的蝴蝶进行检测(相当于二分类),这里选用的是“ faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28 ”模型。该模型是在Open Image 数据集上训练的,总共有545个不同物体类别。

先是看的 object_detection/object_detection_tutorial.ipynb 这里直接导入的是frozen model,导致无法修改阈值,所以师兄换了一种模型导入方式,可以把检测的阈值调低,以提高检测到的蝴蝶数目

然后我需要做的是对检测的结果进行评估(计算Percison),但是走了点弯路。

tensorflow教程中的流程是,先把测试集转成TFrecord格式,然后再inference之后,直接在后面加入检测到的结果。但是在模型的输入处出现的问题,因为原本Frozen Model的输入是一个tensor,而师兄修改后的模型输入直接就是 image,所以我不知道怎么将输入进行转换。这里还需要后续的学习。于是我就换了一个方向,以师兄的code为基础,每次从文件夹中提出一张image,进行inference,然后我就根据image信息,解析其对应的xml文件中的信息,并写成tf_example的格式;同时将模型的输出dets添加到刚刚生成的example中。这样就解决了之前的问题,顺利把 annotation+detection结果保存成了 TFrecord格式。

下一步就是利用 object_detection/metrics/offline_eval_map_corloc.py 进行评估,但是出现了两个问题,导致我一度陷入僵局。第一个就是出现了  " ground_truth_group_of .size  :None type has no attrbute to size ",我以为是我的TFrecord出现了问题,但是最后发现是因为 

decoded_dict = data_parser.parse(example)

这里解析的时候,由于我原本TFrecord中并没有写入  standard_fields.TfExampleFields.object_group_of.object_group_of 信息,所以在解析的时候,这个内容就被填上了None ,所以不存在size,导致上面的问题产生。

self.optional_items_to_handlers = {        fields.InputDataFields.groundtruth_difficult:            Int64Parser(fields.TfExampleFields.object_difficult),        fields.InputDataFields.groundtruth_group_of:            Int64Parser(fields.TfExampleFields.object_group_of)

 查下 open image 中的特有的group_of参数是什么意思: Indicates that the box spans a group of objects (e.g., a bed of flowers or a crowd of people). We asked annotators to use this tag for cases with more than 5 instances which are heavily occluding each other and are physically touching.

 也就是说,带有group_of标记的说明,该框中包含了5个以上的物体,如拥挤的人群,一个铺满鲜花的床等等。

 

还有一个问题就是,我的GroundTruth中只有蝴蝶和背景两类,但是原本模型的label_map中却包含545类,所以其余的类别是没有GT的,这样在程序中有一个判断:

# object_detection/utils/object_detection_evaluation.py if (self.num_gt_instances_per_class == 0).any():      logging.warn(          'The following classes have no ground truth examples: %s',          np.squeeze(np.argwhere(self.num_gt_instances_per_class == 0)) +          self.label_id_offset)

我后来找到后,直接将其注释掉,最终跑通了。

该模型在蝴蝶单类的检测Precision=0.728

 



 

以下主要解释下评估的代码,防止以后忘记。假设模型输出validation_detections.tfrecord已保存在

models/research/butterfly路径下

第一步是生成配置文件:

# From models/research/butterflySPLIT=validation  # or testmkdir -p ${SPLIT}_eval_metricsecho "label_map_path: '../object_detection/data/oid_bbox_trainable_label_map.pbtxt'tf_record_input_reader: { input_path: '${SPLIT}_detections.tfrecord' }" > ${SPLIT}_eval_metrics/${SPLIT}_input_config.pbtxtecho "metrics_set: 'open_images_detection_metrics'" > ${SPLIT}_eval_metrics/${SPLIT}_eval_config.pbtxt

然后运行评估程序:

# From tensorflow/models/research/butterflySPLIT=validation  # or testPYTHONPATH=$PYTHONPATH:$(readlink -f ..) \python -m object_detection/metrics/offline_eval_map_corloc \  --eval_dir=${SPLIT}_eval_metrics \      #结果保存的路径  --eval_config_path=${SPLIT}_eval_metrics/${SPLIT}_eval_config.pbtxt \    --input_config_path=${SPLIT}_eval_metrics/${SPLIT}_input_config.pbtxt  #输入的路径

首先来看下主程序  models/research/object_detection/metrics/offline_eval_map_corloc.py

  import csv

  import os
  import re
  import tensorflow as tf

 

  from object_detection import evaluator

  from object_detection.core import standard_fields
  from object_detection.metrics import tf_example_parser
  from object_detection.utils import config_util
  from object_detection.utils import label_map_util

... def read_data_and_evaluate(input_config, eval_config): 略 def write_metrics(metrics, output_dir): ... def main(argv):  del argv  required_flags = ['input_config_path', 'eval_config_path', 'eval_dir'] #对应输入的三个参数  for flag_name in required_flags:    if not getattr(FLAGS, flag_name):      raise ValueError('Flag --{} is required'.format(flag_name))  configs = config_util.get_configs_from_multiple_files(      eval_input_config_path=FLAGS.input_config_path,      eval_config_path=FLAGS.eval_config_path)  eval_config = configs['eval_config']  input_config = configs['eval_input_config']  metrics = read_data_and_evaluate(input_config, eval_config)    #主要实现部分在这里  # Save metrics  write_metrics(metrics, FLAGS.eval_dir)

具体来看下

  read_data_and_evaluate(input_config, eval_config):

def read_data_and_evaluate(input_config, eval_config):  """Reads pre-computed object detections and groundtruth from tf_record.  Args:    input_config: input config proto of type  输入配置文件      object_detection.protos.InputReader.    eval_config: evaluation config proto of type 评估配置文件      object_detection.protos.EvalConfig.  Returns:    Evaluated detections metrics.  返回:评估结果  Raises:    ValueError: if input_reader type is not supported or metric type is unknown.  """  if input_config.WhichOneof('input_reader') == 'tf_record_input_reader':    input_paths = input_config.tf_record_input_reader.input_path    label_map = label_map_util.load_labelmap(input_config.label_map_path)#载入label_map    max_num_classes = max([item.id for item in label_map.item])      #获得最大的类别对应id (545)    categories = label_map_util.convert_label_map_to_categories(        label_map, max_num_classes)                       #list类型,eg. categories[110]={'id':111,'name':'Butterfly'}    object_detection_evaluators = evaluator.get_evaluators(        eval_config, categories)    # Support a single evaluator    object_detection_evaluator = object_detection_evaluators[0]      #对应object_detection_evaluation.OpenImagesDetectionEvaluator    skipped_images = 0    processed_images = 0    for input_path in _generate_filenames(input_paths):      tf.logging.info('Processing file: {0}'.format(input_path))      record_iterator = tf.python_io.tf_record_iterator(path=input_path)  #读取 validation_detection.tfrecord      data_parser = tf_example_parser.TfExampleDetectionAndGTParser()      for string_record in record_iterator:                   #迭代器,一共238个测试样本,每次读取一个样本检测结果        tf.logging.log_every_n(tf.logging.INFO, 'Processed %d images...', 1000,                               processed_images)        processed_images += 1        example = tf.train.Example()        example.ParseFromString(string_record)                #解析TFrecord--> example.features.feature 中以字典形式存放数据        decoded_dict = data_parser.parse(example)              #对TFrecord进一步解析,还原:groundtruth_boxes、groundtruth_classes、detection_boxes、detection_classes、detection_scores                                            if decoded_dict:      #对应 object_detection/utils/object_detection_evaluation.py 中的 class OpenImagesDetectionEvaluator(),默认iou_threshold=0.5          object_detection_evaluator.add_single_ground_truth_image_info(               decoded_dict[standard_fields.DetectionResultFields.key],              decoded_dict)          object_detection_evaluator.add_single_detected_image_info(              decoded_dict[standard_fields.DetectionResultFields.key],              decoded_dict)        else:          skipped_images += 1          tf.logging.info('Skipped images: {0}'.format(skipped_images))    return object_detection_evaluator.evaluate()  raise ValueError('Unsupported input_reader_config.')

可以看出,主要的评测过程又放在了 

  object_detection/utils/object_detection_evaluation.py

第一个是 class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator): 继承自  class ObjectDetectionEvaluator(DetectionEvaluator) ,而这个class 又继承自 class DetectionEvaluator(object)

所以我们从上往下看这几个函数,先是基类 DetectionEvaluator(object):Line:42

class DetectionEvaluator(object):  """Interface for object detection evalution classes.  Example usage of the Evaluator:  ------------------------------  evaluator = DetectionEvaluator(categories)                            即挨个添加 GT和detections ,最后一起evaluate()  # Detections and groundtruth for image 1.  evaluator.add_single_groundtruth_image_info(...)  evaluator.add_single_detected_image_info(...)  # Detections and groundtruth for image 2.  evaluator.add_single_groundtruth_image_info(...)  evaluator.add_single_detected_image_info(...)  metrics_dict = evaluator.evaluate()  """  __metaclass__ = ABCMeta  def __init__(self, categories):    """Constructor.    Args:      categories: A list of dicts, each of which has the following keys -        'id': (required) an integer id uniquely identifying this category.        'name': (required) string representing category name e.g., 'cat', 'dog'.    """    self._categories = categories  @abstractmethod  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):    """Adds groundtruth for a single image to be used for evaluation.    Args:      image_id: A unique string/integer identifier for the image.      groundtruth_dict: A dictionary of groundtruth numpy arrays required        for evaluations.    """    pass  @abstractmethod  def add_single_detected_image_info(self, image_id, detections_dict):    """Adds detections for a single image to be used for evaluation.    Args:      image_id: A unique string/integer identifier for the image.      detections_dict: A dictionary of detection numpy arrays required        for evaluation.    """    pass  @abstractmethod  def evaluate(self):    """Evaluates detections and returns a dictionary of metrics."""    pass  @abstractmethod  def clear(self):    """Clears the state to prepare for a fresh evaluation."""    pass

然后class ObjectDetectionEvaluator(DetectionEvaluator)  Line:104class ObjectDetectionEvaluator(DetectionEvaluator):

"""A class to evaluate detections."""  def __init__(self,               categories,               matching_iou_threshold=0.5,               evaluate_corlocs=False,               metric_prefix=None,               use_weighted_mean_ap=False,               evaluate_masks=False):    """Constructor.    Args:      xxxxxx    Raises:      ValueError: If the category ids are not 1-indexed.    """   ...#这个地方是最关键的,后面会一直用到    self._evaluation = ObjectDetectionEvaluation(        num_groundtruth_classes=self._num_classes,        matching_iou_threshold=self._matching_iou_threshold,        use_weighted_mean_ap=self._use_weighted_mean_ap,        label_id_offset=self._label_id_offset)    ...  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):    """Adds groundtruth for a single image to be used for evaluation.    """    ...略   self._evaluation.add_single_ground_truth_image_info(xxx)     def add_single_detected_image_info(self, image_id, detections_dict):    """Adds detections for a single image to be used for evaluation.    """    ...
  self._evaluation.add_detected_image_info(xxx)
   def evaluate(self):    """Compute evaluation result.    """    ...   (per_class_ap, mean_ap, _, _, per_class_corloc, mean_corloc) = (self._evaluation.evaluate())
... def clear(self):    """Clears the state to prepare for a fresh evaluation."""    self._evaluation = ObjectDetectionEvaluation(        num_groundtruth_classes=self._num_classes,        matching_iou_threshold=self._matching_iou_threshold,        use_weighted_mean_ap=self._use_weighted_mean_ap,        label_id_offset=self._label_id_offset)    self._image_ids.clear()

最后是OpenImagesDetectionEvaluator(ObjectDetectionEvaluator)    Line:376

class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):  """A class to evaluate detections using Open Images V2 metrics.    Open Images V2 introduce group_of type of bounding boxes and this metric    handles those boxes appropriately.  """  def __init__(self,               categories,               matching_iou_threshold=0.5,               evaluate_corlocs=False):    """Constructor.    Args:      categories: A list of dicts, each of which has the following keys -        'id': (required) an integer id uniquely identifying this category.        'name': (required) string representing category name e.g., 'cat', 'dog'.      matching_iou_threshold: IOU threshold to use for matching groundtruth        boxes to detection boxes.      evaluate_corlocs: if True, additionally evaluates and returns CorLoc.    """    super(OpenImagesDetectionEvaluator, self).__init__(        categories,        matching_iou_threshold,        evaluate_corlocs,        metric_prefix='OpenImagesV2')  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):    """Adds groundtruth for a single image to be used for evaluation.    """    if image_id in self._image_ids:      raise ValueError('Image with id {} already added.'.format(image_id))    groundtruth_classes = (        groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes] -        self._label_id_offset)    # If the key is not present in the groundtruth_dict or the array is empty    # (unless there are no annotations for the groundtruth on this image)    # use values from the dictionary or insert None otherwise.    if (standard_fields.InputDataFields.groundtruth_group_of in        groundtruth_dict.keys() and        (groundtruth_dict[standard_fields.InputDataFields.groundtruth_group_of]         .size or not groundtruth_classes.size)):      groundtruth_group_of = groundtruth_dict[          standard_fields.InputDataFields.groundtruth_group_of]    else:      groundtruth_group_of = None      if not len(self._image_ids) % 1000:        logging.warn(            'image %s does not have groundtruth group_of flag specified',            image_id)    self._evaluation.add_single_ground_truth_image_info(        image_id,        groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes],        groundtruth_classes,        groundtruth_is_difficult_list=None,        groundtruth_is_group_of_list=groundtruth_group_of)    self._image_ids.update([image_id])

可以看出,这里只是修改了add_single_ground_truth_image_info 函数,其他都没变。而在其父类中,有把主要的工作交给了 class ObjectDetectionEvaluation(object) 来处理,这下整个代码在逐渐清晰起来。我最后会画个程序包含关系图,可能更容易理解些。

下面整个才是主要的保存 GT 和 Detection 结果的部分哦!!!

class ObjectDetectionEvaluation(object):  """Internal implementation of Pascal object detection metrics."""  def __init__(self,num_groundtruth_classes,matching_iou_threshold=0.5,nms_iou_threshold=1.0,nms_max_output_boxes=10000,use_weighted_mean_ap=False,label_id_offset=0):    if num_groundtruth_classes < 1:      raise ValueError('Need at least 1 groundtruth class for evaluation.')    self.per_image_eval = per_image_evaluation.PerImageEvaluation(        num_groundtruth_classes=num_groundtruth_classes,        matching_iou_threshold=matching_iou_threshold,        nms_iou_threshold=nms_iou_threshold,        nms_max_output_boxes=nms_max_output_boxes)    def clear_detections(self):    self._initialize_detections()  def add_single_ground_truth_image_info(self,image_key, groundtruth_boxes, groundtruth_class_labels, groundtruth_is_difficult_list=None, groundtruth_is_group_of_list=None, groundtruth_masks=None):   def add_single_detected_image_info(self, image_key, detected_boxes,detected_scores, detected_class_labels,detected_masks=None):       scores, tp_fp_labels, is_class_correctly_detected_in_image = (        self.per_image_eval.compute_object_detection_metrics(            detected_boxes=detected_boxes,            detected_scores=detected_scores,            detected_class_labels=detected_class_labels,            groundtruth_boxes=groundtruth_boxes,            groundtruth_class_labels=groundtruth_class_labels,            groundtruth_is_difficult_list=groundtruth_is_difficult_list,            groundtruth_is_group_of_list=groundtruth_is_group_of_list,            detected_masks=detected_masks,            groundtruth_masks=groundtruth_masks))    for i in range(self.num_class):      if scores[i].shape[0] > 0:        self.scores_per_class[i].append(scores[i])        self.tp_fp_labels_per_class[i].append(tp_fp_labels[i])    (self.num_images_correctly_detected_per_class    ) += is_class_correctly_detected_in_image   def evaluate(self):    """Compute evaluation result.    Returns:      A named tuple with the following fields -        average_precision: float numpy array of average precision for            each class.        mean_ap: mean average precision of all classes, float scalar        precisions: List of precisions, each precision is a float numpy            array        recalls: List of recalls, each recall is a float numpy array        corloc: numpy float array        mean_corloc: Mean CorLoc score for each class, float scalar    """         scores = np.concatenate(self.scores_per_class[class_index])      tp_fp_labels = np.concatenate(self.tp_fp_labels_per_class[class_index])      precision, recall = metrics.compute_precision_recall(          scores, tp_fp_labels, self.num_gt_instances_per_class[class_index])      self.precisions_per_class.append(precision)      self.recalls_per_class.append(recall)      average_precision = metrics.compute_average_precision(precision, recall)      self.average_precision_per_class[class_index] = average_precision    self.corloc_per_class = metrics.compute_cor_loc(        self.num_gt_imgs_per_class,        self.num_images_correctly_detected_per_class)     mean_ap = np.nanmean(self.average_precision_per_class)    mean_corloc = np.nanmean(self.corloc_per_class)    return ObjectDetectionEvalMetrics(        self.average_precision_per_class, mean_ap, self.precisions_per_class,        self.recalls_per_class, self.corloc_per_class, mean_corloc)

我把不重要的部分都剃掉了,主要有两个重要的函数 1. object_detection/utils/per_image_evaluatuion.py  计算单张图的precision和recall

                        2. object_detection/utils/metrics.py          统计上述结果,并计算mAP等数值

1. object_detection/utils/per_image_evaluatuion.py  计算单张图的precision和recall

 

scores, tp_fp_labels, is_class_correctly_detected_in_image = compute_object_detection_metrics(...)-->      scores, tp_fp_labels = self._compute_tp_fp(...)       -->for i in range(self.num_groundtruth_classes):               scores, tp_fp_labels = self._compute_tp_fp_for_single_class(...)               -->(iou, ioa, scores,num_detected_boxes) = self._get_overlaps_and_scores_box_mode(...)                   -->detected_boxlist = np_box_list_ops.non_max_suppression(...)           -->

 

转载于:https://www.cnblogs.com/caffeaoto/p/8758962.html

你可能感兴趣的文章
To learn, or not to learn Windows Mobile - that is the question
查看>>
数据库高级应用之事务
查看>>
25-Fibonacci(矩阵快速幂)
查看>>
fastcgi与cgi的区别[转载]
查看>>
理解与模拟一个简单servlet容器
查看>>
Linux高阶命令进阶
查看>>
c++ 插入容器元素(insert)
查看>>
重写List集合的ToString方法
查看>>
localX,mouseX,stageX的区别
查看>>
链表 创建 插入 删除 查找 合并
查看>>
Matlab将变量写入文本文件
查看>>
在JS中创建对象的方式
查看>>
ognl.NoSuchPropertyException(没有对应属性异常)
查看>>
京东金融面试
查看>>
【面试题】反转单链表
查看>>
Jsp学习(四)
查看>>
SQL Cookbook:查询结果排序
查看>>
jQuery DOM节点操作 - 父节点、子节点、兄弟节点
查看>>
Cookie 和 LocalStorage
查看>>
【博客美化小妙招】你希望有一个可爱的看板娘吗?
查看>>