Yufan Chen committed · Commit 68ebefc · 1 Parent(s): b36858c

Update the basic elements of README.md, the PubLayNet-P dataset download link, the perturbation implementation code, the RoDLA model evaluation code, and the RoDLA model code.

This view is limited to 50 files because the commit contains too many changes.
- README.md +150 -3
- model/README.md +137 -0
- model/configs/_base_/datasets/doclaynet.py +49 -0
- model/configs/_base_/datasets/m6doc.py +49 -0
- model/configs/_base_/datasets/publaynet.py +50 -0
- model/configs/_base_/default_runtime.py +16 -0
- model/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py +196 -0
- model/configs/_base_/models/cascade_mask_rcnn_r50_fpn_crowdhuman.py +183 -0
- model/configs/_base_/models/cascade_rcnn_r50_fpn.py +179 -0
- model/configs/_base_/models/fast_rcnn_r50_fpn.py +62 -0
- model/configs/_base_/models/faster_rcnn_r50_caffe_c4.py +114 -0
- model/configs/_base_/models/faster_rcnn_r50_caffe_dc5.py +105 -0
- model/configs/_base_/models/faster_rcnn_r50_fpn.py +108 -0
- model/configs/_base_/models/mask_rcnn_convnext_fpn.py +128 -0
- model/configs/_base_/models/mask_rcnn_r50_caffe_c4.py +125 -0
- model/configs/_base_/models/mask_rcnn_r50_fpn.py +120 -0
- model/configs/_base_/models/retinanet_r50_fpn.py +60 -0
- model/configs/_base_/models/rpn_r50_caffe_c4.py +58 -0
- model/configs/_base_/models/rpn_r50_fpn.py +58 -0
- model/configs/_base_/models/ssd300.py +56 -0
- model/configs/_base_/schedules/schedule_0.5x.py +11 -0
- model/configs/_base_/schedules/schedule_1x.py +11 -0
- model/configs/_base_/schedules/schedule_2x.py +11 -0
- model/configs/_base_/schedules/schedule_3x.py +11 -0
- model/configs/_base_/schedules/schedule_4x.py +11 -0
- model/configs/doclaynet/rodla_internimage_xl_3x_doclaynet.py +178 -0
- model/configs/m6doc/rodla_internimage_xl_2x_m6doc.py +187 -0
- model/configs/publaynet/rodla_internimage_xl_2x_publaynet.py +177 -0
- model/deploy.py +310 -0
- model/dist_test.sh +9 -0
- model/dist_train.sh +9 -0
- model/get_flops.py +120 -0
- model/image_demo.py +61 -0
- model/mmcv_custom/__init__.py +11 -0
- model/mmcv_custom/checkpoint.py +487 -0
- model/mmcv_custom/custom_layer_decay_optimizer_constructor.py +142 -0
- model/mmdet_custom/__init__.py +8 -0
- model/mmdet_custom/datasets/__init__.py +9 -0
- model/mmdet_custom/datasets/doclaynet.py +527 -0
- model/mmdet_custom/datasets/m6doc.py +529 -0
- model/mmdet_custom/datasets/publaynet.py +528 -0
- model/mmdet_custom/models/__init__.py +10 -0
- model/mmdet_custom/models/backbones/__init__.py +11 -0
- model/mmdet_custom/models/backbones/beit.py +601 -0
- model/mmdet_custom/models/backbones/intern_image.py +702 -0
- model/mmdet_custom/models/backbones/swin_transformer.py +648 -0
- model/mmdet_custom/models/dense_heads/__init__.py +11 -0
- model/mmdet_custom/models/dense_heads/deformable_detr_head.py +332 -0
- model/mmdet_custom/models/dense_heads/detr_head.py +954 -0
- model/mmdet_custom/models/dense_heads/dino_head.py +364 -0
README.md
CHANGED
@@ -1,3 +1,150 @@
<h1 align="center">📓RoDLA</h1>
<h3 align="center">Benchmarking the Robustness of Document Layout Analysis Models (CVPR'24)</h3>

<p align="center">
    <a href="https://arxiv.org/pdf/2403.14442.pdf">
    <img src="https://img.shields.io/badge/PDF-arXiv-brightgreen" /></a>
    <a href="https://yufanchen96.github.io/projects/RoDLA/">
    <img src="https://img.shields.io/badge/Project-Homepage-red" /></a>
    <a href="https://pytorch.org/get-started/previous-versions/#linux-and-windows">
    <img src="https://img.shields.io/badge/Framework-PyTorch%201.10.2-orange" /></a>
    <a href="https://github.com/yufanchen96/RoDLA/blob/main/LICENSE">
    <img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" /></a>
    <img alt="visits" src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fyufanchen96%2FRoDLA&count_bg=%23A53DC8&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visits&edge_flat=false">
</p>

## 🏡 Project Homepage

This is the official repository for our CVPR 2024 paper **RoDLA: Benchmarking the Robustness of Document Layout Analysis Models**. For more results and benchmarking details, please visit our [project homepage](https://yufanchen96.github.io/projects/RoDLA/).

## 🔎 Introduction

We introduce **RoDLA**, a benchmark for evaluating the robustness of Document Layout Analysis (DLA) models. RoDLA is a large-scale benchmark containing **450,000+** documents with diverse layouts and contents.
We also provide a set of evaluation metrics to facilitate the comparison of different DLA models. We hope that RoDLA can serve as a standard benchmark for the robustness evaluation of DLA models.
<p align="center">
<img src="assets/benchmark_v2.png" width="360" />
</p>

## 📝 Catalog
- [x] Perturbation Benchmark Dataset
  - [x] PubLayNet-P
  - [ ] DocLayNet-P
  - [ ] M<sup>6</sup>Doc-P
- [x] Perturbation Generation and Evaluation Code
- [ ] RoDLA Model Checkpoints
- [ ] RoDLA Model Training Code
- [x] RoDLA Model Evaluation Code

## 📦 Installation
**1. Clone the repository**
```
git clone https://github.com/yufanchen96/RoDLA.git
cd RoDLA
```

**2. Create a conda virtual environment**
```
# create virtual environment
conda create -n RoDLA python=3.7 -y
conda activate RoDLA
```

**3. Install benchmark dependencies**

- Install Basic Dependencies

```
pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
pip install -U openmim
mim install mmcv-full==1.5.0
pip install timm==0.6.11 mmdet==2.28.1
pip install Pillow==9.5.0
pip install opencv-python termcolor yacs pyyaml scipy
```

- Install ocrodeg Dependencies
```
cd ./ocrodeg
pip install -e .
cd ..
```

- Compile CUDA operators
```
cd ./model/ops_dcnv3
sh ./make.sh
python test.py
cd ../..
```
- You can also install the operator using .whl files

    [DCNv3-1.0-whl](https://github.com/OpenGVLab/InternImage/releases/tag/whl_files)
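With everything installed, a quick interpreter check (a minimal sketch, not part of the repository) can confirm the pinned versions before moving on to data preparation:

```python
# minimal environment sanity check (sketch; not part of the repository)
import mmcv
import mmdet
import torch
import torchvision

print(torch.__version__)          # expect 1.10.2+cu113
print(torchvision.__version__)    # expect 0.11.3+cu113
print(mmcv.__version__)           # expect 1.5.0
print(mmdet.__version__)          # expect 2.28.1
print(torch.cuda.is_available())  # should be True before compiling the CUDA ops
```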
## 📂 Dataset Preparation

### RoDLA Benchmark Dataset Preparation
Download the RoDLA dataset from Google Drive to the desired root directory.
- [PubLayNet-P](https://drive.google.com/file/d/1bfjaxb5fAjU7sFqtM3GfNYm0ynrB5Vwo/view?usp=drive_link)
- [DocLayNet-P]()
- [M<sup>6</sup>Doc-P]()

### Self-generated Perturbation Dataset Preparation
Alternatively, prepare the perturbed dataset yourself as follows:
```
cd ./perturbation

python apply_perturbation.py \
    --dataset_dir ./publaynet/val \
    --json_dir ./publaynet/val.json \
    --dataset_name PubLayNet-P \
    --output_dir ./PubLayNet-P \
    --pert_method all \
    --background_folder ./background \
    --metric all
```
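The `--metric all` option is what produces the per-perturbation `psnr.json`, `ms_ssim.json`, and `cw_ssim.json` files listed in the next section. As a rough illustration of what such a metric records, here is a minimal sketch computing PSNR for one image pair; it assumes `scikit-image`, which is not among the pinned dependencies above:

```python
# sketch: PSNR between a clean page and its perturbed counterpart
# (illustrative only; requires scikit-image, which is not pinned above)
import cv2
from skimage.metrics import peak_signal_noise_ratio

clean = cv2.imread('./publaynet/val/PMC538274_00004.jpg')
perturbed = cv2.imread('./PubLayNet-P/Background/Background_1/val/PMC538274_00004.jpg')

# a higher PSNR means the perturbed page stays closer to the clean original
print(peak_signal_noise_ratio(clean, perturbed))
```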
### Dataset Structure

After dataset preparation, the perturbed dataset structure is as follows:
```
.desired_root
└── PubLayNet-P
    ├── Background
    │   ├── Background_1
    │   │   ├── psnr.json
    │   │   ├── ms_ssim.json
    │   │   ├── cw_ssim.json
    │   │   ├── val.json
    │   │   ├── val
    │   │   │   ├── PMC538274_00004.jpg
    │   │   │   ...
    │   ├── Background_2
    │   ...
    ├── Rotation
    ...
```
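A short script (a sketch, assuming exactly the layout above) can verify that every perturbation level ships with its annotation file and image folder before running any evaluation:

```python
# sketch: sanity-check the perturbed dataset tree (assumes the layout above)
from pathlib import Path

root = Path('.desired_root/PubLayNet-P')
for level_dir in sorted(root.glob('*/*')):           # e.g. Background/Background_1
    ann = level_dir / 'val.json'
    imgs = level_dir / 'val'
    n = len(list(imgs.glob('*.jpg'))) if imgs.is_dir() else 0
    print(f'{level_dir.relative_to(root)}: ann={ann.is_file()}, images={n}')
```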
## 🚀 Quick Start

### Evaluate the RoDLA model
```
cd ./model
python -u test.py configs/publaynet/rodla_internimage_xl_2x_publaynet.py \
    checkpoint_dir/rodla_internimage_xl_2x_publaynet.pth \
    --work-dir result/rodla_internimage_publaynet/Speckle_1 \
    --eval bbox \
    --cfg-options data.test.ann_file='PubLayNet-P/Speckle/Speckle/val.json' \
                  data.test.img_prefix='PubLayNet-P/Speckle/Speckle/val/'
```
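The `--cfg-options` overrides simply repoint `data.test` at one perturbation folder, so benchmarking a checkpoint across the whole suite is a loop over folders. A minimal driver sketch (the folder names are assumptions based on the dataset structure above):

```python
# sketch: evaluate one checkpoint across several perturbation levels
# (folder names below are assumptions based on the dataset structure)
import subprocess

CONFIG = 'configs/publaynet/rodla_internimage_xl_2x_publaynet.py'
CKPT = 'checkpoint_dir/rodla_internimage_xl_2x_publaynet.pth'

for pert in ['Background/Background_1', 'Rotation/Rotation_1', 'Speckle/Speckle']:
    work_dir = 'result/rodla_internimage_publaynet/' + pert.replace('/', '_')
    subprocess.run([
        'python', '-u', 'test.py', CONFIG, CKPT,
        '--work-dir', work_dir,
        '--eval', 'bbox',
        '--cfg-options',
        f'data.test.ann_file=PubLayNet-P/{pert}/val.json',
        f'data.test.img_prefix=PubLayNet-P/{pert}/val/',
    ], check=True)
```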
## 🌳 Citation
If you find this code useful for your research, please consider citing:
```
@misc{chen2024rodla,
    title={RoDLA: Benchmarking the Robustness of Document Layout Analysis Models},
    author={Yufan Chen and Jiaming Zhang and Kunyu Peng and Junwei Zheng and Ruiping Liu and Philip Torr and Rainer Stiefelhagen},
    year={2024},
    eprint={2403.14442},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
model/README.md
ADDED
@@ -0,0 +1,137 @@
# InternImage for Object Detection

This folder contains the implementation of InternImage for object detection.

Our detection code is developed on top of [MMDetection v2.28.1](https://github.com/open-mmlab/mmdetection/tree/v2.28.1).

## Usage

### Install

- Clone this repo:

```bash
git clone https://github.com/OpenGVLab/InternImage.git
cd InternImage
```

- Create a conda virtual environment and activate it:

```bash
conda create -n internimage python=3.7 -y
conda activate internimage
```

- Install `CUDA>=10.2` with `cudnn>=7` following
  the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
- Install `PyTorch>=1.10.0` and `torchvision>=0.9.0` with `CUDA>=10.2`:

  For example, to install torch==1.11 with CUDA==11.3:
```bash
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
```

- Install `timm==0.6.11` and `mmcv-full==1.5.0`:

```bash
pip install -U openmim
mim install mmcv-full==1.5.0
pip install timm==0.6.11 mmdet==2.28.1
```

- Install other requirements:

```bash
pip install opencv-python termcolor yacs pyyaml scipy
```

- Compile CUDA operators
```bash
cd ./ops_dcnv3
sh ./make.sh
# unit test (all checks should print True)
python test.py
```
- You can also install the operator using .whl files

  [DCNv3-1.0-whl](https://github.com/OpenGVLab/InternImage/releases/tag/whl_files)

### Data Preparation

Prepare COCO according to the guidelines in [MMDetection v2.28.1](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/1_exist_data_model.md).

### Evaluation

To evaluate our `InternImage` on COCO val, run:

```bash
sh dist_test.sh <config-file> <checkpoint> <gpu-num> --eval bbox segm
```

For example, to evaluate `InternImage-T` with a single GPU:

```bash
python test.py configs/coco/mask_rcnn_internimage_t_fpn_1x_coco.py checkpoint_dir/det/mask_rcnn_internimage_t_fpn_1x_coco.pth --eval bbox segm
```

For example, to evaluate `InternImage-B` on a single node with 8 GPUs:

```bash
sh dist_test.sh configs/coco/mask_rcnn_internimage_b_fpn_1x_coco.py checkpoint_dir/det/mask_rcnn_internimage_b_fpn_1x_coco.pth 8 --eval bbox segm
```
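Beyond the test scripts, MMDetection v2.28.1 also exposes a high-level Python API that is convenient for spot-checking a checkpoint on a single image. A sketch (the demo image path is a placeholder; the `mmcv_custom`/`mmdet_custom` imports register the InternImage modules, as in `image_demo.py`):

```python
# sketch: single-image inference via MMDetection's high-level API
import mmcv_custom   # noqa: F401 -- registers custom checkpoint utilities
import mmdet_custom  # noqa: F401 -- registers the InternImage backbone
from mmdet.apis import inference_detector, init_detector

config = 'configs/coco/mask_rcnn_internimage_t_fpn_1x_coco.py'
ckpt = 'checkpoint_dir/det/mask_rcnn_internimage_t_fpn_1x_coco.pth'

model = init_detector(config, ckpt, device='cuda:0')
result = inference_detector(model, 'demo.jpg')             # placeholder image
model.show_result('demo.jpg', result, out_file='out.jpg')  # save visualization
```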
### Training on COCO

To train an `InternImage` on COCO, run:

```bash
sh dist_train.sh <config-file> <gpu-num>
```

For example, to train `InternImage-T` with 8 GPUs on 1 node, run:

```bash
sh dist_train.sh configs/coco/mask_rcnn_internimage_t_fpn_1x_coco.py 8
```

### Manage Jobs with Slurm

For example, to train `InternImage-L` with 32 GPUs on 4 nodes, run:

```bash
GPUS=32 sh slurm_train.sh <partition> <job-name> configs/coco/cascade_internimage_xl_fpn_3x_coco.py work_dirs/cascade_internimage_xl_fpn_3x_coco
```

### Export

To export a detection model from PyTorch to TensorRT, run:
```shell
MODEL="model_name"
CKPT_PATH="/path/to/model/ckpt.pth"

python deploy.py \
    "./deploy/configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py" \
    "./configs/coco/${MODEL}.py" \
    "${CKPT_PATH}" \
    "./deploy/demo.jpg" \
    --work-dir "./work_dirs/mmdet/instance-seg/${MODEL}" \
    --device cuda \
    --dump-info
```

For example, to export `mask_rcnn_internimage_t_fpn_1x_coco` from PyTorch to TensorRT, run:
```shell
MODEL="mask_rcnn_internimage_t_fpn_1x_coco"
CKPT_PATH="/path/to/model/ckpt/mask_rcnn_internimage_t_fpn_1x_coco.pth"

python deploy.py \
    "./deploy/configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py" \
    "./configs/coco/${MODEL}.py" \
    "${CKPT_PATH}" \
    "./deploy/demo.jpg" \
    --work-dir "./work_dirs/mmdet/instance-seg/${MODEL}" \
    --device cuda \
    --dump-info
```
model/configs/_base_/datasets/doclaynet.py
ADDED
@@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'DocLayNetDataset'
data_root = 'data/DocLayNet/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'COCO/train.json',
        img_prefix=data_root + 'PNG/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'COCO/val.json',
        img_prefix=data_root + 'PNG/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'COCO/val.json',
        img_prefix=data_root + 'PNG/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'], classwise=True)
model/configs/_base_/datasets/m6doc.py
ADDED
@@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'M6DocDataset'
data_root = 'data/M6Doc/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'], classwise=True)
model/configs/_base_/datasets/publaynet.py
ADDED
@@ -0,0 +1,50 @@
# dataset settings
dataset_type = 'PubLayNetDataset'
data_root = 'data/PubLayNet/publaynet/'
classes = ('text', 'title', 'list', 'table', 'figure')
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'train.json',
        img_prefix=data_root + 'train/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'val.json',
        img_prefix=data_root + 'val/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'], classwise=True)
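These `_base_` dataset files are consumed through MMDetection's config system rather than imported directly. A minimal sketch of turning the PubLayNet file into a dataset object (assumes `data/PubLayNet/publaynet/` exists as configured, and that the repo's `mmdet_custom` package, which registers `PubLayNetDataset`, is importable):

```python
# sketch: build the PubLayNet test split from its _base_ config
import mmdet_custom  # noqa: F401 -- registers PubLayNetDataset
from mmcv import Config
from mmdet.datasets import build_dataset

cfg = Config.fromfile('model/configs/_base_/datasets/publaynet.py')
dataset = build_dataset(cfg.data.test)
print(len(dataset), dataset.CLASSES)
```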
model/configs/_base_/default_runtime.py
ADDED
@@ -0,0 +1,16 @@
checkpoint_config = dict(interval=0)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
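Fields such as `load_from` or the logging interval are usually overridden per experiment rather than edited here; the same mechanism behind `--cfg-options` is also available programmatically (a sketch):

```python
# sketch: override runtime defaults programmatically instead of editing the file
from mmcv import Config

cfg = Config.fromfile('model/configs/_base_/default_runtime.py')
cfg.load_from = 'checkpoint_dir/rodla_internimage_xl_2x_publaynet.pth'
cfg.log_config.interval = 100  # log every 100 iterations instead of 50
print(cfg.pretty_text)
```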
model/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
ADDED
@@ -0,0 +1,196 @@
# model settings
model = dict(
    type='CascadeRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
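Note that the three cascade stages train against progressively stricter IoU thresholds (0.5 / 0.6 / 0.7) while the heads stay COCO-sized (`num_classes=80`); dataset-specific configs such as `rodla_internimage_xl_2x_publaynet.py` are expected to override the class counts. The same adjustment done programmatically, as a sketch, for PubLayNet's 5 classes:

```python
# sketch: retarget the COCO-sized cascade heads to PubLayNet's 5 classes
from mmcv import Config

cfg = Config.fromfile('model/configs/_base_/models/cascade_mask_rcnn_r50_fpn.py')
for head in cfg.model.roi_head.bbox_head:   # one bbox head per cascade stage
    head.num_classes = 5                    # text, title, list, table, figure
cfg.model.roi_head.mask_head.num_classes = 5
```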
model/configs/_base_/models/cascade_mask_rcnn_r50_fpn_crowdhuman.py
ADDED
@@ -0,0 +1,183 @@
# model settings
model = dict(
    type='CascadeRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ]),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
model/configs/_base_/models/cascade_rcnn_r50_fpn.py
ADDED
@@ -0,0 +1,179 @@
# model settings
model = dict(
    type='CascadeRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ]),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)))
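Like the two variants above, this file embeds `train_cfg`/`test_cfg` inside the `model` dict, so the whole detector can be instantiated from the config alone (a sketch; requires the mmdet/mmcv versions pinned earlier):

```python
# sketch: instantiate the detector described by this base config
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('model/configs/_base_/models/cascade_rcnn_r50_fpn.py')
model = build_detector(cfg.model)  # train_cfg/test_cfg are embedded in the dict
print(type(model).__name__)       # expect: CascadeRCNN
```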
model/configs/_base_/models/fast_rcnn_r50_fpn.py
ADDED
@@ -0,0 +1,62 @@
# model settings
model = dict(
    type='FastRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)))
model/configs/_base_/models/faster_rcnn_r50_caffe_c4.py
ADDED
@@ -0,0 +1,114 @@
# model settings
norm_cfg = dict(type='BN', requires_grad=False)
model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ),
        frozen_stages=1,
        norm_cfg=norm_cfg,
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron2/resnet50_caffe')),
    rpn_head=dict(
        type='RPNHead',
        in_channels=1024,
        feat_channels=1024,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[2, 4, 8, 16, 32],
            ratios=[0.5, 1.0, 2.0],
            strides=[16]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        shared_head=dict(
            type='ResLayer',
            depth=50,
            stage=3,
            stride=2,
            dilation=1,
            style='caffe',
            norm_cfg=norm_cfg,
            norm_eval=True),
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=1024,
            featmap_strides=[16]),
        bbox_head=dict(
            type='BBoxHead',
            with_avg_pool=True,
            roi_feat_size=7,
            in_channels=2048,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=12000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=6000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)))
model/configs/_base_/models/faster_rcnn_r50_caffe_dc5.py
ADDED
@@ -0,0 +1,105 @@
| 1 |
+
# model settings
|
| 2 |
+
norm_cfg = dict(type='BN', requires_grad=False)
|
| 3 |
+
model = dict(
|
| 4 |
+
type='FasterRCNN',
|
| 5 |
+
backbone=dict(
|
| 6 |
+
type='ResNet',
|
| 7 |
+
depth=50,
|
| 8 |
+
num_stages=4,
|
| 9 |
+
strides=(1, 2, 2, 1),
|
| 10 |
+
dilations=(1, 1, 1, 2),
|
| 11 |
+
out_indices=(3, ),
|
| 12 |
+
frozen_stages=1,
|
| 13 |
+
norm_cfg=norm_cfg,
|
| 14 |
+
norm_eval=True,
|
| 15 |
+
style='caffe',
|
| 16 |
+
init_cfg=dict(
|
| 17 |
+
type='Pretrained',
|
| 18 |
+
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
|
| 19 |
+
rpn_head=dict(
|
| 20 |
+
type='RPNHead',
|
| 21 |
+
in_channels=2048,
|
| 22 |
+
feat_channels=2048,
|
| 23 |
+
anchor_generator=dict(
|
| 24 |
+
type='AnchorGenerator',
|
| 25 |
+
scales=[2, 4, 8, 16, 32],
|
| 26 |
+
ratios=[0.5, 1.0, 2.0],
|
| 27 |
+
strides=[16]),
|
| 28 |
+
bbox_coder=dict(
|
| 29 |
+
type='DeltaXYWHBBoxCoder',
|
| 30 |
+
target_means=[.0, .0, .0, .0],
|
| 31 |
+
target_stds=[1.0, 1.0, 1.0, 1.0]),
|
| 32 |
+
loss_cls=dict(
|
| 33 |
+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
|
| 34 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
|
| 35 |
+
roi_head=dict(
|
| 36 |
+
type='StandardRoIHead',
|
| 37 |
+
bbox_roi_extractor=dict(
|
| 38 |
+
type='SingleRoIExtractor',
|
| 39 |
+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
|
| 40 |
+
out_channels=2048,
|
| 41 |
+
featmap_strides=[16]),
|
| 42 |
+
bbox_head=dict(
|
| 43 |
+
type='Shared2FCBBoxHead',
|
| 44 |
+
in_channels=2048,
|
| 45 |
+
fc_out_channels=1024,
|
| 46 |
+
roi_feat_size=7,
|
| 47 |
+
num_classes=80,
|
| 48 |
+
bbox_coder=dict(
|
| 49 |
+
type='DeltaXYWHBBoxCoder',
|
| 50 |
+
target_means=[0., 0., 0., 0.],
|
| 51 |
+
target_stds=[0.1, 0.1, 0.2, 0.2]),
|
| 52 |
+
reg_class_agnostic=False,
|
| 53 |
+
loss_cls=dict(
|
| 54 |
+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
|
| 55 |
+
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
|
| 56 |
+
# model training and testing settings
|
| 57 |
+
train_cfg=dict(
|
| 58 |
+
rpn=dict(
|
| 59 |
+
assigner=dict(
|
| 60 |
+
type='MaxIoUAssigner',
|
| 61 |
+
pos_iou_thr=0.7,
|
| 62 |
+
neg_iou_thr=0.3,
|
| 63 |
+
min_pos_iou=0.3,
|
| 64 |
+
match_low_quality=True,
|
| 65 |
+
ignore_iof_thr=-1),
|
| 66 |
+
sampler=dict(
|
| 67 |
+
type='RandomSampler',
|
| 68 |
+
num=256,
|
| 69 |
+
pos_fraction=0.5,
|
| 70 |
+
neg_pos_ub=-1,
|
| 71 |
+
add_gt_as_proposals=False),
|
| 72 |
+
allowed_border=0,
|
| 73 |
+
pos_weight=-1,
|
| 74 |
+
debug=False),
|
| 75 |
+
rpn_proposal=dict(
|
| 76 |
+
nms_pre=12000,
|
| 77 |
+
max_per_img=2000,
|
| 78 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
| 79 |
+
min_bbox_size=0),
|
| 80 |
+
rcnn=dict(
|
| 81 |
+
assigner=dict(
|
| 82 |
+
type='MaxIoUAssigner',
|
| 83 |
+
pos_iou_thr=0.5,
|
| 84 |
+
neg_iou_thr=0.5,
|
| 85 |
+
min_pos_iou=0.5,
|
| 86 |
+
match_low_quality=False,
|
| 87 |
+
ignore_iof_thr=-1),
|
| 88 |
+
sampler=dict(
|
| 89 |
+
type='RandomSampler',
|
| 90 |
+
num=512,
|
| 91 |
+
pos_fraction=0.25,
|
| 92 |
+
neg_pos_ub=-1,
|
| 93 |
+
add_gt_as_proposals=True),
|
| 94 |
+
pos_weight=-1,
|
| 95 |
+
debug=False)),
|
| 96 |
+
test_cfg=dict(
|
| 97 |
+
rpn=dict(
|
| 98 |
+
nms=dict(type='nms', iou_threshold=0.7),
|
| 99 |
+
nms_pre=6000,
|
| 100 |
+
max_per_img=1000,
|
| 101 |
+
min_bbox_size=0),
|
| 102 |
+
rcnn=dict(
|
| 103 |
+
score_thr=0.05,
|
| 104 |
+
nms=dict(type='nms', iou_threshold=0.5),
|
| 105 |
+
max_per_img=100)))
|
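For reference, a minimal pure-Python sketch (not the mmdet implementation) of how the single-level AnchorGenerator above expands scales=[2, 4, 8, 16, 32] and ratios=[0.5, 1.0, 2.0] on the stride-16 feature map: by mmdet convention the anchor base size equals the stride, so anchor areas span roughly 32x32 to 512x512 pixels.

# toy anchor-shape derivation; mirrors mmdet's scale/ratio convention
from math import sqrt

stride = 16
scales = [2, 4, 8, 16, 32]
ratios = [0.5, 1.0, 2.0]
for scale in scales:
    for ratio in ratios:
        h = stride * scale * sqrt(ratio)  # ratio > 1 -> taller anchors
        w = stride * scale / sqrt(ratio)
        print(f'scale={scale:2d} ratio={ratio:3.1f} -> {w:6.1f} x {h:6.1f}')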
model/configs/_base_/models/faster_rcnn_r50_fpn.py
ADDED
@@ -0,0 +1,108 @@
# model settings
model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))
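For reference, a minimal sketch of how such a _base_ model file is consumed, assuming an mmdet 2.x environment; since train_cfg and test_cfg live inside the model dict here, build_detector needs nothing else.

# sketch only: load a base model file directly and instantiate the detector
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('model/configs/_base_/models/faster_rcnn_r50_fpn.py')
model = build_detector(cfg.model)  # train_cfg/test_cfg already inside cfg.model
print(type(model).__name__)        # -> FasterRCNN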
model/configs/_base_/models/mask_rcnn_convnext_fpn.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# model settings
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='ConvNeXt',
        in_chans=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.2,
        layer_scale_init_value=1e-6,
        out_indices=[0, 1, 2, 3],
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
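For reference, a minimal sketch of the config inheritance these _base_ files exist for. The child file below is hypothetical and only illustrates that a derived config states the fields that differ; mmcv's Config.fromfile merges the rest recursively from the base.

# hypothetical child config: adapt the base to a 5-class layout dataset
_base_ = './mask_rcnn_convnext_fpn.py'
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=5),
        mask_head=dict(num_classes=5)))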
model/configs/_base_/models/mask_rcnn_r50_caffe_c4.py
ADDED
@@ -0,0 +1,125 @@
# model settings
norm_cfg = dict(type='BN', requires_grad=False)
model = dict(
    type='MaskRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ),
        frozen_stages=1,
        norm_cfg=norm_cfg,
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron2/resnet50_caffe')),
    rpn_head=dict(
        type='RPNHead',
        in_channels=1024,
        feat_channels=1024,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[2, 4, 8, 16, 32],
            ratios=[0.5, 1.0, 2.0],
            strides=[16]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        shared_head=dict(
            type='ResLayer',
            depth=50,
            stage=3,
            stride=2,
            dilation=1,
            style='caffe',
            norm_cfg=norm_cfg,
            norm_eval=True),
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=1024,
            featmap_strides=[16]),
        bbox_head=dict(
            type='BBoxHead',
            with_avg_pool=True,
            roi_feat_size=7,
            in_channels=2048,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=None,
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=0,
            in_channels=2048,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=12000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=14,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=6000,
            nms=dict(type='nms', iou_threshold=0.7),
            max_per_img=1000,
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
model/configs/_base_/models/mask_rcnn_r50_fpn.py
ADDED
@@ -0,0 +1,120 @@
# model settings
model = dict(
    type='MaskRCNN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
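For reference, a minimal inference sketch assuming an mmdet 2.x install; the config and checkpoint paths below are placeholders, not files shipped in this commit.

# sketch only: run a trained detector on one image
from mmdet.apis import init_detector, inference_detector

model = init_detector('path/to/config.py', 'path/to/checkpoint.pth',
                      device='cuda:0')
result = inference_detector(model, 'demo.jpg')
# per-class bbox arrays; for Mask R-CNN, a (bboxes, masks) tuple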
model/configs/_base_/models/retinanet_r50_fpn.py
ADDED
@@ -0,0 +1,60 @@
# model settings
model = dict(
    type='RetinaNet',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_input',
        num_outs=5),
    bbox_head=dict(
        type='RetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.4,
            min_pos_iou=0,
            ignore_iof_thr=-1),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100))
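For reference, a minimal scalar sketch of the focal loss configured above (gamma=2.0, alpha=0.25); mmdet's FocalLoss is the batched PyTorch/CUDA version of the same formula, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).

# toy single-sample focal loss, not the mmdet implementation
from math import log

def binary_focal_loss(p, target, gamma=2.0, alpha=0.25):
    # p: predicted foreground probability after sigmoid; target: 0 or 1
    pt = p if target == 1 else 1.0 - p
    alpha_t = alpha if target == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - pt) ** gamma * log(pt)

print(binary_focal_loss(0.1, 0))  # confident background -> near-zero loss
print(binary_focal_loss(0.1, 1))  # hard positive -> large loss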
model/configs/_base_/models/rpn_r50_caffe_c4.py
ADDED
@@ -0,0 +1,58 @@
# model settings
model = dict(
    type='RPN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=3,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        out_indices=(2, ),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron2/resnet50_caffe')),
    neck=None,
    rpn_head=dict(
        type='RPNHead',
        in_channels=1024,
        feat_channels=1024,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[2, 4, 8, 16, 32],
            ratios=[0.5, 1.0, 2.0],
            strides=[16]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=12000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0)))
model/configs/_base_/models/rpn_r50_fpn.py
ADDED
@@ -0,0 +1,58 @@
# model settings
model = dict(
    type='RPN',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0)))
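For reference, a minimal sketch (plain Python, not mmdet code) of the MaxIoUAssigner rule used by the RPN configs above: anchors whose best IoU clears pos_iou_thr become positives, those below neg_iou_thr become background, and the band between is ignored; min_pos_iou additionally lets each ground-truth box claim its best-matching anchor, which this toy version omits.

# toy label assignment from an anchor's max IoU with any ground-truth box
def assign_label(max_iou, pos_iou_thr=0.7, neg_iou_thr=0.3):
    if max_iou >= pos_iou_thr:
        return 1    # positive sample
    if max_iou < neg_iou_thr:
        return 0    # negative (background) sample
    return -1       # ignored during loss computation

print([assign_label(iou) for iou in (0.05, 0.45, 0.75)])  # [0, -1, 1]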
model/configs/_base_/models/ssd300.py
ADDED
@@ -0,0 +1,56 @@
# model settings
input_size = 300
model = dict(
    type='SingleStageDetector',
    backbone=dict(
        type='SSDVGG',
        depth=16,
        with_last_pool=False,
        ceil_mode=True,
        out_indices=(3, 4),
        out_feature_indices=(22, 34),
        init_cfg=dict(
            type='Pretrained', checkpoint='open-mmlab://vgg16_caffe')),
    neck=dict(
        type='SSDNeck',
        in_channels=(512, 1024),
        out_channels=(512, 1024, 512, 256, 256, 256),
        level_strides=(2, 2, 1, 1),
        level_paddings=(1, 1, 0, 0),
        l2_norm_scale=20),
    bbox_head=dict(
        type='SSDHead',
        in_channels=(512, 1024, 512, 256, 256, 256),
        num_classes=80,
        anchor_generator=dict(
            type='SSDAnchorGenerator',
            scale_major=False,
            input_size=input_size,
            basesize_ratio_range=(0.15, 0.9),
            strides=[8, 16, 32, 64, 100, 300],
            ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2])),
    # model training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.,
            ignore_iof_thr=-1,
            gt_max_assign_all=False),
        smoothl1_beta=1.,
        allowed_border=-1,
        pos_weight=-1,
        neg_pos_ratio=3,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        nms=dict(type='nms', iou_threshold=0.45),
        min_bbox_size=0,
        score_thr=0.02,
        max_per_img=200))
cudnn_benchmark = True
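For reference, a sketch of the classic SSD scale rule that basesize_ratio_range=(0.15, 0.9) encodes: per-level box scales interpolate linearly across the 300 px input. mmdet's SSDAnchorGenerator additionally special-cases the first level, so treat the numbers below as illustrative rather than the exact generated sizes.

# toy SSD scale interpolation: s_k = s_min + (s_max - s_min) * k / (m - 1)
def ssd_scales(s_min=0.15, s_max=0.9, m=6, input_size=300):
    return [round(input_size * (s_min + (s_max - s_min) * k / (m - 1)))
            for k in range(m)]

print(ssd_scales())  # e.g. [45, 90, 135, 180, 225, 270]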
model/configs/_base_/schedules/schedule_0.5x.py
ADDED
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=6)
model/configs/_base_/schedules/schedule_1x.py
ADDED
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
model/configs/_base_/schedules/schedule_2x.py
ADDED
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[16, 22])
runner = dict(type='EpochBasedRunner', max_epochs=24)
model/configs/_base_/schedules/schedule_3x.py
ADDED
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
model/configs/_base_/schedules/schedule_4x.py
ADDED
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[36, 44])
runner = dict(type='EpochBasedRunner', max_epochs=48)
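For reference, a minimal sketch of the step policy shared by these five schedules: after a 500-iteration linear warmup, the base LR holds until each milestone epoch and is then multiplied by mmcv's default gamma of 0.1. The 0.5x/1x/2x/3x/4x names scale the total epochs and milestones proportionally.

# toy step-decay rule behind lr_config above (warmup omitted)
def lr_at_epoch(epoch, base_lr=0.02, milestones=(36, 44), gamma=0.1):
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

# 4x schedule: 48 epochs, drops after epochs 36 and 44
print([lr_at_epoch(e) for e in (0, 36, 44)])  # [0.02, 0.002, 0.0002]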
model/configs/doclaynet/rodla_internimage_xl_3x_doclaynet.py
ADDED
@@ -0,0 +1,178 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
_base_ = [
    '../_base_/datasets/doclaynet.py',
    '../_base_/schedules/schedule_3x.py',
    '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_xl_22k_192to384.pth'
model = dict(
    type='DINO',
    backbone=dict(
        _delete_=True,
        type='InternImage',
        core_op='DCNv3',
        channels=192,
        depths=[5, 5, 22, 5],
        groups=[12, 24, 48, 96],
        mlp_ratio=4.,
        drop_path_rate=0.4,
        norm_layer='LN',
        layer_scale=1.0,
        offset_scale=2.0,
        post_norm=True,
        with_cp=True,
        out_indices=(1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='ChannelMapper',
        in_channels=[384, 768, 1536],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    bbox_head=dict(
        type='DINOHead',
        num_query=3000,
        num_classes=11,
        in_channels=2048,
        sync_cls_avg_factor=True,
        as_two_stage=True,
        with_box_refine=True,
        dn_cfg=dict(
            type='CdnQueryGenerator',
            noise_scale=dict(label=0.5, box=1.0),
            group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
        transformer=dict(
            type='DinoTransformer',
            two_stage_num_proposals=3000,
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='TAPFANTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        dropout=0.0),
                    feedforward_channels=2048,
                    #ffn_dropout=0.0,  # 0.1 for DeformDETR
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='DinoTransformerDecoder',
                num_layers=6,
                return_intermediate=True,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.0),
                        dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=256,
                            dropout=0.0),
                    ],
                    feedforward_channels=2048,
                    ffn_dropout=0.0,  # 0.1 for DeformDETR
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm')))),
        positional_encoding=dict(
            type='SinePositionalEncoding',
            num_feats=128,
            temperature=20,
            normalize=True),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            cls_cost=dict(type='FocalLossCost', weight=2.0),
            reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
            iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))),
    test_cfg=dict(max_per_img=300))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            [
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 4200), (500, 4200), (600, 4200)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=False),
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(
    samples_per_gpu=1,
    train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(
    _delete_=True, type='AdamW', lr=0.0001 * 4, weight_decay=0.005,
    constructor='CustomLayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=37, layer_decay_rate=0.90,
                       depths=[5, 5, 22, 5]))
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=0.01,
    step=[24, 33])
evaluation = dict(save_best='auto', metric=['bbox'], classwise=True)
checkpoint_config = dict(
    interval=1,
    max_keep_ckpts=3,
    save_last=True,
)
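For reference, a toy sketch of what _delete_=True means during mmcv config merging, as used in the backbone and optimizer overrides above: the inherited dict is discarded instead of merged into, so no stale keys from a base config can leak into the new node.

# toy recursive merge imitating mmcv's _delete_ rule (not mmcv code)
def merge(base, override):
    if override.pop('_delete_', False):
        return override            # drop the inherited dict entirely
    out = dict(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)
        else:
            out[k] = v
    return out

base = {'type': 'ResNet', 'depth': 50}
new = {'_delete_': True, 'type': 'InternImage', 'channels': 192}
print(merge(base, new))  # {'type': 'InternImage', 'channels': 192}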
model/configs/m6doc/rodla_internimage_xl_2x_m6doc.py
ADDED
@@ -0,0 +1,187 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
_base_ = [
    '../_base_/datasets/m6doc.py',
    '../_base_/schedules/schedule_3x.py',
    '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_xl_22k_192to384.pth'
model = dict(
    type='DINO',
    backbone=dict(
        _delete_=True,
        type='InternImage',
        core_op='DCNv3',
        channels=192,
        depths=[5, 5, 22, 5],
        groups=[12, 24, 48, 96],
        mlp_ratio=4.,
        drop_path_rate=0.4,
        norm_layer='LN',
        layer_scale=1.0,
        offset_scale=2.0,
        post_norm=True,
        with_cp=True,
        out_indices=(1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='ChannelMapper',
        in_channels=[384, 768, 1536],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    bbox_head=dict(
        type='DINOHead',
        num_query=3000,
        num_classes=75,
        in_channels=2048,
        sync_cls_avg_factor=True,
        as_two_stage=True,
        with_box_refine=True,
        dn_cfg=dict(
            type='CdnQueryGenerator',
            noise_scale=dict(label=0.5, box=1.0),
            group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=200)),
        transformer=dict(
            type='DinoTransformer',
            num_feature_levels=4,
            two_stage_num_proposals=3000,
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='TAPFANTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        dropout=0.0),
                    feedforward_channels=2048,
                    #ffn_dropout=0.0,  # 0.1 for DeformDETR
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='DinoTransformerDecoder',
                num_layers=8,
                return_intermediate=True,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    num_feature_levels=4,
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.0),
                        dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=256,
                            dropout=0.0),
                    ],
                    feedforward_channels=2048,
                    ffn_dropout=0.0,  # 0.1 for DeformDETR
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm')))),
        positional_encoding=dict(
            type='SinePositionalEncoding',
            num_feats=128,
            temperature=20,
            normalize=True),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            cls_cost=dict(type='FocalLossCost', weight=2.0),
            reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
            iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))),
    test_cfg=dict(max_per_img=300))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            [
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 4200), (500, 4200), (600, 4200)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=False),
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(
    samples_per_gpu=2,
    train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(
    _delete_=True, type='AdamW', lr=0.0001 * 2, weight_decay=0.05,
    constructor='CustomLayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=37, layer_decay_rate=0.90,
                       depths=[5, 5, 22, 5]))
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=0.1, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[24, 33])
evaluation = dict(save_best='auto', metric=['bbox'], classwise=True)
checkpoint_config = dict(
    interval=1,
    max_keep_ckpts=3,
    save_last=True,
)

log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
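For reference, the optimizer_config grad_clip entry above maps directly onto PyTorch gradient clipping: mmcv's OptimizerHook calls clip_grad_norm_ with exactly these keyword arguments before each optimizer step. A minimal sketch:

# what grad_clip=dict(max_norm=0.1, norm_type=2) amounts to at runtime
import torch

def clip_like_config(params, max_norm=0.1, norm_type=2):
    # returns the total pre-clip gradient norm, as clip_grad_norm_ does
    return torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm,
                                          norm_type=norm_type)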
model/configs/publaynet/rodla_internimage_xl_2x_publaynet.py
ADDED
@@ -0,0 +1,177 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
_base_ = [
    '../_base_/datasets/publaynet.py',
    '../_base_/schedules/schedule_2x.py',
    '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/OpenGVLab/InternImage/releases/download/cls_model/internimage_xl_22k_192to384.pth'
model = dict(
    type='DINO',
    backbone=dict(
        _delete_=True,
        type='InternImage',
        core_op='DCNv3',
        channels=192,
        depths=[5, 5, 22, 5],
        groups=[12, 24, 48, 96],
        mlp_ratio=4.,
        drop_path_rate=0.4,
        norm_layer='LN',
        layer_scale=1.0,
        offset_scale=2.0,
        post_norm=True,
        with_cp=True,
        out_indices=(1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='ChannelMapper',
        in_channels=[384, 768, 1536],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4),
    bbox_head=dict(
        type='DINOHead',
        num_query=3000,
        num_classes=5,
        in_channels=2048,
        sync_cls_avg_factor=True,
        as_two_stage=True,
        with_box_refine=True,
        dn_cfg=dict(
            type='CdnQueryGenerator',
            noise_scale=dict(label=0.5, box=1.0),
            group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)),
        transformer=dict(
            type='DinoTransformer',
            two_stage_num_proposals=3000,
            encoder=dict(
                type='DetrTransformerEncoder',
                num_layers=6,
                transformerlayers=dict(
                    type='TAPFANTransformerLayer',
                    attn_cfgs=dict(
                        type='MultiScaleDeformableAttention',
                        embed_dims=256,
                        dropout=0.0),
                    feedforward_channels=2048,
                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
            decoder=dict(
                type='DinoTransformerDecoder',
                num_layers=6,
                return_intermediate=True,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.0),
                        dict(
                            type='MultiScaleDeformableAttention',
                            embed_dims=256,
                            dropout=0.0),
                    ],
                    feedforward_channels=2048,
                    ffn_dropout=0.0,
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm')))),
        positional_encoding=dict(
            type='SinePositionalEncoding',
            num_feats=128,
            temperature=20,
            normalize=True),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            cls_cost=dict(type='FocalLossCost', weight=2.0),
            reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
            iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))),
    test_cfg=dict(max_per_img=300))
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different
# from the default setting in mmdet.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='AutoAugment',
        policies=[
            [
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='Resize',
                    img_scale=[(400, 4200), (500, 4200), (600, 4200)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='RandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=False),
                dict(
                    type='Resize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(
    samples_per_gpu=2,
    train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(
    _delete_=True, type='AdamW', lr=0.0002, weight_decay=0.0001,
    constructor='CustomLayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=37, layer_decay_rate=0.90,
                       depths=[5, 5, 22, 5]))
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=0.1, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[16, 22])
evaluation = dict(save_best='auto', metric=['bbox'], classwise=True)
checkpoint_config = dict(
    interval=1,
    max_keep_ckpts=3,
    save_last=True,
)
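For reference, a minimal sketch of the multiscale_mode='value' behavior used throughout the train_pipeline above: each image draws one target scale uniformly from the candidate list (the larger number caps the long edge, the smaller the short edge), and keep_ratio=True then rescales without changing aspect ratio.

# toy version of mmdet Resize's 'value' mode: uniform choice over candidates
import random

img_scales = [(480, 1333), (512, 1333), (544, 1333), (576, 1333),
              (608, 1333), (640, 1333), (672, 1333), (704, 1333),
              (736, 1333), (768, 1333), (800, 1333)]
chosen = random.choice(img_scales)  # one scale per image per iteration
print(chosen)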
model/deploy.py
ADDED
@@ -0,0 +1,310 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp
from functools import partial

import mmcv
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import (create_calib_input_data, extract_model,
                           get_predefined_partition_cfg, torch2onnx,
                           torch2torchscript, visualize_model)
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.apis.utils import to_backend
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
                            get_ir_config, get_partition_config,
                            get_root_logger, load_config, target_wrapper)

import mmcv_custom
import mmdet_custom


def parse_args():
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert the model')
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test model')
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        help=('dataset config path used to calibrate in int8 mode. If not '
              'specified, it will use "val" dataset in model config instead.'),
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory for quantize model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')
    args = parser.parse_args()
    return args


def create_process(name, target, args, kwargs, ret_value=None):
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level

    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is not None:
        if ret_value.value != 0:
            logger.error(f'{name} failed.')
            exit(1)
        else:
            logger.info(f'{name} success.')


def torch2ir(ir_type: IR):
    """Return the conversion function from torch to the intermediate
    representation.

    Args:
        ir_type (IR): The type of the intermediate representation.
    """
    if ir_type == IR.ONNX:
        return torch2onnx
    elif ir_type == IR.TORCHSCRIPT:
        return torch2torchscript
    else:
        raise KeyError(f'Unexpected IR type {ir_type}')


def main():
    args = parse_args()
    set_start_method('spawn', force=True)
    logger = get_root_logger()
    log_level = logging.getLevelName(args.log_level)
    logger.setLevel(log_level)

    pipeline_funcs = [
        torch2onnx, torch2torchscript, extract_model, create_calib_input_data
    ]
    PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
    PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    quant = args.quant
    quant_image_dir = args.quant_image_dir

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if it does not exist
    mmcv.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        export2SDK(
            deploy_cfg,
            model_cfg,
            args.work_dir,
            pth=checkpoint_path,
            device=args.device)

    ret_value = mp.Value('d', 0, lock=False)

    # convert to IR
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    torch2ir(ir_type)(
        args.img,
        args.work_dir,
        ir_save_file,
        deploy_cfg_path,
        model_cfg_path,
        checkpoint_path,
        device=args.device)

    # convert backend
    ir_files = [osp.join(args.work_dir, ir_save_file)]

    # partition model
    partition_cfgs = get_partition_config(deploy_cfg)

    if partition_cfgs is not None:

        if 'partition_cfg' in partition_cfgs:
            partition_cfgs = partition_cfgs.get('partition_cfg', None)
        else:
            assert 'type' in partition_cfgs
            partition_cfgs = get_predefined_partition_cfg(
                deploy_cfg, partition_cfgs['type'])

        origin_ir_file = ir_files[0]
        ir_files = []
        for partition_cfg in partition_cfgs:
            save_file = partition_cfg['save_file']
            save_path = osp.join(args.work_dir, save_file)
            start = partition_cfg['start']
            end = partition_cfg['end']
            dynamic_axes = partition_cfg.get('dynamic_axes', None)

            extract_model(
                origin_ir_file,
                start,
                end,
                dynamic_axes=dynamic_axes,
                save_file=save_path)

            ir_files.append(save_path)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)
        create_calib_input_data(
            calib_path,
            deploy_cfg_path,
            model_cfg_path,
            checkpoint_path,
            dataset_cfg=args.calib_dataset_cfg,
            dataset_type='val',
            device=args.device)

    backend_files = ir_files
    # convert backend
    backend = get_backend(deploy_cfg)

    # preprocess deploy_cfg
    if backend == Backend.RKNN:
        # TODO: Add this to task_processor in the future
        import tempfile

        from mmdeploy.utils import (get_common_config, get_normalization,
                                    get_quantization_config,
                                    get_rknn_quantization)
        quantization_cfg = get_quantization_config(deploy_cfg)
        common_params = get_common_config(deploy_cfg)
        if get_rknn_quantization(deploy_cfg) is True:
            transform = get_normalization(model_cfg)
            common_params.update(
                dict(
                    mean_values=[transform['mean']],
                    std_values=[transform['std']]))

        dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name
        with open(dataset_file, 'w') as f:
            f.writelines([osp.abspath(args.img)])
        quantization_cfg.setdefault('dataset', dataset_file)
    if backend == Backend.ASCEND:
        # TODO: Add this to backend manager in the future
        if args.dump_info:
            from mmdeploy.backend.ascend import update_sdk_pipeline
            update_sdk_pipeline(args.work_dir)

    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # ncnn quantization
    if backend == Backend.NCNN and quant:
        from onnx2ncnn_quant_table import get_table

        from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
        model_param_paths = backend_files[::2]
        model_bin_paths = backend_files[1::2]
        backend_files = []
        for onnx_path, model_param_path, model_bin_path in zip(
                ir_files, model_param_paths, model_bin_paths):

            deploy_cfg, model_cfg = load_config(deploy_cfg_path,
                                                model_cfg_path)
            quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file(  # noqa: E501
                onnx_path, args.work_dir)

            create_process(
                'ncnn quant table',
                target=get_table,
                args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
                      quant_table, quant_image_dir, args.device),
                kwargs=dict(),
                ret_value=ret_value)

            create_process(
                'ncnn_int8',
                target=ncnn2int8,
                args=(model_param_path, model_bin_path, quant_table,
                      quant_param, quant_bin),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [quant_param, quant_bin]

    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

    # get backend inference result, try render
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # get pytorch model inference result, try visualize if possible
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All process success.')


if __name__ == '__main__':
    main()
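Note: a minimal sketch of invoking this export script, assuming an mmdeploy deploy config and a trained RoDLA checkpoint are available; all bracketed paths are placeholders:

    # convert the PyTorch model to the backend named in the deploy config,
    # then render one test image with both the backend and the PyTorch models
    python model/deploy.py <deploy_cfg.py> <model_cfg.py> <checkpoint.pth> <sample_page.png> \
        --work-dir work_dirs/deploy --device cuda:0 --dump-info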
model/dist_test.sh
ADDED
@@ -0,0 +1,9 @@
#!/usr/bin/env bash

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29511}
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
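Note: a sketch of the usual invocation, assuming 8 GPUs; arguments after the GPU count are forwarded to test.py, and PORT overrides the default master port:

    PORT=29511 bash model/dist_test.sh <config.py> <checkpoint.pth> 8 --eval bbox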
model/dist_train.sh
ADDED
@@ -0,0 +1,9 @@
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
PORT=${PORT:-29500}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
    $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
model/get_flops.py
ADDED
@@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
from mmcv import Config, DictAction

from mmdet.models import build_detector
import mmcv_custom  # noqa: F401,F403
import mmdet_custom  # noqa: F401,F403

try:
    from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
    from mmcv.cnn import get_model_complexity_info
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[800, 1280],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image to the minimum size that is divisible '
        'by size_divisor, -1 means do not pad the image.')
    args = parser.parse_args()
    return args


def dcnv3_flops(n, k, c):
    return 5 * n * k * c


def get_flops(model, input_shape):
    flops, params = get_model_complexity_info(model, input_shape, as_strings=False)

    backbone = model.backbone
    backbone_name = type(backbone).__name__
    _, H, W = input_shape

    temp = 0
    if 'InternImage' in backbone_name:
        depths = backbone.depths  # e.g. [4, 4, 18, 4]
        for idx, depth in enumerate(depths):
            channels = backbone.channels * (2 ** idx)
            h = H / (4 * (2 ** idx))
            w = W / (4 * (2 ** idx))
            temp += depth * dcnv3_flops(n=h * w, k=3 * 3, c=channels)

    flops = flops + temp
    return flops_to_string(flops), params_to_string(params)


if __name__ == '__main__':

    args = parse_args()

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    orig_shape = (3, h, w)
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30

    if divisor > 0 and \
            input_shape != orig_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {orig_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')
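Note: a sketch of running the FLOPs counter; --shape takes height and width, which are first rounded up to the size divisor:

    python model/get_flops.py <config.py> --shape 800 1280 --size-divisor 32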
model/image_demo.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from argparse import ArgumentParser

from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector, show_result_pyplot)
import mmcv
import mmcv_custom  # noqa: F401,F403
import mmdet_custom  # noqa: F401,F403
import os.path as osp


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--out', type=str, default="demo", help='out dir')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='coco',
        choices=['coco', 'voc', 'citys', 'random'],
        help='Color palette used for visualization')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    parser.add_argument(
        '--async-test',
        action='store_true',
        help='whether to set async options for async inference.')
    args = parser.parse_args()
    return args


def main(args):
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # test a single image
    result = inference_detector(model, args.img)

    mmcv.mkdir_or_exist(args.out)
    out_file = osp.join(args.out, osp.basename(args.img))
    # show the results
    model.show_result(
        args.img,
        result,
        score_thr=args.score_thr,
        show=False,
        bbox_color=args.palette,
        text_color=(200, 200, 200),
        mask_color=args.palette,
        out_file=out_file
    )
    print(f"Result is saved at {out_file}")


if __name__ == '__main__':
    args = parse_args()
    main(args)
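Note: a sketch of running the demo on a single image; the result image is written into the --out directory under the input's basename:

    python model/image_demo.py <page.png> <config.py> <checkpoint.pth> --out demo --score-thr 0.3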
model/mmcv_custom/__init__.py
ADDED
@@ -0,0 +1,11 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

# -*- coding: utf-8 -*-
from .custom_layer_decay_optimizer_constructor import CustomLayerDecayOptimizerConstructor
from .checkpoint import load_checkpoint

__all__ = ['CustomLayerDecayOptimizerConstructor', 'load_checkpoint']
model/mmcv_custom/checkpoint.py
ADDED
@@ -0,0 +1,487 @@
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.
    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.
    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph', 'petrel']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def get_external_models():
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls


def get_mmcls_models():
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)

    return mmcls_urls


def get_deprecated_model_names():
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            new_state_dict[k[9:]] = v
    new_checkpoint = dict(state_dict=new_state_dict)

    return new_checkpoint


def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).
    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.
    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_urls = get_external_models()
        model_name = filename[13:]
        deprecated_urls = get_deprecated_model_names()
        if model_name in deprecated_urls:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{deprecated_urls[model_name]}')
            model_name = deprecated_urls[model_name]
        model_url = model_urls[model_name]
        # check if is url
        if model_url.startswith(('http://', 'https://')):
            checkpoint = load_url_dist(model_url)
        else:
            filename = osp.join(_get_mmcv_home(), model_url)
            if not osp.isfile(filename):
                raise IOError(f'{filename} is not a checkpoint file')
            checkpoint = torch.load(filename, map_location=map_location)
    elif filename.startswith('mmcls://'):
        model_urls = get_mmcls_models()
        model_name = filename[8:]
        checkpoint = load_url_dist(model_urls[model_name])
        checkpoint = _process_mmcls_checkpoint(checkpoint)
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    elif filename.startswith('pavi://'):
        model_path = filename[7:]
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
    elif filename.startswith('s3://'):
        checkpoint = load_fileclient_dist(
            filename, backend='petrel', map_location=map_location)
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.
    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v
                      for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        if table_key not in model.state_dict().keys():
            print(table_key)
            continue
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {table_key}, pass")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                    size=(S2, S2), mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(
                    nH2, L2).permute(1, 0)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.
    Args:
        state_dict (OrderedDict): Model weights on GPU.
    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary.
    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.
    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.
    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).
    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.
    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.
    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
model/mmcv_custom/custom_layer_decay_optimizer_constructor.py
ADDED
@@ -0,0 +1,142 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
"""
Mostly copy-paste from BEiT library:
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py
"""

import json

from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import get_dist_info
from mmdet.utils import get_root_logger


def get_num_layer_for_swin(var_name, num_max_layer, depths):
    if var_name.startswith("backbone.patch_embed"):
        return 0
    elif "level_embeds" in var_name:
        return 0
    elif var_name.startswith("backbone.layers") or var_name.startswith(
            "backbone.levels"):
        if var_name.split('.')[3] not in ['downsample', 'norm']:
            stage_id = int(var_name.split('.')[2])
            layer_id = int(var_name.split('.')[4])
            # layers for Swin-Large: [2, 2, 18, 2]
            if stage_id == 0:
                return layer_id + 1
            elif stage_id == 1:
                return layer_id + 1 + depths[0]
            elif stage_id == 2:
                return layer_id + 1 + depths[0] + depths[1]
            else:
                return layer_id + 1 + depths[0] + depths[1] + depths[2]
        else:
            stage_id = int(var_name.split('.')[2])
            if stage_id == 0:
                return 1 + depths[0]
            elif stage_id == 1:
                return 1 + depths[0] + depths[1]
            elif stage_id == 2:
                return 1 + depths[0] + depths[1] + depths[2]
            else:
                return 1 + depths[0] + depths[1] + depths[2]
    else:
        return num_max_layer - 1


@OPTIMIZER_BUILDERS.register_module()
class CustomLayerDecayOptimizerConstructor(DefaultOptimizerConstructor):

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}
        logger = get_root_logger()
        logger.info(self.paramwise_cfg)
        backbone_small_lr = self.paramwise_cfg.get('backbone_small_lr', False)
        dino_head = self.paramwise_cfg.get('dino_head', False)
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
        depths = self.paramwise_cfg.get('depths')
        offset_lr_scale = self.paramwise_cfg.get('offset_lr_scale', 1.0)

        logger.info("Build CustomLayerDecayOptimizerConstructor %f - %d" %
                    (layer_decay_rate, num_layers))
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias") or \
                    "relative_position" in name or \
                    "norm" in name or \
                    "sampling_offsets" in name:
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay

            layer_id = get_num_layer_for_swin(name, num_layers, depths)
            if layer_id == num_layers - 1 and dino_head and \
                    ("sampling_offsets" in name or "reference_points" in name):
                group_name = "layer_%d_%s_0.1x" % (layer_id, group_name)
            elif "sampling_offsets" in name or "reference_points" in name:
                group_name = "layer_%d_%s_offset_lr_scale" % (layer_id,
                                                              group_name)
            else:
                group_name = "layer_%d_%s" % (layer_id, group_name)

            if group_name not in parameter_groups:
                scale = layer_decay_rate ** (num_layers - layer_id - 1)
                if scale < 1 and backbone_small_lr:
                    scale = scale * 0.1
                if "0.1x" in group_name:
                    scale = scale * 0.1
                if "offset_lr_scale" in group_name:
                    scale = scale * offset_lr_scale

                parameter_groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": scale,
                    "group_name": group_name,
                    "lr": scale * self.base_lr,
                }

            parameter_groups[group_name]["params"].append(param)
            parameter_groups[group_name]["param_names"].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    "param_names": parameter_groups[key]["param_names"],
                    "lr_scale": parameter_groups[key]["lr_scale"],
                    "lr": parameter_groups[key]["lr"],
                    "weight_decay": parameter_groups[key]["weight_decay"],
                }
            logger.info("Param groups = %s" % json.dumps(to_display, indent=2))

        # state_dict = module.state_dict()
        # for group_name in parameter_groups:
        #     group = parameter_groups[group_name]
        #     for name in group["param_names"]:
        #         group["params"].append(state_dict[name])

        params.extend(parameter_groups.values())
model/mmdet_custom/__init__.py
ADDED
@@ -0,0 +1,8 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from .models import *  # noqa: F401,F403
from .datasets import *
model/mmdet_custom/datasets/__init__.py
ADDED
@@ -0,0 +1,9 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from .publaynet import PubLayNetDataset
from .doclaynet import DocLayNetDataset
from .m6doc import M6DocDataset
model/mmdet_custom/datasets/doclaynet.py
ADDED
@@ -0,0 +1,527 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from mmdet.datasets.api_wrappers import COCO, COCOeval
+
+from mmdet.datasets.custom import CustomDataset
+from mmdet.datasets.builder import DATASETS
+
+
+@DATASETS.register_module()
+class DocLayNetDataset(CustomDataset):
+
+    CLASSES = ("Caption", "Footnote", "Formula", "List-item", "Page-footer",
+               "Page-header", "Picture", "Section-header", "Table", "Text",
+               "Title")
+
+    def load_annotations(self, ann_file):
+        """Load annotation from COCO style annotation file.
+
+        Args:
+            ann_file (str): Path of annotation file.
+
+        Returns:
+            list[dict]: Annotation info from COCO api.
+        """
+        self.coco = COCO(ann_file)
+        # The order of returned `cat_ids` will not
+        # change with the order of the CLASSES
+        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+        self.img_ids = self.coco.get_img_ids()
+        data_infos = []
+        total_ann_ids = []
+        for i in self.img_ids:
+            info = self.coco.load_imgs([i])[0]
+            info['filename'] = info['file_name']
+            data_infos.append(info)
+            ann_ids = self.coco.get_ann_ids(img_ids=[i])
+            total_ann_ids.extend(ann_ids)
+        assert len(set(total_ann_ids)) == len(
+            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+        return data_infos
+
+    def get_ann_info(self, idx):
+        """Get COCO annotation by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            dict: Annotation info of specified index.
+        """
+        img_id = self.data_infos[idx]['id']
+        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+        ann_info = self.coco.load_anns(ann_ids)
+        return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+    def get_cat_ids(self, idx):
+        """Get COCO category ids by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            list[int]: All categories in the image of specified index.
+        """
+        img_id = self.data_infos[idx]['id']
+        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+        ann_info = self.coco.load_anns(ann_ids)
+        return [ann['category_id'] for ann in ann_info]
+
+    def _filter_imgs(self, min_size=32):
+        """Filter images too small or without ground truths."""
+        valid_inds = []
+        # obtain images that contain annotation
+        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+        # obtain images that contain annotations of the required categories
+        ids_in_cat = set()
+        for i, class_id in enumerate(self.cat_ids):
+            ids_in_cat |= set(self.coco.cat_img_map[class_id])
+        # merge the image id sets of the two conditions and use the merged set
+        # to filter out images if self.filter_empty_gt=True
+        ids_in_cat &= ids_with_ann
+
+        valid_img_ids = []
+        for i, img_info in enumerate(self.data_infos):
+            img_id = self.img_ids[i]
+            if self.filter_empty_gt and img_id not in ids_in_cat:
+                continue
+            if min(img_info['width'], img_info['height']) >= min_size:
+                valid_inds.append(i)
+                valid_img_ids.append(img_id)
+        self.img_ids = valid_img_ids
+        return valid_inds
+
+    def _parse_ann_info(self, img_info, ann_info):
+        """Parse bbox and mask annotation.
+
+        Args:
+            img_info (dict): Image info of an image.
+            ann_info (list[dict]): Annotation info of an image.
+
+        Returns:
+            dict: A dict containing the following keys: bboxes, bboxes_ignore,
+                labels, masks, seg_map. "masks" are raw annotations and not
+                decoded into binary masks.
+        """
+        gt_bboxes = []
+        gt_labels = []
+        gt_bboxes_ignore = []
+        gt_masks_ann = []
+        for i, ann in enumerate(ann_info):
+            if ann.get('ignore', False):
+                continue
+            x1, y1, w, h = ann['bbox']
+            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+            if inter_w * inter_h == 0:
+                continue
+            if ann['area'] <= 0 or w < 1 or h < 1:
+                continue
+            if ann['category_id'] not in self.cat_ids:
+                continue
+            bbox = [x1, y1, x1 + w, y1 + h]
+            if ann.get('iscrowd', False):
+                gt_bboxes_ignore.append(bbox)
+            else:
+                gt_bboxes.append(bbox)
+                gt_labels.append(self.cat2label[ann['category_id']])
+                gt_masks_ann.append(ann.get('segmentation', None))
+
+        if gt_bboxes:
+            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+            gt_labels = np.array(gt_labels, dtype=np.int64)
+        else:
+            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+            gt_labels = np.array([], dtype=np.int64)
+
+        if gt_bboxes_ignore:
+            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+        else:
+            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+        seg_map = img_info['filename'].replace('jpg', 'png')
+
+        ann = dict(
+            bboxes=gt_bboxes,
+            labels=gt_labels,
+            bboxes_ignore=gt_bboxes_ignore,
+            masks=gt_masks_ann,
+            seg_map=seg_map)
+
+        return ann
+
+    def xyxy2xywh(self, bbox):
+        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+        evaluation.
+
+        Args:
+            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+                ``xyxy`` order.
+
+        Returns:
+            list[float]: The converted bounding boxes, in ``xywh`` order.
+        """
+        _bbox = bbox.tolist()
+        return [
+            _bbox[0],
+            _bbox[1],
+            _bbox[2] - _bbox[0],
+            _bbox[3] - _bbox[1],
+        ]
+
+    def _proposal2json(self, results):
+        """Convert proposal results to COCO json style."""
+        json_results = []
+        for idx in range(len(self)):
+            img_id = self.img_ids[idx]
+            bboxes = results[idx]
+            for i in range(bboxes.shape[0]):
+                data = dict()
+                data['image_id'] = img_id
+                data['bbox'] = self.xyxy2xywh(bboxes[i])
+                data['score'] = float(bboxes[i][4])
+                data['category_id'] = 1
+                json_results.append(data)
+        return json_results
+
+    def _det2json(self, results):
+        """Convert detection results to COCO json style."""
+        json_results = []
+        for idx in range(len(self)):
+            img_id = self.img_ids[idx]
+            result = results[idx]
+            for label in range(len(result)):
+                bboxes = result[label]
+                for i in range(bboxes.shape[0]):
+                    data = dict()
+                    data['image_id'] = img_id
+                    data['bbox'] = self.xyxy2xywh(bboxes[i])
+                    data['score'] = float(bboxes[i][4])
+                    data['category_id'] = self.cat_ids[label]
+                    json_results.append(data)
+        return json_results
+
+    def _segm2json(self, results):
+        """Convert instance segmentation results to COCO json style."""
+        bbox_json_results = []
+        segm_json_results = []
+        for idx in range(len(self)):
+            img_id = self.img_ids[idx]
+            det, seg = results[idx]
+            for label in range(len(det)):
+                # bbox results
+                bboxes = det[label]
+                for i in range(bboxes.shape[0]):
+                    data = dict()
+                    data['image_id'] = img_id
+                    data['bbox'] = self.xyxy2xywh(bboxes[i])
+                    data['score'] = float(bboxes[i][4])
+                    data['category_id'] = self.cat_ids[label]
+                    bbox_json_results.append(data)
+
+                # segm results
+                # some detectors use different scores for bbox and mask
+                if isinstance(seg, tuple):
+                    segms = seg[0][label]
+                    mask_score = seg[1][label]
+                else:
+                    segms = seg[label]
+                    mask_score = [bbox[4] for bbox in bboxes]
+                for i in range(bboxes.shape[0]):
+                    data = dict()
+                    data['image_id'] = img_id
+                    data['bbox'] = self.xyxy2xywh(bboxes[i])
+                    data['score'] = float(mask_score[i])
+                    data['category_id'] = self.cat_ids[label]
+                    if isinstance(segms[i]['counts'], bytes):
+                        segms[i]['counts'] = segms[i]['counts'].decode()
+                    data['segmentation'] = segms[i]
+                    segm_json_results.append(data)
+        return bbox_json_results, segm_json_results
+
+    def results2json(self, results, outfile_prefix):
+        """Dump the detection results to a COCO style json file.
+
+        There are 3 types of results: proposals, bbox predictions, mask
+        predictions, and they have different data types. This method will
+        automatically recognize the type, and dump them to json files.
+
+        Args:
+            results (list[list | tuple | ndarray]): Testing results of the
+                dataset.
+            outfile_prefix (str): The filename prefix of the json files. If the
+                prefix is "somepath/xxx", the json files will be named
+                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+                "somepath/xxx.proposal.json".
+
+        Returns:
+            dict[str: str]: Possible keys are "bbox", "segm", "proposal", and
+                values are corresponding filenames.
+        """
+        result_files = dict()
+        if isinstance(results[0], list):
+            json_results = self._det2json(results)
+            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+            mmcv.dump(json_results, result_files['bbox'])
+        elif isinstance(results[0], tuple):
+            json_results = self._segm2json(results)
+            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+            result_files['segm'] = f'{outfile_prefix}.segm.json'
+            mmcv.dump(json_results[0], result_files['bbox'])
+            mmcv.dump(json_results[1], result_files['segm'])
+        elif isinstance(results[0], np.ndarray):
+            json_results = self._proposal2json(results)
+            result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+            mmcv.dump(json_results, result_files['proposal'])
+        else:
+            raise TypeError('invalid type of results')
+        return result_files
+
+    def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+        gt_bboxes = []
+        for i in range(len(self.img_ids)):
+            ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+            ann_info = self.coco.load_anns(ann_ids)
+            if len(ann_info) == 0:
+                gt_bboxes.append(np.zeros((0, 4)))
+                continue
+            bboxes = []
+            for ann in ann_info:
+                if ann.get('ignore', False) or ann['iscrowd']:
+                    continue
+                x1, y1, w, h = ann['bbox']
+                bboxes.append([x1, y1, x1 + w, y1 + h])
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.shape[0] == 0:
+                bboxes = np.zeros((0, 4))
+            gt_bboxes.append(bboxes)
+
+        recalls = eval_recalls(
+            gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+        ar = recalls.mean(axis=1)
+        return ar
+
+    def format_results(self, results, jsonfile_prefix=None, **kwargs):
+        """Format the results to json (standard format for COCO evaluation).
+
+        Args:
+            results (list[tuple | numpy.ndarray]): Testing results of the
+                dataset.
+            jsonfile_prefix (str | None): The prefix of json files. It includes
+                the file path and the prefix of filename, e.g., "a/b/prefix".
+                If not specified, a temp file will be created. Default: None.
+
+        Returns:
+            tuple: (result_files, tmp_dir), result_files is a dict containing
+                the json filepaths, tmp_dir is the temporal directory created
+                for saving json files when jsonfile_prefix is not specified.
+        """
+        assert isinstance(results, list), 'results must be a list'
+        assert len(results) == len(self), (
+            'The length of results is not equal to the dataset len: {} != {}'.
+            format(len(results), len(self)))
+
+        if jsonfile_prefix is None:
+            tmp_dir = tempfile.TemporaryDirectory()
+            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+        else:
+            tmp_dir = None
+        result_files = self.results2json(results, jsonfile_prefix)
+        return result_files, tmp_dir
+
+    def evaluate(self,
+                 results,
+                 metric='bbox',
+                 logger=None,
+                 jsonfile_prefix=None,
+                 classwise=False,
+                 proposal_nums=(100, 300, 1000),
+                 iou_thrs=None,
+                 metric_items=None):
+        """Evaluation in COCO protocol.
+
+        Args:
+            results (list[list | tuple]): Testing results of the dataset.
+            metric (str | list[str]): Metrics to be evaluated. Options are
+                'bbox', 'segm', 'proposal', 'proposal_fast'.
+            logger (logging.Logger | str | None): Logger used for printing
+                related information during evaluation. Default: None.
+            jsonfile_prefix (str | None): The prefix of json files. It includes
+                the file path and the prefix of filename, e.g., "a/b/prefix".
+                If not specified, a temp file will be created. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+            proposal_nums (Sequence[int]): Proposal number used for evaluating
+                recalls, such as recall@100, recall@1000.
+                Default: (100, 300, 1000).
+            iou_thrs (Sequence[float], optional): IoU threshold used for
+                evaluating recalls/mAPs. If set to a list, the average of all
+                IoUs will also be computed. If not specified, [0.50, 0.55,
+                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+                Default: None.
+            metric_items (list[str] | str, optional): Metric items that will
+                be returned. If not specified, ``['AR@100', 'AR@300',
+                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000']`` will be
+                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+                ``metric=='bbox' or metric=='segm'``.
+
+        Returns:
+            dict[str, float]: COCO style evaluation metric.
+        """
+        metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+        for metric in metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(f'metric {metric} is not supported')
+        if iou_thrs is None:
+            iou_thrs = np.linspace(
+                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+        if metric_items is not None:
+            if not isinstance(metric_items, list):
+                metric_items = [metric_items]
+
+        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+
+        eval_results = OrderedDict()
+        cocoGt = self.coco
+        for metric in metrics:
+            msg = f'Evaluating {metric}...'
+            if logger is None:
+                msg = '\n' + msg
+            print_log(msg, logger=logger)
+
+            if metric == 'proposal_fast':
+                ar = self.fast_eval_recall(
+                    results, proposal_nums, iou_thrs, logger='silent')
+                log_msg = []
+                for i, num in enumerate(proposal_nums):
+                    eval_results[f'AR@{num}'] = ar[i]
+                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+                log_msg = ''.join(log_msg)
+                print_log(log_msg, logger=logger)
+                continue
+
+            iou_type = 'bbox' if metric == 'proposal' else metric
+            if metric not in result_files:
+                raise KeyError(f'{metric} is not in results')
+            try:
+                predictions = mmcv.load(result_files[metric])
+                if iou_type == 'segm':
+                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
+                    # When evaluating mask AP, if the results contain bbox,
+                    # cocoapi will use the box area instead of the mask area
+                    # for calculating the instance area. Though the overall AP
+                    # is not affected, this leads to different
+                    # small/medium/large mask AP results.
+                    for x in predictions:
+                        x.pop('bbox')
+                    warnings.simplefilter('once')
+                    warnings.warn(
+                        'The key "bbox" is deleted for more accurate mask AP '
+                        'of small/medium/large instances since v2.12.0. This '
+                        'does not change the overall mAP calculation.',
+                        UserWarning)
+                cocoDt = cocoGt.loadRes(predictions)
+            except IndexError:
+                print_log(
+                    'The testing results of the whole dataset are empty.',
+                    logger=logger,
+                    level=logging.ERROR)
+                break
+
+            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+            cocoEval.params.catIds = self.cat_ids
+            cocoEval.params.imgIds = self.img_ids
+            cocoEval.params.maxDets = list(proposal_nums)
+            cocoEval.params.iouThrs = iou_thrs
+            # mapping of cocoEval.stats
+            coco_metric_names = {
+                'mAP': 0,
+                'mAP_50': 1,
+                'mAP_75': 2,
+                'mAP_s': 3,
+                'mAP_m': 4,
+                'mAP_l': 5,
+                'AR@100': 6,
+                'AR@300': 7,
+                'AR@1000': 8,
+                'AR_s@1000': 9,
+                'AR_m@1000': 10,
+                'AR_l@1000': 11
+            }
+            if metric_items is not None:
+                for metric_item in metric_items:
+                    if metric_item not in coco_metric_names:
+                        raise KeyError(
+                            f'metric item {metric_item} is not supported')
+
+            if metric == 'proposal':
+                cocoEval.params.useCats = 0
+                cocoEval.evaluate()
+                cocoEval.accumulate()
+                cocoEval.summarize()
+                if metric_items is None:
+                    metric_items = [
+                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+                        'AR_m@1000', 'AR_l@1000'
+                    ]
+
+                for item in metric_items:
+                    val = float(
+                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
+                    eval_results[item] = val
+            else:
+                cocoEval.evaluate()
+                cocoEval.accumulate()
+                cocoEval.summarize()
+                if classwise:  # Compute per-category AP
+                    # Compute per-category AP
+                    # from https://github.com/facebookresearch/detectron2/
+                    precisions = cocoEval.eval['precision']
+                    # precision: (iou, recall, cls, area range, max dets)
+                    assert len(self.cat_ids) == precisions.shape[2]
+
+                    results_per_category = []
+                    for idx, catId in enumerate(self.cat_ids):
+                        # area range index 0: all area ranges
+                        # max dets index -1: typically 100 per image
+                        nm = self.coco.loadCats(catId)[0]
+                        precision = precisions[:, :, idx, 0, -1]
+                        precision = precision[precision > -1]
+                        if precision.size:
+                            ap = np.mean(precision)
+                        else:
+                            ap = float('nan')
+                        results_per_category.append(
+                            (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+                    num_columns = min(6, len(results_per_category) * 2)
+                    results_flatten = list(
+                        itertools.chain(*results_per_category))
+                    headers = ['category', 'AP'] * (num_columns // 2)
+                    results_2d = itertools.zip_longest(*[
+                        results_flatten[i::num_columns]
+                        for i in range(num_columns)
+                    ])
+                    table_data = [headers]
+                    table_data += [result for result in results_2d]
+                    table = AsciiTable(table_data)
+                    print_log('\n' + table.table, logger=logger)
+
+                if metric_items is None:
+                    metric_items = [
+                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+                    ]
+
+                for metric_item in metric_items:
+                    key = f'{metric}_{metric_item}'
+                    val = float(
+                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+                    )
+                    eval_results[key] = val
+                ap = cocoEval.stats[:6]
+                eval_results[f'{metric}_mAP_copypaste'] = (
+                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+                    f'{ap[4]:.3f} {ap[5]:.3f}')
+        if tmp_dir is not None:
+            tmp_dir.cleanup()
+        return eval_results
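A minimal offline-evaluation sketch for the class above. The annotation/image paths are placeholders, and the dummy detections exist only so `COCO.loadRes` has input; in practice `results` would come from mmdet's `single_gpu_test`:

import numpy as np
from mmdet_custom.datasets import DocLayNetDataset

dataset = DocLayNetDataset(
    ann_file='data/doclaynet/COCO/val.json',   # hypothetical path
    img_prefix='data/doclaynet/PNG/',          # hypothetical path
    pipeline=[],
    test_mode=True)

# Shaped like real bbox outputs: a list over images, each a list over
# classes of (N, 5) float32 arrays [x1, y1, x2, y2, score].
results = [[np.array([[0., 0., 10., 10., .01]], dtype=np.float32)
            for _ in dataset.CLASSES] for _ in range(len(dataset))]

metrics = dataset.evaluate(results, metric='bbox', classwise=True)
print(metrics['bbox_mAP'], metrics['bbox_mAP_50'])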
model/mmdet_custom/datasets/m6doc.py
ADDED
@@ -0,0 +1,529 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from mmdet.datasets.api_wrappers import COCO, COCOeval
+
+from mmdet.datasets.custom import CustomDataset
+from mmdet.datasets.builder import DATASETS
+
+
+@DATASETS.register_module()
+class M6DocDataset(CustomDataset):
+
+    CLASSES = ("_background_", "QR code", "advertisement", "algorithm",
+               "answer", "author", "barcode", "bill", "blank", "bracket",
+               "breakout", "byline", "caption", "catalogue", "chapter title",
+               "code", "correction", "credit", "dateline", "drop cap",
+               "editor's note", "endnote", "examinee information",
+               "fifth-level title", "figure", "first-level question number",
+               "first-level title", "flag", "folio", "footer", "footnote",
+               "formula", "fourth-level section title", "fourth-level title",
+               "header", "headline", "index", "inside", "institute",
+               "jump line", "kicker", "lead", "marginal note", "matching",
+               "mugshot", "option", "ordered list", "other question number",
+               "page number", "paragraph", "part", "play", "poem",
+               "reference", "sealing line", "second-level question number",
+               "second-level title", "section", "section title", "sidebar",
+               "sub section title", "subhead", "subsub section title",
+               "supplementary note", "table", "table caption", "table note",
+               "teasers", "third-level question number", "third-level title",
+               "title", "translator", "underscore", "unordered list",
+               "weather forecast")
+
+    def load_annotations(self, ann_file):
+        """Load annotation from COCO style annotation file.
+
+        Args:
+            ann_file (str): Path of annotation file.
+
+        Returns:
+            list[dict]: Annotation info from COCO api.
+        """
+        self.coco = COCO(ann_file)
+        # The order of returned `cat_ids` will not
+        # change with the order of the CLASSES
+        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+        self.img_ids = self.coco.get_img_ids()
+        data_infos = []
+        total_ann_ids = []
+        for i in self.img_ids:
+            info = self.coco.load_imgs([i])[0]
+            info['filename'] = info['file_name']
+            data_infos.append(info)
+            ann_ids = self.coco.get_ann_ids(img_ids=[i])
+            total_ann_ids.extend(ann_ids)
+        assert len(set(total_ann_ids)) == len(
+            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+        return data_infos
+
+    def get_ann_info(self, idx):
+        """Get COCO annotation by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            dict: Annotation info of specified index.
+        """
+        img_id = self.data_infos[idx]['id']
+        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+        ann_info = self.coco.load_anns(ann_ids)
+        return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+    def get_cat_ids(self, idx):
+        """Get COCO category ids by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            list[int]: All categories in the image of specified index.
+        """
+        img_id = self.data_infos[idx]['id']
+        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+        ann_info = self.coco.load_anns(ann_ids)
+        return [ann['category_id'] for ann in ann_info]
+
+    def _filter_imgs(self, min_size=32):
+        """Filter images too small or without ground truths."""
+        valid_inds = []
+        # obtain images that contain annotation
+        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+        # obtain images that contain annotations of the required categories
+        ids_in_cat = set()
+        for i, class_id in enumerate(self.cat_ids):
+            ids_in_cat |= set(self.coco.cat_img_map[class_id])
+        # merge the image id sets of the two conditions and use the merged set
+        # to filter out images if self.filter_empty_gt=True
+        ids_in_cat &= ids_with_ann
+
+        valid_img_ids = []
+        for i, img_info in enumerate(self.data_infos):
+            img_id = self.img_ids[i]
+            if self.filter_empty_gt and img_id not in ids_in_cat:
+                continue
+            if min(img_info['width'], img_info['height']) >= min_size:
+                valid_inds.append(i)
+                valid_img_ids.append(img_id)
+        self.img_ids = valid_img_ids
+        return valid_inds
+
+    def _parse_ann_info(self, img_info, ann_info):
+        """Parse bbox and mask annotation.
+
+        Args:
+            img_info (dict): Image info of an image.
+            ann_info (list[dict]): Annotation info of an image.
+
+        Returns:
+            dict: A dict containing the following keys: bboxes, bboxes_ignore,
+                labels, masks, seg_map. "masks" are raw annotations and not
+                decoded into binary masks.
+        """
+        gt_bboxes = []
+        gt_labels = []
+        gt_bboxes_ignore = []
+        gt_masks_ann = []
+        for i, ann in enumerate(ann_info):
+            if ann.get('ignore', False):
+                continue
+            x1, y1, w, h = ann['bbox']
+            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+            if inter_w * inter_h == 0:
+                continue
+            if ann['area'] <= 0 or w < 1 or h < 1:
+                continue
+            if ann['category_id'] not in self.cat_ids:
+                continue
+            bbox = [x1, y1, x1 + w, y1 + h]
+            if ann.get('iscrowd', False):
+                gt_bboxes_ignore.append(bbox)
+            else:
+                gt_bboxes.append(bbox)
+                gt_labels.append(self.cat2label[ann['category_id']])
+                gt_masks_ann.append(ann.get('segmentation', None))
+
+        if gt_bboxes:
+            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+            gt_labels = np.array(gt_labels, dtype=np.int64)
+        else:
+            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+            gt_labels = np.array([], dtype=np.int64)
+
+        if gt_bboxes_ignore:
+            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+        else:
+            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+        seg_map = img_info['filename'].replace('jpg', 'png')
+
+        ann = dict(
+            bboxes=gt_bboxes,
+            labels=gt_labels,
+            bboxes_ignore=gt_bboxes_ignore,
+            masks=gt_masks_ann,
+            seg_map=seg_map)
+
+        return ann
+
+    def xyxy2xywh(self, bbox):
+        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+        evaluation.
+
+        Args:
+            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+                ``xyxy`` order.
+
+        Returns:
+            list[float]: The converted bounding boxes, in ``xywh`` order.
+        """
+        _bbox = bbox.tolist()
+        return [
+            _bbox[0],
+            _bbox[1],
+            _bbox[2] - _bbox[0],
+            _bbox[3] - _bbox[1],
+        ]
+
+    def _proposal2json(self, results):
+        """Convert proposal results to COCO json style."""
+        json_results = []
+        for idx in range(len(self)):
+            img_id = self.img_ids[idx]
+            bboxes = results[idx]
+            for i in range(bboxes.shape[0]):
+                data = dict()
+                data['image_id'] = img_id
+                data['bbox'] = self.xyxy2xywh(bboxes[i])
+                data['score'] = float(bboxes[i][4])
+                data['category_id'] = 1
+                json_results.append(data)
+        return json_results
+
+    def _det2json(self, results):
+        """Convert detection results to COCO json style."""
+        json_results = []
+        for idx in range(len(self)):
+            img_id = self.img_ids[idx]
+            result = results[idx]
+            for label in range(len(result)):
+                bboxes = result[label]
+                for i in range(bboxes.shape[0]):
+                    data = dict()
+                    data['image_id'] = img_id
+                    data['bbox'] = self.xyxy2xywh(bboxes[i])
+                    data['score'] = float(bboxes[i][4])
+                    data['category_id'] = self.cat_ids[label]
+                    json_results.append(data)
+        return json_results
+
+    def _segm2json(self, results):
+        """Convert instance segmentation results to COCO json style."""
+        bbox_json_results = []
+        segm_json_results = []
+        for idx in range(len(self)):
+            img_id = self.img_ids[idx]
+            det, seg = results[idx]
+            for label in range(len(det)):
+                # bbox results
+                bboxes = det[label]
+                for i in range(bboxes.shape[0]):
+                    data = dict()
+                    data['image_id'] = img_id
+                    data['bbox'] = self.xyxy2xywh(bboxes[i])
+                    data['score'] = float(bboxes[i][4])
+                    data['category_id'] = self.cat_ids[label]
+                    bbox_json_results.append(data)
+
+                # segm results
+                # some detectors use different scores for bbox and mask
+                if isinstance(seg, tuple):
+                    segms = seg[0][label]
+                    mask_score = seg[1][label]
+                else:
+                    segms = seg[label]
+                    mask_score = [bbox[4] for bbox in bboxes]
+                for i in range(bboxes.shape[0]):
+                    data = dict()
+                    data['image_id'] = img_id
+                    data['bbox'] = self.xyxy2xywh(bboxes[i])
+                    data['score'] = float(mask_score[i])
+                    data['category_id'] = self.cat_ids[label]
+                    if isinstance(segms[i]['counts'], bytes):
+                        segms[i]['counts'] = segms[i]['counts'].decode()
+                    data['segmentation'] = segms[i]
+                    segm_json_results.append(data)
+        return bbox_json_results, segm_json_results
+
+    def results2json(self, results, outfile_prefix):
+        """Dump the detection results to a COCO style json file.
+
+        There are 3 types of results: proposals, bbox predictions, mask
+        predictions, and they have different data types. This method will
+        automatically recognize the type, and dump them to json files.
+
+        Args:
+            results (list[list | tuple | ndarray]): Testing results of the
+                dataset.
+            outfile_prefix (str): The filename prefix of the json files. If the
+                prefix is "somepath/xxx", the json files will be named
+                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+                "somepath/xxx.proposal.json".
+
+        Returns:
+            dict[str: str]: Possible keys are "bbox", "segm", "proposal", and
+                values are corresponding filenames.
+        """
+        result_files = dict()
+        if isinstance(results[0], list):
+            json_results = self._det2json(results)
+            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+            mmcv.dump(json_results, result_files['bbox'])
+        elif isinstance(results[0], tuple):
+            json_results = self._segm2json(results)
+            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+            result_files['segm'] = f'{outfile_prefix}.segm.json'
+            mmcv.dump(json_results[0], result_files['bbox'])
+            mmcv.dump(json_results[1], result_files['segm'])
+        elif isinstance(results[0], np.ndarray):
+            json_results = self._proposal2json(results)
+            result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+            mmcv.dump(json_results, result_files['proposal'])
+        else:
+            raise TypeError('invalid type of results')
+        return result_files
+
+    def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+        gt_bboxes = []
+        for i in range(len(self.img_ids)):
+            ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+            ann_info = self.coco.load_anns(ann_ids)
+            if len(ann_info) == 0:
+                gt_bboxes.append(np.zeros((0, 4)))
+                continue
+            bboxes = []
+            for ann in ann_info:
+                if ann.get('ignore', False) or ann['iscrowd']:
+                    continue
+                x1, y1, w, h = ann['bbox']
+                bboxes.append([x1, y1, x1 + w, y1 + h])
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.shape[0] == 0:
+                bboxes = np.zeros((0, 4))
+            gt_bboxes.append(bboxes)
+
+        recalls = eval_recalls(
+            gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+        ar = recalls.mean(axis=1)
+        return ar
+
+    def format_results(self, results, jsonfile_prefix=None, **kwargs):
+        """Format the results to json (standard format for COCO evaluation).
+
+        Args:
+            results (list[tuple | numpy.ndarray]): Testing results of the
+                dataset.
+            jsonfile_prefix (str | None): The prefix of json files. It includes
+                the file path and the prefix of filename, e.g., "a/b/prefix".
+                If not specified, a temp file will be created. Default: None.
+
+        Returns:
+            tuple: (result_files, tmp_dir), result_files is a dict containing
+                the json filepaths, tmp_dir is the temporal directory created
+                for saving json files when jsonfile_prefix is not specified.
+        """
+        assert isinstance(results, list), 'results must be a list'
+        assert len(results) == len(self), (
+            'The length of results is not equal to the dataset len: {} != {}'.
+            format(len(results), len(self)))
+
+        if jsonfile_prefix is None:
+            tmp_dir = tempfile.TemporaryDirectory()
+            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+        else:
+            tmp_dir = None
+        result_files = self.results2json(results, jsonfile_prefix)
+        return result_files, tmp_dir
+
+    def evaluate(self,
+                 results,
+                 metric='bbox',
+                 logger=None,
+                 jsonfile_prefix=None,
+                 classwise=False,
+                 proposal_nums=(100, 300, 1000),
+                 iou_thrs=None,
+                 metric_items=None):
+        """Evaluation in COCO protocol.
+
+        Args:
+            results (list[list | tuple]): Testing results of the dataset.
+            metric (str | list[str]): Metrics to be evaluated. Options are
+                'bbox', 'segm', 'proposal', 'proposal_fast'.
+            logger (logging.Logger | str | None): Logger used for printing
+                related information during evaluation. Default: None.
+            jsonfile_prefix (str | None): The prefix of json files. It includes
+                the file path and the prefix of filename, e.g., "a/b/prefix".
+                If not specified, a temp file will be created. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+            proposal_nums (Sequence[int]): Proposal number used for evaluating
+                recalls, such as recall@100, recall@1000.
+                Default: (100, 300, 1000).
+            iou_thrs (Sequence[float], optional): IoU threshold used for
+                evaluating recalls/mAPs. If set to a list, the average of all
+                IoUs will also be computed. If not specified, [0.50, 0.55,
+                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+                Default: None.
+            metric_items (list[str] | str, optional): Metric items that will
+                be returned. If not specified, ``['AR@100', 'AR@300',
+                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000']`` will be
+                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+                ``metric=='bbox' or metric=='segm'``.
+
+        Returns:
+            dict[str, float]: COCO style evaluation metric.
+        """
+        metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+        for metric in metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(f'metric {metric} is not supported')
+        if iou_thrs is None:
+            iou_thrs = np.linspace(
+                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+        if metric_items is not None:
+            if not isinstance(metric_items, list):
+                metric_items = [metric_items]
+
+        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+
+        eval_results = OrderedDict()
+        cocoGt = self.coco
+        for metric in metrics:
+            msg = f'Evaluating {metric}...'
+            if logger is None:
+                msg = '\n' + msg
+            print_log(msg, logger=logger)
+
+            if metric == 'proposal_fast':
+                ar = self.fast_eval_recall(
+                    results, proposal_nums, iou_thrs, logger='silent')
+                log_msg = []
+                for i, num in enumerate(proposal_nums):
+                    eval_results[f'AR@{num}'] = ar[i]
+                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+                log_msg = ''.join(log_msg)
+                print_log(log_msg, logger=logger)
+                continue
+
+            iou_type = 'bbox' if metric == 'proposal' else metric
+            if metric not in result_files:
+                raise KeyError(f'{metric} is not in results')
+            try:
+                predictions = mmcv.load(result_files[metric])
+                if iou_type == 'segm':
+                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
+                    # When evaluating mask AP, if the results contain bbox,
+                    # cocoapi will use the box area instead of the mask area
+                    # for calculating the instance area. Though the overall AP
+                    # is not affected, this leads to different
+                    # small/medium/large mask AP results.
+                    for x in predictions:
+                        x.pop('bbox')
+                    warnings.simplefilter('once')
+                    warnings.warn(
+                        'The key "bbox" is deleted for more accurate mask AP '
+                        'of small/medium/large instances since v2.12.0. This '
+                        'does not change the overall mAP calculation.',
+                        UserWarning)
+                cocoDt = cocoGt.loadRes(predictions)
+            except IndexError:
+                print_log(
+                    'The testing results of the whole dataset are empty.',
+                    logger=logger,
+                    level=logging.ERROR)
+                break
+
+            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+            cocoEval.params.catIds = self.cat_ids
+            cocoEval.params.imgIds = self.img_ids
+            cocoEval.params.maxDets = list(proposal_nums)
+            cocoEval.params.iouThrs = iou_thrs
+            # mapping of cocoEval.stats
+            coco_metric_names = {
+                'mAP': 0,
+                'mAP_50': 1,
+                'mAP_75': 2,
+                'mAP_s': 3,
+                'mAP_m': 4,
+                'mAP_l': 5,
+                'AR@100': 6,
+                'AR@300': 7,
+                'AR@1000': 8,
+                'AR_s@1000': 9,
+                'AR_m@1000': 10,
+                'AR_l@1000': 11
+            }
+            if metric_items is not None:
+                for metric_item in metric_items:
+                    if metric_item not in coco_metric_names:
+                        raise KeyError(
+                            f'metric item {metric_item} is not supported')
+
+            if metric == 'proposal':
+                cocoEval.params.useCats = 0
+                cocoEval.evaluate()
+                cocoEval.accumulate()
+                cocoEval.summarize()
+                if metric_items is None:
+                    metric_items = [
+                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+                        'AR_m@1000', 'AR_l@1000'
+                    ]
+
+                for item in metric_items:
+                    val = float(
+                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
+                    eval_results[item] = val
+            else:
+                cocoEval.evaluate()
+                cocoEval.accumulate()
+                cocoEval.summarize()
+                if classwise:  # Compute per-category AP
+                    # Compute per-category AP
+                    # from https://github.com/facebookresearch/detectron2/
+                    precisions = cocoEval.eval['precision']
+                    # precision: (iou, recall, cls, area range, max dets)
+                    assert len(self.cat_ids) == precisions.shape[2]
+
+                    results_per_category = []
+                    for idx, catId in enumerate(self.cat_ids):
+                        # area range index 0: all area ranges
+                        # max dets index -1: typically 100 per image
+                        nm = self.coco.loadCats(catId)[0]
+                        precision = precisions[:, :, idx, 0, -1]
+                        precision = precision[precision > -1]
+                        if precision.size:
+                            ap = np.mean(precision)
+                        else:
+                            ap = float('nan')
+                        results_per_category.append(
+                            (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+                    num_columns = min(6, len(results_per_category) * 2)
+                    results_flatten = list(
+                        itertools.chain(*results_per_category))
+                    headers = ['category', 'AP'] * (num_columns // 2)
+                    results_2d = itertools.zip_longest(*[
+                        results_flatten[i::num_columns]
+                        for i in range(num_columns)
+                    ])
+                    table_data = [headers]
+                    table_data += [result for result in results_2d]
+                    table = AsciiTable(table_data)
+                    print_log('\n' + table.table, logger=logger)
+
+                if metric_items is None:
+                    metric_items = [
+                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+                    ]
+
+                for metric_item in metric_items:
+                    key = f'{metric}_{metric_item}'
+                    val = float(
+                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+                    )
+                    eval_results[key] = val
+                ap = cocoEval.stats[:6]
+                eval_results[f'{metric}_mAP_copypaste'] = (
+                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+                    f'{ap[4]:.3f} {ap[5]:.3f}')
+        if tmp_dir is not None:
+            tmp_dir.cleanup()
+        return eval_results
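One subtlety worth keeping in mind for a CLASSES tuple this large: category ids in the COCO json need not be contiguous, and `cat2label` remaps them to the contiguous 0-based labels the detection head trains on. A small illustrative sketch (the ids are made up):

cat_ids = [1, 3, 7, 42]                      # hypothetical ids from the json
cat2label = {cat_id: i for i, cat_id in enumerate(cat_ids)}
assert cat2label[7] == 2                     # head labels run 0..num_classes-1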
model/mmdet_custom/datasets/publaynet.py
ADDED
@@ -0,0 +1,528 @@
| 1 |
+
import itertools
|
| 2 |
+
import logging
|
| 3 |
+
import os.path as osp
|
| 4 |
+
import tempfile
|
| 5 |
+
import warnings
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import mmcv
|
| 9 |
+
import numpy as np
|
| 10 |
+
from mmcv.utils import print_log
from terminaltables import AsciiTable

from mmdet.core import eval_recalls
from mmdet.datasets.api_wrappers import COCO, COCOeval
from mmdet.datasets.custom import CustomDataset
from mmdet.datasets.builder import DATASETS


@DATASETS.register_module()
class PubLayNetDataset(CustomDataset):

    CLASSES = ('text', 'title', 'list', 'table', 'figure',)

    def load_annotations(self, ann_file):
        """Load annotation from COCO style annotation file.

        Args:
            ann_file (str): Path of annotation file.

        Returns:
            list[dict]: Annotation info from COCO api.
        """
        self.coco = COCO(ann_file)
        # The order of returned `cat_ids` will not
        # change with the order of the CLASSES
        self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)

        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.get_img_ids()
        data_infos = []
        total_ann_ids = []
        for i in self.img_ids:
            info = self.coco.load_imgs([i])[0]
            info['filename'] = info['file_name']
            data_infos.append(info)
            ann_ids = self.coco.get_ann_ids(img_ids=[i])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        return data_infos

    def get_ann_info(self, idx):
        """Get COCO annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return self._parse_ann_info(self.data_infos[idx], ann_info)

    def get_cat_ids(self, idx):
        """Get COCO category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        img_id = self.data_infos[idx]['id']
        ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
        ann_info = self.coco.load_anns(ann_ids)
        return [ann['category_id'] for ann in ann_info]

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        # obtain images that contain annotation
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for i, class_id in enumerate(self.cat_ids):
            ids_in_cat |= set(self.coco.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged set
        # to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_img_ids = []
        for i, img_info in enumerate(self.data_infos):
            img_id = self.img_ids[i]
            if self.filter_empty_gt and img_id not in ids_in_cat:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
                valid_img_ids.append(img_id)
        self.img_ids = valid_img_ids
        return valid_inds

    def _parse_ann_info(self, img_info, ann_info):
        """Parse bbox and mask annotation.

        Args:
            img_info (dict): Meta info of an image.
            ann_info (list[dict]): Annotation info of an image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, seg_map. "masks" are raw annotations and not
                decoded into binary masks.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_masks_ann = []
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
                gt_masks_ann.append(ann.get('segmentation', None))

        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        seg_map = img_info['filename'].replace('jpg', 'png')

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks=gt_masks_ann,
            seg_map=seg_map)

        return ann

    def xyxy2xywh(self, bbox):
        """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
        evaluation.

        Args:
            bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
                ``xyxy`` order.

        Returns:
            list[float]: The converted bounding boxes, in ``xywh`` order.
        """
        _bbox = bbox.tolist()
        return [
            _bbox[0],
            _bbox[1],
            _bbox[2] - _bbox[0],
            _bbox[3] - _bbox[1],
        ]
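COCO stores boxes as [x, y, width, height] while the detector works in [x1, y1, x2, y2]; a quick worked sketch of the conversion above (the values are made up):

import numpy as np

# Worked example of xyxy -> xywh: same top-left corner,
# width = x2 - x1, height = y2 - y1.
x1, y1, x2, y2 = np.array([10.0, 20.0, 50.0, 80.0])
xywh = [x1, y1, x2 - x1, y2 - y1]
assert xywh == [10.0, 20.0, 40.0, 60.0]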
    def _proposal2json(self, results):
        """Convert proposal results to COCO json style."""
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            bboxes = results[idx]
            for i in range(bboxes.shape[0]):
                data = dict()
                data['image_id'] = img_id
                data['bbox'] = self.xyxy2xywh(bboxes[i])
                data['score'] = float(bboxes[i][4])
                data['category_id'] = 1
                json_results.append(data)
        return json_results

    def _det2json(self, results):
        """Convert detection results to COCO json style."""
        json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            result = results[idx]
            for label in range(len(result)):
                bboxes = result[label]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(bboxes[i][4])
                    data['category_id'] = self.cat_ids[label]
                    json_results.append(data)
        return json_results

    def _segm2json(self, results):
        """Convert instance segmentation results to COCO json style."""
        bbox_json_results = []
        segm_json_results = []
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            det, seg = results[idx]
            for label in range(len(det)):
                # bbox results
                bboxes = det[label]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(bboxes[i][4])
                    data['category_id'] = self.cat_ids[label]
                    bbox_json_results.append(data)

                # segm results
                # some detectors use different scores for bbox and mask
                if isinstance(seg, tuple):
                    segms = seg[0][label]
                    mask_score = seg[1][label]
                else:
                    segms = seg[label]
                    mask_score = [bbox[4] for bbox in bboxes]
                for i in range(bboxes.shape[0]):
                    data = dict()
                    data['image_id'] = img_id
                    data['bbox'] = self.xyxy2xywh(bboxes[i])
                    data['score'] = float(mask_score[i])
                    data['category_id'] = self.cat_ids[label]
                    if isinstance(segms[i]['counts'], bytes):
                        segms[i]['counts'] = segms[i]['counts'].decode()
                    data['segmentation'] = segms[i]
                    segm_json_results.append(data)
        return bbox_json_results, segm_json_results

    def results2json(self, results, outfile_prefix):
        """Dump the detection results to a COCO style json file.

        There are 3 types of results: proposals, bbox predictions, and mask
        predictions, and they have different data types. This method will
        automatically recognize the type, and dump them to json files.

        Args:
            results (list[list | tuple | ndarray]): Testing results of the
                dataset.
            outfile_prefix (str): The filename prefix of the json files. If the
                prefix is "somepath/xxx", the json files will be named
                "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
                "somepath/xxx.proposal.json".

        Returns:
            dict[str: str]: Possible keys are "bbox", "segm", "proposal", and
                values are corresponding filenames.
        """
        result_files = dict()
        if isinstance(results[0], list):
            json_results = self._det2json(results)
            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
            mmcv.dump(json_results, result_files['bbox'])
        elif isinstance(results[0], tuple):
            json_results = self._segm2json(results)
            result_files['bbox'] = f'{outfile_prefix}.bbox.json'
            result_files['proposal'] = f'{outfile_prefix}.bbox.json'
            result_files['segm'] = f'{outfile_prefix}.segm.json'
            mmcv.dump(json_results[0], result_files['bbox'])
            mmcv.dump(json_results[1], result_files['segm'])
        elif isinstance(results[0], np.ndarray):
            json_results = self._proposal2json(results)
            result_files['proposal'] = f'{outfile_prefix}.proposal.json'
            mmcv.dump(json_results, result_files['proposal'])
        else:
            raise TypeError('invalid type of results')
        return result_files

    def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
        gt_bboxes = []
        for i in range(len(self.img_ids)):
            ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
            ann_info = self.coco.load_anns(ann_ids)
            if len(ann_info) == 0:
                gt_bboxes.append(np.zeros((0, 4)))
                continue
            bboxes = []
            for ann in ann_info:
                if ann.get('ignore', False) or ann['iscrowd']:
                    continue
                x1, y1, w, h = ann['bbox']
                bboxes.append([x1, y1, x1 + w, y1 + h])
            bboxes = np.array(bboxes, dtype=np.float32)
            if bboxes.shape[0] == 0:
                bboxes = np.zeros((0, 4))
            gt_bboxes.append(bboxes)

        recalls = eval_recalls(
            gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
        ar = recalls.mean(axis=1)
        return ar

    def format_results(self, results, jsonfile_prefix=None, **kwargs):
        """Format the results to json (standard format for COCO evaluation).

        Args:
            results (list[tuple | numpy.ndarray]): Testing results of the
                dataset.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a dict containing
                the json filepaths, tmp_dir is the temporary directory created
                for saving json files when jsonfile_prefix is not specified.
        """
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            'The length of results is not equal to the dataset len: {} != {}'.
            format(len(results), len(self)))

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None
        result_files = self.results2json(results, jsonfile_prefix)
        return result_files, tmp_dir

    def evaluate(self,
                 results,
                 metric='bbox',
                 logger=None,
                 jsonfile_prefix=None,
                 classwise=False,
                 proposal_nums=(100, 300, 1000),
                 iou_thrs=None,
                 metric_items=None):
        """Evaluation in COCO protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'bbox', 'segm', 'proposal', 'proposal_fast'.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            jsonfile_prefix (str | None): The prefix of json files. It includes
                the file path and the prefix of filename, e.g., "a/b/prefix".
                If not specified, a temp file will be created. Default: None.
            classwise (bool): Whether to evaluate the AP for each class.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thrs (Sequence[float], optional): IoU threshold used for
                evaluating recalls/mAPs. If set to a list, the average of all
                IoUs will also be computed. If not specified, [0.50, 0.55,
                0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
                Default: None.
            metric_items (list[str] | str, optional): Metric items that will
                be returned. If not specified, ``['AR@100', 'AR@300',
                'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000']`` will be
                used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
                'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
                ``metric=='bbox' or metric=='segm'``.

        Returns:
            dict[str, float]: COCO style evaluation metric.
        """
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')
        if iou_thrs is None:
            iou_thrs = np.linspace(
                .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
        if metric_items is not None:
            if not isinstance(metric_items, list):
                metric_items = [metric_items]

        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

        eval_results = OrderedDict()
        cocoGt = self.coco
        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'proposal_fast':
                ar = self.fast_eval_recall(
                    results, proposal_nums, iou_thrs, logger='silent')
                log_msg = []
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
                    log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            iou_type = 'bbox' if metric == 'proposal' else metric
            if metric not in result_files:
                raise KeyError(f'{metric} is not in results')
            try:
                predictions = mmcv.load(result_files[metric])
                if iou_type == 'segm':
                    # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331  # noqa
                    # When evaluating mask AP, if the results contain bbox,
                    # cocoapi will use the box area instead of the mask area
                    # for calculating the instance area. Though the overall AP
                    # is not affected, this leads to different
                    # small/medium/large mask AP results.
                    for x in predictions:
                        x.pop('bbox')
                    warnings.simplefilter('once')
                    warnings.warn(
                        'The key "bbox" is deleted for more accurate mask AP '
                        'of small/medium/large instances since v2.12.0. This '
                        'does not change the overall mAP calculation.',
                        UserWarning)
                cocoDt = cocoGt.loadRes(predictions)
            except IndexError:
                print_log(
                    'The testing results of the whole dataset are empty.',
                    logger=logger,
                    level=logging.ERROR)
                break

            cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
            cocoEval.params.catIds = self.cat_ids
            cocoEval.params.imgIds = self.img_ids
            cocoEval.params.maxDets = list(proposal_nums)
            cocoEval.params.iouThrs = iou_thrs
            # mapping of cocoEval.stats
            coco_metric_names = {
                'mAP': 0,
                'mAP_50': 1,
                'mAP_75': 2,
                'mAP_s': 3,
                'mAP_m': 4,
                'mAP_l': 5,
                'AR@100': 6,
                'AR@300': 7,
                'AR@1000': 8,
                'AR_s@1000': 9,
                'AR_m@1000': 10,
                'AR_l@1000': 11
            }
            if metric_items is not None:
                for metric_item in metric_items:
                    if metric_item not in coco_metric_names:
                        raise KeyError(
                            f'metric item {metric_item} is not supported')

            if metric == 'proposal':
                cocoEval.params.useCats = 0
                cocoEval.evaluate()
                cocoEval.accumulate()
                cocoEval.summarize()
                if metric_items is None:
                    metric_items = [
                        'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
                        'AR_m@1000', 'AR_l@1000'
                    ]

                for item in metric_items:
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
                    eval_results[item] = val
            else:
                cocoEval.evaluate()
                cocoEval.accumulate()
                cocoEval.summarize()
                if classwise:  # Compute per-category AP
                    # from https://github.com/facebookresearch/detectron2/
                    precisions = cocoEval.eval['precision']
                    # precision: (iou, recall, cls, area range, max dets)
                    assert len(self.cat_ids) == precisions.shape[2]

                    results_per_category = []
                    for idx, catId in enumerate(self.cat_ids):
                        # area range index 0: all area ranges
                        # max dets index -1: typically 100 per image
                        nm = self.coco.loadCats(catId)[0]
                        precision = precisions[:, :, idx, 0, -1]
                        precision = precision[precision > -1]
                        if precision.size:
                            ap = np.mean(precision)
                        else:
                            ap = float('nan')
                        results_per_category.append(
                            (f'{nm["name"]}', f'{float(ap):0.3f}'))

                    num_columns = min(6, len(results_per_category) * 2)
                    results_flatten = list(
                        itertools.chain(*results_per_category))
                    headers = ['category', 'AP'] * (num_columns // 2)
                    results_2d = itertools.zip_longest(*[
                        results_flatten[i::num_columns]
                        for i in range(num_columns)
                    ])
                    table_data = [headers]
                    table_data += [result for result in results_2d]
                    table = AsciiTable(table_data)
                    print_log('\n' + table.table, logger=logger)

                if metric_items is None:
                    metric_items = [
                        'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
                    ]

                for metric_item in metric_items:
                    key = f'{metric}_{metric_item}'
                    val = float(
                        f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
                    )
                    eval_results[key] = val
                ap = cocoEval.stats[:6]
                eval_results[f'{metric}_mAP_copypaste'] = (
                    f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
                    f'{ap[4]:.3f} {ap[5]:.3f}')
        if tmp_dir is not None:
            tmp_dir.cleanup()
        return eval_results
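Since the class is registered under the name `PubLayNetDataset`, configs can reference it by string. A minimal, hypothetical config fragment (the real settings live in model/configs/_base_/datasets/publaynet.py; the paths below are placeholders, not taken from this commit):

# Hypothetical config fragment; adjust paths to your data layout.
dataset_type = 'PubLayNetDataset'
data_root = 'data/publaynet/'  # assumption
data = dict(
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'val.json',
        img_prefix=data_root + 'val/'))
# `evaluate` above then computes COCO bbox mAP, optionally per class.
evaluation = dict(metric='bbox', classwise=True)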
model/mmdet_custom/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from .backbones import *  # noqa: F401,F403
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
from .utils import *  # noqa: F401,F403
model/mmdet_custom/models/backbones/__init__.py
ADDED
@@ -0,0 +1,11 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from .intern_image import InternImage
from .beit import BEiT
from .swin_transformer import SwinTransformerV1

__all__ = ['InternImage', 'BEiT', 'SwinTransformerV1']
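These star-imports exist for their side effects: importing each submodule executes its `register_module()` decorators so that config strings such as `type='InternImage'` or `type='PubLayNetDataset'` resolve against the mmdet registries. A hedged sketch of the explicit alternative using mmcv's `custom_imports` config mechanism (the module paths assume this repo's layout):

# Hypothetical config fragment: ask the runner to import the custom
# package so its registry decorators execute before the model is built.
custom_imports = dict(
    imports=['mmdet_custom.models', 'mmdet_custom.datasets'],
    allow_failed_imports=False)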
model/mmdet_custom/models/backbones/beit.py
ADDED
@@ -0,0 +1,601 @@
""" Vision Transformer (ViT) in PyTorch

A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020 Ross Wightman
"""
import warnings
import math
import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from mmdet.models.builder import BACKBONES
from mmcv.runner import _load_checkpoint
# needed by BEiT.init_weights below; this import was missing in the original file
from mmcv.cnn import constant_init, trunc_normal_init
from mmdet.utils import get_root_logger
from collections import OrderedDict


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        **kwargs
    }


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # dropout after fc1 is commented out to match the original BERT implementation
        x = self.fc2(x)
        x = self.drop(x)
        return x
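`DropPath` wraps timm's `drop_path`, i.e. stochastic depth: during training each sample's residual branch is zeroed with probability `drop_prob` and the survivors are rescaled by `1 / (1 - drop_prob)`, so the expected output matches the input; at eval time it is the identity. A minimal sketch (batch size and values are illustrative):

import torch
from timm.models.layers import drop_path

x = torch.ones(8, 4)  # 8 samples in a batch
y = drop_path(x, drop_prob=0.5, training=True)
# Each row is either all zeros or all 2.0 (= 1 / (1 - 0.5)),
# so E[y] == x even though half the rows are dropped on average.
print(y[:, 0])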
class Attention(nn.Module):
    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token to cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)

            # trunc_normal_(self.relative_position_bias_table, std=.0)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, training_window_size=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            # convert the tensor to a plain tuple once, so the comparison with
            # the stored tuple below is well-defined
            training_window_size = tuple(training_window_size.tolist())
            if training_window_size == self.window_size:
                relative_position_bias = \
                    self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                        self.window_size[0] * self.window_size[1] + 1,
                        self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0)
            else:
                new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
                # new_num_relative_distance covers every possible relative position,
                # including the cls-cls, token-cls and cls-token cases
                new_relative_position_bias_table = F.interpolate(
                    self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
                                                                                 2 * self.window_size[0] - 1,
                                                                                 2 * self.window_size[1] - 1),
                    size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
                    align_corners=False)
                new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
                                                                                          new_num_relative_distance - 3).permute(
                    1, 0)
                new_relative_position_bias_table = torch.cat(
                    [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)

                # get pair-wise relative position index for each token inside the window
                coords_h = torch.arange(training_window_size[0])
                coords_w = torch.arange(training_window_size[1])
                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
                relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
                relative_coords[:, :, 1] += training_window_size[1] - 1
                relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
                relative_position_index = \
                    torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
                                dtype=relative_coords.dtype)
                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
                relative_position_index[0, 0:] = new_num_relative_distance - 3
                relative_position_index[0:, 0] = new_num_relative_distance - 2
                relative_position_index[0, 0] = new_num_relative_distance - 1

                relative_position_bias = \
                    new_relative_position_bias_table[relative_position_index.view(-1)].view(
                        training_window_size[0] * training_window_size[1] + 1,
                        training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
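The index bookkeeping in `Attention.__init__` is easier to see with small numbers: a 2x2 window has (2*2-1)*(2*2-1) = 9 distinct relative offsets plus 3 slots for cls-to-token, token-to-cls and cls-to-cls, so the bias table has 12 rows and the index matrix is 5x5 (4 patch tokens + 1 cls token). A standalone sketch of the same construction:

import torch

Wh, Ww = 2, 2
num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3  # 9 + 3 = 12
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))
flat = torch.flatten(coords, 1)
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0)
rel[:, :, 0] += Wh - 1                       # shift offsets to start at 0
rel[:, :, 1] += Ww - 1
rel[:, :, 0] *= 2 * Ww - 1
index = torch.zeros((Wh * Ww + 1,) * 2, dtype=rel.dtype)
index[1:, 1:] = rel.sum(-1)                  # token-token offsets in [0, 8]
index[0, 0:] = num_relative_distance - 3     # cls -> token
index[0:, 0] = num_relative_distance - 2     # token -> cls
index[0, 0] = num_relative_distance - 1      # cls -> cls
print(index)                                 # 5x5 matrix of table row indices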
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 window_size=None, attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if init_values is not None:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None, training_window_size=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(
                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
                                                            training_window_size=training_window_size))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches_w = self.patch_shape[0]
        self.num_patches_h = self.patch_shape[1]
        # the so-called patch_shape is the patch shape during pre-training
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, position_embedding=None, **kwargs):
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]

        if position_embedding is not None:
            # interpolate the position embedding to the corresponding size
            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
                                                                                                                  1, 2)
            position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
            x = x + position_embedding

        x = x.flatten(2).transpose(1, 2)
        return x, (Hp, Wp)
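Concretely, with `patch_size=16` the strided convolution turns an 800x1216 detection input into a 50x76 grid of 768-dimensional patch tokens (3800 tokens in total). A quick shape check using the `PatchEmbed` defined above (the input size is an assumption, not fixed by the model):

import torch

patch_embed = PatchEmbed(img_size=[224, 224], patch_size=16, embed_dim=768)
x = torch.randn(1, 3, 800, 1216)
tokens, (Hp, Wp) = patch_embed(x)
print(tokens.shape, Hp, Wp)  # torch.Size([1, 3800, 768]) 50 76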
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class RelativePositionBias(nn.Module):

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token to cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, training_window_size):
        # convert the tensor to a plain tuple once, so the comparison with the
        # stored tuple below is well-defined
        training_window_size = tuple(training_window_size.tolist())
        if training_window_size == self.window_size:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        else:
            new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
            # new_num_relative_distance covers every possible relative position,
            # including the cls-cls, token-cls and cls-token cases
            new_relative_position_bias_table = F.interpolate(
                self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
                                                                             2 * self.window_size[0] - 1,
                                                                             2 * self.window_size[1] - 1),
                size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
                align_corners=False)
            new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
                                                                                     new_num_relative_distance - 3).permute(
                1, 0)
            new_relative_position_bias_table = torch.cat(
                [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(training_window_size[0])
            coords_w = torch.arange(training_window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += training_window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += training_window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
            relative_position_index = \
                torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
                            dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = new_num_relative_distance - 3
            relative_position_index[0:, 0] = new_num_relative_distance - 2
            relative_position_index[0, 0] = new_num_relative_distance - 1

            relative_position_bias = \
                new_relative_position_bias_table[relative_position_index.view(-1)].view(
                    training_window_size[0] * training_window_size[1] + 1,
                    training_window_size[0] * training_window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        return relative_position_bias
@BACKBONES.register_module()
class BEiT(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self,
                 img_size=[224, 224],
                 patch_size=16,
                 in_chans=3,
                 num_classes=11,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 hybrid_backbone=None,
                 norm_layer=nn.LayerNorm,
                 init_values=None,
                 use_abs_pos_emb=False,
                 use_rel_pos_bias=False,
                 use_shared_rel_pos_bias=False,
                 use_checkpoint=True,
                 init_cfg=None,
                 out_indices=(0, 1, 2, 3),
                 ):
        super(BEiT, self).__init__()
        self.init_cfg = init_cfg
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.use_checkpoint = use_checkpoint

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.out_indices = out_indices

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.use_rel_pos_bias = use_rel_pos_bias
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
            for i in range(depth)])

        # trunc_normal_(self.mask_token, std=.02)

        if patch_size == 16:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                # nn.SyncBatchNorm(embed_dim),
                nn.BatchNorm2d(embed_dim),
                nn.GELU(),
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn2 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn3 = nn.Identity()

            self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif patch_size == 8:
            self.fpn1 = nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            )

            self.fpn2 = nn.Identity()

            self.fpn3 = nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            self.fpn4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=4, stride=4),
            )

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in ' \
                f'`init_cfg` of {self.__class__.__name__}'
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v
                else:
                    state_dict[k] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # load state_dict
            meg = self.load_state_dict(state_dict, False)
            logger.info(meg)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
        # Hp, Wp are HW for patches
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.pos_embed is not None:
            cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)

        features = []
        training_window_size = torch.tensor([Hp, Wp])

        rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None

        for i, blk in enumerate(self.blocks):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
            else:
                x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
            if i in self.out_indices:
                xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
                features.append(xp.contiguous())

        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        return features

    def forward(self, x):
        x = self.forward_features(x)
        return x
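Since a plain ViT produces a single stride-16 feature map, `fpn1`-`fpn4` resample the four tapped blocks into a pyramid at strides 4, 8, 16 and 32 so a standard FPN neck can consume it. A minimal sketch of the two extreme levels, assuming `embed_dim=768` and a 50x76 patch grid (values are illustrative):

import torch
import torch.nn as nn

embed_dim, Hp, Wp = 768, 50, 76
feat = torch.randn(1, embed_dim, Hp, Wp)            # stride-16 ViT feature
fpn1 = nn.Sequential(                               # upsample x4 -> stride 4
    nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2),
    nn.BatchNorm2d(embed_dim), nn.GELU(),
    nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2))
fpn4 = nn.MaxPool2d(2, 2)                           # downsample x2 -> stride 32
print(fpn1(feat).shape)  # torch.Size([1, 768, 200, 304])
print(fpn4(feat).shape)  # torch.Size([1, 768, 25, 38])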
model/mmdet_custom/models/backbones/intern_image.py
ADDED
@@ -0,0 +1,702 @@
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# InternImage
|
| 3 |
+
# Copyright (c) 2022 OpenGVLab
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
import torch.utils.checkpoint as checkpoint
|
| 11 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 12 |
+
from mmcv.runner import _load_checkpoint
|
| 13 |
+
from mmcv.cnn import constant_init, trunc_normal_init
|
| 14 |
+
from mmdet.utils import get_root_logger
|
| 15 |
+
from mmdet.models.builder import BACKBONES
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from ops_dcnv3 import modules as opsm
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class to_channels_first(nn.Module):
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
return x.permute(0, 3, 1, 2)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class to_channels_last(nn.Module):
|
| 31 |
+
|
| 32 |
+
def __init__(self):
|
| 33 |
+
super().__init__()
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
return x.permute(0, 2, 3, 1)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_norm_layer(dim,
|
| 40 |
+
norm_layer,
|
| 41 |
+
in_format='channels_last',
|
| 42 |
+
out_format='channels_last',
|
| 43 |
+
eps=1e-6):
|
| 44 |
+
layers = []
|
| 45 |
+
if norm_layer == 'BN':
|
| 46 |
+
if in_format == 'channels_last':
|
| 47 |
+
layers.append(to_channels_first())
|
| 48 |
+
layers.append(nn.BatchNorm2d(dim))
|
| 49 |
+
if out_format == 'channels_last':
|
| 50 |
+
layers.append(to_channels_last())
|
| 51 |
+
elif norm_layer == 'LN':
|
| 52 |
+
if in_format == 'channels_first':
|
| 53 |
+
layers.append(to_channels_last())
|
| 54 |
+
layers.append(nn.LayerNorm(dim, eps=eps))
|
| 55 |
+
if out_format == 'channels_first':
|
| 56 |
+
layers.append(to_channels_first())
|
| 57 |
+
else:
|
| 58 |
+
raise NotImplementedError(
|
| 59 |
+
f'build_norm_layer does not support {norm_layer}')
|
| 60 |
+
return nn.Sequential(*layers)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def build_act_layer(act_layer):
|
| 64 |
+
if act_layer == 'ReLU':
|
| 65 |
+
return nn.ReLU(inplace=True)
|
| 66 |
+
elif act_layer == 'SiLU':
|
| 67 |
+
return nn.SiLU(inplace=True)
|
| 68 |
+
elif act_layer == 'GELU':
|
| 69 |
+
return nn.GELU()
|
| 70 |
+
|
| 71 |
+
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
|
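A quick sanity check of the two factories above (an added sketch; the tensor shape is an illustrative assumption): with norm_layer='BN', the permute wrappers let BatchNorm2d consume a channels-last tensor and hand back channels-last output.

norm = build_norm_layer(64, 'BN', in_format='channels_last', out_format='channels_last')
act = build_act_layer('GELU')
x = torch.randn(2, 56, 56, 64)  # (B, H, W, C), channels-last
y = act(norm(x))                # shape preserved: (2, 56, 56, 64)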
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CrossAttention(nn.Module):
|
| 75 |
+
r""" Cross Attention Module
|
| 76 |
+
Args:
|
| 77 |
+
dim (int): Number of input channels.
|
| 78 |
+
num_heads (int): Number of attention heads. Default: 8
|
| 79 |
+
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
| 80 |
+
Default: False.
|
| 81 |
+
qk_scale (float | None, optional): Override default qk scale of
|
| 82 |
+
head_dim ** -0.5 if set. Default: None.
|
| 83 |
+
attn_drop (float, optional): Dropout ratio of attention weight.
|
| 84 |
+
Default: 0.0
|
| 85 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 86 |
+
attn_head_dim (int, optional): Dimension of attention head.
|
| 87 |
+
out_dim (int, optional): Dimension of output.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self,
|
| 91 |
+
dim,
|
| 92 |
+
num_heads=8,
|
| 93 |
+
qkv_bias=False,
|
| 94 |
+
qk_scale=None,
|
| 95 |
+
attn_drop=0.,
|
| 96 |
+
proj_drop=0.,
|
| 97 |
+
attn_head_dim=None,
|
| 98 |
+
out_dim=None):
|
| 99 |
+
super().__init__()
|
| 100 |
+
if out_dim is None:
|
| 101 |
+
out_dim = dim
|
| 102 |
+
self.num_heads = num_heads
|
| 103 |
+
head_dim = dim // num_heads
|
| 104 |
+
if attn_head_dim is not None:
|
| 105 |
+
head_dim = attn_head_dim
|
| 106 |
+
all_head_dim = head_dim * self.num_heads
|
| 107 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 108 |
+
assert all_head_dim == dim
|
| 109 |
+
|
| 110 |
+
self.q = nn.Linear(dim, all_head_dim, bias=False)
|
| 111 |
+
self.k = nn.Linear(dim, all_head_dim, bias=False)
|
| 112 |
+
self.v = nn.Linear(dim, all_head_dim, bias=False)
|
| 113 |
+
|
| 114 |
+
if qkv_bias:
|
| 115 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 116 |
+
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 117 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 118 |
+
else:
|
| 119 |
+
self.q_bias = None
|
| 120 |
+
self.k_bias = None
|
| 121 |
+
self.v_bias = None
|
| 122 |
+
|
| 123 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 124 |
+
self.proj = nn.Linear(all_head_dim, out_dim)
|
| 125 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 126 |
+
|
| 127 |
+
def forward(self, x, k=None, v=None):  # k and v are required at call time; the None defaults are placeholders
|
| 128 |
+
B, N, C = x.shape
|
| 129 |
+
N_k = k.shape[1]
|
| 130 |
+
N_v = v.shape[1]
|
| 131 |
+
|
| 132 |
+
q_bias, k_bias, v_bias = None, None, None
|
| 133 |
+
if self.q_bias is not None:
|
| 134 |
+
q_bias = self.q_bias
|
| 135 |
+
k_bias = self.k_bias
|
| 136 |
+
v_bias = self.v_bias
|
| 137 |
+
|
| 138 |
+
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
|
| 139 |
+
q = q.reshape(B, N, 1, self.num_heads,
|
| 140 |
+
-1).permute(2, 0, 3, 1,
|
| 141 |
+
4).squeeze(0) # (B, N_head, N_q, dim)
|
| 142 |
+
|
| 143 |
+
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
|
| 144 |
+
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1,
|
| 145 |
+
4).squeeze(0)
|
| 146 |
+
|
| 147 |
+
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
|
| 148 |
+
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1,
|
| 149 |
+
4).squeeze(0)
|
| 150 |
+
|
| 151 |
+
q = q * self.scale
|
| 152 |
+
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
|
| 153 |
+
|
| 154 |
+
attn = attn.softmax(dim=-1)
|
| 155 |
+
attn = self.attn_drop(attn)
|
| 156 |
+
|
| 157 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
| 158 |
+
x = self.proj(x)
|
| 159 |
+
x = self.proj_drop(x)
|
| 160 |
+
|
| 161 |
+
return x
|
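A minimal usage sketch for CrossAttention (added for illustration; sizes are assumed): a single query token attends over a longer key/value sequence, which is the pattern AttentionPoolingBlock below builds on.

attn = CrossAttention(dim=96, num_heads=8, qkv_bias=True)
q_tokens = torch.randn(2, 1, 96)     # (B, N_q, C): one query token per image
kv_tokens = torch.randn(2, 196, 96)  # (B, N_kv, C)
out = attn(q_tokens, k=kv_tokens, v=kv_tokens)  # (2, 1, 96)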
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class AttentiveBlock(nn.Module):
|
| 165 |
+
r"""Attentive Block
|
| 166 |
+
Args:
|
| 167 |
+
dim (int): Number of input channels.
|
| 168 |
+
num_heads (int): Number of attention heads. Default: 8
|
| 169 |
+
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
| 170 |
+
Default: False.
|
| 171 |
+
qk_scale (float | None, optional): Override default qk scale of
|
| 172 |
+
head_dim ** -0.5 if set. Default: None.
|
| 173 |
+
drop (float, optional): Dropout rate. Default: 0.0.
|
| 174 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0.
|
| 175 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate.
|
| 176 |
+
Default: 0.0.
|
| 177 |
+
norm_layer (str, optional): Normalization layer type passed to build_norm_layer. Default: "LN".
|
| 178 |
+
attn_head_dim (int, optional): Dimension of attention head. Default: None.
|
| 179 |
+
out_dim (int, optional): Dimension of output. Default: None.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(self,
|
| 183 |
+
dim,
|
| 184 |
+
num_heads,
|
| 185 |
+
qkv_bias=False,
|
| 186 |
+
qk_scale=None,
|
| 187 |
+
drop=0.,
|
| 188 |
+
attn_drop=0.,
|
| 189 |
+
drop_path=0.,
|
| 190 |
+
norm_layer="LN",
|
| 191 |
+
attn_head_dim=None,
|
| 192 |
+
out_dim=None):
|
| 193 |
+
super().__init__()
|
| 194 |
+
|
| 195 |
+
self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
|
| 196 |
+
self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
|
| 197 |
+
self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
|
| 198 |
+
self.cross_dcn = CrossAttention(dim,
|
| 199 |
+
num_heads=num_heads,
|
| 200 |
+
qkv_bias=qkv_bias,
|
| 201 |
+
qk_scale=qk_scale,
|
| 202 |
+
attn_drop=attn_drop,
|
| 203 |
+
proj_drop=drop,
|
| 204 |
+
attn_head_dim=attn_head_dim,
|
| 205 |
+
out_dim=out_dim)
|
| 206 |
+
|
| 207 |
+
self.drop_path = DropPath(
|
| 208 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
| 209 |
+
|
| 210 |
+
def forward(self,
|
| 211 |
+
x_q,
|
| 212 |
+
x_kv,
|
| 213 |
+
pos_q,
|
| 214 |
+
pos_k,
|
| 215 |
+
bool_masked_pos,
|
| 216 |
+
rel_pos_bias=None):
|
| 217 |
+
x_q = self.norm1_q(x_q + pos_q)
|
| 218 |
+
x_k = self.norm1_k(x_kv + pos_k)
|
| 219 |
+
x_v = self.norm1_v(x_kv)
|
| 220 |
+
|
| 221 |
+
x = self.cross_dcn(x_q, k=x_k, v=x_v)
|
| 222 |
+
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class AttentionPoolingBlock(AttentiveBlock):
|
| 227 |
+
|
| 228 |
+
def forward(self, x):
|
| 229 |
+
x_q = x.mean(1, keepdim=True)
|
| 230 |
+
x_kv = x
|
| 231 |
+
pos_q, pos_k = 0, 0
|
| 232 |
+
x = super().forward(x_q, x_kv, pos_q, pos_k,
|
| 233 |
+
bool_masked_pos=None,
|
| 234 |
+
rel_pos_bias=None)
|
| 235 |
+
x = x.squeeze(1)
|
| 236 |
+
return x
|
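Illustrative usage (a hedged sketch with assumed sizes): the block mean-pools the token sequence into one query and refines it by cross-attending back over all tokens, so each image collapses to a single feature vector.

pool = AttentionPoolingBlock(dim=96, num_heads=8, qkv_bias=True, norm_layer='LN', out_dim=96)
tokens = torch.randn(2, 196, 96)  # (B, N, C)
pooled = pool(tokens)             # (B, C) == (2, 96)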
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class StemLayer(nn.Module):
|
| 240 |
+
r""" Stem layer of InternImage
|
| 241 |
+
Args:
|
| 242 |
+
in_chans (int): number of input channels
|
| 243 |
+
out_chans (int): number of output channels
|
| 244 |
+
act_layer (str): activation layer
|
| 245 |
+
norm_layer (str): normalization layer
|
| 246 |
+
"""
|
| 247 |
+
|
| 248 |
+
def __init__(self,
|
| 249 |
+
in_chans=3,
|
| 250 |
+
out_chans=96,
|
| 251 |
+
act_layer='GELU',
|
| 252 |
+
norm_layer='BN'):
|
| 253 |
+
super().__init__()
|
| 254 |
+
self.conv1 = nn.Conv2d(in_chans,
|
| 255 |
+
out_chans // 2,
|
| 256 |
+
kernel_size=3,
|
| 257 |
+
stride=2,
|
| 258 |
+
padding=1)
|
| 259 |
+
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
|
| 260 |
+
'channels_first', 'channels_first')
|
| 261 |
+
self.act = build_act_layer(act_layer)
|
| 262 |
+
self.conv2 = nn.Conv2d(out_chans // 2,
|
| 263 |
+
out_chans,
|
| 264 |
+
kernel_size=3,
|
| 265 |
+
stride=2,
|
| 266 |
+
padding=1)
|
| 267 |
+
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
|
| 268 |
+
'channels_last')
|
| 269 |
+
|
| 270 |
+
def forward(self, x):
|
| 271 |
+
x = self.conv1(x)
|
| 272 |
+
x = self.norm1(x)
|
| 273 |
+
x = self.act(x)
|
| 274 |
+
x = self.conv2(x)
|
| 275 |
+
x = self.norm2(x)
|
| 276 |
+
return x
|
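The stem halves the resolution twice (two stride-2 convs) and hands the rest of the network a channels-last tensor; a shape sketch with assumed sizes:

stem = StemLayer(in_chans=3, out_chans=96, act_layer='GELU', norm_layer='BN')
img = torch.randn(2, 3, 224, 224)  # channels-first input image
feat = stem(img)                   # channels-last output: (2, 56, 56, 96)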
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class DownsampleLayer(nn.Module):
|
| 280 |
+
r""" Downsample layer of InternImage
|
| 281 |
+
Args:
|
| 282 |
+
channels (int): number of input channels
|
| 283 |
+
norm_layer (str): normalization layer
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
def __init__(self, channels, norm_layer='LN'):
|
| 287 |
+
super().__init__()
|
| 288 |
+
self.conv = nn.Conv2d(channels,
|
| 289 |
+
2 * channels,
|
| 290 |
+
kernel_size=3,
|
| 291 |
+
stride=2,
|
| 292 |
+
padding=1,
|
| 293 |
+
bias=False)
|
| 294 |
+
self.norm = build_norm_layer(2 * channels, norm_layer,
|
| 295 |
+
'channels_first', 'channels_last')
|
| 296 |
+
|
| 297 |
+
def forward(self, x):
|
| 298 |
+
x = self.conv(x.permute(0, 3, 1, 2))
|
| 299 |
+
x = self.norm(x)
|
| 300 |
+
return x
|
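Stage outputs are channels-last, so the layer permutes to channels-first for the stride-2 conv and the norm permutes back; a shape sketch (assumed sizes):

down = DownsampleLayer(channels=96, norm_layer='LN')
x = torch.randn(2, 56, 56, 96)  # channels-last stage output
y = down(x)                     # (2, 28, 28, 192): spatial halved, channels doubled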
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class MLPLayer(nn.Module):
|
| 304 |
+
r""" MLP layer of InternImage
|
| 305 |
+
Args:
|
| 306 |
+
in_features (int): number of input features
|
| 307 |
+
hidden_features (int): number of hidden features
|
| 308 |
+
out_features (int): number of output features
|
| 309 |
+
act_layer (str): activation layer
|
| 310 |
+
drop (float): dropout rate
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
def __init__(self,
|
| 314 |
+
in_features,
|
| 315 |
+
hidden_features=None,
|
| 316 |
+
out_features=None,
|
| 317 |
+
act_layer='GELU',
|
| 318 |
+
drop=0.):
|
| 319 |
+
super().__init__()
|
| 320 |
+
out_features = out_features or in_features
|
| 321 |
+
hidden_features = hidden_features or in_features
|
| 322 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 323 |
+
self.act = build_act_layer(act_layer)
|
| 324 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 325 |
+
self.drop = nn.Dropout(drop)
|
| 326 |
+
|
| 327 |
+
def forward(self, x):
|
| 328 |
+
x = self.fc1(x)
|
| 329 |
+
x = self.act(x)
|
| 330 |
+
x = self.drop(x)
|
| 331 |
+
x = self.fc2(x)
|
| 332 |
+
x = self.drop(x)
|
| 333 |
+
return x
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class InternImageLayer(nn.Module):
|
| 337 |
+
r""" Basic layer of InternImage
|
| 338 |
+
Args:
|
| 339 |
+
core_op (nn.Module): core operation of InternImage
|
| 340 |
+
channels (int): number of input channels
|
| 341 |
+
groups (int): Number of DCN groups in this layer.
|
| 342 |
+
mlp_ratio (float): ratio of mlp hidden features to input channels
|
| 343 |
+
drop (float): dropout rate
|
| 344 |
+
drop_path (float): drop path rate
|
| 345 |
+
act_layer (str): activation layer
|
| 346 |
+
norm_layer (str): normalization layer
|
| 347 |
+
post_norm (bool): whether to use post normalization
|
| 348 |
+
layer_scale (float): layer scale
|
| 349 |
+
offset_scale (float): offset scale
|
| 350 |
+
with_cp (bool): whether to use checkpoint
|
| 351 |
+
"""
|
| 352 |
+
|
| 353 |
+
def __init__(self,
|
| 354 |
+
core_op,
|
| 355 |
+
channels,
|
| 356 |
+
groups,
|
| 357 |
+
mlp_ratio=4.,
|
| 358 |
+
drop=0.,
|
| 359 |
+
drop_path=0.,
|
| 360 |
+
act_layer='GELU',
|
| 361 |
+
norm_layer='LN',
|
| 362 |
+
post_norm=False,
|
| 363 |
+
layer_scale=None,
|
| 364 |
+
offset_scale=1.0,
|
| 365 |
+
with_cp=False,
|
| 366 |
+
dw_kernel_size=None, # for InternImage-H/G
|
| 367 |
+
res_post_norm=False, # for InternImage-H/G
|
| 368 |
+
center_feature_scale=False): # for InternImage-H/G
|
| 369 |
+
super().__init__()
|
| 370 |
+
self.channels = channels
|
| 371 |
+
self.groups = groups
|
| 372 |
+
self.mlp_ratio = mlp_ratio
|
| 373 |
+
self.with_cp = with_cp
|
| 374 |
+
|
| 375 |
+
self.norm1 = build_norm_layer(channels, 'LN')
|
| 376 |
+
self.post_norm = post_norm
|
| 377 |
+
self.dcn = core_op(
|
| 378 |
+
channels=channels,
|
| 379 |
+
kernel_size=3,
|
| 380 |
+
stride=1,
|
| 381 |
+
pad=1,
|
| 382 |
+
dilation=1,
|
| 383 |
+
group=groups,
|
| 384 |
+
offset_scale=offset_scale,
|
| 385 |
+
act_layer=act_layer,
|
| 386 |
+
norm_layer=norm_layer,
|
| 387 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
| 388 |
+
center_feature_scale=center_feature_scale) # for InternImage-H/G
|
| 389 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
| 390 |
+
else nn.Identity()
|
| 391 |
+
self.norm2 = build_norm_layer(channels, 'LN')
|
| 392 |
+
self.mlp = MLPLayer(in_features=channels,
|
| 393 |
+
hidden_features=int(channels * mlp_ratio),
|
| 394 |
+
act_layer=act_layer,
|
| 395 |
+
drop=drop)
|
| 396 |
+
self.layer_scale = layer_scale is not None
|
| 397 |
+
if self.layer_scale:
|
| 398 |
+
self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
|
| 399 |
+
requires_grad=True)
|
| 400 |
+
self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
|
| 401 |
+
requires_grad=True)
|
| 402 |
+
self.res_post_norm = res_post_norm
|
| 403 |
+
if res_post_norm:
|
| 404 |
+
self.res_post_norm1 = build_norm_layer(channels, 'LN')
|
| 405 |
+
self.res_post_norm2 = build_norm_layer(channels, 'LN')
|
| 406 |
+
|
| 407 |
+
def forward(self, x):
|
| 408 |
+
|
| 409 |
+
def _inner_forward(x):
|
| 410 |
+
if not self.layer_scale:
|
| 411 |
+
if self.post_norm:
|
| 412 |
+
x = x + self.drop_path(self.norm1(self.dcn(x)))
|
| 413 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
| 414 |
+
elif self.res_post_norm: # for InternImage-H/G
|
| 415 |
+
x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
|
| 416 |
+
x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
|
| 417 |
+
else:
|
| 418 |
+
x = x + self.drop_path(self.dcn(self.norm1(x)))
|
| 419 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 420 |
+
return x
|
| 421 |
+
if self.post_norm:
|
| 422 |
+
x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
|
| 423 |
+
x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
|
| 424 |
+
else:
|
| 425 |
+
x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
|
| 426 |
+
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
|
| 427 |
+
return x
|
| 428 |
+
|
| 429 |
+
if self.with_cp and x.requires_grad:
|
| 430 |
+
x = checkpoint.checkpoint(_inner_forward, x)
|
| 431 |
+
else:
|
| 432 |
+
x = _inner_forward(x)
|
| 433 |
+
return x
|
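A construction sketch (hedged: it assumes the ops_dcnv3 extension imported at the top of this file has been compiled and a CUDA device is available, since core_op must be the DCNv3 module):

layer = InternImageLayer(core_op=opsm.DCNv3, channels=64, groups=4, mlp_ratio=4., drop_path=0.1)
x = torch.randn(2, 56, 56, 64).cuda()  # channels-last, as DCNv3 expects
y = layer.cuda()(x)                    # residual block preserves the shape: (2, 56, 56, 64)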
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class InternImageBlock(nn.Module):
|
| 437 |
+
r""" Block of InternImage
|
| 438 |
+
Args:
|
| 439 |
+
core_op (nn.Module): core operation of InternImage
|
| 440 |
+
channels (int): number of input channels
|
| 441 |
+
depth (int): Number of layers in this block.
|
| 442 |
+
groups (int): Number of DCN groups in this block.
|
| 443 |
+
mlp_ratio (float): ratio of mlp hidden features to input channels
|
| 444 |
+
drop (float): dropout rate
|
| 445 |
+
drop_path (float): drop path rate
|
| 446 |
+
act_layer (str): activation layer
|
| 447 |
+
norm_layer (str): normalization layer
|
| 448 |
+
post_norm (bool): whether to use post normalization
|
| 449 |
+
layer_scale (float): layer scale
|
| 450 |
+
offset_scale (float): offset scale
|
| 451 |
+
with_cp (bool): whether to use checkpoint
|
| 452 |
+
"""
|
| 453 |
+
|
| 454 |
+
def __init__(self,
|
| 455 |
+
core_op,
|
| 456 |
+
channels,
|
| 457 |
+
depth,
|
| 458 |
+
groups,
|
| 459 |
+
downsample=True,
|
| 460 |
+
mlp_ratio=4.,
|
| 461 |
+
drop=0.,
|
| 462 |
+
drop_path=0.,
|
| 463 |
+
act_layer='GELU',
|
| 464 |
+
norm_layer='LN',
|
| 465 |
+
post_norm=False,
|
| 466 |
+
offset_scale=1.0,
|
| 467 |
+
layer_scale=None,
|
| 468 |
+
with_cp=False,
|
| 469 |
+
dw_kernel_size=None, # for InternImage-H/G
|
| 470 |
+
post_norm_block_ids=None, # for InternImage-H/G
|
| 471 |
+
res_post_norm=False, # for InternImage-H/G
|
| 472 |
+
center_feature_scale=False): # for InternImage-H/G
|
| 473 |
+
super().__init__()
|
| 474 |
+
self.channels = channels
|
| 475 |
+
self.depth = depth
|
| 476 |
+
self.post_norm = post_norm
|
| 477 |
+
self.center_feature_scale = center_feature_scale
|
| 478 |
+
|
| 479 |
+
self.blocks = nn.ModuleList([
|
| 480 |
+
InternImageLayer(
|
| 481 |
+
core_op=core_op,
|
| 482 |
+
channels=channels,
|
| 483 |
+
groups=groups,
|
| 484 |
+
mlp_ratio=mlp_ratio,
|
| 485 |
+
drop=drop,
|
| 486 |
+
drop_path=drop_path[i] if isinstance(
|
| 487 |
+
drop_path, list) else drop_path,
|
| 488 |
+
act_layer=act_layer,
|
| 489 |
+
norm_layer=norm_layer,
|
| 490 |
+
post_norm=post_norm,
|
| 491 |
+
layer_scale=layer_scale,
|
| 492 |
+
offset_scale=offset_scale,
|
| 493 |
+
with_cp=with_cp,
|
| 494 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
| 495 |
+
res_post_norm=res_post_norm, # for InternImage-H/G
|
| 496 |
+
center_feature_scale=center_feature_scale # for InternImage-H/G
|
| 497 |
+
) for i in range(depth)
|
| 498 |
+
])
|
| 499 |
+
if not self.post_norm or center_feature_scale:
|
| 500 |
+
self.norm = build_norm_layer(channels, 'LN')
|
| 501 |
+
self.post_norm_block_ids = post_norm_block_ids
|
| 502 |
+
if post_norm_block_ids is not None: # for InternImage-H/G
|
| 503 |
+
self.post_norms = nn.ModuleList(
|
| 504 |
+
[build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
|
| 505 |
+
)
|
| 506 |
+
self.downsample = DownsampleLayer(
|
| 507 |
+
channels=channels, norm_layer=norm_layer) if downsample else None
|
| 508 |
+
|
| 509 |
+
def forward(self, x, return_wo_downsample=False):
|
| 510 |
+
for i, blk in enumerate(self.blocks):
|
| 511 |
+
x = blk(x)
|
| 512 |
+
if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
|
| 513 |
+
index = self.post_norm_block_ids.index(i)
|
| 514 |
+
x = self.post_norms[index](x) # for InternImage-H/G
|
| 515 |
+
if not self.post_norm or self.center_feature_scale:
|
| 516 |
+
x = self.norm(x)
|
| 517 |
+
if return_wo_downsample:
|
| 518 |
+
x_ = x
|
| 519 |
+
if self.downsample is not None:
|
| 520 |
+
x = self.downsample(x)
|
| 521 |
+
|
| 522 |
+
if return_wo_downsample:
|
| 523 |
+
return x, x_
|
| 524 |
+
return x
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
@BACKBONES.register_module()
|
| 528 |
+
class InternImage(nn.Module):
|
| 529 |
+
r""" InternImage
|
| 530 |
+
A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
|
| 531 |
+
https://arxiv.org/abs/2211.05778
|
| 532 |
+
Args:
|
| 533 |
+
core_op (str): Core operator. Default: 'DCNv3'
|
| 534 |
+
channels (int): Number of channels of the first stage. Default: 64
|
| 535 |
+
depths (list): Depth of each block. Default: [3, 4, 18, 5]
|
| 536 |
+
groups (list): Groups of each block. Default: [3, 6, 12, 24]
|
| 537 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 538 |
+
drop_rate (float): Probability of an element to be zeroed. Default: 0.
|
| 539 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
| 540 |
+
act_layer (str): Activation layer. Default: 'GELU'
|
| 541 |
+
norm_layer (str): Normalization layer. Default: 'LN'
|
| 542 |
+
layer_scale (float | None): Initial layer scale value; None disables layer scale. Default: None
|
| 543 |
+
offset_scale (float): Offset scale of the DCN operator. Default: 1.0
|
| 544 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False
|
| 545 |
+
dw_kernel_size (int): Size of the dwconv. Default: None
|
| 546 |
+
level2_post_norm (bool): Whether to use level2 post norm. Default: False
|
| 547 |
+
level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
|
| 548 |
+
res_post_norm (bool): Whether to use res post norm. Default: False
|
| 549 |
+
center_feature_scale (bool): Whether to use center feature scale. Default: False
|
| 550 |
+
"""
|
| 551 |
+
|
| 552 |
+
def __init__(self,
|
| 553 |
+
core_op='DCNv3',
|
| 554 |
+
channels=64,
|
| 555 |
+
depths=[3, 4, 18, 5],
|
| 556 |
+
groups=[3, 6, 12, 24],
|
| 557 |
+
mlp_ratio=4.,
|
| 558 |
+
drop_rate=0.,
|
| 559 |
+
drop_path_rate=0.2,
|
| 560 |
+
drop_path_type='linear',
|
| 561 |
+
act_layer='GELU',
|
| 562 |
+
norm_layer='LN',
|
| 563 |
+
layer_scale=None,
|
| 564 |
+
offset_scale=1.0,
|
| 565 |
+
post_norm=False,
|
| 566 |
+
with_cp=False,
|
| 567 |
+
dw_kernel_size=None, # for InternImage-H/G
|
| 568 |
+
level2_post_norm=False, # for InternImage-H/G
|
| 569 |
+
level2_post_norm_block_ids=None, # for InternImage-H/G
|
| 570 |
+
res_post_norm=False, # for InternImage-H/G
|
| 571 |
+
center_feature_scale=False, # for InternImage-H/G
|
| 572 |
+
out_indices=(0, 1, 2, 3),
|
| 573 |
+
init_cfg=None,
|
| 574 |
+
**kwargs):
|
| 575 |
+
super().__init__()
|
| 576 |
+
self.core_op = core_op
|
| 577 |
+
self.num_levels = len(depths)
|
| 578 |
+
self.depths = depths
|
| 579 |
+
self.channels = channels
|
| 580 |
+
self.num_features = int(channels * 2**(self.num_levels - 1))
|
| 581 |
+
self.post_norm = post_norm
|
| 582 |
+
self.mlp_ratio = mlp_ratio
|
| 583 |
+
self.init_cfg = init_cfg
|
| 584 |
+
self.out_indices = out_indices
|
| 585 |
+
self.level2_post_norm_block_ids = level2_post_norm_block_ids
|
| 586 |
+
logger = get_root_logger()
|
| 587 |
+
logger.info(f'using core type: {core_op}')
|
| 588 |
+
logger.info(f'using activation layer: {act_layer}')
|
| 589 |
+
logger.info(f'using main norm layer: {norm_layer}')
|
| 590 |
+
logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
|
| 591 |
+
logger.info(f"level2_post_norm: {level2_post_norm}")
|
| 592 |
+
logger.info(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}")
|
| 593 |
+
logger.info(f"res_post_norm: {res_post_norm}")
|
| 594 |
+
|
| 595 |
+
in_chans = 3
|
| 596 |
+
self.patch_embed = StemLayer(in_chans=in_chans,
|
| 597 |
+
out_chans=channels,
|
| 598 |
+
act_layer=act_layer,
|
| 599 |
+
norm_layer=norm_layer)
|
| 600 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 601 |
+
|
| 602 |
+
dpr = [
|
| 603 |
+
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
| 604 |
+
]
|
| 605 |
+
if drop_path_type == 'uniform':
|
| 606 |
+
for i in range(len(dpr)):
|
| 607 |
+
dpr[i] = drop_path_rate
|
| 608 |
+
|
| 609 |
+
self.levels = nn.ModuleList()
|
| 610 |
+
for i in range(self.num_levels):
|
| 611 |
+
post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
|
| 612 |
+
i == 2) else None # for InternImage-H/G
|
| 613 |
+
level = InternImageBlock(
|
| 614 |
+
core_op=getattr(opsm, core_op),
|
| 615 |
+
channels=int(channels * 2**i),
|
| 616 |
+
depth=depths[i],
|
| 617 |
+
groups=groups[i],
|
| 618 |
+
mlp_ratio=self.mlp_ratio,
|
| 619 |
+
drop=drop_rate,
|
| 620 |
+
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
| 621 |
+
act_layer=act_layer,
|
| 622 |
+
norm_layer=norm_layer,
|
| 623 |
+
post_norm=post_norm,
|
| 624 |
+
downsample=(i < self.num_levels - 1),
|
| 625 |
+
layer_scale=layer_scale,
|
| 626 |
+
offset_scale=offset_scale,
|
| 627 |
+
with_cp=with_cp,
|
| 628 |
+
dw_kernel_size=dw_kernel_size, # for InternImage-H/G
|
| 629 |
+
post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G
|
| 630 |
+
res_post_norm=res_post_norm, # for InternImage-H/G
|
| 631 |
+
center_feature_scale=center_feature_scale # for InternImage-H/G
|
| 632 |
+
)
|
| 633 |
+
self.levels.append(level)
|
| 634 |
+
|
| 635 |
+
self.num_layers = len(depths)
|
| 636 |
+
self.apply(self._init_weights)
|
| 637 |
+
self.apply(self._init_deform_weights)
|
| 638 |
+
|
| 639 |
+
def init_weights(self):
|
| 640 |
+
logger = get_root_logger()
|
| 641 |
+
if self.init_cfg is None:
|
| 642 |
+
logger.warning(f'No pre-trained weights for '
|
| 643 |
+
f'{self.__class__.__name__}, '
|
| 644 |
+
f'training starts from scratch')
|
| 645 |
+
for m in self.modules():
|
| 646 |
+
if isinstance(m, nn.Linear):
|
| 647 |
+
trunc_normal_init(m, std=.02, bias=0.)
|
| 648 |
+
elif isinstance(m, nn.LayerNorm):
|
| 649 |
+
constant_init(m, 1.0)
|
| 650 |
+
else:
|
| 651 |
+
assert 'checkpoint' in self.init_cfg, f'Only support ' \
|
| 652 |
+
f'specify `Pretrained` in ' \
|
| 653 |
+
f'`init_cfg` in ' \
|
| 654 |
+
f'{self.__class__.__name__} '
|
| 655 |
+
ckpt = _load_checkpoint(self.init_cfg.checkpoint,
|
| 656 |
+
logger=logger,
|
| 657 |
+
map_location='cpu')
|
| 658 |
+
if 'state_dict' in ckpt:
|
| 659 |
+
_state_dict = ckpt['state_dict']
|
| 660 |
+
elif 'model' in ckpt:
|
| 661 |
+
_state_dict = ckpt['model']
|
| 662 |
+
else:
|
| 663 |
+
_state_dict = ckpt
|
| 664 |
+
|
| 665 |
+
state_dict = OrderedDict()
|
| 666 |
+
for k, v in _state_dict.items():
|
| 667 |
+
if k.startswith('backbone.'):
|
| 668 |
+
state_dict[k[9:]] = v
|
| 669 |
+
else:
|
| 670 |
+
state_dict[k] = v
|
| 671 |
+
|
| 672 |
+
# strip prefix of state_dict
|
| 673 |
+
if list(state_dict.keys())[0].startswith('module.'):
|
| 674 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
| 675 |
+
|
| 676 |
+
# load state_dict
|
| 677 |
+
msg = self.load_state_dict(state_dict, strict=False)
|
| 678 |
+
logger.info(msg)
|
| 679 |
+
|
| 680 |
+
def _init_weights(self, m):
|
| 681 |
+
if isinstance(m, nn.Linear):
|
| 682 |
+
trunc_normal_(m.weight, std=.02)
|
| 683 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 684 |
+
nn.init.constant_(m.bias, 0)
|
| 685 |
+
elif isinstance(m, nn.LayerNorm):
|
| 686 |
+
nn.init.constant_(m.bias, 0)
|
| 687 |
+
nn.init.constant_(m.weight, 1.0)
|
| 688 |
+
|
| 689 |
+
def _init_deform_weights(self, m):
|
| 690 |
+
if isinstance(m, getattr(opsm, self.core_op)):
|
| 691 |
+
m._reset_parameters()
|
| 692 |
+
|
| 693 |
+
def forward(self, x):
|
| 694 |
+
x = self.patch_embed(x)
|
| 695 |
+
x = self.pos_drop(x)
|
| 696 |
+
|
| 697 |
+
seq_out = []
|
| 698 |
+
for level_idx, level in enumerate(self.levels):
|
| 699 |
+
x, x_ = level(x, return_wo_downsample=True)
|
| 700 |
+
if level_idx in self.out_indices:
|
| 701 |
+
seq_out.append(x_.permute(0, 3, 1, 2).contiguous())
|
| 702 |
+
return seq_out
|
model/mmdet_custom/models/backbones/swin_transformer.py
ADDED
|
@@ -0,0 +1,648 @@
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Swin Transformer
|
| 3 |
+
# Copyright (c) 2021 Microsoft
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Written by Ze Liu, Yutong Lin, Yixuan Wei
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.utils.checkpoint as checkpoint
|
| 12 |
+
import numpy as np
|
| 13 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 14 |
+
|
| 15 |
+
from mmcv_custom import load_checkpoint
|
| 16 |
+
from mmdet.utils import get_root_logger
|
| 17 |
+
from mmdet.models.builder import BACKBONES
|
| 18 |
+
|
| 19 |
+
from mmcv.runner import BaseModule
|
| 20 |
+
|
| 21 |
+
class Mlp(nn.Module):
|
| 22 |
+
""" Multilayer perceptron."""
|
| 23 |
+
|
| 24 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 25 |
+
super().__init__()
|
| 26 |
+
out_features = out_features or in_features
|
| 27 |
+
hidden_features = hidden_features or in_features
|
| 28 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 29 |
+
self.act = act_layer()
|
| 30 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 31 |
+
self.drop = nn.Dropout(drop)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
x = self.fc1(x)
|
| 35 |
+
x = self.act(x)
|
| 36 |
+
x = self.drop(x)
|
| 37 |
+
x = self.fc2(x)
|
| 38 |
+
x = self.drop(x)
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def window_partition(x, window_size):
|
| 43 |
+
"""
|
| 44 |
+
Args:
|
| 45 |
+
x: (B, H, W, C)
|
| 46 |
+
window_size (int): window size
|
| 47 |
+
Returns:
|
| 48 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 49 |
+
"""
|
| 50 |
+
B, H, W, C = x.shape
|
| 51 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
| 52 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
| 53 |
+
return windows
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def window_reverse(windows, window_size, H, W):
|
| 57 |
+
"""
|
| 58 |
+
Args:
|
| 59 |
+
windows: (num_windows*B, window_size, window_size, C)
|
| 60 |
+
window_size (int): Window size
|
| 61 |
+
H (int): Height of image
|
| 62 |
+
W (int): Width of image
|
| 63 |
+
Returns:
|
| 64 |
+
x: (B, H, W, C)
|
| 65 |
+
"""
|
| 66 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
| 67 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
| 68 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 69 |
+
return x
|
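The two helpers are exact inverses whenever H and W are multiples of the window size; a small round-trip check (an added sketch):

x = torch.randn(2, 14, 14, 32)             # (B, H, W, C) with H, W divisible by 7
wins = window_partition(x, window_size=7)  # (2 * 2 * 2, 7, 7, 32) == (8, 7, 7, 32)
x_back = window_reverse(wins, 7, H=14, W=14)
assert torch.equal(x, x_back)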
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class WindowAttention(nn.Module):
|
| 73 |
+
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
| 74 |
+
It supports both of shifted and non-shifted window.
|
| 75 |
+
Args:
|
| 76 |
+
dim (int): Number of input channels.
|
| 77 |
+
window_size (tuple[int]): The height and width of the window.
|
| 78 |
+
num_heads (int): Number of attention heads.
|
| 79 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 80 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
| 81 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
| 82 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 86 |
+
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.dim = dim
|
| 89 |
+
self.window_size = window_size # Wh, Ww
|
| 90 |
+
self.num_heads = num_heads
|
| 91 |
+
head_dim = dim // num_heads
|
| 92 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 93 |
+
|
| 94 |
+
# define a parameter table of relative position bias
|
| 95 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 96 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 97 |
+
|
| 98 |
+
# get pair-wise relative position index for each token inside the window
|
| 99 |
+
coords_h = torch.arange(self.window_size[0])
|
| 100 |
+
coords_w = torch.arange(self.window_size[1])
|
| 101 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
| 102 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
| 103 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
| 104 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
| 105 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
| 106 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
| 107 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
| 108 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
| 109 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
| 110 |
+
|
| 111 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 112 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 113 |
+
self.proj = nn.Linear(dim, dim)
|
| 114 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 115 |
+
|
| 116 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
| 117 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 118 |
+
|
| 119 |
+
def forward(self, x, mask=None):
|
| 120 |
+
""" Forward function.
|
| 121 |
+
Args:
|
| 122 |
+
x: input features with shape of (num_windows*B, N, C)
|
| 123 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
| 124 |
+
"""
|
| 125 |
+
B_, N, C = x.shape
|
| 126 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 127 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 128 |
+
|
| 129 |
+
q = q * self.scale
|
| 130 |
+
attn = (q @ k.transpose(-2, -1))
|
| 131 |
+
|
| 132 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
| 133 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
| 134 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 135 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
| 136 |
+
|
| 137 |
+
if mask is not None:
|
| 138 |
+
nW = mask.shape[0]
|
| 139 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
| 140 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
| 141 |
+
attn = self.softmax(attn)
|
| 142 |
+
else:
|
| 143 |
+
attn = self.softmax(attn)
|
| 144 |
+
|
| 145 |
+
attn = self.attn_drop(attn)
|
| 146 |
+
|
| 147 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
| 148 |
+
x = self.proj(x)
|
| 149 |
+
x = self.proj_drop(x)
|
| 150 |
+
return x
|
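A shape sketch (added illustration; sizes assumed): inputs arrive already partitioned into windows, with N = window_size[0] * window_size[1] tokens per window.

w_attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
x = torch.randn(8, 49, 96)  # (num_windows * B, Wh * Ww, C)
y = w_attn(x, mask=None)    # (8, 49, 96)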
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class SwinTransformerBlock(nn.Module):
|
| 154 |
+
""" Swin Transformer Block.
|
| 155 |
+
Args:
|
| 156 |
+
dim (int): Number of input channels.
|
| 157 |
+
num_heads (int): Number of attention heads.
|
| 158 |
+
window_size (int): Window size.
|
| 159 |
+
shift_size (int): Shift size for SW-MSA.
|
| 160 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
| 161 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 162 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 163 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 164 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 165 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
| 166 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
| 167 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
| 171 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
| 172 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.dim = dim
|
| 175 |
+
self.num_heads = num_heads
|
| 176 |
+
self.window_size = window_size
|
| 177 |
+
self.shift_size = shift_size
|
| 178 |
+
self.mlp_ratio = mlp_ratio
|
| 179 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
| 180 |
+
|
| 181 |
+
self.norm1 = norm_layer(dim)
|
| 182 |
+
self.attn = WindowAttention(
|
| 183 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
| 184 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 185 |
+
|
| 186 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 187 |
+
self.norm2 = norm_layer(dim)
|
| 188 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 189 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 190 |
+
|
| 191 |
+
self.H = None
|
| 192 |
+
self.W = None
|
| 193 |
+
|
| 194 |
+
def forward(self, x, mask_matrix):
|
| 195 |
+
""" Forward function.
|
| 196 |
+
Args:
|
| 197 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 198 |
+
H, W: Spatial resolution of the input feature.
|
| 199 |
+
mask_matrix: Attention mask for cyclic shift.
|
| 200 |
+
"""
|
| 201 |
+
B, L, C = x.shape
|
| 202 |
+
H, W = self.H, self.W
|
| 203 |
+
assert L == H * W, "input feature has wrong size"
|
| 204 |
+
|
| 205 |
+
shortcut = x
|
| 206 |
+
x = self.norm1(x)
|
| 207 |
+
x = x.view(B, H, W, C)
|
| 208 |
+
|
| 209 |
+
# pad feature maps to multiples of window size
|
| 210 |
+
pad_l = pad_t = 0
|
| 211 |
+
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
| 212 |
+
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
| 213 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
| 214 |
+
_, Hp, Wp, _ = x.shape
|
| 215 |
+
|
| 216 |
+
# cyclic shift
|
| 217 |
+
if self.shift_size > 0:
|
| 218 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
| 219 |
+
attn_mask = mask_matrix
|
| 220 |
+
else:
|
| 221 |
+
shifted_x = x
|
| 222 |
+
attn_mask = None
|
| 223 |
+
|
| 224 |
+
# partition windows
|
| 225 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
| 226 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
| 227 |
+
|
| 228 |
+
# W-MSA/SW-MSA
|
| 229 |
+
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
|
| 230 |
+
|
| 231 |
+
# merge windows
|
| 232 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
| 233 |
+
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
| 234 |
+
|
| 235 |
+
# reverse cyclic shift
|
| 236 |
+
if self.shift_size > 0:
|
| 237 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
| 238 |
+
else:
|
| 239 |
+
x = shifted_x
|
| 240 |
+
|
| 241 |
+
if pad_r > 0 or pad_b > 0:
|
| 242 |
+
x = x[:, :H, :W, :].contiguous()
|
| 243 |
+
|
| 244 |
+
x = x.view(B, H * W, C)
|
| 245 |
+
|
| 246 |
+
# FFN
|
| 247 |
+
x = shortcut + self.drop_path(x)
|
| 248 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 249 |
+
|
| 250 |
+
return x
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class PatchMerging(nn.Module):
|
| 254 |
+
""" Patch Merging Layer
|
| 255 |
+
Args:
|
| 256 |
+
dim (int): Number of input channels.
|
| 257 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 258 |
+
"""
|
| 259 |
+
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
| 260 |
+
super().__init__()
|
| 261 |
+
self.dim = dim
|
| 262 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
| 263 |
+
self.norm = norm_layer(4 * dim)
|
| 264 |
+
|
| 265 |
+
def forward(self, x, H, W):
|
| 266 |
+
""" Forward function.
|
| 267 |
+
Args:
|
| 268 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 269 |
+
H, W: Spatial resolution of the input feature.
|
| 270 |
+
"""
|
| 271 |
+
B, L, C = x.shape
|
| 272 |
+
assert L == H * W, "input feature has wrong size"
|
| 273 |
+
|
| 274 |
+
x = x.view(B, H, W, C)
|
| 275 |
+
|
| 276 |
+
# padding
|
| 277 |
+
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
| 278 |
+
if pad_input:
|
| 279 |
+
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
| 280 |
+
|
| 281 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
| 282 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
| 283 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
| 284 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
| 285 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
| 286 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
| 287 |
+
|
| 288 |
+
x = self.norm(x)
|
| 289 |
+
x = self.reduction(x)
|
| 290 |
+
|
| 291 |
+
return x
|
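A shape sketch (added, assumed sizes): merging concatenates each 2x2 neighborhood, so the token count drops by 4x while the channel count only doubles after the linear reduction.

merge = PatchMerging(dim=96)
x = torch.randn(2, 56 * 56, 96)  # (B, H*W, C)
y = merge(x, H=56, W=56)         # (2, 28 * 28, 192)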
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class BasicLayer(nn.Module):
|
| 295 |
+
""" A basic Swin Transformer layer for one stage.
|
| 296 |
+
Args:
|
| 297 |
+
dim (int): Number of feature channels
|
| 298 |
+
depth (int): Depth of this stage.
|
| 299 |
+
num_heads (int): Number of attention heads.
|
| 300 |
+
window_size (int): Local window size. Default: 7.
|
| 301 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 302 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
| 303 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
| 304 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
| 305 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
| 306 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
| 307 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
| 308 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
| 309 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
def __init__(self,
|
| 313 |
+
dim,
|
| 314 |
+
depth,
|
| 315 |
+
num_heads,
|
| 316 |
+
window_size=7,
|
| 317 |
+
mlp_ratio=4.,
|
| 318 |
+
qkv_bias=True,
|
| 319 |
+
qk_scale=None,
|
| 320 |
+
drop=0.,
|
| 321 |
+
attn_drop=0.,
|
| 322 |
+
drop_path=0.,
|
| 323 |
+
norm_layer=nn.LayerNorm,
|
| 324 |
+
downsample=None,
|
| 325 |
+
use_checkpoint=False):
|
| 326 |
+
super().__init__()
|
| 327 |
+
self.window_size = window_size
|
| 328 |
+
self.shift_size = window_size // 2
|
| 329 |
+
self.depth = depth
|
| 330 |
+
self.use_checkpoint = use_checkpoint
|
| 331 |
+
|
| 332 |
+
# build blocks
|
| 333 |
+
self.blocks = nn.ModuleList([
|
| 334 |
+
SwinTransformerBlock(
|
| 335 |
+
dim=dim,
|
| 336 |
+
num_heads=num_heads,
|
| 337 |
+
window_size=window_size,
|
| 338 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
| 339 |
+
mlp_ratio=mlp_ratio,
|
| 340 |
+
qkv_bias=qkv_bias,
|
| 341 |
+
qk_scale=qk_scale,
|
| 342 |
+
drop=drop,
|
| 343 |
+
attn_drop=attn_drop,
|
| 344 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 345 |
+
norm_layer=norm_layer)
|
| 346 |
+
for i in range(depth)])
|
| 347 |
+
|
| 348 |
+
# patch merging layer
|
| 349 |
+
if downsample is not None:
|
| 350 |
+
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
| 351 |
+
else:
|
| 352 |
+
self.downsample = None
|
| 353 |
+
|
| 354 |
+
def forward(self, x, H, W):
|
| 355 |
+
""" Forward function.
|
| 356 |
+
Args:
|
| 357 |
+
x: Input feature, tensor size (B, H*W, C).
|
| 358 |
+
H, W: Spatial resolution of the input feature.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
# calculate attention mask for SW-MSA
|
| 362 |
+
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
| 363 |
+
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
| 364 |
+
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
| 365 |
+
h_slices = (slice(0, -self.window_size),
|
| 366 |
+
slice(-self.window_size, -self.shift_size),
|
| 367 |
+
slice(-self.shift_size, None))
|
| 368 |
+
w_slices = (slice(0, -self.window_size),
|
| 369 |
+
slice(-self.window_size, -self.shift_size),
|
| 370 |
+
slice(-self.shift_size, None))
|
| 371 |
+
cnt = 0
|
| 372 |
+
for h in h_slices:
|
| 373 |
+
for w in w_slices:
|
| 374 |
+
img_mask[:, h, w, :] = cnt
|
| 375 |
+
cnt += 1
|
| 376 |
+
|
| 377 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
| 378 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
| 379 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
| 380 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
| 381 |
+
|
| 382 |
+
for blk in self.blocks:
|
| 383 |
+
blk.H, blk.W = H, W
|
| 384 |
+
if self.use_checkpoint:
|
| 385 |
+
x = checkpoint.checkpoint(blk, x, attn_mask)
|
| 386 |
+
else:
|
| 387 |
+
x = blk(x, attn_mask)
|
| 388 |
+
if self.downsample is not None:
|
| 389 |
+
x_down = self.downsample(x, H, W)
|
| 390 |
+
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
| 391 |
+
return x, H, W, x_down, Wh, Ww
|
| 392 |
+
else:
|
| 393 |
+
return x, H, W, x, H, W
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class PatchEmbed(nn.Module):
|
| 397 |
+
""" Image to Patch Embedding
|
| 398 |
+
Args:
|
| 399 |
+
patch_size (int): Patch token size. Default: 4.
|
| 400 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 401 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 402 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
| 406 |
+
super().__init__()
|
| 407 |
+
patch_size = to_2tuple(patch_size)
|
| 408 |
+
self.patch_size = patch_size
|
| 409 |
+
|
| 410 |
+
self.in_chans = in_chans
|
| 411 |
+
self.embed_dim = embed_dim
|
| 412 |
+
|
| 413 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 414 |
+
if norm_layer is not None:
|
| 415 |
+
self.norm = norm_layer(embed_dim)
|
| 416 |
+
else:
|
| 417 |
+
self.norm = None
|
| 418 |
+
|
| 419 |
+
def forward(self, x):
|
| 420 |
+
"""Forward function."""
|
| 421 |
+
# padding
|
| 422 |
+
_, _, H, W = x.size()
|
| 423 |
+
if W % self.patch_size[1] != 0:
|
| 424 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 425 |
+
if H % self.patch_size[0] != 0:
|
| 426 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 427 |
+
|
| 428 |
+
x = self.proj(x) # B C Wh Ww
|
| 429 |
+
if self.norm is not None:
|
| 430 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 431 |
+
x = x.flatten(2).transpose(1, 2)
|
| 432 |
+
x = self.norm(x)
|
| 433 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 434 |
+
|
| 435 |
+
return x
|
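Because the projection pads inputs up to a multiple of patch_size, arbitrary image sizes are accepted; a sketch with deliberately awkward dimensions (assumed values):

embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
img = torch.randn(1, 3, 223, 225)  # neither side divisible by 4
feat = embed(img)                  # padded to 224x228, then projected: (1, 96, 56, 57)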
| 436 |
+
|
| 437 |
+
|
| 438 |
+
@BACKBONES.register_module()
|
| 439 |
+
class SwinTransformerV1(BaseModule):
|
| 440 |
+
""" Swin Transformer backbone.
|
| 441 |
+
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
| 442 |
+
https://arxiv.org/pdf/2103.14030
|
| 443 |
+
Args:
|
| 444 |
+
pretrain_img_size (int): Input image size for training the pretrained model,
|
| 445 |
+
used in absolute position embedding. Default: 224.
|
| 446 |
+
patch_size (int | tuple(int)): Patch size. Default: 4.
|
| 447 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 448 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 449 |
+
depths (tuple[int]): Depths of each Swin Transformer stage.
|
| 450 |
+
num_heads (tuple[int]): Number of attention heads of each stage.
|
| 451 |
+
window_size (int): Window size. Default: 7.
|
| 452 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
|
| 453 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
| 454 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
| 455 |
+
drop_rate (float): Dropout rate.
|
| 456 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0.
|
| 457 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.2.
|
| 458 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
| 459 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
|
| 460 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True.
|
| 461 |
+
out_indices (Sequence[int]): Output from which stages.
|
| 462 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
| 463 |
+
-1 means not freezing any parameters.
|
| 464 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
| 465 |
+
pretrained (str, optional): model pretrained path. Default: None.
|
| 466 |
+
init_cfg (dict or list[dict], optional): Initialization config dict.
|
| 467 |
+
Default: None.
|
| 468 |
+
"""
|
| 469 |
+
|
| 470 |
+
def __init__(self,
|
| 471 |
+
pretrain_img_size=224,
|
| 472 |
+
patch_size=4,
|
| 473 |
+
in_chans=3,
|
| 474 |
+
embed_dim=96,
|
| 475 |
+
depths=[2, 2, 6, 2],
|
| 476 |
+
num_heads=[3, 6, 12, 24],
|
| 477 |
+
window_size=7,
|
| 478 |
+
mlp_ratio=4.,
|
| 479 |
+
qkv_bias=True,
|
| 480 |
+
qk_scale=None,
|
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False,
                 pretrained=None,
                 init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super().__init__(init_cfg=init_cfg)

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.pretrained = pretrained

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0],
                                  pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self):
        """Initialize the weights in backbone."""

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(self.pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            load_checkpoint(self, self.pretrained, strict=False, logger=logger)
        elif self.pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(
                self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping layers frozen."""
        super(SwinTransformerV1, self).train(mode)
        self._freeze_stages()
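
The backbone above is consumed through mmdet's registry, so a config only needs its type name and constructor arguments. The fragment below is a hedged illustration with Swin-T-style placeholder values, not the settings RoDLA actually ships (those live under model/configs/):

```python
# Hypothetical mmdet-style config fragment; every value here is a placeholder.
backbone = dict(
    type='SwinTransformerV1',       # the class reconstructed above
    embed_dim=96,                   # Swin-T width, illustrative only
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_size=7,
    drop_path_rate=0.2,
    ape=False,                      # no absolute position embedding
    patch_norm=True,
    out_indices=(0, 1, 2, 3),       # one feature map per stage for the neck
    frozen_stages=-1,               # -1 trains every stage
    use_checkpoint=False,
    pretrained='swin_tiny_patch4_window7_224.pth',  # hypothetical path
)
```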
model/mmdet_custom/models/dense_heads/__init__.py
ADDED
@@ -0,0 +1,11 @@
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from .deformable_detr_head import DeformableDETRHead
from .detr_head import DETRHead
from .dino_head import DINOHead

__all__ = ['DeformableDETRHead', 'DETRHead', 'DINOHead']
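
Since each head registers itself into mmdet's `HEADS` registry with `force=True`, importing this package makes the custom implementations shadow the stock ones, and configs can then refer to them purely by type name. A minimal sketch, assuming the repository's model/ directory is on `PYTHONPATH`:

```python
# Minimal sketch: the force-registered custom heads resolve through the
# standard mmdet registry once the package has been imported.
import mmdet_custom  # noqa: F401  (triggers the register_module calls)
from mmdet.models.builder import HEADS

head_cls = HEADS.get('DeformableDETRHead')  # the custom class overrides mmdet's
print(head_cls.__module__)
```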
model/mmdet_custom/models/dense_heads/deformable_detr_head.py
ADDED
@@ -0,0 +1,332 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear, bias_init_with_prob, constant_init
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models.builder import HEADS
from .detr_head import DETRHead


@HEADS.register_module(force=True)
class DeformableDETRHead(DETRHead):
    """Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to-
    End Object Detection.

    Code is modified from the `official github repo
    <https://github.com/fundamentalvision/Deformable-DETR>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2010.04159>`_ .

    Args:
        with_box_refine (bool): Whether to refine the reference points
            in the decoder. Defaults to False.
        as_two_stage (bool): Whether to generate the proposals from
            the outputs of the encoder.
        transformer (obj:`ConfigDict`): ConfigDict used for building
            the Encoder and Decoder.
    """

    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 use_2fc_cls_branch=False,
                 **kwargs):
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        self.use_2fc_cls_branch = use_2fc_cls_branch
        if self.as_two_stage:
            transformer['as_two_stage'] = self.as_two_stage

        super(DeformableDETRHead, self).__init__(
            *args, transformer=transformer, **kwargs)

    def _init_layers(self):
        """Initialize classification branch and regression branch of head."""
        if not self.use_2fc_cls_branch:
            fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        else:
            fc_cls = nn.Sequential(*[
                Linear(self.embed_dims, int(self.embed_dims * 1.5)),
                nn.LayerNorm(int(self.embed_dims * 1.5)),
                nn.GELU(),
                Linear(int(self.embed_dims * 1.5), self.cls_out_channels),
            ])
            fc_cls.out_features = self.cls_out_channels

        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, 4))
        reg_branch = nn.Sequential(*reg_branch)

        def _get_clones(module, N):
            return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

        # the last reg_branch is used to generate proposals from the
        # encoder feature map when as_two_stage is True.
        num_pred = (self.transformer.decoder.num_layers + 1) if \
            self.as_two_stage else self.transformer.decoder.num_layers

        if self.with_box_refine:
            self.cls_branches = _get_clones(fc_cls, num_pred)
            self.reg_branches = _get_clones(reg_branch, num_pred)
        else:
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])

        if not self.as_two_stage:
            self.query_embedding = nn.Embedding(
                self.num_query,
                self.embed_dims * 2)

    def init_weights(self):
        """Initialize weights of the DeformDETR head."""
        self.transformer.init_weights()
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            if not self.use_2fc_cls_branch:
                for m in self.cls_branches:
                    nn.init.constant_(m.bias, bias_init)
        for m in self.reg_branches:
            constant_init(m[-1], 0, bias=0)
        nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
        if self.as_two_stage:
            for m in self.reg_branches:
                nn.init.constant_(m[-1].bias.data[2:], 0.0)
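
    # Illustrative note (added for clarity; not part of the original file):
    # with with_box_refine=True every decoder layer gets its own deepcopy of
    # fc_cls/reg_branch via _get_clones, so layers can specialize; otherwise
    # the very same module object is repeated in the ModuleList and all
    # layers share weights. Biasing the first reg branch's w/h outputs to
    # -2.0 keeps initial predicted box sizes small, since sigmoid(-2) ~ 0.12.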
    def forward(self, mlvl_feats, img_metas):
        """Forward function.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 4D-tensor with shape
                (N, C, H, W).
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
                cls_out_channels should include background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head with normalized coordinate format (cx, cy, w, h). \
                Shape [nb_dec, bs, num_query, 4].
            enc_outputs_class (Tensor): The score of each point on the \
                encoder feature map, has shape (N, h*w, num_class). Only \
                returned when as_two_stage is True, otherwise `None`.
            enc_outputs_coord (Tensor): The proposals generated from the \
                encoder feature map, has shape (N, h*w, 4). Only returned \
                when as_two_stage is True, otherwise `None`.
        """
        batch_size = mlvl_feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0

        mlvl_masks = []
        mlvl_positional_encodings = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(img_masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))

        query_embeds = None
        if not self.as_two_stage:
            query_embeds = self.query_embedding.weight
        hs, init_reference, inter_references, \
            enc_outputs_class, enc_outputs_coord = self.transformer(
                mlvl_feats,
                mlvl_masks,
                query_embeds,
                mlvl_positional_encodings,
                reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
                cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
            )
        hs = hs.permute(0, 2, 1, 3)
        outputs_classes = []
        outputs_coords = []

        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            tmp = self.reg_branches[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)
        if self.as_two_stage:
            return outputs_classes, outputs_coords, \
                enc_outputs_class, \
                enc_outputs_coord.sigmoid()
        else:
            return outputs_classes, outputs_coords, \
                None, None
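
    # Illustrative example (added for clarity; not part of the original
    # file): the loop above refines each reference box in logit space,
    #     new_ref = sigmoid(inverse_sigmoid(ref) + delta).
    # For a scalar coordinate ref = 0.5 and predicted delta = 0.4:
    #     inverse_sigmoid(0.5) = log(0.5 / 0.5) = 0.0
    #     new_ref = sigmoid(0.0 + 0.4) ~ 0.599
    # so the update itself is unbounded while the box stays inside (0, 1).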
    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
    def loss(self,
             all_cls_scores,
             all_bbox_preds,
             enc_cls_scores,
             enc_bbox_preds,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification score of all
                decoder layers, has shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds (Tensor): Sigmoid regression
                outputs of all decoder layers. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            enc_cls_scores (Tensor): Classification scores of
                points on the encoder feature map, has shape
                (N, h*w, num_classes). Only passed when as_two_stage is
                True, otherwise None.
            enc_bbox_preds (Tensor): Regression results of each point
                on the encoder feature map, has shape (N, h*w, 4). Only
                passed when as_two_stage is True, otherwise None.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'gt_bboxes_ignore set to None.'

        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]

        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss of proposals generated from the encoder feature map.
        if enc_cls_scores is not None:
            binary_labels_list = [
                torch.zeros_like(gt_labels_list[i])
                for i in range(len(img_metas))
            ]
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_single(enc_cls_scores, enc_bbox_preds,
                                 gt_bboxes_list, binary_labels_list,
                                 img_metas, gt_bboxes_ignore)
            loss_dict['enc_loss_cls'] = enc_loss_cls
            loss_dict['enc_loss_bbox'] = enc_losses_bbox
            loss_dict['enc_loss_iou'] = enc_losses_iou

        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
                                                       losses_bbox[:-1],
                                                       losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1
        return loss_dict

    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
    def get_bboxes(self,
                   all_cls_scores,
                   all_bbox_preds,
                   enc_cls_scores,
                   enc_bbox_preds,
                   img_metas,
                   rescale=False):
        """Transform network outputs for a batch into bbox predictions.

        Args:
            all_cls_scores (Tensor): Classification score of all
                decoder layers, has shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds (Tensor): Sigmoid regression
                outputs of all decoder layers. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            enc_cls_scores (Tensor): Classification scores of
                points on the encoder feature map, has shape
                (N, h*w, num_classes). Only passed when as_two_stage is
                True, otherwise None.
            enc_bbox_preds (Tensor): Regression results of each point
                on the encoder feature map, has shape (N, h*w, 4). Only
                passed when as_two_stage is True, otherwise None.
            img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If True, return boxes in original
                image space. Default False.

        Returns:
            list[list[Tensor, Tensor]]: Each item in result_list is a 2-tuple. \
                The first item is an (n, 5) tensor, where the first 4 columns \
                are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
                5-th column is a score between 0 and 1. The second item is a \
                (n,) tensor where each item is the predicted class label of \
                the corresponding box.
        """
        cls_scores = all_cls_scores[-1]
        bbox_preds = all_bbox_preds[-1]

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score, bbox_pred,
                                                img_shape, scale_factor,
                                                rescale)
            result_list.append(proposals)
        return result_list
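
For orientation, a hedged config fragment showing how a head like this is typically parameterized; all values are placeholders rather than RoDLA's shipped settings (see model/configs/), and the transformer/encoder/decoder sub-configs are omitted:

```python
# Hypothetical config fragment (placeholders only; sub-configs elided).
bbox_head = dict(
    type='DeformableDETRHead',
    num_classes=5,             # dataset-dependent, e.g. a small layout vocabulary
    in_channels=2048,
    num_query=300,
    with_box_refine=True,      # per-layer deepcopied cls/reg branches
    as_two_stage=True,         # adds enc_loss_cls/enc_loss_bbox/enc_loss_iou terms
    use_2fc_cls_branch=False,  # keep the wider LayerNorm+GELU classifier off
    sync_cls_avg_factor=True,
)
```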
model/mmdet_custom/models/dense_heads/detr_head.py
ADDED
@@ -0,0 +1,954 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear, build_activation_layer
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
from mmcv.runner import force_fp32

from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
                        build_assigner, build_sampler, multi_apply,
                        reduce_mean)
from mmdet.models.utils import build_transformer
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead
import numpy as np


@HEADS.register_module(force=True)
class DETRHead(AnchorFreeHead):
    """Implements the DETR transformer head.

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_classes (int): Number of categories excluding the background.
        in_channels (int): Number of channels in the input feature map.
        num_query (int): Number of queries in the Transformer.
        num_reg_fcs (int, optional): Number of fully-connected layers used in
            `FFN`, which is then used for the regression head. Default 2.
        transformer (obj:`mmcv.ConfigDict`|dict): Config for transformer.
            Default: None.
        sync_cls_avg_factor (bool): Whether to sync the avg_factor of
            all ranks. Default to False.
        positional_encoding (obj:`mmcv.ConfigDict`|dict):
            Config for position encoding.
        loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the
            classification loss. Default `CrossEntropyLoss`.
        loss_bbox (obj:`mmcv.ConfigDict`|dict): Config of the
            regression loss. Default `L1Loss`.
        loss_iou (obj:`mmcv.ConfigDict`|dict): Config of the
            regression iou loss. Default `GIoULoss`.
        train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of the
            transformer head.
        test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of the
            transformer head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    _version = 2

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_query=100,
                 num_reg_fcs=2,
                 transformer=None,
                 sync_cls_avg_factor=False,
                 positional_encoding=dict(
                     type='SinePositionalEncoding',
                     num_feats=128,
                     normalize=True),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     bg_cls_weight=0.1,
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=1.0),
                 loss_bbox=dict(type='L1Loss', loss_weight=5.0),
                 loss_iou=dict(type='GIoULoss', loss_weight=2.0),
                 train_cfg=dict(
                     assigner=dict(
                         type='HungarianAssigner',
                         cls_cost=dict(type='ClassificationCost', weight=1.),
                         reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                         iou_cost=dict(
                             type='IoUCost', iou_mode='giou', weight=2.0))),
                 test_cfg=dict(max_per_img=100),
                 init_cfg=None,
                 **kwargs):
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since it brings inconvenience when the initialization of
        # `AnchorFreeHead` is called.
        super(AnchorFreeHead, self).__init__(init_cfg)

        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        class_weight = loss_cls.get('class_weight', None)
        if class_weight is not None and (self.__class__ is DETRHead):
            # assert isinstance(class_weight, float), 'Expected ' \
            #     'class_weight to have type float. Found ' \
            #     f'{type(class_weight)}.'

            # NOTE following the official DETR repo, bg_cls_weight means
            # relative classification weight of the no-object class.
            bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)

            assert isinstance(bg_cls_weight, float), 'Expected ' \
                'bg_cls_weight to have type float. Found ' \
                f'{type(bg_cls_weight)}.'
            if isinstance(class_weight, list):
                class_weight.append(bg_cls_weight)
                class_weight = np.array(class_weight)
                class_weight = torch.from_numpy(class_weight)
                class_weight = torch.ones(num_classes + 1) * class_weight
            elif isinstance(class_weight, float):
                class_weight = torch.ones(num_classes + 1) * class_weight
                # set background class as the last index
                class_weight[num_classes] = bg_cls_weight
            loss_cls.update({'class_weight': class_weight})
            if 'bg_cls_weight' in loss_cls:
                loss_cls.pop('bg_cls_weight')
            self.bg_cls_weight = bg_cls_weight

        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided ' \
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            # assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'],
            #     'The classification weight for loss and matcher should be' \
            #     'exactly the same.'
            # assert loss_bbox['loss_weight'] == assigner['reg_cost'][
            #     'weight'], 'The regression L1 weight for loss and matcher '\
            #     'should be exactly the same.'
            # assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'],
            #     'The regression iou weight for loss and matcher should be' \
            #     'exactly the same.'
            self.assigner = build_assigner(assigner)
            # DETR uses sampling=False, so a PseudoSampler is used
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.num_query = num_query
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_reg_fcs = num_reg_fcs
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_iou = build_loss(loss_iou)

        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.act_cfg = transformer.get('act_cfg',
                                       dict(type='ReLU', inplace=True))
        self.activate = build_activation_layer(self.act_cfg)
        self.positional_encoding = build_positional_encoding(
            positional_encoding)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims
        assert 'num_feats' in positional_encoding
        num_feats = positional_encoding['num_feats']
        assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
            f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
            f' and {num_feats}.'

        self._init_layers()
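
    # Illustrative example (added for clarity; not part of the original
    # file): with num_classes=3, class_weight=1.0 and bg_cls_weight=0.1, the
    # constructed weight vector is tensor([1.0, 1.0, 1.0, 0.1]); the
    # no-object class occupies the last slot and is down-weighted in the
    # cross-entropy loss, following the official DETR recipe.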
    def _init_layers(self):
        """Initialize layers of the transformer head."""
        self.input_proj = Conv2d(
            self.in_channels, self.embed_dims, kernel_size=1)
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        self.reg_ffn = FFN(
            self.embed_dims,
            self.embed_dims,
            self.num_reg_fcs,
            self.act_cfg,
            dropout=0.0,
            add_residual=False)
        self.fc_reg = Linear(self.embed_dims, 4)
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)

    def init_weights(self):
        """Initialize weights of the transformer head."""
        # The initialization for transformer is important
        self.transformer.init_weights()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load checkpoints."""
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since `AnchorFreeHead._load_from_state_dict` should not be
        # called here. Invoking the default `Module._load_from_state_dict`
        # is enough.

        # Names of some parameters have been changed.
        version = local_metadata.get('version', None)
        if (version is None or version < 2) and self.__class__ is DETRHead:
            convert_dict = {
                '.self_attn.': '.attentions.0.',
                '.ffn.': '.ffns.0.',
                '.multihead_attn.': '.attentions.1.',
                '.decoder.norm.': '.decoder.post_norm.'
            }
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                for ori_key, convert_key in convert_dict.items():
                    if ori_key in k:
                        convert_key = k.replace(ori_key, convert_key)
                        state_dict[convert_key] = state_dict[k]
                        del state_dict[k]

        super(AnchorFreeHead,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)

    def forward(self, feats, img_metas):
        """Forward function.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.

                - all_cls_scores_list (list[Tensor]): Classification scores \
                    for each scale level. Each is a 4D-tensor with shape \
                    [nb_dec, bs, num_query, cls_out_channels]. Note \
                    `cls_out_channels` should include background.
                - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
                    outputs for each scale level. Each is a 4D-tensor with \
                    normalized coordinate format (cx, cy, w, h) and shape \
                    [nb_dec, bs, num_query, 4].
        """
        num_levels = len(feats)
        img_metas_list = [img_metas for _ in range(num_levels)]
        return multi_apply(self.forward_single, feats, img_metas_list)

    def forward_single(self, x, img_metas):
        """Forward function for a single feature level.

        Args:
            x (Tensor): Input feature from backbone's single stage, shape
                [bs, c, h, w].
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head,
                shape [nb_dec, bs, num_query, cls_out_channels]. Note
                cls_out_channels should include background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression
                head with normalized coordinate format (cx, cy, w, h).
                Shape [nb_dec, bs, num_query, 4].
        """
        # construct binary masks which are used by the transformer.
        # NOTE following the official DETR repo, non-zero values represent
        # ignored positions, while zero values mean valid positions.
        batch_size = x.size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        masks = x.new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            masks[img_id, :img_h, :img_w] = 0

        x = self.input_proj(x)
        # interpolate masks to have the same spatial shape with x
        masks = F.interpolate(
            masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
        # position encoding
        pos_embed = self.positional_encoding(masks)  # [bs, embed_dim, h, w]
        # outs_dec: [nb_dec, bs, num_query, embed_dim]
        outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
                                       pos_embed)

        all_cls_scores = self.fc_cls(outs_dec)
        all_bbox_preds = self.fc_reg(self.activate(
            self.reg_ffn(outs_dec))).sigmoid()
        return all_cls_scores, all_bbox_preds
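
    # Illustrative example (added for clarity; not part of the original
    # file): for a batch padded to 800x800 that contains a 600x400 (h x w)
    # image,
    #     masks = x.new_ones((1, 800, 800)); masks[0, :600, :400] = 0
    # leaves zeros over valid pixels and ones over padding, so attention and
    # the sine positional encoding ignore the padded region.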
    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
    def loss(self,
             all_cls_scores_list,
             all_bbox_preds_list,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             gt_bboxes_ignore=None):
        """Loss function.

        Only outputs from the last feature level are used for computing
        losses by default.

        Args:
            all_cls_scores_list (list[Tensor]): Classification outputs
                for each feature level. Each is a 4D-tensor with shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds_list (list[Tensor]): Sigmoid regression
                outputs for each feature level. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # NOTE: by default only the outputs from the last feature scale are used.
        all_cls_scores = all_cls_scores_list[-1]
        all_bbox_preds = all_bbox_preds_list[-1]
        assert gt_bboxes_ignore is None, \
            'Only supports gt_bboxes_ignore set to None.'

        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]

        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]
        # loss from other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
                                                       losses_bbox[:-1],
                                                       losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1
        return loss_dict

    def get_fed_loss_classes(self, gt_classes, num_fed_loss_classes, num_classes, weight):
        """
        Args:
            gt_classes: a long tensor of shape R that contains the gt class label of each proposal.
            num_fed_loss_classes: minimum number of classes to keep when calculating the federated
                loss. Negative classes are sampled if the number of unique gt_classes is smaller
                than this value.
            num_classes: number of foreground classes
            weight: probabilities used to sample negative classes
        Returns:
            Tensor:
                classes to keep when calculating the federated loss, including both unique gt
                classes and sampled negative classes.
        """
        unique_gt_classes = torch.unique(gt_classes)
        prob = unique_gt_classes.new_ones(num_classes + 1).float()
        prob[-1] = 0
        if len(unique_gt_classes) < num_fed_loss_classes:
            prob[:num_classes] = weight.float().clone()
            prob[unique_gt_classes] = 0
            sampled_negative_classes = torch.multinomial(
                prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False
            )
            fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes])
        else:
            fed_loss_classes = unique_gt_classes
        return fed_loss_classes
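
    # Illustrative example (added for clarity; not part of the original
    # file): with num_classes=10, gt_classes covering {2, 7} and
    # num_fed_loss_classes=5, three extra negative classes are drawn via
    # torch.multinomial from `weight` (classes 2 and 7 and the background
    # slot are zeroed out first), so the loss is restricted to 5 classes
    # instead of all 10 foreground classes plus background.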
    def loss_single(self,
                    cls_scores,
                    bbox_preds,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Loss function for outputs from a single decoder layer of a single
        feature level.

        Args:
            cls_scores (Tensor): Box score logits from a single decoder layer
                for all images. Shape [bs, num_query, cls_out_channels].
            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
                for all images, with normalized coordinate (cx, cy, w, h) and
                shape [bs, num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components for outputs from
                a single decoder layer.
        """
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list,
                                           img_metas, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes across all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used to rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(img_metas, bbox_preds):
            img_h, img_w, _ = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses the relative position of boxes (cxcywh) in the image,
        # thus the learning target is normalized by the image size. So here
        # we need to re-scale them for calculating the IoU loss
        bbox_preds = bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    gt_bboxes_list,
                    gt_labels_list,
                    img_metas,
                    gt_bboxes_ignore_list=None):
        """Compute regression and classification targets for a batch of images.

        Outputs from a single decoder layer of a single feature level are used.

        Args:
            cls_scores_list (list[Tensor]): Box score logits from a single
                decoder layer for each image with shape [num_query,
                cls_out_channels].
            bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
                decoder layer for each image, with normalized coordinate
                (cx, cy, w, h) and shape [num_query, 4].
            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            img_metas (list[dict]): List of image meta information.
            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
                boxes which can be ignored for each image. Default None.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels for all images.
                - label_weights_list (list[Tensor]): Label weights for all \
                    images.
                - bbox_targets_list (list[Tensor]): BBox targets for all \
                    images.
                - bbox_weights_list (list[Tensor]): BBox weights for all \
                    images.
                - num_total_pos (int): Number of positive samples in all \
                    images.
                - num_total_neg (int): Number of negative samples in all \
                    images.
        """
        assert gt_bboxes_ignore_list is None, \
            'Only supports gt_bboxes_ignore set to None.'
        num_imgs = len(cls_scores_list)
        gt_bboxes_ignore_list = [
            gt_bboxes_ignore_list for _ in range(num_imgs)
        ]

        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single, cls_scores_list, bbox_preds_list,
             gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)
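
    # Illustrative example (added for clarity; not part of the original
    # file): a normalized prediction (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4) on
    # a 1000x800 (w x h) image becomes xyxy = (0.4, 0.3, 0.6, 0.7) after
    # bbox_cxcywh_to_xyxy, then (400, 240, 600, 560) after multiplying by the
    # (w, h, w, h) factor, which is the absolute-coordinate box that
    # loss_single feeds to the GIoU loss.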
    def _get_area_thr(self, img_shape, type):
        # Scale-dependent min/max box-area thresholds (squared edge lengths).
        MIN_V = 0
        MAX_V = 1e10
        short_edge = min(img_shape[0], img_shape[1])
        if type == 'v1':
            DELTA = 4
            if short_edge <= 600:
                min_edge = 128 - DELTA
                max_edge = MAX_V
            elif 600 < short_edge <= 800:
                min_edge = 96 - DELTA
                max_edge = MAX_V
            elif 800 < short_edge <= 1000:
                min_edge = 64 - DELTA
                max_edge = MAX_V
            elif 1000 < short_edge <= 1200:
                min_edge = 32 - DELTA
                max_edge = MAX_V
            elif 1200 < short_edge <= 1400:
                min_edge = MIN_V
                max_edge = MAX_V
            else:
                min_edge = MIN_V
                max_edge = 2 + DELTA
        elif type == 'v2':
            if short_edge <= 1000:
                min_edge = 112
                max_edge = MAX_V
            elif 1000 < short_edge <= 1400:
                min_edge = 32
                max_edge = 160
            elif short_edge > 1400:
                min_edge = 0
                max_edge = 80
        elif type == 'v3':
            if short_edge <= 800:
                min_edge = 96
                max_edge = MAX_V
            elif 800 < short_edge <= 1000:
                min_edge = 64
                max_edge = MAX_V
            elif 1000 < short_edge <= 1400:
                min_edge = MIN_V
                max_edge = MAX_V
            elif 1400 < short_edge <= 1600:
                min_edge = MIN_V
                max_edge = 96
            elif short_edge > 1600:
                min_edge = MIN_V
                max_edge = 64
        elif type == 'v4':
            DELTA = 4
            if short_edge <= 800:
                min_edge = 96 - DELTA
                max_edge = MAX_V
            elif 800 < short_edge <= 1000:
                min_edge = 64 - DELTA
                max_edge = MAX_V
            elif 1000 < short_edge <= 1400:
                min_edge = MIN_V
                max_edge = MAX_V
            elif 1400 < short_edge <= 1600:
                min_edge = MIN_V
                max_edge = 64 + DELTA
            elif short_edge > 1600:
                min_edge = MIN_V
                max_edge = 32 + DELTA

        return min_edge ** 2, max_edge ** 2
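
    # Illustrative example (added for clarity; not part of the original
    # file): the returned thresholds are squared edge lengths, e.g. for an
    # image of shape (800, 1200, 3) the short edge is 800, so type 'v2'
    # yields (112 ** 2, 1e10 ** 2) = (12544, 1e20): boxes under roughly
    # 112x112 px fall below the minimum area at that scale.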
    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_bboxes,
                           gt_labels,
                           img_meta,
                           gt_bboxes_ignore=None):
        """Compute regression and classification targets for one image.

        Outputs from a single decoder layer of a single feature level are used.

        Args:
            cls_score (Tensor): Box score logits from a single decoder layer
                for one image. Shape [num_query, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
                for one image, with normalized coordinate (cx, cy, w, h) and
                shape [num_query, 4].
            gt_bboxes (Tensor): Ground truth bboxes for one image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth class indices for one image
                with shape (num_gts, ).
            img_meta (dict): Meta information for one image.
            gt_bboxes_ignore (Tensor, optional): Bounding boxes
                which can be ignored. Default None.

        Returns:
            tuple[Tensor]: a tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                - label_weights (Tensor): Label weights of each image.
                - bbox_targets (Tensor): BBox targets of each image.
                - bbox_weights (Tensor): BBox weights of each image.
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
        """
        num_bboxes = bbox_pred.size(0)
        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
                                             gt_labels, img_meta,
                                             gt_bboxes_ignore)
        sampling_result = self.sampler.sample(assign_result, bbox_pred,
                                              gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(bbox_pred)
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0
        img_h, img_w, _ = img_meta['img_shape']

        # DETR regresses the relative position of boxes (cxcywh) in the image.
        # Thus the learning target should be normalized by the image size, and
        # the box format converted from the default x1y1x2y2 to cxcywh.
        factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                       img_h]).unsqueeze(0)
        pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
        pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
        bbox_targets[pos_inds] = pos_gt_bboxes_targets
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)
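
    # Illustrative example (added for clarity; not part of the original
    # file): a ground-truth box (100, 50, 300, 250) on a 1000x800 (w x h)
    # image is normalized by (w, h, w, h) and converted to cxcywh, giving
    #     cx = 200/1000 = 0.2, cy = 150/800 = 0.1875, w = 0.2, h = 0.25,
    # which matches the normalized format the regression branch predicts.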
| 656 |
+
# over-write because img_metas are needed as inputs for bbox_head.
|
| 657 |
+
def forward_train(self,
|
| 658 |
+
x,
|
| 659 |
+
img_metas,
|
| 660 |
+
gt_bboxes,
|
| 661 |
+
gt_labels=None,
|
| 662 |
+
gt_bboxes_ignore=None,
|
| 663 |
+
proposal_cfg=None,
|
| 664 |
+
**kwargs):
|
| 665 |
+
"""Forward function for training mode.
|
| 666 |
+
|
| 667 |
+
Args:
|
| 668 |
+
x (list[Tensor]): Features from backbone.
|
| 669 |
+
img_metas (list[dict]): Meta information of each image, e.g.,
|
| 670 |
+
image size, scaling factor, etc.
|
| 671 |
+
gt_bboxes (Tensor): Ground truth bboxes of the image,
|
| 672 |
+
shape (num_gts, 4).
|
| 673 |
+
gt_labels (Tensor): Ground truth labels of each box,
|
| 674 |
+
shape (num_gts,).
|
| 675 |
+
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
|
| 676 |
+
ignored, shape (num_ignored_gts, 4).
|
| 677 |
+
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
|
| 678 |
+
if None, test_cfg would be used.
|
| 679 |
+
|
| 680 |
+
Returns:
|
| 681 |
+
dict[str, Tensor]: A dictionary of loss components.
|
| 682 |
+
"""
|
| 683 |
+
assert proposal_cfg is None, '"proposal_cfg" must be None'
|
| 684 |
+
outs = self(x, img_metas)
|
| 685 |
+
if gt_labels is None:
|
| 686 |
+
loss_inputs = outs + (gt_bboxes, img_metas)
|
| 687 |
+
else:
|
| 688 |
+
loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
|
| 689 |
+
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
|
| 690 |
+
return losses
|

    @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
    def get_bboxes(self,
                   all_cls_scores_list,
                   all_bbox_preds_list,
                   img_metas,
                   rescale=False):
        """Transform network outputs for a batch into bbox predictions.

        Args:
            all_cls_scores_list (list[Tensor]): Classification outputs
                for each feature level. Each is a 4D-tensor with shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds_list (list[Tensor]): Sigmoid regression
                outputs for each feature level. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            img_metas (list[dict]): Meta information of each image.
            rescale (bool, optional): If True, return boxes in original
                image space. Default False.

        Returns:
            list[list[Tensor, Tensor]]: Each item in result_list is a
                2-tuple. The first item is an (n, 5) tensor, where the first
                4 columns are bounding box positions (tl_x, tl_y, br_x, br_y)
                and the 5-th column is a score between 0 and 1. The second
                item is an (n,) tensor where each item is the predicted class
                label of the corresponding box.
        """
        # NOTE By default, only the outputs from the last feature level and
        # the last decoder layer are used.
        cls_scores = all_cls_scores_list[-1][-1]
        bbox_preds = all_bbox_preds_list[-1][-1]

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score = cls_scores[img_id]
            bbox_pred = bbox_preds[img_id]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score, bbox_pred,
                                                img_shape, scale_factor,
                                                rescale)
            result_list.append(proposals)

        return result_list

    def _get_bboxes_single(self,
                           cls_score,
                           bbox_pred,
                           img_shape,
                           scale_factor,
                           rescale=False):
        """Transform outputs from the last decoder layer into bbox predictions
        for each image.

        Args:
            cls_score (Tensor): Box score logits from the last decoder layer
                for each image. Shape [num_query, cls_out_channels].
            bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
                for each image, with coordinate format (cx, cy, w, h) and
                shape [num_query, 4].
            img_shape (tuple[int]): Shape of input image, (height, width, 3).
            scale_factor (ndarray, optional): Scale factor of the image,
                arranged as (w_scale, h_scale, w_scale, h_scale).
            rescale (bool, optional): If True, return boxes in original image
                space. Default False.

        Returns:
            tuple[Tensor]: Results of detected bboxes and labels.

                - det_bboxes: Predicted bboxes with shape [num_query, 5],
                  where the first 4 columns are bounding box positions
                  (tl_x, tl_y, br_x, br_y) and the 5-th column contains
                  scores between 0 and 1.
                - det_labels: Predicted labels of the corresponding boxes
                  with shape [num_query].
        """
        assert len(cls_score) == len(bbox_pred)
        max_per_img = self.test_cfg.get('max_per_img', self.num_query)
        # exclude background
        if self.loss_cls.use_sigmoid:
            cls_score = cls_score.sigmoid()
            scores, indexes = cls_score.view(-1).topk(max_per_img)
            det_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            bbox_pred = bbox_pred[bbox_index]
        else:
            scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
            scores, bbox_index = scores.topk(max_per_img)
            bbox_pred = bbox_pred[bbox_index]
            det_labels = det_labels[bbox_index]

        det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
        det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
        det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
        det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
        det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
        if rescale:
            det_bboxes /= det_bboxes.new_tensor(scale_factor)
        det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)

        return det_bboxes, det_labels
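The sigmoid branch above flattens the [num_query, num_classes] score matrix before taking the top-k, so the class id and the query id have to be decoded back out of the flat index with a modulo and an integer division. A small self-contained check of that decoding:

import torch

num_query, num_classes, max_per_img = 4, 3, 2
scores = torch.rand(num_query, num_classes)

top_scores, indexes = scores.view(-1).topk(max_per_img)
det_labels = indexes % num_classes   # column of the flat index -> class id
bbox_index = indexes // num_classes  # row of the flat index -> query id

# sanity check against the unflattened score matrix
assert torch.equal(top_scores, scores[bbox_index, det_labels])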

    def simple_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[tuple[Tensor, Tensor]]: Each item in result_list is a
                2-tuple. The first item is ``bboxes`` with shape (n, 5),
                where 5 represents (tl_x, tl_y, br_x, br_y, score).
                The second tensor in the tuple is ``labels`` with
                shape (n,).
        """
        # forward of this head requires img_metas
        outs = self.forward(feats, img_metas)
        results_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
        return results_list

    def forward_onnx(self, feats, img_metas):
        """Forward function for exporting to ONNX.

        Overwrites `forward` because `masks` is directly created with
        zeros (valid position tag) and has the same spatial size as `x`.
        Thus the construction of `masks` is different from that in `forward`.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.

                - all_cls_scores_list (list[Tensor]): Classification scores
                  for each scale level. Each is a 4D-tensor with shape
                  [nb_dec, bs, num_query, cls_out_channels]. Note that
                  `cls_out_channels` should include the background class.
                - all_bbox_preds_list (list[Tensor]): Sigmoid regression
                  outputs for each scale level. Each is a 4D-tensor with
                  normalized coordinate format (cx, cy, w, h) and shape
                  [nb_dec, bs, num_query, 4].
        """
        num_levels = len(feats)
        img_metas_list = [img_metas for _ in range(num_levels)]
        return multi_apply(self.forward_single_onnx, feats, img_metas_list)

    def forward_single_onnx(self, x, img_metas):
        """Forward function for a single feature level with ONNX exportation.

        Args:
            x (Tensor): Input feature from backbone's single stage, shape
                [bs, c, h, w].
            img_metas (list[dict]): List of image information.

        Returns:
            all_cls_scores (Tensor): Outputs from the classification head,
                shape [nb_dec, bs, num_query, cls_out_channels]. Note that
                cls_out_channels should include the background class.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression
                head with normalized coordinate format (cx, cy, w, h).
                Shape [nb_dec, bs, num_query, 4].
        """
        # Note `img_shape` is not dynamically traceable to ONNX,
        # since the related augmentation was done with numpy on
        # CPU. Thus `masks` is directly created with zeros (valid tag)
        # and the same spatial shape as `x`.
        # The difference between torch and the exported ONNX model may be
        # ignored, since the same performance is achieved (e.g.
        # 40.1 vs 40.1 for DETR).
        batch_size = x.size(0)
        h, w = x.size()[-2:]
        masks = x.new_zeros((batch_size, h, w))  # [B, h, w]

        x = self.input_proj(x)
        # interpolate masks to have the same spatial shape as x
        masks = F.interpolate(
            masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
        pos_embed = self.positional_encoding(masks)
        outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
                                       pos_embed)

        all_cls_scores = self.fc_cls(outs_dec)
        all_bbox_preds = self.fc_reg(self.activate(
            self.reg_ffn(outs_dec))).sigmoid()
        return all_cls_scores, all_bbox_preds

    def onnx_export(self, all_cls_scores_list, all_bbox_preds_list, img_metas):
        """Transform network outputs into bbox predictions, with ONNX
        exportation.

        Args:
            all_cls_scores_list (list[Tensor]): Classification outputs
                for each feature level. Each is a 4D-tensor with shape
                [nb_dec, bs, num_query, cls_out_channels].
            all_bbox_preds_list (list[Tensor]): Sigmoid regression
                outputs for each feature level. Each is a 4D-tensor with
                normalized coordinate format (cx, cy, w, h) and shape
                [nb_dec, bs, num_query, 4].
            img_metas (list[dict]): Meta information of each image.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        assert len(img_metas) == 1, \
            'Only one input image is supported when exporting to ONNX'

        cls_scores = all_cls_scores_list[-1][-1]
        bbox_preds = all_bbox_preds_list[-1][-1]

        # Note `img_shape` is not dynamically traceable to ONNX,
        # so `img_shape_for_onnx` (the padded shape of the image tensor)
        # is used instead.
        img_shape = img_metas[0]['img_shape_for_onnx']
        max_per_img = self.test_cfg.get('max_per_img', self.num_query)
        batch_size = cls_scores.size(0)
        # `batch_index_offset` is used for the gather of concatenated tensor
        batch_index_offset = torch.arange(batch_size).to(
            cls_scores.device) * max_per_img
        batch_index_offset = batch_index_offset.unsqueeze(1).expand(
            batch_size, max_per_img)

        # supports dynamic batch inference
        if self.loss_cls.use_sigmoid:
            cls_scores = cls_scores.sigmoid()
            scores, indexes = cls_scores.view(batch_size, -1).topk(
                max_per_img, dim=1)
            det_labels = indexes % self.num_classes
            bbox_index = indexes // self.num_classes
            bbox_index = (bbox_index + batch_index_offset).view(-1)
            bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
            bbox_preds = bbox_preds.view(batch_size, -1, 4)
        else:
            scores, det_labels = F.softmax(
                cls_scores, dim=-1)[..., :-1].max(-1)
            scores, bbox_index = scores.topk(max_per_img, dim=1)
            bbox_index = (bbox_index + batch_index_offset).view(-1)
            bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
            det_labels = det_labels.view(-1)[bbox_index]
            bbox_preds = bbox_preds.view(batch_size, -1, 4)
            det_labels = det_labels.view(batch_size, -1)

        det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
        # use `img_shape_tensor` for dynamic export to ONNX
        img_shape_tensor = img_shape.flip(0).repeat(2)  # [w, h, w, h]
        img_shape_tensor = img_shape_tensor.unsqueeze(0).unsqueeze(0).expand(
            batch_size, det_bboxes.size(1), 4)
        det_bboxes = det_bboxes * img_shape_tensor
        # dynamically clip bboxes
        x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
        from mmdet.core.export import dynamic_clip_for_onnx
        x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, img_shape)
        det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
        det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)

        return det_bboxes, det_labels
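The `batch_index_offset` trick above turns per-image top-k indices into indices of the flattened (batch * queries) prediction tensor, which keeps the gather exportable with a dynamic batch size. A minimal sketch of the same idea (note that the head above uses max_per_img as the offset stride, which matches the per-image row count because max_per_img defaults to num_query):

import torch

batch_size, num_query, max_per_img = 2, 5, 3
bbox_preds = torch.rand(batch_size, num_query, 4)
scores = torch.rand(batch_size, num_query)

# top-k indices are relative to each image ...
_, bbox_index = scores.topk(max_per_img, dim=1)
# ... so add a per-image offset before indexing the flattened predictions
offset = (torch.arange(batch_size) * num_query).unsqueeze(1)
flat_index = (bbox_index + offset).view(-1)
gathered = bbox_preds.view(-1, 4)[flat_index].view(batch_size, -1, 4)

assert torch.equal(
    gathered,
    torch.stack([bbox_preds[b, bbox_index[b]] for b in range(batch_size)]))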
model/mmdet_custom/models/dense_heads/dino_head.py
ADDED
@@ -0,0 +1,364 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, multi_apply,
                        reduce_mean)
from ..utils import build_dn_generator
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.models.builder import HEADS
from .deformable_detr_head import DeformableDETRHead
from mmcv.runner import force_fp32


@HEADS.register_module()
class DINOHead(DeformableDETRHead):

    def __init__(self, *args, dn_cfg=None, **kwargs):
        super(DINOHead, self).__init__(*args, **kwargs)
        self._init_layers()
        self.init_denoising(dn_cfg)
        assert self.as_two_stage, \
            'as_two_stage must be True for DINO'
        assert self.with_box_refine, \
            'with_box_refine must be True for DINO'

    def _init_layers(self):
        super()._init_layers()
        # NOTE The original DINO repo sets num_embeddings to 92 for COCO:
        # indices 0~90 represent the target classes and index 91 is an
        # [Unknown] class. However, the embedding of the unknown class is
        # never used in the original DINO.
        self.label_embedding = nn.Embedding(self.cls_out_channels,
                                            self.embed_dims)

    def init_denoising(self, dn_cfg):
        if dn_cfg is not None:
            dn_cfg['num_classes'] = self.num_classes
            dn_cfg['num_queries'] = self.num_query
            dn_cfg['hidden_dim'] = self.embed_dims
        self.dn_generator = build_dn_generator(dn_cfg)

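The generator config is a plain dict whose num_classes, num_queries and hidden_dim fields are filled in by the head itself, so a user-supplied dn_cfg only needs to carry the noise settings. A hypothetical example, assuming a CDN-style query generator is registered behind build_dn_generator (the exact type name and keys depend on this repo's utils module):

# hypothetical dn_cfg; keys depend on the registered generator
dn_cfg = dict(
    type='CdnQueryGenerator',
    noise_scale=dict(label=0.5, box=1.0),
    group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100))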
    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        assert proposal_cfg is None, '"proposal_cfg" must be None'
        assert self.dn_generator is not None, '"dn_cfg" must be set'
        dn_label_query, dn_bbox_query, attn_mask, dn_meta = \
            self.dn_generator(gt_bboxes, gt_labels,
                              self.label_embedding, img_metas)
        outs = self(x, img_metas, dn_label_query, dn_bbox_query, attn_mask)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas, dn_meta)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, dn_meta)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def forward(self,
                mlvl_feats,
                img_metas,
                dn_label_query=None,
                dn_bbox_query=None,
                attn_mask=None):
        batch_size = mlvl_feats[0].size(0)
        input_img_h, input_img_w = img_metas[0]['batch_input_shape']
        img_masks = mlvl_feats[0].new_ones(
            (batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w, _ = img_metas[img_id]['img_shape']
            img_masks[img_id, :img_h, :img_w] = 0

        mlvl_masks = []
        mlvl_positional_encodings = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(
                    img_masks[None],
                    size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_positional_encodings.append(
                self.positional_encoding(mlvl_masks[-1]))

        query_embeds = None
        hs, inter_references, topk_score, topk_anchor = \
            self.transformer(
                mlvl_feats,
                mlvl_masks,
                query_embeds,
                mlvl_positional_encodings,
                dn_label_query,
                dn_bbox_query,
                attn_mask,
                reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
                cls_branches=self.cls_branches if self.as_two_stage else None  # noqa:E501
            )
        hs = hs.permute(0, 2, 1, 3)

        if dn_label_query is not None and dn_label_query.size(1) == 0:
            # NOTE: If there is no target in the image, the parameters of
            # label_embedding won't be used in producing loss, which raises
            # a RuntimeError when using distributed mode.
            hs[0] += self.label_embedding.weight[0, 0] * 0.0

        outputs_classes = []
        outputs_coords = []

        for lvl in range(hs.shape[0]):
            reference = inter_references[lvl]
            reference = inverse_sigmoid(reference, eps=1e-3)
            outputs_class = self.cls_branches[lvl](hs[lvl])
            tmp = self.reg_branches[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)

        outputs_classes = torch.stack(outputs_classes)
        outputs_coords = torch.stack(outputs_coords)

        return outputs_classes, outputs_coords, topk_score, topk_anchor

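The refinement loop above nudges each reference box in logit space: the output of the regression branch is added to inverse_sigmoid(reference) and squashed back through a sigmoid, so small deltas move the box only slightly. A minimal sketch of one refinement step (with a simplified inverse_sigmoid standing in for mmdet's version):

import torch

def inverse_sigmoid(x, eps=1e-3):
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))

reference = torch.tensor([[0.40, 0.50, 0.20, 0.30]])  # (cx, cy, w, h), sigmoid space
delta = torch.tensor([[0.25, -0.10, 0.00, 0.00]])     # regression branch output
refined = (delta + inverse_sigmoid(reference)).sigmoid()
print(refined)  # a small nudge around the previous box, still inside (0, 1)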
    @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
    def loss(self,
             all_cls_scores,
             all_bbox_preds,
             enc_topk_scores,
             enc_topk_anchors,
             gt_bboxes_list,
             gt_labels_list,
             img_metas,
             dn_meta=None,
             gt_bboxes_ignore=None):
        assert gt_bboxes_ignore is None, \
            f'{self.__class__.__name__} only supports ' \
            f'gt_bboxes_ignore set to None.'

        loss_dict = dict()

        # extract the denoising and matching parts of the outputs
        all_cls_scores, all_bbox_preds, dn_cls_scores, dn_bbox_preds = \
            self.extract_dn_outputs(all_cls_scores, all_bbox_preds, dn_meta)

        if enc_topk_scores is not None:
            # calculate the loss from the encoder feature maps
            # NOTE Deformable DETR calculates a binary cls loss
            # for all encoder embeddings, while DINO calculates a
            # multi-class loss for the top-k embeddings.
            enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
                self.loss_single(enc_topk_scores, enc_topk_anchors,
                                 gt_bboxes_list, gt_labels_list,
                                 img_metas, gt_bboxes_ignore)

            # collate the loss from the encoder feature maps
            loss_dict['interm_loss_cls'] = enc_loss_cls
            loss_dict['interm_loss_bbox'] = enc_losses_bbox
            loss_dict['interm_loss_iou'] = enc_losses_iou

        # calculate the loss from all decoder layers
        num_dec_layers = len(all_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        all_gt_bboxes_ignore_list = [
            gt_bboxes_ignore for _ in range(num_dec_layers)
        ]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        losses_cls, losses_bbox, losses_iou = multi_apply(
            self.loss_single, all_cls_scores, all_bbox_preds,
            all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
            all_gt_bboxes_ignore_list)

        # collate the loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_bbox'] = losses_bbox[-1]
        loss_dict['loss_iou'] = losses_iou[-1]

        # collate the loss from the other decoder layers
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
                                                       losses_bbox[:-1],
                                                       losses_iou[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
            loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
            num_dec_layer += 1

        if dn_cls_scores is not None:
            # calculate the denoising loss from all decoder layers
            dn_meta = [dn_meta for _ in img_metas]
            dn_losses_cls, dn_losses_bbox, dn_losses_iou = self.loss_dn(
                dn_cls_scores, dn_bbox_preds, gt_bboxes_list, gt_labels_list,
                img_metas, dn_meta)

            # collate the denoising loss
            loss_dict['dn_loss_cls'] = dn_losses_cls[-1]
            loss_dict['dn_loss_bbox'] = dn_losses_bbox[-1]
            loss_dict['dn_loss_iou'] = dn_losses_iou[-1]
            num_dec_layer = 0
            for loss_cls_i, loss_bbox_i, loss_iou_i in zip(
                    dn_losses_cls[:-1], dn_losses_bbox[:-1],
                    dn_losses_iou[:-1]):
                loss_dict[f'd{num_dec_layer}.dn_loss_cls'] = loss_cls_i
                loss_dict[f'd{num_dec_layer}.dn_loss_bbox'] = loss_bbox_i
                loss_dict[f'd{num_dec_layer}.dn_loss_iou'] = loss_iou_i
                num_dec_layer += 1

        return loss_dict

    def loss_dn(self, dn_cls_scores, dn_bbox_preds, gt_bboxes_list,
                gt_labels_list, img_metas, dn_meta):
        num_dec_layers = len(dn_cls_scores)
        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
        img_metas_list = [img_metas for _ in range(num_dec_layers)]
        dn_meta_list = [dn_meta for _ in range(num_dec_layers)]
        return multi_apply(self.loss_dn_single, dn_cls_scores, dn_bbox_preds,
                           all_gt_bboxes_list, all_gt_labels_list,
                           img_metas_list, dn_meta_list)

    def loss_dn_single(self, dn_cls_scores, dn_bbox_preds, gt_bboxes_list,
                       gt_labels_list, img_metas, dn_meta):
        num_imgs = dn_cls_scores.size(0)
        bbox_preds_list = [dn_bbox_preds[i] for i in range(num_imgs)]
        cls_reg_targets = self.get_dn_target(bbox_preds_list, gt_bboxes_list,
                                             gt_labels_list, img_metas,
                                             dn_meta)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = dn_cls_scores.reshape(-1, self.cls_out_channels)
        # construct a weighted avg_factor to match the official DETR repo
        cls_avg_factor = \
            num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        if len(cls_scores) > 0:
            loss_cls = self.loss_cls(
                cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
        else:
            loss_cls = torch.zeros(  # TODO: How to better return zero loss
                1,
                dtype=cls_scores.dtype,
                device=cls_scores.device)

        # Compute the average number of gt boxes across all GPUs, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # construct factors used to rescale bboxes
        factors = []
        for img_meta, bbox_pred in zip(img_metas, dn_bbox_preds):
            img_h, img_w, _ = img_meta['img_shape']
            factor = bbox_pred.new_tensor([img_w, img_h, img_w,
                                           img_h]).unsqueeze(0).repeat(
                                               bbox_pred.size(0), 1)
            factors.append(factor)
        factors = torch.cat(factors, 0)

        # DETR regresses the relative position of boxes (cxcywh) in the image,
        # thus the learning target is normalized by the image size. So here
        # we need to re-scale them for calculating the IoU loss
        bbox_preds = dn_bbox_preds.reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

        # regression IoU loss, GIoU loss by default
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
        return loss_cls, loss_bbox, loss_iou

    def get_dn_target(self, dn_bbox_preds_list, gt_bboxes_list, gt_labels_list,
                      img_metas, dn_meta):
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         pos_inds_list,
         neg_inds_list) = multi_apply(self._get_dn_target_single,
                                      dn_bbox_preds_list, gt_bboxes_list,
                                      gt_labels_list, img_metas, dn_meta)
        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def _get_dn_target_single(self, dn_bbox_pred, gt_bboxes, gt_labels,
                              img_meta, dn_meta):
        num_groups = dn_meta['num_dn_group']
        pad_size = dn_meta['pad_size']
        assert pad_size % num_groups == 0
        single_pad = pad_size // num_groups
        num_bboxes = dn_bbox_pred.size(0)

        if len(gt_labels) > 0:
            # torch.arange replaces the deprecated torch.range used here
            # previously; the produced index range is unchanged.
            t = torch.arange(len(gt_labels)).long().cuda()
            t = t.unsqueeze(0).repeat(num_groups, 1)
            pos_assigned_gt_inds = t.flatten()
            pos_inds = (torch.arange(num_groups) *
                        single_pad).long().cuda().unsqueeze(1) + t
            pos_inds = pos_inds.flatten()
        else:
            pos_inds = pos_assigned_gt_inds = torch.tensor([]).long().cuda()
        neg_inds = pos_inds + single_pad // 2

        # label targets
        labels = gt_bboxes.new_full((num_bboxes, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets
        bbox_targets = torch.zeros_like(dn_bbox_pred)
        bbox_weights = torch.zeros_like(dn_bbox_pred)
        bbox_weights[pos_inds] = 1.0
        img_h, img_w, _ = img_meta['img_shape']

        # DETR regresses the relative position of boxes (cxcywh) in the
        # image. Thus the learning target should be normalized by the image
        # size, and the box format should be converted from the default
        # x1y1x2y2 to cxcywh.
        factor = dn_bbox_pred.new_tensor([img_w, img_h, img_w,
                                          img_h]).unsqueeze(0)
        gt_bboxes_normalized = gt_bboxes / factor
        gt_bboxes_targets = bbox_xyxy_to_cxcywh(gt_bboxes_normalized)
        bbox_targets[pos_inds] = gt_bboxes_targets.repeat([num_groups, 1])

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)
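The index arithmetic above relies on the CDN layout of the denoising queries: the padded queries are split into num_dn_group groups of single_pad queries each, the first half of every group holds the positive (noised ground-truth) queries, and the offset single_pad // 2 points at the matching negatives. A small sketch of that layout, assuming the usual CDN convention that single_pad = 2 * num_gts:

import torch

num_gts, num_groups = 2, 3
single_pad = 2 * num_gts  # each group: num_gts positives + num_gts negatives
t = torch.arange(num_gts).unsqueeze(0).repeat(num_groups, 1)
pos_inds = ((torch.arange(num_groups) * single_pad).unsqueeze(1) + t).flatten()
neg_inds = pos_inds + single_pad // 2
print(pos_inds.tolist())  # [0, 1, 4, 5, 8, 9]
print(neg_inds.tolist())  # [2, 3, 6, 7, 10, 11]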

    @staticmethod
    def extract_dn_outputs(all_cls_scores, all_bbox_preds, dn_meta):
        # if dn_meta and dn_meta['pad_size'] > 0:
        if dn_meta is not None:
            denoising_cls_scores = all_cls_scores[:, :, :
                                                  dn_meta['pad_size'], :]
            denoising_bbox_preds = all_bbox_preds[:, :, :
                                                  dn_meta['pad_size'], :]
            matching_cls_scores = all_cls_scores[:, :, dn_meta['pad_size']:, :]
            matching_bbox_preds = all_bbox_preds[:, :, dn_meta['pad_size']:, :]
        else:
            denoising_cls_scores = None
            denoising_bbox_preds = None
            matching_cls_scores = all_cls_scores
            matching_bbox_preds = all_bbox_preds
        return (matching_cls_scores, matching_bbox_preds, denoising_cls_scores,
                denoising_bbox_preds)
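Because the denoising queries are always packed in front of the matching queries, the split above is a plain slice along the query dimension at pad_size. A quick shape check:

import torch

nb_dec, bs, pad_size, num_query, C = 6, 1, 8, 10, 4
all_cls_scores = torch.rand(nb_dec, bs, pad_size + num_query, C)

dn_scores = all_cls_scores[:, :, :pad_size, :]        # denoising part
matching_scores = all_cls_scores[:, :, pad_size:, :]  # matching part
assert dn_scores.size(2) == pad_size
assert matching_scores.size(2) == num_query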