datapipeline 首先出现在 videoanalyst/data/adaptor_dataset.py 中。这里定义了 dataset 类(AdaptorDataset),它继承自 torch.utils.data.Dataset,并在其中创建 datapipeline:
```python
self.datapipeline = datapipeline_builder.build(self.task, self.cfg, seed=seed)
```

videoanalyst/data/datapipeline/builder.py 中定义了 datapipeline 的整体构建流程,依次用到了 sampler、transformer、target 三个组件:

```python
sampler = build_sampler(task, cfg.sampler, seed=seed)
transformers = build_transformer(task, cfg.transformer, seed=seed)
target = build_target(task, cfg.target)
pipeline = []
pipeline.extend(transformers)
pipeline.append(target)
cfg = cfg.datapipeline
name = cfg.name
module = MODULES[name](sampler, pipeline)
```

videoanalyst/data/sampler/builder.py 中定义了 sampler 的整体框架,它由若干 submodules 组成。submodules 分为 dataset 和 filter 两部分:
```python
submodules_cfg = cfg.submodules
dataset_cfg = submodules_cfg.dataset
datasets = dataset_builder.build(task, dataset_cfg)
if submodules_cfg.filter.name != "":
    filter_cfg = submodules_cfg.filter
    data_filter = filter_builder.build(task, filter_cfg)
else:
    data_filter = None
name = cfg.name
module = MODULES[name](datasets, seed=seed, data_filter=data_filter)
```

videoanalyst/data/sampler/sampler_impl/track_pair_sampler.py 中定义了 sampler 的具体实现,每次返回一个样本对:
```python
sampled_data = dict(
    data1=data1,
    data2=data2,
    is_negative_pair=is_negative_pair,
)
```

以 LaSOT 为例,videoanalyst/data/dataset/dataset_impl/lasot.py 定义了 dataset 的返回值:
```python
from videoanalyst.evaluation.got_benchmark.datasets import LaSOT

def update_params(self):
    r"""
    an interface for update params
    """
    dataset_root = osp.realpath(self._hyper_params["dataset_root"])
    subset = self._hyper_params["subset"]
    check_integrity = self._hyper_params["check_integrity"]
    self._state["dataset"] = LaSOT(dataset_root, subset=subset, check_integrity=check_integrity)

def __getitem__(self, item: int) -> Dict:
    img_files, anno = self._state["dataset"][item]
    anno = xywh2xyxy(anno)
    sequence_data = dict(image=img_files, anno=anno)
    return sequence_data

def __len__(self):
    return len(self._state["dataset"])
```

其中用到的 LaSOT 类位于 videoanalyst/evaluation/got_benchmark/datasets/lasot.py,它的 `__getitem__` 返回值为:
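```python
return img_files, anno
```

也就是说,它返回图像文件 img_files 和标注 anno。上面 dataset 的 `__getitem__` 会先把 anno 从 xywh 转成 xyxy 格式,再把两者封装成 sequence_data 返回。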
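videoanalyst/data/filter/filter_impl/track_pair_filter.py 主要负责判断采样到的 data 是否需要被过滤。返回 True 表示该样本不合格、需要重新采样,返回 False 表示可以使用:

- 若 data 为 None(还没有采样到数据),直接返回 True。
- 否则先把标注转换为 xywh 格式的 bbox(mask 标注则取其外接矩形)。
- 再交给 filter_unreasonable_training_boxes 判断:目标过小、过大或长宽比过大时同样返回 True,合理时返回 False。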
```python
def __call__(self, data: Dict) -> bool:
    if data is None:
        return True
    im, anno = data["image"], data["anno"]
    if self._hyper_params["target_type"] == "bbox":
        bbox = xyxy2xywh(anno)
    elif self._hyper_params["target_type"] == "mask":
        bbox = cv2.boundingRect(anno)
    else:
        logger.error("unspported target type {} in filter".format(
            self._hyper_params["target_type"]))
        exit()
    filter_flag = filter_unreasonable_training_boxes(
        im, bbox, self._hyper_params)
    return filter_flag
```

此时我们再回头看看 sampler 具体做了什么:
```
Sample procedure:
__getitem__
│
├── _sample_track_pair                           # 正样本对:返回一对图片和标注,各自以 dict 封装
│   ├── _sample_dataset                          # 随机选择 dataset
│   ├── _sample_sequence_from_dataset            # 随机选择序列
│   ├── _sample_track_frame_from_static_image    # 图片数据集(如 COCO):采样一帧,再 deepcopy 作为第二帧
│   └── _sample_track_pair_from_sequence         # 视频数据集(如 LaSOT)
│       └── _sample_pair_idx_pair_within_max_diff  # 在最大帧间隔范围内选一对帧
│
└── _sample_track_frame (x2)                     # 负样本对:独立采样两次,各得一帧
    ├── _sample_dataset
    ├── _sample_sequence_from_dataset
    ├── _sample_track_frame_from_static_image    # 图片数据集
    └── _sample_track_frame_from_sequence        # 视频数据集
```

(注:从下面的实现可以看出,带最大帧间隔 max_diffs 的是 `_sample_track_pair_from_sequence`,它只在正样本对分支中使用。负样本对分支调用的是 `_sample_track_frame_from_sequence`。)实现如下:
```python
def __getitem__(self, item) -> dict:
    is_negative_pair = (self._state["rng"].rand() < self._hyper_params["negative_pair_ratio"])
    data1 = data2 = None
    sample_try_num = 0
    while self.data_filter(data1) or self.data_filter(data2):
        if is_negative_pair:
            data1 = self._sample_track_frame()
            data2 = self._sample_track_frame()
        else:
            data1, data2 = self._sample_track_pair()
        data1["image"] = load_image(data1["image"])
        data2["image"] = load_image(data2["image"])
        sample_try_num += 1
    sampled_data = dict(
        data1=data1,
        data2=data2,
        is_negative_pair=is_negative_pair,
    )
    return sampled_data

def _sample_track_pair(self) -> Tuple[Dict, Dict]:
    dataset_idx, dataset = self._sample_dataset()
    sequence_data = self._sample_sequence_from_dataset(dataset)
    len_seq = self._get_len_seq(sequence_data)
    if len_seq == 1 and not isinstance(sequence_data["anno"][0], list):
        # static image dataset
        data1 = self._sample_track_frame_from_static_image(sequence_data)
        data2 = deepcopy(data1)
    else:
        # video dataset
        data1, data2 = self._sample_track_pair_from_sequence(
            sequence_data, self._state["max_diffs"][dataset_idx])
    return data1, data2

def _sample_track_frame(self) -> Dict:
    _, dataset = self._sample_dataset()
    sequence_data = self._sample_sequence_from_dataset(dataset)
    len_seq = self._get_len_seq(sequence_data)
    if len_seq == 1:
        # static image dataset
        data_frame = self._sample_track_frame_from_static_image(
            sequence_data)
    else:
        # video dataset
        data_frame = self._sample_track_frame_from_sequence(sequence_data)
    return data_frame
```

总之,最终返回的是 `sampled_data = dict(data1=data1, data2=data2, is_negative_pair=is_negative_pair)`。

注意 `__getitem__` 中的 while 循环:
- data1、data2 初始为 None,而 filter 对 None 返回 True,因此至少会采样一次。
- 之后只要任一样本被 filter 判定为不合格(返回 True),就会整体重新采样。
videoanalyst/data/transformer/transformer_impl/random_crop_transformer.py 中实现了对原始图像的 random crop,裁剪后的结果仍写回 sampled_data:
```python
def __call__(self, sampled_data: Dict) -> Dict:
    r"""
    sampled_data: Dict()
        input data
        Dict(data1=Dict(image, anno), data2=Dict(image, anno))
    """
    data1 = sampled_data["data1"]
    data2 = sampled_data["data2"]
    im_temp, bbox_temp = data1["image"], data1["anno"]
    im_curr, bbox_curr = data2["image"], data2["anno"]
    im_z, bbox_z, im_x, bbox_x, _, _ = crop_track_pair(
        im_temp,
        bbox_temp,
        im_curr,
        bbox_curr,
        config=self._hyper_params,
        rng=self._state["rng"])
    sampled_data["data1"] = dict(image=im_z, anno=bbox_z)
    sampled_data["data2"] = dict(image=im_x, anno=bbox_x)
    return sampled_data
```

videoanalyst/data/target/target_impl/densebox_target.py 在 transformer 输出的基础上生成三种 label:cls_label、ctr_label 和 box_label。
```python
def __call__(self, sampled_data: Dict) -> Dict:
    data_z = sampled_data["data1"]
    im_z, bbox_z = data_z["image"], data_z["anno"]
    data_x = sampled_data["data2"]
    im_x, bbox_x = data_x["image"], data_x["anno"]
    is_negative_pair = sampled_data["is_negative_pair"]

    # input tensor
    im_z = im_z.transpose(2, 0, 1)
    im_x = im_x.transpose(2, 0, 1)

    # training target
    cls_label, ctr_label, box_label = make_densebox_target(
        bbox_x.reshape(1, 4), self._hyper_params)
    if is_negative_pair:
        cls_label[cls_label == 0] = -1
        cls_label[cls_label == 1] = 0

    training_data = dict(
        im_z=im_z,
        im_x=im_x,
        bbox_z=bbox_z,
        bbox_x=bbox_x,
        cls_gt=cls_label,
        ctr_gt=ctr_label,
        box_gt=box_label,
        is_negative_pair=int(is_negative_pair),
    )
    #training_data = super().__call__(training_data)
    return training_data
```

对于负样本对,cls_label 中原本为 0 的位置被改成 -1,原本为 1 的位置被改成 0。

此时再回头看 datapipeline,它做了三件事:
1. sampler 从 dataset 中选出图片对;
2. transformer 按照 z 和 x 的尺寸进行 crop;
3. target 生成训练 label。
最后回到整个流程的起点:videoanalyst/data/builder.py 中使用了自定义的 dataset 类,这个类的实现位于 videoanalyst/data/adaptor_dataset.py,datapipeline 正是在这里出现的。可以看到 `__getitem__` 的返回值是:
```python
def __getitem__(self, item):
    if self.datapipeline is None:
        # build datapipeline with random seed the first time when __getitem__ is called
        # usually, dataset is already spawned (into subprocess) at this point.
        seed = (torch.initial_seed() + item * self._SEED_STEP + self.ext_seed * self._EXT_SEED_STEP) % self._SEED_DIVIDER
        self.datapipeline = datapipeline_builder.build(self.task, self.cfg, seed=seed)
        logger.info("AdaptorDataset #%d built datapipeline with seed=%d" % (item, seed))
    training_data = self.datapipeline[item]
    return training_data
```

也就是说,datapipeline 在第一次调用 `__getitem__` 时才用随机种子构建,此时 dataset 通常已被分发到子进程中。最终返回的 training_data 就是依次经过 sampler、transformer、target 处理后得到的训练数据。