Upload 123 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- data/hyps/hyp.scratch-high.yaml +30 -0
- data/images/horses.jpg +3 -0
- data/multiplane.yaml +9 -0
- models/__init__.py +1 -0
- models/common.py +1296 -0
- models/detect/pk-yolo.yaml +126 -0
- models/detect/yolov9-e.yaml +144 -0
- models/experimental.py +275 -0
- models/repvit.py +440 -0
- models/tf.py +596 -0
- models/yolo.py +771 -0
- spark repvit/repvit_1kpretrained_timm_style.pth +3 -0
- spark/downstream_d2/README.md +101 -0
- spark/downstream_d2/configs/Base-RCNN-FPN.yaml +42 -0
- spark/downstream_d2/configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml +57 -0
- spark/downstream_d2/convert-timm-to-d2.py +43 -0
- spark/downstream_d2/lr_decay.py +132 -0
- spark/downstream_d2/train_net.py +322 -0
- spark/downstream_imagenet/README.md +54 -0
- spark/downstream_imagenet/arg.py +137 -0
- spark/downstream_imagenet/data.py +151 -0
- spark/downstream_imagenet/lr_decay.py +61 -0
- spark/downstream_imagenet/main.py +189 -0
- spark/downstream_imagenet/mixup.py +168 -0
- spark/downstream_imagenet/models/__init__.py +104 -0
- spark/downstream_imagenet/models/convnext_official.py +201 -0
- spark/downstream_imagenet/requirements.txt +5 -0
- spark/downstream_imagenet/util.py +131 -0
- spark/downstream_mmdet/README.md +76 -0
- spark/downstream_mmdet/configs/_base_/default_runtime.py +16 -0
- spark/downstream_mmdet/configs/_base_/models/cascade_mask_rcnn_convnext_fpn.py +208 -0
- spark/downstream_mmdet/configs/_base_/models/mask_rcnn_convnext_fpn.py +128 -0
- spark/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py +95 -0
- spark/downstream_mmdet/mmcv_custom/__init__.py +15 -0
- spark/downstream_mmdet/mmcv_custom/customized_text.py +130 -0
- spark/downstream_mmdet/mmcv_custom/layer_decay_optimizer_constructor.py +123 -0
- spark/downstream_mmdet/mmcv_custom/runner/checkpoint.py +85 -0
- spark/downstream_mmdet/mmdet/models/backbones/__init__.py +20 -0
- spark/downstream_mmdet/mmdet/models/backbones/convnext.py +180 -0
- spark/pretrain/README.md +118 -0
- spark/pretrain/decoder.py +74 -0
- spark/pretrain/dist.py +118 -0
- spark/pretrain/encoder.py +208 -0
- spark/pretrain/main.py +191 -0
- spark/pretrain/models/__init__.py +62 -0
- spark/pretrain/models/convnext.py +125 -0
- spark/pretrain/models/custom.py +141 -0
- spark/pretrain/models/custom_detr.py +102 -0
- spark/pretrain/models/custom_origin.py +89 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/images/horses.jpg filter=lfs diff=lfs merge=lfs -text
+spark/pretrain/viz_imgs/recon.png filter=lfs diff=lfs merge=lfs -text
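The two added rules route data/images/horses.jpg and spark/pretrain/viz_imgs/recon.png through Git LFS, so a clone without LFS smudging will contain small pointer stubs instead of the binaries. A small illustrative sketch (not part of the upload) for spotting such a stub:

# Minimal sketch: detect whether a tracked file is still a Git LFS pointer stub.
# An LFS pointer is a tiny text file that starts with the spec line checked below.
from pathlib import Path

def is_lfs_pointer(path):
    try:
        head = Path(path).read_bytes()[:60]
    except OSError:
        return False
    return head.startswith(b'version https://git-lfs.github.com/spec/v1')

print(is_lfs_pointer('data/images/horses.jpg'))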
data/hyps/hyp.scratch-high.yaml
ADDED
@@ -0,0 +1,30 @@
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
box: 7.5 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 0.7 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
dfl: 1.5 # dfl loss gain
iou_t: 0.20 # IoU training threshold
anchor_t: 5.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.9 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.15 # image mixup (probability)
copy_paste: 0.3 # segment copy-paste (probability)
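In YOLO-style training code these values are typically read into a plain dict and forwarded to the optimizer and augmentation pipeline. A minimal sketch (assuming PyYAML and PyTorch are installed; the model below is only a placeholder) of consuming the file above:

# Minimal sketch: load the training hyperparameters defined above.
# Assumes PyYAML is available; the keys mirror the YAML shown in this diff.
import yaml
import torch

with open('data/hyps/hyp.scratch-high.yaml') as f:
    hyp = yaml.safe_load(f)  # dict, e.g. hyp['lr0'] == 0.01

model = torch.nn.Linear(10, 2)  # placeholder model for illustration only
optimizer = torch.optim.SGD(model.parameters(),
                            lr=hyp['lr0'],
                            momentum=hyp['momentum'],
                            weight_decay=hyp['weight_decay'])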
data/images/horses.jpg
ADDED
Git LFS Details
data/multiplane.yaml
ADDED
@@ -0,0 +1,9 @@
train: ./axial_t1wce_2_class/images/train
val: ./axial_t1wce_2_class/images/test
# train: ./coronal_t1wce_2_class/images/train
# val: ./coronal_t1wce_2_class/images/test
# train: ./sagittal_t1wce_2_class/images/train
# val: ./sagittal_t1wce_2_class/images/test

nc: 2
names: ['negative','positive']
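Only one train/val pair is active at a time; switching to another anatomical plane means uncommenting a different pair. A minimal sketch (path handling only, nothing model-specific) that loads the config and reports whether the selected split directories exist:

# Minimal sketch: read data/multiplane.yaml and verify the selected split paths.
# Assumes the axial pair is the uncommented one, as in the file above.
from pathlib import Path
import yaml

cfg = yaml.safe_load(open('data/multiplane.yaml'))
assert cfg['nc'] == len(cfg['names'])  # two classes: negative / positive

for split in ('train', 'val'):
    p = Path(cfg[split])
    print(split, p, 'exists' if p.is_dir() else 'MISSING')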
models/__init__.py
ADDED
@@ -0,0 +1 @@
# init
models/common.py
ADDED
@@ -0,0 +1,1296 @@
| 1 |
+
import ast
|
| 2 |
+
import contextlib
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import platform
|
| 6 |
+
import warnings
|
| 7 |
+
import zipfile
|
| 8 |
+
from collections import OrderedDict, namedtuple
|
| 9 |
+
from copy import copy
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from urllib.parse import urlparse
|
| 12 |
+
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import cv2
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import requests
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from IPython.display import display
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from torch.cuda import amp
|
| 24 |
+
|
| 25 |
+
from models.repvit import RepViT
|
| 26 |
+
from utils import TryExcept
|
| 27 |
+
from utils.dataloaders import exif_transpose, letterbox
|
| 28 |
+
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
|
| 29 |
+
increment_path, is_notebook, make_divisible, non_max_suppression, scale_boxes,
|
| 30 |
+
xywh2xyxy, xyxy2xywh, yaml_load)
|
| 31 |
+
from utils.plots import Annotator, colors, save_one_box
|
| 32 |
+
from utils.torch_utils import copy_attr, smart_inference_mode
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
| 36 |
+
# Pad to 'same' shape outputs
|
| 37 |
+
if d > 1:
|
| 38 |
+
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
| 39 |
+
if p is None:
|
| 40 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
| 41 |
+
return p
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Conv(nn.Module):
|
| 45 |
+
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
| 46 |
+
default_act = nn.SiLU() # default activation
|
| 47 |
+
|
| 48 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
| 51 |
+
self.bn = nn.BatchNorm2d(c2)
|
| 52 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
| 53 |
+
|
| 54 |
+
def forward(self, x):
|
| 55 |
+
return self.act(self.bn(self.conv(x)))
|
| 56 |
+
|
| 57 |
+
def forward_fuse(self, x):
|
| 58 |
+
return self.act(self.conv(x))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class AConv(nn.Module):
|
| 62 |
+
def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.cv1 = Conv(c1, c2, 3, 2, 1)
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
|
| 68 |
+
return self.cv1(x)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ADown(nn.Module):
|
| 72 |
+
def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.c = c2 // 2
|
| 75 |
+
self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
|
| 76 |
+
self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
|
| 80 |
+
x1,x2 = x.chunk(2, 1)
|
| 81 |
+
x1 = self.cv1(x1)
|
| 82 |
+
x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
|
| 83 |
+
x2 = self.cv2(x2)
|
| 84 |
+
return torch.cat((x1, x2), 1)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class RepConvN(nn.Module):
|
| 88 |
+
"""RepConv is a basic rep-style block, including training and deploy status
|
| 89 |
+
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
|
| 90 |
+
"""
|
| 91 |
+
default_act = nn.SiLU() # default activation
|
| 92 |
+
|
| 93 |
+
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
|
| 94 |
+
super().__init__()
|
| 95 |
+
assert k == 3 and p == 1
|
| 96 |
+
self.g = g
|
| 97 |
+
self.c1 = c1
|
| 98 |
+
self.c2 = c2
|
| 99 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
| 100 |
+
|
| 101 |
+
self.bn = None
|
| 102 |
+
self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
|
| 103 |
+
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
|
| 104 |
+
|
| 105 |
+
def forward_fuse(self, x):
|
| 106 |
+
"""Forward process"""
|
| 107 |
+
return self.act(self.conv(x))
|
| 108 |
+
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
"""Forward process"""
|
| 111 |
+
id_out = 0 if self.bn is None else self.bn(x)
|
| 112 |
+
return self.act(self.conv1(x) + self.conv2(x) + id_out)
|
| 113 |
+
|
| 114 |
+
def get_equivalent_kernel_bias(self):
|
| 115 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
|
| 116 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
|
| 117 |
+
kernelid, biasid = self._fuse_bn_tensor(self.bn)
|
| 118 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
| 119 |
+
|
| 120 |
+
def _avg_to_3x3_tensor(self, avgp):
|
| 121 |
+
channels = self.c1
|
| 122 |
+
groups = self.g
|
| 123 |
+
kernel_size = avgp.kernel_size
|
| 124 |
+
input_dim = channels // groups
|
| 125 |
+
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
|
| 126 |
+
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
|
| 127 |
+
return k
|
| 128 |
+
|
| 129 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
| 130 |
+
if kernel1x1 is None:
|
| 131 |
+
return 0
|
| 132 |
+
else:
|
| 133 |
+
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
| 134 |
+
|
| 135 |
+
def _fuse_bn_tensor(self, branch):
|
| 136 |
+
if branch is None:
|
| 137 |
+
return 0, 0
|
| 138 |
+
if isinstance(branch, Conv):
|
| 139 |
+
kernel = branch.conv.weight
|
| 140 |
+
running_mean = branch.bn.running_mean
|
| 141 |
+
running_var = branch.bn.running_var
|
| 142 |
+
gamma = branch.bn.weight
|
| 143 |
+
beta = branch.bn.bias
|
| 144 |
+
eps = branch.bn.eps
|
| 145 |
+
elif isinstance(branch, nn.BatchNorm2d):
|
| 146 |
+
if not hasattr(self, 'id_tensor'):
|
| 147 |
+
input_dim = self.c1 // self.g
|
| 148 |
+
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
|
| 149 |
+
for i in range(self.c1):
|
| 150 |
+
kernel_value[i, i % input_dim, 1, 1] = 1
|
| 151 |
+
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
| 152 |
+
kernel = self.id_tensor
|
| 153 |
+
running_mean = branch.running_mean
|
| 154 |
+
running_var = branch.running_var
|
| 155 |
+
gamma = branch.weight
|
| 156 |
+
beta = branch.bias
|
| 157 |
+
eps = branch.eps
|
| 158 |
+
std = (running_var + eps).sqrt()
|
| 159 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
| 160 |
+
return kernel * t, beta - running_mean * gamma / std
|
| 161 |
+
|
| 162 |
+
def fuse_convs(self):
|
| 163 |
+
if hasattr(self, 'conv'):
|
| 164 |
+
return
|
| 165 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
| 166 |
+
self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
|
| 167 |
+
out_channels=self.conv1.conv.out_channels,
|
| 168 |
+
kernel_size=self.conv1.conv.kernel_size,
|
| 169 |
+
stride=self.conv1.conv.stride,
|
| 170 |
+
padding=self.conv1.conv.padding,
|
| 171 |
+
dilation=self.conv1.conv.dilation,
|
| 172 |
+
groups=self.conv1.conv.groups,
|
| 173 |
+
bias=True).requires_grad_(False)
|
| 174 |
+
self.conv.weight.data = kernel
|
| 175 |
+
self.conv.bias.data = bias
|
| 176 |
+
for para in self.parameters():
|
| 177 |
+
para.detach_()
|
| 178 |
+
self.__delattr__('conv1')
|
| 179 |
+
self.__delattr__('conv2')
|
| 180 |
+
if hasattr(self, 'nm'):
|
| 181 |
+
self.__delattr__('nm')
|
| 182 |
+
if hasattr(self, 'bn'):
|
| 183 |
+
self.__delattr__('bn')
|
| 184 |
+
if hasattr(self, 'id_tensor'):
|
| 185 |
+
self.__delattr__('id_tensor')
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class SP(nn.Module):
|
| 189 |
+
def __init__(self, k=3, s=1):
|
| 190 |
+
super(SP, self).__init__()
|
| 191 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
|
| 192 |
+
|
| 193 |
+
def forward(self, x):
|
| 194 |
+
return self.m(x)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class MP(nn.Module):
|
| 198 |
+
# Max pooling
|
| 199 |
+
def __init__(self, k=2):
|
| 200 |
+
super(MP, self).__init__()
|
| 201 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=k)
|
| 202 |
+
|
| 203 |
+
def forward(self, x):
|
| 204 |
+
return self.m(x)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class ConvTranspose(nn.Module):
|
| 208 |
+
# Convolution transpose 2d layer
|
| 209 |
+
default_act = nn.SiLU() # default activation
|
| 210 |
+
|
| 211 |
+
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
|
| 214 |
+
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
|
| 215 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
| 216 |
+
|
| 217 |
+
def forward(self, x):
|
| 218 |
+
return self.act(self.bn(self.conv_transpose(x)))
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class DWConv(Conv):
|
| 222 |
+
# Depth-wise convolution
|
| 223 |
+
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
|
| 224 |
+
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class DWConvTranspose2d(nn.ConvTranspose2d):
|
| 228 |
+
# Depth-wise transpose convolution
|
| 229 |
+
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
|
| 230 |
+
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class DFL(nn.Module):
|
| 234 |
+
# DFL module
|
| 235 |
+
def __init__(self, c1=17):
|
| 236 |
+
super().__init__()
|
| 237 |
+
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
| 238 |
+
self.conv.weight.data[:] = nn.Parameter(torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)) # / 120.0
|
| 239 |
+
self.c1 = c1
|
| 240 |
+
# self.bn = nn.BatchNorm2d(4)
|
| 241 |
+
|
| 242 |
+
def forward(self, x):
|
| 243 |
+
b, c, a = x.shape # batch, channels, anchors
|
| 244 |
+
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
|
| 245 |
+
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class BottleneckBase(nn.Module):
|
| 249 |
+
# Standard bottleneck
|
| 250 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(1, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 251 |
+
super().__init__()
|
| 252 |
+
c_ = int(c2 * e) # hidden channels
|
| 253 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
| 254 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
| 255 |
+
self.add = shortcut and c1 == c2
|
| 256 |
+
|
| 257 |
+
def forward(self, x):
|
| 258 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class RBottleneckBase(nn.Module):
|
| 262 |
+
# Standard bottleneck
|
| 263 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 264 |
+
super().__init__()
|
| 265 |
+
c_ = int(c2 * e) # hidden channels
|
| 266 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
| 267 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
| 268 |
+
self.add = shortcut and c1 == c2
|
| 269 |
+
|
| 270 |
+
def forward(self, x):
|
| 271 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class RepNRBottleneckBase(nn.Module):
|
| 275 |
+
# Standard bottleneck
|
| 276 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 277 |
+
super().__init__()
|
| 278 |
+
c_ = int(c2 * e) # hidden channels
|
| 279 |
+
self.cv1 = RepConvN(c1, c_, k[0], 1)
|
| 280 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
| 281 |
+
self.add = shortcut and c1 == c2
|
| 282 |
+
|
| 283 |
+
def forward(self, x):
|
| 284 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class Bottleneck(nn.Module):
|
| 288 |
+
# Standard bottleneck
|
| 289 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 290 |
+
super().__init__()
|
| 291 |
+
c_ = int(c2 * e) # hidden channels
|
| 292 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
| 293 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
| 294 |
+
self.add = shortcut and c1 == c2
|
| 295 |
+
|
| 296 |
+
def forward(self, x):
|
| 297 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class RepNBottleneck(nn.Module):
|
| 301 |
+
# Standard bottleneck
|
| 302 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
| 303 |
+
super().__init__()
|
| 304 |
+
c_ = int(c2 * e) # hidden channels
|
| 305 |
+
self.cv1 = RepConvN(c1, c_, k[0], 1)
|
| 306 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
| 307 |
+
self.add = shortcut and c1 == c2
|
| 308 |
+
|
| 309 |
+
def forward(self, x):
|
| 310 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
| 311 |
+
|
| 312 |
+
class Backbone(nn.Module):
|
| 313 |
+
def __init__(self):
|
| 314 |
+
super(Backbone, self).__init__()
|
| 315 |
+
self.cfgs = [
|
| 316 |
+
# k, t, c, SE, HS, s
|
| 317 |
+
[3, 2, 64 * 2, 1, 0, 1],
|
| 318 |
+
[3, 2, 64 * 2, 0, 0, 1],
|
| 319 |
+
[3, 2, 64 * 2, 1, 0, 1],
|
| 320 |
+
[3, 2, 64 * 2, 0, 0, 1],
|
| 321 |
+
[3, 2, 64 * 2, 0, 0, 1],
|
| 322 |
+
[3, 2, 128 * 2, 0, 0, 2],
|
| 323 |
+
[3, 2, 128 * 2, 1, 0, 1],
|
| 324 |
+
[3, 2, 128 * 2, 0, 0, 1],
|
| 325 |
+
[3, 2, 128 * 2, 1, 0, 1],
|
| 326 |
+
[3, 2, 128 * 2, 0, 0, 1],
|
| 327 |
+
[3, 2, 128 * 2, 0, 0, 1],
|
| 328 |
+
[3, 2, 256 * 2, 0, 1, 2],
|
| 329 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 330 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 331 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 332 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 333 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 334 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 335 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 336 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 337 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 338 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 339 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 340 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 341 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 342 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 343 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 344 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 345 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 346 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 347 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 348 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 349 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 350 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 351 |
+
[3, 2, 256 * 2, 1, 1, 1],
|
| 352 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 353 |
+
[3, 2, 256 * 2, 0, 1, 1],
|
| 354 |
+
[3, 2, 512 * 2, 0, 1, 2],
|
| 355 |
+
[3, 2, 512 * 2, 1, 1, 1],
|
| 356 |
+
[3, 2, 512 * 2, 0, 1, 1],
|
| 357 |
+
[3, 2, 512 * 2, 1, 1, 1],
|
| 358 |
+
[3, 2, 512 * 2, 0, 1, 1]
|
| 359 |
+
]
|
| 360 |
+
self.backbone = RepViT(self.cfgs )
|
| 361 |
+
|
| 362 |
+
def forward(self, x):
|
| 363 |
+
outputs = self.backbone (x)
|
| 364 |
+
return outputs
|
| 365 |
+
class Down0(nn.Module):
|
| 366 |
+
def __init__(self,inp):
|
| 367 |
+
super(Down0, self).__init__()
|
| 368 |
+
|
| 369 |
+
def forward(self, x):
|
| 370 |
+
return x[0]
|
| 371 |
+
class Down1(nn.Module):
|
| 372 |
+
def __init__(self,inp):
|
| 373 |
+
super(Down1, self).__init__()
|
| 374 |
+
|
| 375 |
+
def forward(self, x):
|
| 376 |
+
return x[1]
|
| 377 |
+
class Down2(nn.Module):
|
| 378 |
+
def __init__(self,inp):
|
| 379 |
+
super(Down2, self).__init__()
|
| 380 |
+
|
| 381 |
+
def forward(self, x):
|
| 382 |
+
return x[2]
|
| 383 |
+
class Down3(nn.Module):
|
| 384 |
+
def __init__(self,inp):
|
| 385 |
+
super(Down3, self).__init__()
|
| 386 |
+
|
| 387 |
+
def forward(self, x):
|
| 388 |
+
return x[3]
|
| 389 |
+
|
| 390 |
+
class Down4(nn.Module):
|
| 391 |
+
def __init__(self,inp):
|
| 392 |
+
super(Down4, self).__init__()
|
| 393 |
+
|
| 394 |
+
def forward(self, x):
|
| 395 |
+
return x[4]
|
| 396 |
+
class Res(nn.Module):
|
| 397 |
+
# ResNet bottleneck
|
| 398 |
+
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
| 399 |
+
super(Res, self).__init__()
|
| 400 |
+
c_ = int(c2 * e) # hidden channels
|
| 401 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 402 |
+
self.cv2 = Conv(c_, c_, 3, 1, g=g)
|
| 403 |
+
self.cv3 = Conv(c_, c2, 1, 1)
|
| 404 |
+
self.add = shortcut and c1 == c2
|
| 405 |
+
|
| 406 |
+
def forward(self, x):
|
| 407 |
+
return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class RepNRes(nn.Module):
|
| 411 |
+
# ResNet bottleneck
|
| 412 |
+
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
| 413 |
+
super(RepNRes, self).__init__()
|
| 414 |
+
c_ = int(c2 * e) # hidden channels
|
| 415 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 416 |
+
self.cv2 = RepConvN(c_, c_, 3, 1, g=g)
|
| 417 |
+
self.cv3 = Conv(c_, c2, 1, 1)
|
| 418 |
+
self.add = shortcut and c1 == c2
|
| 419 |
+
|
| 420 |
+
def forward(self, x):
|
| 421 |
+
return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class BottleneckCSP(nn.Module):
|
| 425 |
+
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
|
| 426 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
| 427 |
+
super().__init__()
|
| 428 |
+
c_ = int(c2 * e) # hidden channels
|
| 429 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 430 |
+
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
|
| 431 |
+
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
|
| 432 |
+
self.cv4 = Conv(2 * c_, c2, 1, 1)
|
| 433 |
+
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
|
| 434 |
+
self.act = nn.SiLU()
|
| 435 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
| 436 |
+
|
| 437 |
+
def forward(self, x):
|
| 438 |
+
y1 = self.cv3(self.m(self.cv1(x)))
|
| 439 |
+
y2 = self.cv2(x)
|
| 440 |
+
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class CSP(nn.Module):
|
| 444 |
+
# CSP Bottleneck with 3 convolutions
|
| 445 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
| 446 |
+
super().__init__()
|
| 447 |
+
c_ = int(c2 * e) # hidden channels
|
| 448 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 449 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
| 450 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
| 451 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
| 452 |
+
|
| 453 |
+
def forward(self, x):
|
| 454 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class RepNCSP(nn.Module):
|
| 458 |
+
# CSP Bottleneck with 3 convolutions
|
| 459 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
| 460 |
+
super().__init__()
|
| 461 |
+
c_ = int(c2 * e) # hidden channels
|
| 462 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 463 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
| 464 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
| 465 |
+
self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
| 466 |
+
|
| 467 |
+
def forward(self, x):
|
| 468 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class CSPBase(nn.Module):
|
| 472 |
+
# CSP Bottleneck with 3 convolutions
|
| 473 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
| 474 |
+
super().__init__()
|
| 475 |
+
c_ = int(c2 * e) # hidden channels
|
| 476 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 477 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
| 478 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
| 479 |
+
self.m = nn.Sequential(*(BottleneckBase(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
| 480 |
+
|
| 481 |
+
def forward(self, x):
|
| 482 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class SPP(nn.Module):
|
| 486 |
+
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
|
| 487 |
+
def __init__(self, c1, c2, k=(5, 9, 13)):
|
| 488 |
+
super().__init__()
|
| 489 |
+
c_ = c1 // 2 # hidden channels
|
| 490 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 491 |
+
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
|
| 492 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
| 493 |
+
|
| 494 |
+
def forward(self, x):
|
| 495 |
+
x = self.cv1(x)
|
| 496 |
+
with warnings.catch_warnings():
|
| 497 |
+
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
| 498 |
+
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class ASPP(torch.nn.Module):
|
| 502 |
+
|
| 503 |
+
def __init__(self, in_channels, out_channels):
|
| 504 |
+
super().__init__()
|
| 505 |
+
kernel_sizes = [1, 3, 3, 1]
|
| 506 |
+
dilations = [1, 3, 6, 1]
|
| 507 |
+
paddings = [0, 3, 6, 0]
|
| 508 |
+
self.aspp = torch.nn.ModuleList()
|
| 509 |
+
for aspp_idx in range(len(kernel_sizes)):
|
| 510 |
+
conv = torch.nn.Conv2d(
|
| 511 |
+
in_channels,
|
| 512 |
+
out_channels,
|
| 513 |
+
kernel_size=kernel_sizes[aspp_idx],
|
| 514 |
+
stride=1,
|
| 515 |
+
dilation=dilations[aspp_idx],
|
| 516 |
+
padding=paddings[aspp_idx],
|
| 517 |
+
bias=True)
|
| 518 |
+
self.aspp.append(conv)
|
| 519 |
+
self.gap = torch.nn.AdaptiveAvgPool2d(1)
|
| 520 |
+
self.aspp_num = len(kernel_sizes)
|
| 521 |
+
for m in self.modules():
|
| 522 |
+
if isinstance(m, torch.nn.Conv2d):
|
| 523 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 524 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 525 |
+
m.bias.data.fill_(0)
|
| 526 |
+
|
| 527 |
+
def forward(self, x):
|
| 528 |
+
avg_x = self.gap(x)
|
| 529 |
+
out = []
|
| 530 |
+
for aspp_idx in range(self.aspp_num):
|
| 531 |
+
inp = avg_x if (aspp_idx == self.aspp_num - 1) else x
|
| 532 |
+
out.append(F.relu_(self.aspp[aspp_idx](inp)))
|
| 533 |
+
out[-1] = out[-1].expand_as(out[-2])
|
| 534 |
+
out = torch.cat(out, dim=1)
|
| 535 |
+
return out
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class SPPCSPC(nn.Module):
|
| 539 |
+
# CSP SPP https://github.com/WongKinYiu/CrossStagePartialNetworks
|
| 540 |
+
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
|
| 541 |
+
super(SPPCSPC, self).__init__()
|
| 542 |
+
c_ = int(2 * c2 * e) # hidden channels
|
| 543 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 544 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
| 545 |
+
self.cv3 = Conv(c_, c_, 3, 1)
|
| 546 |
+
self.cv4 = Conv(c_, c_, 1, 1)
|
| 547 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
| 548 |
+
self.cv5 = Conv(4 * c_, c_, 1, 1)
|
| 549 |
+
self.cv6 = Conv(c_, c_, 3, 1)
|
| 550 |
+
self.cv7 = Conv(2 * c_, c2, 1, 1)
|
| 551 |
+
|
| 552 |
+
def forward(self, x):
|
| 553 |
+
x1 = self.cv4(self.cv3(self.cv1(x)))
|
| 554 |
+
y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
|
| 555 |
+
y2 = self.cv2(x)
|
| 556 |
+
return self.cv7(torch.cat((y1, y2), dim=1))
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class SPPF(nn.Module):
|
| 560 |
+
# Spatial Pyramid Pooling - Fast (SPPF) layer by Glenn Jocher
|
| 561 |
+
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
| 562 |
+
super().__init__()
|
| 563 |
+
c_ = c1 // 2 # hidden channels
|
| 564 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 565 |
+
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
| 566 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
| 567 |
+
# self.m = SoftPool2d(kernel_size=k, stride=1, padding=k // 2)
|
| 568 |
+
|
| 569 |
+
def forward(self, x):
|
| 570 |
+
x = self.cv1(x)
|
| 571 |
+
with warnings.catch_warnings():
|
| 572 |
+
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
| 573 |
+
y1 = self.m(x)
|
| 574 |
+
y2 = self.m(y1)
|
| 575 |
+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
import torch.nn.functional as F
|
| 579 |
+
from torch.nn.modules.utils import _pair
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
class ReOrg(nn.Module):
|
| 583 |
+
# yolo
|
| 584 |
+
def __init__(self):
|
| 585 |
+
super(ReOrg, self).__init__()
|
| 586 |
+
|
| 587 |
+
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
| 588 |
+
return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
class Contract(nn.Module):
|
| 592 |
+
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
|
| 593 |
+
def __init__(self, gain=2):
|
| 594 |
+
super().__init__()
|
| 595 |
+
self.gain = gain
|
| 596 |
+
|
| 597 |
+
def forward(self, x):
|
| 598 |
+
b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
|
| 599 |
+
s = self.gain
|
| 600 |
+
x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
|
| 601 |
+
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
|
| 602 |
+
return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
class Expand(nn.Module):
|
| 606 |
+
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
|
| 607 |
+
def __init__(self, gain=2):
|
| 608 |
+
super().__init__()
|
| 609 |
+
self.gain = gain
|
| 610 |
+
|
| 611 |
+
def forward(self, x):
|
| 612 |
+
b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
|
| 613 |
+
s = self.gain
|
| 614 |
+
x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
|
| 615 |
+
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
|
| 616 |
+
return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
class Concat(nn.Module):
|
| 620 |
+
# Concatenate a list of tensors along dimension
|
| 621 |
+
def __init__(self, dimension=1):
|
| 622 |
+
super().__init__()
|
| 623 |
+
self.d = dimension
|
| 624 |
+
|
| 625 |
+
def forward(self, x):
|
| 626 |
+
return torch.cat(x, self.d)
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class Shortcut(nn.Module):
|
| 630 |
+
def __init__(self, dimension=0):
|
| 631 |
+
super(Shortcut, self).__init__()
|
| 632 |
+
self.d = dimension
|
| 633 |
+
|
| 634 |
+
def forward(self, x):
|
| 635 |
+
return x[0]+x[1]
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class Silence(nn.Module):
|
| 639 |
+
def __init__(self):
|
| 640 |
+
super(Silence, self).__init__()
|
| 641 |
+
def forward(self, x):
|
| 642 |
+
return x
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
##### GELAN #####
|
| 646 |
+
|
| 647 |
+
class SPPELAN(nn.Module):
|
| 648 |
+
# spp-elan
|
| 649 |
+
def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
|
| 650 |
+
super().__init__()
|
| 651 |
+
self.c = c3
|
| 652 |
+
self.cv1 = Conv(c1, c3, 1, 1)
|
| 653 |
+
self.cv2 = SP(5)
|
| 654 |
+
self.cv3 = SP(5)
|
| 655 |
+
self.cv4 = SP(5)
|
| 656 |
+
self.cv5 = Conv(4*c3, c2, 1, 1)
|
| 657 |
+
|
| 658 |
+
def forward(self, x):
|
| 659 |
+
y = [self.cv1(x)]
|
| 660 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
|
| 661 |
+
return self.cv5(torch.cat(y, 1))
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
class RepNCSPELAN4(nn.Module):
|
| 665 |
+
# csp-elan
|
| 666 |
+
def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
|
| 667 |
+
super().__init__()
|
| 668 |
+
self.c = c3//2
|
| 669 |
+
self.cv1 = Conv(c1, c3, 1, 1)
|
| 670 |
+
self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
|
| 671 |
+
self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
|
| 672 |
+
self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
|
| 673 |
+
|
| 674 |
+
def forward(self, x):
|
| 675 |
+
y = list(self.cv1(x).chunk(2, 1))
|
| 676 |
+
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
|
| 677 |
+
return self.cv4(torch.cat(y, 1))
|
| 678 |
+
|
| 679 |
+
def forward_split(self, x):
|
| 680 |
+
y = list(self.cv1(x).split((self.c, self.c), 1))
|
| 681 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
| 682 |
+
return self.cv4(torch.cat(y, 1))
|
| 683 |
+
|
| 684 |
+
#################
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
##### YOLOR #####
|
| 688 |
+
|
| 689 |
+
class ImplicitA(nn.Module):
|
| 690 |
+
def __init__(self, channel):
|
| 691 |
+
super(ImplicitA, self).__init__()
|
| 692 |
+
self.channel = channel
|
| 693 |
+
self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
| 694 |
+
nn.init.normal_(self.implicit, std=.02)
|
| 695 |
+
|
| 696 |
+
def forward(self, x):
|
| 697 |
+
return self.implicit + x
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class ImplicitM(nn.Module):
|
| 701 |
+
def __init__(self, channel):
|
| 702 |
+
super(ImplicitM, self).__init__()
|
| 703 |
+
self.channel = channel
|
| 704 |
+
self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
|
| 705 |
+
nn.init.normal_(self.implicit, mean=1., std=.02)
|
| 706 |
+
|
| 707 |
+
def forward(self, x):
|
| 708 |
+
return self.implicit * x
|
| 709 |
+
|
| 710 |
+
#################
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
##### CBNet #####
|
| 714 |
+
|
| 715 |
+
class CBLinear(nn.Module):
|
| 716 |
+
def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
|
| 717 |
+
super(CBLinear, self).__init__()
|
| 718 |
+
self.c2s = c2s
|
| 719 |
+
self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
|
| 720 |
+
|
| 721 |
+
def forward(self, x):
|
| 722 |
+
outs = self.conv(x).split(self.c2s, dim=1)
|
| 723 |
+
return outs
|
| 724 |
+
|
| 725 |
+
class CBFuse(nn.Module):
|
| 726 |
+
def __init__(self, idx):
|
| 727 |
+
super(CBFuse, self).__init__()
|
| 728 |
+
self.idx = idx
|
| 729 |
+
|
| 730 |
+
def forward(self, xs):
|
| 731 |
+
target_size = xs[-1].shape[2:]
|
| 732 |
+
res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
|
| 733 |
+
out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
|
| 734 |
+
return out
|
| 735 |
+
|
| 736 |
+
#################
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
class DetectMultiBackend(nn.Module):
|
| 740 |
+
# YOLO MultiBackend class for python inference on various backends
|
| 741 |
+
def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
|
| 742 |
+
# Usage:
|
| 743 |
+
# PyTorch: weights = *.pt
|
| 744 |
+
# TorchScript: *.torchscript
|
| 745 |
+
# ONNX Runtime: *.onnx
|
| 746 |
+
# ONNX OpenCV DNN: *.onnx --dnn
|
| 747 |
+
# OpenVINO: *_openvino_model
|
| 748 |
+
# CoreML: *.mlmodel
|
| 749 |
+
# TensorRT: *.engine
|
| 750 |
+
# TensorFlow SavedModel: *_saved_model
|
| 751 |
+
# TensorFlow GraphDef: *.pb
|
| 752 |
+
# TensorFlow Lite: *.tflite
|
| 753 |
+
# TensorFlow Edge TPU: *_edgetpu.tflite
|
| 754 |
+
# PaddlePaddle: *_paddle_model
|
| 755 |
+
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
|
| 756 |
+
|
| 757 |
+
super().__init__()
|
| 758 |
+
w = str(weights[0] if isinstance(weights, list) else weights)
|
| 759 |
+
pt, jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
|
| 760 |
+
fp16 &= pt or jit or onnx or engine # FP16
|
| 761 |
+
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
|
| 762 |
+
stride = 32 # default stride
|
| 763 |
+
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
| 764 |
+
if not (pt or triton):
|
| 765 |
+
w = attempt_download(w) # download if not local
|
| 766 |
+
|
| 767 |
+
if pt: # PyTorch
|
| 768 |
+
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
| 769 |
+
stride = max(int(model.stride.max()), 32) # model stride
|
| 770 |
+
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
| 771 |
+
model.half() if fp16 else model.float()
|
| 772 |
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
| 773 |
+
elif jit: # TorchScript
|
| 774 |
+
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
| 775 |
+
extra_files = {'config.txt': ''} # model metadata
|
| 776 |
+
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
| 777 |
+
model.half() if fp16 else model.float()
|
| 778 |
+
if extra_files['config.txt']: # load metadata dict
|
| 779 |
+
d = json.loads(extra_files['config.txt'],
|
| 780 |
+
object_hook=lambda d: {int(k) if k.isdigit() else k: v
|
| 781 |
+
for k, v in d.items()})
|
| 782 |
+
stride, names = int(d['stride']), d['names']
|
| 783 |
+
elif dnn: # ONNX OpenCV DNN
|
| 784 |
+
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
| 785 |
+
check_requirements('opencv-python>=4.5.4')
|
| 786 |
+
net = cv2.dnn.readNetFromONNX(w)
|
| 787 |
+
elif onnx: # ONNX Runtime
|
| 788 |
+
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
| 789 |
+
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
| 790 |
+
import onnxruntime
|
| 791 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
| 792 |
+
session = onnxruntime.InferenceSession(w, providers=providers)
|
| 793 |
+
output_names = [x.name for x in session.get_outputs()]
|
| 794 |
+
meta = session.get_modelmeta().custom_metadata_map # metadata
|
| 795 |
+
if 'stride' in meta:
|
| 796 |
+
stride, names = int(meta['stride']), eval(meta['names'])
|
| 797 |
+
elif xml: # OpenVINO
|
| 798 |
+
LOGGER.info(f'Loading {w} for OpenVINO inference...')
|
| 799 |
+
check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
| 800 |
+
from openvino.runtime import Core, Layout, get_batch
|
| 801 |
+
ie = Core()
|
| 802 |
+
if not Path(w).is_file(): # if not *.xml
|
| 803 |
+
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
| 804 |
+
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
| 805 |
+
if network.get_parameters()[0].get_layout().empty:
|
| 806 |
+
network.get_parameters()[0].set_layout(Layout("NCHW"))
|
| 807 |
+
batch_dim = get_batch(network)
|
| 808 |
+
if batch_dim.is_static:
|
| 809 |
+
batch_size = batch_dim.get_length()
|
| 810 |
+
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
|
| 811 |
+
stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
|
| 812 |
+
elif engine: # TensorRT
|
| 813 |
+
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
| 814 |
+
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
| 815 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
| 816 |
+
if device.type == 'cpu':
|
| 817 |
+
device = torch.device('cuda:0')
|
| 818 |
+
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
| 819 |
+
logger = trt.Logger(trt.Logger.INFO)
|
| 820 |
+
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
| 821 |
+
model = runtime.deserialize_cuda_engine(f.read())
|
| 822 |
+
context = model.create_execution_context()
|
| 823 |
+
bindings = OrderedDict()
|
| 824 |
+
output_names = []
|
| 825 |
+
fp16 = False # default updated below
|
| 826 |
+
dynamic = False
|
| 827 |
+
for i in range(model.num_bindings):
|
| 828 |
+
name = model.get_binding_name(i)
|
| 829 |
+
dtype = trt.nptype(model.get_binding_dtype(i))
|
| 830 |
+
if model.binding_is_input(i):
|
| 831 |
+
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
| 832 |
+
dynamic = True
|
| 833 |
+
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
| 834 |
+
if dtype == np.float16:
|
| 835 |
+
fp16 = True
|
| 836 |
+
else: # output
|
| 837 |
+
output_names.append(name)
|
| 838 |
+
shape = tuple(context.get_binding_shape(i))
|
| 839 |
+
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
| 840 |
+
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
| 841 |
+
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
| 842 |
+
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
| 843 |
+
elif coreml: # CoreML
|
| 844 |
+
LOGGER.info(f'Loading {w} for CoreML inference...')
|
| 845 |
+
import coremltools as ct
|
| 846 |
+
model = ct.models.MLModel(w)
|
| 847 |
+
elif saved_model: # TF SavedModel
|
| 848 |
+
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
| 849 |
+
import tensorflow as tf
|
| 850 |
+
keras = False # assume TF1 saved_model
|
| 851 |
+
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
| 852 |
+
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
| 853 |
+
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
| 854 |
+
import tensorflow as tf
|
| 855 |
+
|
| 856 |
+
def wrap_frozen_graph(gd, inputs, outputs):
|
| 857 |
+
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
|
| 858 |
+
ge = x.graph.as_graph_element
|
| 859 |
+
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
| 860 |
+
|
| 861 |
+
def gd_outputs(gd):
|
| 862 |
+
name_list, input_list = [], []
|
| 863 |
+
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
| 864 |
+
name_list.append(node.name)
|
| 865 |
+
input_list.extend(node.input)
|
| 866 |
+
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
| 867 |
+
|
| 868 |
+
gd = tf.Graph().as_graph_def() # TF GraphDef
|
| 869 |
+
with open(w, 'rb') as f:
|
| 870 |
+
gd.ParseFromString(f.read())
|
| 871 |
+
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
| 872 |
+
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
| 873 |
+
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
| 874 |
+
from tflite_runtime.interpreter import Interpreter, load_delegate
|
| 875 |
+
except ImportError:
|
| 876 |
+
import tensorflow as tf
|
| 877 |
+
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
|
| 878 |
+
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
| 879 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
| 880 |
+
delegate = {
|
| 881 |
+
'Linux': 'libedgetpu.so.1',
|
| 882 |
+
'Darwin': 'libedgetpu.1.dylib',
|
| 883 |
+
'Windows': 'edgetpu.dll'}[platform.system()]
|
| 884 |
+
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
|
| 885 |
+
else: # TFLite
|
| 886 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
|
| 887 |
+
interpreter = Interpreter(model_path=w) # load TFLite model
|
| 888 |
+
interpreter.allocate_tensors() # allocate
|
| 889 |
+
input_details = interpreter.get_input_details() # inputs
|
| 890 |
+
output_details = interpreter.get_output_details() # outputs
|
| 891 |
+
# load metadata
|
| 892 |
+
with contextlib.suppress(zipfile.BadZipFile):
|
| 893 |
+
with zipfile.ZipFile(w, "r") as model:
|
| 894 |
+
meta_file = model.namelist()[0]
|
| 895 |
+
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
| 896 |
+
stride, names = int(meta['stride']), meta['names']
|
| 897 |
+
elif tfjs: # TF.js
|
| 898 |
+
raise NotImplementedError('ERROR: YOLO TF.js inference is not supported')
|
| 899 |
+
elif paddle: # PaddlePaddle
|
| 900 |
+
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
|
| 901 |
+
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
|
| 902 |
+
import paddle.inference as pdi
|
| 903 |
+
if not Path(w).is_file(): # if not *.pdmodel
|
| 904 |
+
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
|
| 905 |
+
weights = Path(w).with_suffix('.pdiparams')
|
| 906 |
+
config = pdi.Config(str(w), str(weights))
|
| 907 |
+
if cuda:
|
| 908 |
+
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
| 909 |
+
predictor = pdi.create_predictor(config)
|
| 910 |
+
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
| 911 |
+
output_names = predictor.get_output_names()
|
| 912 |
+
elif triton: # NVIDIA Triton Inference Server
|
| 913 |
+
LOGGER.info(f'Using {w} as Triton Inference Server...')
|
| 914 |
+
check_requirements('tritonclient[all]')
|
| 915 |
+
from utils.triton import TritonRemoteModel
|
| 916 |
+
model = TritonRemoteModel(url=w)
|
| 917 |
+
nhwc = model.runtime.startswith("tensorflow")
|
| 918 |
+
else:
|
| 919 |
+
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
| 920 |
+
|
| 921 |
+
# class names
|
| 922 |
+
if 'names' not in locals():
|
| 923 |
+
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
|
| 924 |
+
if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
|
| 925 |
+
names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
|
| 926 |
+
|
| 927 |
+
self.__dict__.update(locals()) # assign all variables to self
|
| 928 |
+
|
| 929 |
+
def forward(self, im, augment=False, visualize=False):
|
| 930 |
+
# YOLO MultiBackend inference
|
| 931 |
+
b, ch, h, w = im.shape # batch, channel, height, width
|
| 932 |
+
if self.fp16 and im.dtype != torch.float16:
|
| 933 |
+
im = im.half() # to FP16
|
| 934 |
+
if self.nhwc:
|
| 935 |
+
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
| 936 |
+
|
| 937 |
+
if self.pt: # PyTorch
|
| 938 |
+
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
| 939 |
+
elif self.jit: # TorchScript
|
| 940 |
+
y = self.model(im)
|
| 941 |
+
elif self.dnn: # ONNX OpenCV DNN
|
| 942 |
+
im = im.cpu().numpy() # torch to numpy
|
| 943 |
+
self.net.setInput(im)
|
| 944 |
+
y = self.net.forward()
|
| 945 |
+
elif self.onnx: # ONNX Runtime
|
| 946 |
+
im = im.cpu().numpy() # torch to numpy
|
| 947 |
+
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
| 948 |
+
elif self.xml: # OpenVINO
|
| 949 |
+
im = im.cpu().numpy() # FP32
|
| 950 |
+
y = list(self.executable_network([im]).values())
|
| 951 |
+
elif self.engine: # TensorRT
|
| 952 |
+
if self.dynamic and im.shape != self.bindings['images'].shape:
|
| 953 |
+
i = self.model.get_binding_index('images')
|
| 954 |
+
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
| 955 |
+
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
| 956 |
+
for name in self.output_names:
|
| 957 |
+
i = self.model.get_binding_index(name)
|
| 958 |
+
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
| 959 |
+
s = self.bindings['images'].shape
|
| 960 |
+
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
| 961 |
+
self.binding_addrs['images'] = int(im.data_ptr())
|
| 962 |
+
self.context.execute_v2(list(self.binding_addrs.values()))
|
| 963 |
+
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
| 964 |
+
elif self.coreml: # CoreML
|
| 965 |
+
im = im.cpu().numpy()
|
| 966 |
+
im = Image.fromarray((im[0] * 255).astype('uint8'))
|
| 967 |
+
# im = im.resize((192, 320), Image.ANTIALIAS)
|
| 968 |
+
y = self.model.predict({'image': im}) # coordinates are xywh normalized
|
| 969 |
+
if 'confidence' in y:
|
| 970 |
+
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
| 971 |
+
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
|
| 972 |
+
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
| 973 |
+
else:
|
| 974 |
+
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
| 975 |
+
elif self.paddle: # PaddlePaddle
|
| 976 |
+
im = im.cpu().numpy().astype(np.float32)
|
| 977 |
+
self.input_handle.copy_from_cpu(im)
|
| 978 |
+
self.predictor.run()
|
| 979 |
+
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
| 980 |
+
elif self.triton: # NVIDIA Triton Inference Server
|
| 981 |
+
y = self.model(im)
|
| 982 |
+
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
| 983 |
+
im = im.cpu().numpy()
|
| 984 |
+
if self.saved_model: # SavedModel
|
| 985 |
+
y = self.model(im, training=False) if self.keras else self.model(im)
|
| 986 |
+
elif self.pb: # GraphDef
|
| 987 |
+
y = self.frozen_func(x=self.tf.constant(im))
|
| 988 |
+
else: # Lite or Edge TPU
|
| 989 |
+
input = self.input_details[0]
|
| 990 |
+
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
|
| 991 |
+
if int8:
|
| 992 |
+
scale, zero_point = input['quantization']
|
| 993 |
+
im = (im / scale + zero_point).astype(np.uint8) # de-scale
|
| 994 |
+
self.interpreter.set_tensor(input['index'], im)
|
| 995 |
+
self.interpreter.invoke()
|
| 996 |
+
y = []
|
| 997 |
+
for output in self.output_details:
|
| 998 |
+
x = self.interpreter.get_tensor(output['index'])
|
| 999 |
+
if int8:
|
| 1000 |
+
scale, zero_point = output['quantization']
|
| 1001 |
+
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
| 1002 |
+
y.append(x)
|
| 1003 |
+
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
|
| 1004 |
+
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
| 1005 |
+
|
| 1006 |
+
if isinstance(y, (list, tuple)):
|
| 1007 |
+
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
| 1008 |
+
else:
|
| 1009 |
+
return self.from_numpy(y)
|
| 1010 |
+
|
| 1011 |
+
def from_numpy(self, x):
|
| 1012 |
+
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
|
| 1013 |
+
|
| 1014 |
+
def warmup(self, imgsz=(1, 3, 640, 640)):
|
| 1015 |
+
# Warmup model by running inference once
|
| 1016 |
+
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
|
| 1017 |
+
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
|
| 1018 |
+
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
| 1019 |
+
for _ in range(2 if self.jit else 1): #
|
| 1020 |
+
self.forward(im) # warmup
|
| 1021 |
+
|
| 1022 |
+
@staticmethod
|
| 1023 |
+
def _model_type(p='path/to/model.pt'):
|
| 1024 |
+
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
| 1025 |
+
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
| 1026 |
+
from export import export_formats
|
| 1027 |
+
from utils.downloads import is_url
|
| 1028 |
+
sf = list(export_formats().Suffix) # export suffixes
|
| 1029 |
+
if not is_url(p, check=False):
|
| 1030 |
+
check_suffix(p, sf) # checks
|
| 1031 |
+
url = urlparse(p) # if url may be Triton inference server
|
| 1032 |
+
types = [s in Path(p).name for s in sf]
|
| 1033 |
+
types[8] &= not types[9] # tflite &= not edgetpu
|
| 1034 |
+
triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
|
| 1035 |
+
return types + [triton]
|
| 1036 |
+
|
| 1037 |
+
@staticmethod
|
| 1038 |
+
def _load_metadata(f=Path('path/to/meta.yaml')):
|
| 1039 |
+
# Load metadata from meta.yaml if it exists
|
| 1040 |
+
if f.exists():
|
| 1041 |
+
d = yaml_load(f)
|
| 1042 |
+
return d['stride'], d['names'] # assign stride, names
|
| 1043 |
+
return None, None
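# --- usage sketch (not part of the upstream file) -----------------------------
# Minimal illustration of the suffix-based backend selection done by _model_type()
# above. Paths are placeholders, and the snippet assumes the repo's
# export.export_formats() / utils.downloads.is_url() behave like upstream YOLO;
# call the helper manually rather than at import time.
def _sketch_model_type():
    flags = DetectMultiBackend._model_type('weights/best.onnx')
    # flags = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton]
    assert flags[2] and not any(flags[:2] + flags[3:])                   # '.onnx' -> ONNX Runtime branch
    assert DetectMultiBackend._model_type('grpc://localhost:8001')[-1]   # http/grpc URL -> Triton branch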
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
class AutoShape(nn.Module):
|
| 1047 |
+
# YOLO input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
| 1048 |
+
conf = 0.25 # NMS confidence threshold
|
| 1049 |
+
iou = 0.45 # NMS IoU threshold
|
| 1050 |
+
agnostic = False # NMS class-agnostic
|
| 1051 |
+
multi_label = False # NMS multiple labels per box
|
| 1052 |
+
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
| 1053 |
+
max_det = 1000 # maximum number of detections per image
|
| 1054 |
+
amp = False # Automatic Mixed Precision (AMP) inference
|
| 1055 |
+
|
| 1056 |
+
def __init__(self, model, verbose=True):
|
| 1057 |
+
super().__init__()
|
| 1058 |
+
if verbose:
|
| 1059 |
+
LOGGER.info('Adding AutoShape... ')
|
| 1060 |
+
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
| 1061 |
+
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
| 1062 |
+
self.pt = not self.dmb or model.pt # PyTorch model
|
| 1063 |
+
self.model = model.eval()
|
| 1064 |
+
if self.pt:
|
| 1065 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
| 1066 |
+
m.inplace = False # Detect.inplace=False for safe multithread inference
|
| 1067 |
+
m.export = True # do not output loss values
|
| 1068 |
+
|
| 1069 |
+
def _apply(self, fn):
|
| 1070 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
| 1071 |
+
self = super()._apply(fn)
|
| 1072 |
+
from models.yolo import Detect, Segment
|
| 1073 |
+
if self.pt:
|
| 1074 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
| 1075 |
+
if isinstance(m, (Detect, Segment)):
|
| 1076 |
+
for k in 'stride', 'anchor_grid', 'stride_grid', 'grid':
|
| 1077 |
+
x = getattr(m, k)
|
| 1078 |
+
setattr(m, k, list(map(fn, x))) if isinstance(x, (list, tuple)) else setattr(m, k, fn(x))
|
| 1079 |
+
return self
|
| 1080 |
+
|
| 1081 |
+
@smart_inference_mode()
|
| 1082 |
+
def forward(self, ims, size=640, augment=False, profile=False):
|
| 1083 |
+
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
|
| 1084 |
+
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
| 1085 |
+
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
| 1086 |
+
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
| 1087 |
+
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
| 1088 |
+
# numpy: = np.zeros((640,1280,3)) # HWC
|
| 1089 |
+
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
| 1090 |
+
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
| 1091 |
+
|
| 1092 |
+
dt = (Profile(), Profile(), Profile())
|
| 1093 |
+
with dt[0]:
|
| 1094 |
+
if isinstance(size, int): # expand
|
| 1095 |
+
size = (size, size)
|
| 1096 |
+
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
| 1097 |
+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
| 1098 |
+
if isinstance(ims, torch.Tensor): # torch
|
| 1099 |
+
with amp.autocast(autocast):
|
| 1100 |
+
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
| 1101 |
+
|
| 1102 |
+
# Pre-process
|
| 1103 |
+
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
| 1104 |
+
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
| 1105 |
+
for i, im in enumerate(ims):
|
| 1106 |
+
f = f'image{i}' # filename
|
| 1107 |
+
if isinstance(im, (str, Path)): # filename or uri
|
| 1108 |
+
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
| 1109 |
+
im = np.asarray(exif_transpose(im))
|
| 1110 |
+
elif isinstance(im, Image.Image): # PIL Image
|
| 1111 |
+
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
| 1112 |
+
files.append(Path(f).with_suffix('.jpg').name)
|
| 1113 |
+
if im.shape[0] < 5: # image in CHW
|
| 1114 |
+
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
| 1115 |
+
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
| 1116 |
+
s = im.shape[:2] # HWC
|
| 1117 |
+
shape0.append(s) # image shape
|
| 1118 |
+
g = max(size) / max(s) # gain
|
| 1119 |
+
shape1.append([int(y * g) for y in s])
|
| 1120 |
+
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
| 1121 |
+
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] # inf shape
|
| 1122 |
+
x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
|
| 1123 |
+
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
| 1124 |
+
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
| 1125 |
+
|
| 1126 |
+
with amp.autocast(autocast):
|
| 1127 |
+
# Inference
|
| 1128 |
+
with dt[1]:
|
| 1129 |
+
y = self.model(x, augment=augment) # forward
|
| 1130 |
+
|
| 1131 |
+
# Post-process
|
| 1132 |
+
with dt[2]:
|
| 1133 |
+
y = non_max_suppression(y if self.dmb else y[0],
|
| 1134 |
+
self.conf,
|
| 1135 |
+
self.iou,
|
| 1136 |
+
self.classes,
|
| 1137 |
+
self.agnostic,
|
| 1138 |
+
self.multi_label,
|
| 1139 |
+
max_det=self.max_det) # NMS
|
| 1140 |
+
for i in range(n):
|
| 1141 |
+
scale_boxes(shape1, y[i][:, :4], shape0[i])
|
| 1142 |
+
|
| 1143 |
+
return Detections(ims, y, files, dt, self.names, x.shape)
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
class Detections:
|
| 1147 |
+
# YOLO detections class for inference results
|
| 1148 |
+
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
| 1149 |
+
super().__init__()
|
| 1150 |
+
d = pred[0].device # device
|
| 1151 |
+
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
| 1152 |
+
self.ims = ims # list of images as numpy arrays
|
| 1153 |
+
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
| 1154 |
+
self.names = names # class names
|
| 1155 |
+
self.files = files # image filenames
|
| 1156 |
+
self.times = times # profiling times
|
| 1157 |
+
self.xyxy = pred # xyxy pixels
|
| 1158 |
+
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
|
| 1159 |
+
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
| 1160 |
+
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
| 1161 |
+
self.n = len(self.pred) # number of images (batch size)
|
| 1162 |
+
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
| 1163 |
+
self.s = tuple(shape) # inference BCHW shape
|
| 1164 |
+
|
| 1165 |
+
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
| 1166 |
+
s, crops = '', []
|
| 1167 |
+
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
| 1168 |
+
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
| 1169 |
+
if pred.shape[0]:
|
| 1170 |
+
for c in pred[:, -1].unique():
|
| 1171 |
+
n = (pred[:, -1] == c).sum() # detections per class
|
| 1172 |
+
s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
|
| 1173 |
+
s = s.rstrip(', ')
|
| 1174 |
+
if show or save or render or crop:
|
| 1175 |
+
annotator = Annotator(im, example=str(self.names))
|
| 1176 |
+
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
|
| 1177 |
+
label = f'{self.names[int(cls)]} {conf:.2f}'
|
| 1178 |
+
if crop:
|
| 1179 |
+
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
|
| 1180 |
+
crops.append({
|
| 1181 |
+
'box': box,
|
| 1182 |
+
'conf': conf,
|
| 1183 |
+
'cls': cls,
|
| 1184 |
+
'label': label,
|
| 1185 |
+
'im': save_one_box(box, im, file=file, save=save)})
|
| 1186 |
+
else: # all others
|
| 1187 |
+
annotator.box_label(box, label if labels else '', color=colors(cls))
|
| 1188 |
+
im = annotator.im
|
| 1189 |
+
else:
|
| 1190 |
+
s += '(no detections)'
|
| 1191 |
+
|
| 1192 |
+
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
|
| 1193 |
+
if show:
|
| 1194 |
+
display(im) if is_notebook() else im.show(self.files[i])
|
| 1195 |
+
if save:
|
| 1196 |
+
f = self.files[i]
|
| 1197 |
+
im.save(save_dir / f) # save
|
| 1198 |
+
if i == self.n - 1:
|
| 1199 |
+
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
| 1200 |
+
if render:
|
| 1201 |
+
self.ims[i] = np.asarray(im)
|
| 1202 |
+
if pprint:
|
| 1203 |
+
s = s.lstrip('\n')
|
| 1204 |
+
return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
|
| 1205 |
+
if crop:
|
| 1206 |
+
if save:
|
| 1207 |
+
LOGGER.info(f'Saved results to {save_dir}\n')
|
| 1208 |
+
return crops
|
| 1209 |
+
|
| 1210 |
+
@TryExcept('Showing images is not supported in this environment')
|
| 1211 |
+
def show(self, labels=True):
|
| 1212 |
+
self._run(show=True, labels=labels) # show results
|
| 1213 |
+
|
| 1214 |
+
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
|
| 1215 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
|
| 1216 |
+
self._run(save=True, labels=labels, save_dir=save_dir) # save results
|
| 1217 |
+
|
| 1218 |
+
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
|
| 1219 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
|
| 1220 |
+
return self._run(crop=True, save=save, save_dir=save_dir) # crop results
|
| 1221 |
+
|
| 1222 |
+
def render(self, labels=True):
|
| 1223 |
+
self._run(render=True, labels=labels) # render results
|
| 1224 |
+
return self.ims
|
| 1225 |
+
|
| 1226 |
+
def pandas(self):
|
| 1227 |
+
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
| 1228 |
+
new = copy(self) # return copy
|
| 1229 |
+
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
|
| 1230 |
+
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
|
| 1231 |
+
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
|
| 1232 |
+
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
|
| 1233 |
+
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
|
| 1234 |
+
return new
|
| 1235 |
+
|
| 1236 |
+
def tolist(self):
|
| 1237 |
+
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
| 1238 |
+
r = range(self.n) # iterable
|
| 1239 |
+
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
| 1240 |
+
# for d in x:
|
| 1241 |
+
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
| 1242 |
+
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
| 1243 |
+
return x
|
| 1244 |
+
|
| 1245 |
+
def print(self):
|
| 1246 |
+
LOGGER.info(self.__str__())
|
| 1247 |
+
|
| 1248 |
+
def __len__(self): # override len(results)
|
| 1249 |
+
return self.n
|
| 1250 |
+
|
| 1251 |
+
def __str__(self): # override print(results)
|
| 1252 |
+
return self._run(pprint=True) # print results
|
| 1253 |
+
|
| 1254 |
+
def __repr__(self):
|
| 1255 |
+
return f'YOLO {self.__class__} instance\n' + self.__str__()
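# --- usage sketch (not part of the upstream file) -----------------------------
# How AutoShape and Detections above are typically consumed. The weights path and
# image file are placeholders; attempt_load comes from models/experimental.py.
def _sketch_autoshape_usage():
    import torch
    from models.experimental import attempt_load

    model = AutoShape(attempt_load('yolo.pt', device=torch.device('cpu')))
    results = model('image.jpg', size=640)      # str/Path, PIL, numpy, torch, or a list of them
    results.print()                             # per-class counts and per-stage timing
    df = results.pandas().xyxy[0]               # DataFrame: xmin, ymin, xmax, ymax, confidence, class, name
    crops = results.crop(save=False)            # list of dicts: 'box', 'conf', 'cls', 'label', 'im'
    return df, crops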
|
| 1256 |
+
|
| 1257 |
+
|
| 1258 |
+
class Proto(nn.Module):
|
| 1259 |
+
# YOLO mask Proto module for segmentation models
|
| 1260 |
+
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
|
| 1261 |
+
super().__init__()
|
| 1262 |
+
self.cv1 = Conv(c1, c_, k=3)
|
| 1263 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
| 1264 |
+
self.cv2 = Conv(c_, c_, k=3)
|
| 1265 |
+
self.cv3 = Conv(c_, c2)
|
| 1266 |
+
|
| 1267 |
+
def forward(self, x):
|
| 1268 |
+
return self.cv3(self.cv2(self.upsample(self.cv1(x))))
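# --- shape sketch (not part of the upstream file) ------------------------------
# Proto upsamples 2x and projects to c2 mask-prototype channels, i.e.
# (b, c1, h, w) -> (b, c2, 2h, 2w) with the defaults above.
def _sketch_proto_shapes():
    import torch
    p = Proto(c1=256)                          # defaults: c_=256, c2=32
    assert p(torch.zeros(1, 256, 20, 20)).shape == (1, 32, 40, 40)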
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
class UConv(nn.Module):
|
| 1272 |
+
def __init__(self, c1, c_=256, c2=256): # ch_in, number of protos, number of masks
|
| 1273 |
+
super().__init__()
|
| 1274 |
+
|
| 1275 |
+
self.cv1 = Conv(c1, c_, k=3)
|
| 1276 |
+
self.cv2 = nn.Conv2d(c_, c2, 1, 1)
|
| 1277 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 1278 |
+
|
| 1279 |
+
def forward(self, x):
|
| 1280 |
+
return self.up(self.cv2(self.cv1(x)))
|
| 1281 |
+
|
| 1282 |
+
|
| 1283 |
+
class Classify(nn.Module):
|
| 1284 |
+
# YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
| 1285 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
| 1286 |
+
super().__init__()
|
| 1287 |
+
c_ = 1280 # efficientnet_b0 size
|
| 1288 |
+
self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
|
| 1289 |
+
self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
|
| 1290 |
+
self.drop = nn.Dropout(p=0.0, inplace=True)
|
| 1291 |
+
self.linear = nn.Linear(c_, c2) # to x(b,c2)
|
| 1292 |
+
|
| 1293 |
+
def forward(self, x):
|
| 1294 |
+
if isinstance(x, list):
|
| 1295 |
+
x = torch.cat(x, 1)
|
| 1296 |
+
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
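For completeness, a small sketch of the Classify head's pooling path noted in its comment, x(b,c1,20,20) -> x(b,c2); the channel sizes below are illustrative only.

import torch
from models.common import Classify

head = Classify(c1=512, c2=1000)               # Conv -> AdaptiveAvgPool2d(1) -> flatten -> Linear
assert head(torch.zeros(2, 512, 20, 20)).shape == (2, 1000)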
models/detect/pk-yolo.yaml
ADDED
|
@@ -0,0 +1,126 @@
# YOLOv9

# parameters
nc: 2  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
# activation: nn.LeakyReLU(0.1)
# activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],

   [-1, 1, Backbone, []],
   # conv down
   [1, 1, Down0, [64]],    # 2 320 1
   [1, 1, Down1, [128]],   # 3 160 3
   [1, 1, Down2, [256]],   # 4 80 5
   [1, 1, Down3, [512]],   # 5 40 7
   [1, 1, Down4, [1024]],  # 6 20 9

   # routing
   [2, 1, CBLinear, [[64]]], # 10
   [3, 1, CBLinear, [[64, 128]]], # 11
   [4, 1, CBLinear, [[64, 128, 256]]], # 12
   [5, 1, CBLinear, [[64, 128, 256, 512]]], # 13
   [6, 1, CBLinear, [[64, 128, 256, 512, 1024]]], # 14 -3

   # conv down fuse
   [0, 1, Conv, [64, 3, 2]],  # 15-P1/2
   [[7, 8, 9, 10, 11, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]], # 16

   # conv down fuse
   [-1, 1, Conv, [128, 3, 2]],  # 17-P2/4
   [[8, 9, 10, 11, -1], 1, CBFuse, [[1, 1, 1, 1]]], # 18

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 19

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 20-P3/8
   [[9, 10, 11, -1], 1, CBFuse, [[2, 2, 2]]], # 21

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 22

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 23-P4/16
   [[10, 11, -1], 1, CBFuse, [[3, 3]]], # 24

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 25

   # avg-conv down fuse
   [-1, 1, ADown, [1024]],  # 26-P5/32
   [[11, -1], 1, CBFuse, [[4]]], # 27

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 28 25
  ]

# YOLOv9 head
head:
  [
   # multi-level auxiliary branch

   # elan-spp block
   [6, 1, SPPELAN, [512, 256]],  # 29

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 32

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 35

   # main branch

   # elan-spp block
   [25, 1, SPPELAN, [512, 256]],  # 36

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 22], 1, Concat, [1]],  # cat backbone P4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 39

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 19], 1, Concat, [1]],  # cat backbone P3

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 42 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 36], 1, Concat, [1]],  # cat head P4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 45 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 33], 1, Concat, [1]],  # cat head P5

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]],  # 48 (P5/32-large)

   # detect
   [[32, 29, 26, 39, 42, 45], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]
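Each row above is [from, number, module, args]: `from` indexes earlier layers (-1 = previous layer), `number` is the repeat count before depth scaling, and `args` are the module's constructor arguments. A small sketch of reading one backbone row (the actual module lookup and channel bookkeeping are done by the model builder in models/yolo.py):

import yaml

cfg = yaml.safe_load(open('models/detect/pk-yolo.yaml'))
f, n, m, args = cfg['backbone'][2]     # -> 1, 1, 'Down0', [64]
# from=1 means this layer consumes the output of layer 1 (the Backbone module),
# so Down0..Down4 tap its multi-scale feature maps before the CBLinear routing.
print(f, n, m, args)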
models/detect/yolov9-e.yaml
ADDED
|
@@ -0,0 +1,144 @@
# YOLOv9

# parameters
nc: 2  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, Silence, []],

   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 3

   # avg-conv down
   [-1, 1, ADown, [256]],  # 4-P3/8

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 5

   # avg-conv down
   [-1, 1, ADown, [512]],  # 6-P4/16

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 7

   # avg-conv down
   [-1, 1, ADown, [1024]],  # 8-P5/32

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 9

   # routing
   [1, 1, CBLinear, [[64]]], # 10
   [3, 1, CBLinear, [[64, 128]]], # 11
   [5, 1, CBLinear, [[64, 128, 256]]], # 12
   [7, 1, CBLinear, [[64, 128, 256, 512]]], # 13
   [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]], # 14

   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 15-P1/2
   [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]], # 16

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 17-P2/4
   [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]], # 18

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 19

   # avg-conv down fuse
   [-1, 1, ADown, [256]],  # 20-P3/8
   [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]], # 21

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 22

   # avg-conv down fuse
   [-1, 1, ADown, [512]],  # 23-P4/16
   [[13, 14, -1], 1, CBFuse, [[3, 3]]], # 24

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 25

   # avg-conv down fuse
   [-1, 1, ADown, [1024]],  # 26-P5/32
   [[14, -1], 1, CBFuse, [[4]]], # 27

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 28
  ]

# YOLOv9 head
head:
  [
   # multi-level auxiliary branch

   # elan-spp block
   [9, 1, SPPELAN, [512, 256]],  # 29

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 32

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 35

   # main branch

   # elan-spp block
   [28, 1, SPPELAN, [512, 256]],  # 36

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 25], 1, Concat, [1]],  # cat backbone P4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 39

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 22], 1, Concat, [1]],  # cat backbone P3

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 42 (P3/8-small)

   # avg-conv-down merge
   [-1, 1, ADown, [256]],
   [[-1, 39], 1, Concat, [1]],  # cat head P4

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 45 (P4/16-medium)

   # avg-conv-down merge
   [-1, 1, ADown, [512]],
   [[-1, 36], 1, Concat, [1]],  # cat head P5

   # csp-elan block
   [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]],  # 48 (P5/32-large)

   # detect
   [[35, 32, 29, 42, 45, 48], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]
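depth_multiple and width_multiple scale per-row repeat counts and channel widths when the config is built; both are 1.0 in these files, so the rows are used as written. A hedged sketch of the usual YOLO-style scaling rule (an assumption about the builder's behaviour, shown for illustration only):

import math

def scale_row(n, c_out, gd=1.0, gw=1.0, divisor=8):
    n = max(round(n * gd), 1) if n > 1 else n           # scale module repeats
    c_out = math.ceil(c_out * gw / divisor) * divisor   # scale channels, keep divisible by 8
    return n, c_out

print(scale_row(n=2, c_out=256, gd=0.33, gw=0.25))      # -> (1, 64)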
models/experimental.py
ADDED
|
@@ -0,0 +1,275 @@
| 1 |
+
import math
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from utils.downloads import attempt_download
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Sum(nn.Module):
|
| 11 |
+
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
| 12 |
+
def __init__(self, n, weight=False): # n: number of inputs
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.weight = weight # apply weights boolean
|
| 15 |
+
self.iter = range(n - 1) # iter object
|
| 16 |
+
if weight:
|
| 17 |
+
self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
y = x[0] # no weight
|
| 21 |
+
if self.weight:
|
| 22 |
+
w = torch.sigmoid(self.w) * 2
|
| 23 |
+
for i in self.iter:
|
| 24 |
+
y = y + x[i + 1] * w[i]
|
| 25 |
+
else:
|
| 26 |
+
for i in self.iter:
|
| 27 |
+
y = y + x[i + 1]
|
| 28 |
+
return y
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MixConv2d(nn.Module):
|
| 32 |
+
# Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
|
| 33 |
+
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
|
| 34 |
+
super().__init__()
|
| 35 |
+
n = len(k) # number of convolutions
|
| 36 |
+
if equal_ch: # equal c_ per group
|
| 37 |
+
i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
|
| 38 |
+
c_ = [(i == g).sum() for g in range(n)] # intermediate channels
|
| 39 |
+
else: # equal weight.numel() per group
|
| 40 |
+
b = [c2] + [0] * n
|
| 41 |
+
a = np.eye(n + 1, n, k=-1)
|
| 42 |
+
a -= np.roll(a, 1, axis=1)
|
| 43 |
+
a *= np.array(k) ** 2
|
| 44 |
+
a[0] = 1
|
| 45 |
+
c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
|
| 46 |
+
|
| 47 |
+
self.m = nn.ModuleList([
|
| 48 |
+
nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
|
| 49 |
+
self.bn = nn.BatchNorm2d(c2)
|
| 50 |
+
self.act = nn.SiLU()
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
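# --- sketch (not part of the upstream file) ------------------------------------
# The equal_ch branch above spreads c2 output channels as evenly as possible over
# the n kernel sizes via a linspace/floor trick; this shows that split in isolation.
def _sketch_mixconv_split(c2=64, n=3):
    i = torch.linspace(0, n - 1E-6, c2).floor()          # c2 values in {0, ..., n-1}
    return [(i == g).sum().item() for g in range(n)]     # e.g. [22, 21, 21] for c2=64, n=3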
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Ensemble(nn.ModuleList):
|
| 57 |
+
# Ensemble of models
|
| 58 |
+
def __init__(self):
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
def forward(self, x, augment=False, profile=False, visualize=False):
|
| 62 |
+
y = [module(x, augment, profile, visualize)[0] for module in self]
|
| 63 |
+
# y = torch.stack(y).max(0)[0] # max ensemble
|
| 64 |
+
# y = torch.stack(y).mean(0) # mean ensemble
|
| 65 |
+
y = torch.cat(y, 1) # nms ensemble
|
| 66 |
+
return y, None # inference, train output
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ORT_NMS(torch.autograd.Function):
|
| 70 |
+
'''ONNX-Runtime NMS operation'''
|
| 71 |
+
@staticmethod
|
| 72 |
+
def forward(ctx,
|
| 73 |
+
boxes,
|
| 74 |
+
scores,
|
| 75 |
+
max_output_boxes_per_class=torch.tensor([100]),
|
| 76 |
+
iou_threshold=torch.tensor([0.45]),
|
| 77 |
+
score_threshold=torch.tensor([0.25])):
|
| 78 |
+
device = boxes.device
|
| 79 |
+
batch = scores.shape[0]
|
| 80 |
+
num_det = random.randint(0, 100)
|
| 81 |
+
batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
|
| 82 |
+
idxs = torch.arange(100, 100 + num_det).to(device)
|
| 83 |
+
zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
|
| 84 |
+
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
|
| 85 |
+
selected_indices = selected_indices.to(torch.int64)
|
| 86 |
+
return selected_indices
|
| 87 |
+
|
| 88 |
+
@staticmethod
|
| 89 |
+
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
|
| 90 |
+
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TRT_NMS(torch.autograd.Function):
|
| 94 |
+
'''TensorRT NMS operation'''
|
| 95 |
+
@staticmethod
|
| 96 |
+
def forward(
|
| 97 |
+
ctx,
|
| 98 |
+
boxes,
|
| 99 |
+
scores,
|
| 100 |
+
background_class=-1,
|
| 101 |
+
box_coding=1,
|
| 102 |
+
iou_threshold=0.45,
|
| 103 |
+
max_output_boxes=100,
|
| 104 |
+
plugin_version="1",
|
| 105 |
+
score_activation=0,
|
| 106 |
+
score_threshold=0.25,
|
| 107 |
+
):
|
| 108 |
+
|
| 109 |
+
batch_size, num_boxes, num_classes = scores.shape
|
| 110 |
+
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
|
| 111 |
+
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
|
| 112 |
+
det_scores = torch.randn(batch_size, max_output_boxes)
|
| 113 |
+
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
|
| 114 |
+
return num_det, det_boxes, det_scores, det_classes
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def symbolic(g,
|
| 118 |
+
boxes,
|
| 119 |
+
scores,
|
| 120 |
+
background_class=-1,
|
| 121 |
+
box_coding=1,
|
| 122 |
+
iou_threshold=0.45,
|
| 123 |
+
max_output_boxes=100,
|
| 124 |
+
plugin_version="1",
|
| 125 |
+
score_activation=0,
|
| 126 |
+
score_threshold=0.25):
|
| 127 |
+
out = g.op("TRT::EfficientNMS_TRT",
|
| 128 |
+
boxes,
|
| 129 |
+
scores,
|
| 130 |
+
background_class_i=background_class,
|
| 131 |
+
box_coding_i=box_coding,
|
| 132 |
+
iou_threshold_f=iou_threshold,
|
| 133 |
+
max_output_boxes_i=max_output_boxes,
|
| 134 |
+
plugin_version_s=plugin_version,
|
| 135 |
+
score_activation_i=score_activation,
|
| 136 |
+
score_threshold_f=score_threshold,
|
| 137 |
+
outputs=4)
|
| 138 |
+
nums, boxes, scores, classes = out
|
| 139 |
+
return nums, boxes, scores, classes
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class ONNX_ORT(nn.Module):
|
| 143 |
+
'''onnx module with ONNX-Runtime NMS operation.'''
|
| 144 |
+
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.device = device if device else torch.device("cpu")
|
| 147 |
+
self.max_obj = torch.tensor([max_obj]).to(device)
|
| 148 |
+
self.iou_threshold = torch.tensor([iou_thres]).to(device)
|
| 149 |
+
self.score_threshold = torch.tensor([score_thres]).to(device)
|
| 150 |
+
self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
|
| 151 |
+
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
|
| 152 |
+
dtype=torch.float32,
|
| 153 |
+
device=self.device)
|
| 154 |
+
self.n_classes=n_classes
|
| 155 |
+
|
| 156 |
+
def forward(self, x):
|
| 157 |
+
## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
|
| 158 |
+
## thanks https://github.com/thaitc-hust
|
| 159 |
+
if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
|
| 160 |
+
x = x[1]
|
| 161 |
+
x = x.permute(0, 2, 1)
|
| 162 |
+
bboxes_x = x[..., 0:1]
|
| 163 |
+
bboxes_y = x[..., 1:2]
|
| 164 |
+
bboxes_w = x[..., 2:3]
|
| 165 |
+
bboxes_h = x[..., 3:4]
|
| 166 |
+
bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
|
| 167 |
+
bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
|
| 168 |
+
obj_conf = x[..., 4:]
|
| 169 |
+
scores = obj_conf
|
| 170 |
+
bboxes @= self.convert_matrix
|
| 171 |
+
max_score, category_id = scores.max(2, keepdim=True)
|
| 172 |
+
dis = category_id.float() * self.max_wh
|
| 173 |
+
nmsbox = bboxes + dis
|
| 174 |
+
max_score_tp = max_score.transpose(1, 2).contiguous()
|
| 175 |
+
selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
|
| 176 |
+
X, Y = selected_indices[:, 0], selected_indices[:, 2]
|
| 177 |
+
selected_boxes = bboxes[X, Y, :]
|
| 178 |
+
selected_categories = category_id[X, Y, :].float()
|
| 179 |
+
selected_scores = max_score[X, Y, :]
|
| 180 |
+
X = X.unsqueeze(1).float()
|
| 181 |
+
return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class ONNX_TRT(nn.Module):
|
| 185 |
+
'''onnx module with TensorRT NMS operation.'''
|
| 186 |
+
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
|
| 187 |
+
super().__init__()
|
| 188 |
+
assert max_wh is None
|
| 189 |
+
self.device = device if device else torch.device('cpu')
|
| 190 |
+
self.background_class = -1,
|
| 191 |
+
self.box_coding = 1,
|
| 192 |
+
self.iou_threshold = iou_thres
|
| 193 |
+
self.max_obj = max_obj
|
| 194 |
+
self.plugin_version = '1'
|
| 195 |
+
self.score_activation = 0
|
| 196 |
+
self.score_threshold = score_thres
|
| 197 |
+
self.n_classes=n_classes
|
| 198 |
+
|
| 199 |
+
def forward(self, x):
|
| 200 |
+
## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
|
| 201 |
+
## thanks https://github.com/thaitc-hust
|
| 202 |
+
if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
|
| 203 |
+
x = x[1]
|
| 204 |
+
x = x.permute(0, 2, 1)
|
| 205 |
+
bboxes_x = x[..., 0:1]
|
| 206 |
+
bboxes_y = x[..., 1:2]
|
| 207 |
+
bboxes_w = x[..., 2:3]
|
| 208 |
+
bboxes_h = x[..., 3:4]
|
| 209 |
+
bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
|
| 210 |
+
bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
|
| 211 |
+
obj_conf = x[..., 4:]
|
| 212 |
+
scores = obj_conf
|
| 213 |
+
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(bboxes, scores, self.background_class, self.box_coding,
|
| 214 |
+
self.iou_threshold, self.max_obj,
|
| 215 |
+
self.plugin_version, self.score_activation,
|
| 216 |
+
self.score_threshold)
|
| 217 |
+
return num_det, det_boxes, det_scores, det_classes
|
| 218 |
+
|
| 219 |
+
class End2End(nn.Module):
|
| 220 |
+
'''export onnx or tensorrt model with NMS operation.'''
|
| 221 |
+
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
|
| 222 |
+
super().__init__()
|
| 223 |
+
device = device if device else torch.device('cpu')
|
| 224 |
+
assert isinstance(max_wh,(int)) or max_wh is None
|
| 225 |
+
self.model = model.to(device)
|
| 226 |
+
self.model.model[-1].end2end = True
|
| 227 |
+
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
|
| 228 |
+
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
|
| 229 |
+
self.end2end.eval()
|
| 230 |
+
|
| 231 |
+
def forward(self, x):
|
| 232 |
+
x = self.model(x)
|
| 233 |
+
x = self.end2end(x)
|
| 234 |
+
return x
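# --- export sketch (not part of the upstream file) ------------------------------
# Typical use of End2End: wrap a loaded detector and trace it to ONNX with NMS
# baked in. max_wh=None selects the TensorRT EfficientNMS path, an integer max_wh
# selects the ONNX Runtime NMS path. The weights path, output name and opset are
# placeholders.
def _sketch_end2end_export(weights='yolo.pt', out='yolo-end2end.onnx'):
    device = torch.device('cpu')
    model = attempt_load(weights, device=device)
    e2e = End2End(model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=device)
    dummy = torch.zeros(1, 3, 640, 640, device=device)
    torch.onnx.export(e2e, dummy, out, opset_version=12, input_names=['images'],
                      output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes'])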
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def attempt_load(weights, device=None, inplace=True, fuse=True):
|
| 238 |
+
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
| 239 |
+
from models.yolo import Detect, Model
|
| 240 |
+
|
| 241 |
+
model = Ensemble()
|
| 242 |
+
for w in weights if isinstance(weights, list) else [weights]:
|
| 243 |
+
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
| 244 |
+
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
| 245 |
+
|
| 246 |
+
# Model compatibility updates
|
| 247 |
+
if not hasattr(ckpt, 'stride'):
|
| 248 |
+
ckpt.stride = torch.tensor([32.])
|
| 249 |
+
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
| 250 |
+
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
| 251 |
+
|
| 252 |
+
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
| 253 |
+
|
| 254 |
+
# Module compatibility updates
|
| 255 |
+
for m in model.modules():
|
| 256 |
+
t = type(m)
|
| 257 |
+
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
| 258 |
+
m.inplace = inplace # torch 1.7.0 compatibility
|
| 259 |
+
# if t is Detect and not isinstance(m.anchor_grid, list):
|
| 260 |
+
# delattr(m, 'anchor_grid')
|
| 261 |
+
# setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
| 262 |
+
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
| 263 |
+
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
| 264 |
+
|
| 265 |
+
# Return model
|
| 266 |
+
if len(model) == 1:
|
| 267 |
+
return model[-1]
|
| 268 |
+
|
| 269 |
+
# Return detection ensemble
|
| 270 |
+
print(f'Ensemble created with {weights}\n')
|
| 271 |
+
for k in 'names', 'nc', 'yaml':
|
| 272 |
+
setattr(model, k, getattr(model[0], k))
|
| 273 |
+
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
| 274 |
+
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
| 275 |
+
return model
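# --- usage sketch (not part of the upstream file) -------------------------------
# attempt_load() accepts a single weights file or a list: a single path returns the
# model itself, a list returns an Ensemble whose outputs are concatenated for NMS.
# The paths below are placeholders.
def _sketch_attempt_load():
    single = attempt_load('best.pt', device=torch.device('cpu'), fuse=True)
    ensemble = attempt_load(['a.pt', 'b.pt'], device=torch.device('cpu'))
    print(single.stride, single.names, len(ensemble))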
|
models/repvit.py
ADDED
|
@@ -0,0 +1,440 @@
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import numpy as np
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
def _make_divisible(v, divisor, min_value=None):
|
| 6 |
+
"""
|
| 7 |
+
This function is taken from the original tf repo.
|
| 8 |
+
It ensures that all layers have a channel number that is divisible by 8
|
| 9 |
+
It can be seen here:
|
| 10 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
| 11 |
+
:param v:
|
| 12 |
+
:param divisor:
|
| 13 |
+
:param min_value:
|
| 14 |
+
:return:
|
| 15 |
+
"""
|
| 16 |
+
if min_value is None:
|
| 17 |
+
min_value = divisor
|
| 18 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
| 19 |
+
# Make sure that round down does not go down by more than 10%.
|
| 20 |
+
if new_v < 0.9 * v:
|
| 21 |
+
new_v += divisor
|
| 22 |
+
return new_v
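# A few worked values for _make_divisible (nearest multiple of `divisor`, bumped up
# whenever plain rounding would drop more than 10% below v); cheap sanity checks only.
assert _make_divisible(64, 8) == 64
assert _make_divisible(67, 8) == 64   # rounds down, still >= 0.9 * 67
assert _make_divisible(30, 8) == 32
assert _make_divisible(23, 8) == 24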
|
| 23 |
+
|
| 24 |
+
from timm.models.layers import SqueezeExcite
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
class Conv2d_BN(torch.nn.Sequential):
|
| 29 |
+
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
|
| 30 |
+
groups=1, bn_weight_init=1, resolution=-10000):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.add_module('c', torch.nn.Conv2d(
|
| 33 |
+
a, b, ks, stride, pad, dilation, groups, bias=False))
|
| 34 |
+
self.add_module('bn', torch.nn.BatchNorm2d(b))
|
| 35 |
+
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
|
| 36 |
+
torch.nn.init.constant_(self.bn.bias, 0)
|
| 37 |
+
|
| 38 |
+
@torch.no_grad()
|
| 39 |
+
def fuse(self):
|
| 40 |
+
c, bn = self._modules.values()
|
| 41 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
| 42 |
+
w = c.weight * w[:, None, None, None]
|
| 43 |
+
b = bn.bias - bn.running_mean * bn.weight / \
|
| 44 |
+
(bn.running_var + bn.eps)**0.5
|
| 45 |
+
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
|
| 46 |
+
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
|
| 47 |
+
device=c.weight.device)
|
| 48 |
+
m.weight.data.copy_(w)
|
| 49 |
+
m.bias.data.copy_(b)
|
| 50 |
+
return m
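# --- equivalence sketch (not part of the upstream file) -------------------------
# Quick check that fuse() above folds the BatchNorm into the convolution without
# changing the output; run in eval mode so BN uses its running statistics.
def _check_conv2d_bn_fuse():
    m = Conv2d_BN(8, 16, ks=3, stride=1, pad=1).eval()
    x = torch.randn(2, 8, 32, 32)
    with torch.no_grad():
        assert torch.allclose(m(x), m.fuse()(x), atol=1e-5)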
|
| 51 |
+
|
| 52 |
+
class Residual(torch.nn.Module):
|
| 53 |
+
def __init__(self, m, drop=0.):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.m = m
|
| 56 |
+
self.drop = drop
|
| 57 |
+
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
if self.training and self.drop > 0:
|
| 60 |
+
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
|
| 61 |
+
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
| 62 |
+
else:
|
| 63 |
+
return x + self.m(x)
|
| 64 |
+
|
| 65 |
+
@torch.no_grad()
|
| 66 |
+
def fuse(self):
|
| 67 |
+
if isinstance(self.m, Conv2d_BN):
|
| 68 |
+
m = self.m.fuse()
|
| 69 |
+
assert(m.groups == m.in_channels)
|
| 70 |
+
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
| 71 |
+
identity = torch.nn.functional.pad(identity, [1,1,1,1])
|
| 72 |
+
m.weight += identity.to(m.weight.device)
|
| 73 |
+
return m
|
| 74 |
+
elif isinstance(self.m, torch.nn.Conv2d):
|
| 75 |
+
m = self.m
|
| 76 |
+
assert(m.groups != m.in_channels)
|
| 77 |
+
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
|
| 78 |
+
identity = torch.nn.functional.pad(identity, [1,1,1,1])
|
| 79 |
+
m.weight += identity.to(m.weight.device)
|
| 80 |
+
return m
|
| 81 |
+
else:
|
| 82 |
+
return self
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class RepVGGDW(torch.nn.Module):
|
| 86 |
+
def __init__(self, ed) -> None:
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
|
| 89 |
+
self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
|
| 90 |
+
self.dim = ed
|
| 91 |
+
self.bn = torch.nn.BatchNorm2d(ed)
|
| 92 |
+
|
| 93 |
+
def forward(self, x):
|
| 94 |
+
return self.bn((self.conv(x) + self.conv1(x)) + x)
|
| 95 |
+
|
| 96 |
+
@torch.no_grad()
|
| 97 |
+
def fuse(self):
|
| 98 |
+
conv = self.conv.fuse()
|
| 99 |
+
conv1 = self.conv1
|
| 100 |
+
|
| 101 |
+
conv_w = conv.weight
|
| 102 |
+
conv_b = conv.bias
|
| 103 |
+
conv1_w = conv1.weight
|
| 104 |
+
conv1_b = conv1.bias
|
| 105 |
+
|
| 106 |
+
conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
|
| 107 |
+
|
| 108 |
+
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
|
| 109 |
+
|
| 110 |
+
final_conv_w = conv_w + conv1_w + identity
|
| 111 |
+
final_conv_b = conv_b + conv1_b
|
| 112 |
+
|
| 113 |
+
conv.weight.data.copy_(final_conv_w)
|
| 114 |
+
conv.bias.data.copy_(final_conv_b)
|
| 115 |
+
|
| 116 |
+
bn = self.bn
|
| 117 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
| 118 |
+
w = conv.weight * w[:, None, None, None]
|
| 119 |
+
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
|
| 120 |
+
(bn.running_var + bn.eps)**0.5
|
| 121 |
+
conv.weight.data.copy_(w)
|
| 122 |
+
conv.bias.data.copy_(b)
|
| 123 |
+
return conv
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class RepViTBlock(nn.Module):
|
| 127 |
+
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
|
| 128 |
+
super(RepViTBlock, self).__init__()
|
| 129 |
+
assert stride in [1, 2]
|
| 130 |
+
|
| 131 |
+
self.identity = stride == 1 and inp == oup
|
| 132 |
+
assert(hidden_dim == 2 * inp)
|
| 133 |
+
|
| 134 |
+
if stride == 2:
|
| 135 |
+
self.token_mixer = nn.Sequential(
|
| 136 |
+
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
|
| 137 |
+
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
| 138 |
+
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
|
| 139 |
+
)
|
| 140 |
+
self.channel_mixer = Residual(nn.Sequential(
|
| 141 |
+
# pw
|
| 142 |
+
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
|
| 143 |
+
nn.GELU() if use_hs else nn.GELU(),
|
| 144 |
+
# pw-linear
|
| 145 |
+
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
|
| 146 |
+
))
|
| 147 |
+
else:
|
| 148 |
+
assert(self.identity)
|
| 149 |
+
self.token_mixer = nn.Sequential(
|
| 150 |
+
RepVGGDW(inp),
|
| 151 |
+
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
|
| 152 |
+
)
|
| 153 |
+
self.channel_mixer = Residual(nn.Sequential(
|
| 154 |
+
# pw
|
| 155 |
+
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
|
| 156 |
+
nn.GELU() if use_hs else nn.GELU(),
|
| 157 |
+
# pw-linear
|
| 158 |
+
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
|
| 159 |
+
))
|
| 160 |
+
|
| 161 |
+
def forward(self, x):
|
| 162 |
+
return self.channel_mixer(self.token_mixer(x))
|
| 163 |
+
|
| 164 |
+
from timm.models.vision_transformer import trunc_normal_
|
| 165 |
+
class BN_Linear(torch.nn.Sequential):
|
| 166 |
+
def __init__(self, a, b, bias=True, std=0.02):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.add_module('bn', torch.nn.BatchNorm1d(a))
|
| 169 |
+
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
|
| 170 |
+
trunc_normal_(self.l.weight, std=std)
|
| 171 |
+
if bias:
|
| 172 |
+
torch.nn.init.constant_(self.l.bias, 0)
|
| 173 |
+
|
| 174 |
+
@torch.no_grad()
|
| 175 |
+
def fuse(self):
|
| 176 |
+
bn, l = self._modules.values()
|
| 177 |
+
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
| 178 |
+
b = bn.bias - self.bn.running_mean * \
|
| 179 |
+
self.bn.weight / (bn.running_var + bn.eps)**0.5
|
| 180 |
+
w = l.weight * w[None, :]
|
| 181 |
+
if l.bias is None:
|
| 182 |
+
b = b @ self.l.weight.T
|
| 183 |
+
else:
|
| 184 |
+
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
|
| 185 |
+
m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
|
| 186 |
+
m.weight.data.copy_(w)
|
| 187 |
+
m.bias.data.copy_(b)
|
| 188 |
+
return m
|
| 189 |
+
|
| 190 |
+
class RepViT(nn.Module):
|
| 191 |
+
def __init__(self, cfgs, distillation=False, pretrained=None, init_cfg=None, out_indices=[]):
|
| 192 |
+
super(RepViT, self).__init__()
|
| 193 |
+
# setting of inverted residual blocks
|
| 194 |
+
self.cfgs = cfgs
|
| 195 |
+
|
| 196 |
+
# building first layer
|
| 197 |
+
input_channel = self.cfgs[0][2]
|
| 198 |
+
patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU() )
|
| 199 |
+
layers = [patch_embed]
|
| 200 |
+
patch_embed2 = torch.nn.Sequential(Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1), torch.nn.GELU())
|
| 201 |
+
layers.append(patch_embed2)
|
| 202 |
+
|
| 203 |
+
# building inverted residual blocks
|
| 204 |
+
block = RepViTBlock
|
| 205 |
+
for k, t, c, use_se, use_hs, s in self.cfgs:
|
| 206 |
+
output_channel = _make_divisible(c, 8)
|
| 207 |
+
exp_size = _make_divisible(input_channel * t, 8)
|
| 208 |
+
layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
|
| 209 |
+
input_channel = output_channel
|
| 210 |
+
self.features = nn.ModuleList(layers)
|
| 211 |
+
#
|
| 212 |
+
# self.init_cfg = init_cfg
|
| 213 |
+
# assert(self.init_cfg is not None)
|
| 214 |
+
self.out_indices = out_indices
|
| 215 |
+
|
| 216 |
+
#self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
|
| 217 |
+
self.train()
|
| 218 |
+
self.out_indices=[0,5,11, 37, 42]
|
| 219 |
+
# 320 160 80 40 20
|
| 220 |
+
def train(self, mode=True):
|
| 221 |
+
"""Convert the model into training mode while keep layers freezed."""
|
| 222 |
+
super(RepViT, self).train(mode)
|
| 223 |
+
|
| 224 |
+
def forward(self, x):
|
| 225 |
+
outs = []
|
| 226 |
+
for i, f in enumerate(self.features):
|
| 227 |
+
x = f(x)
|
| 228 |
+
#print(x.shape)
|
| 229 |
+
if i in self.out_indices:
|
| 230 |
+
outs.append(x)
|
| 231 |
+
#print(x.shape)
|
| 232 |
+
# assert(len(outs) == 4)
|
| 233 |
+
return outs
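# --- feature-pyramid sketch (not part of the upstream file) ---------------------
# Instantiate a variant and inspect the multi-scale outputs collected via
# out_indices (hard-coded to [0, 5, 11, 37, 42] in __init__). With repvit_m1_5 a
# 640x640 input should yield five maps at spatial sizes 320/160/80/40/20, matching
# the "# 320 160 80 40 20" note above; smaller variants may not reach every index.
def _demo_repvit_features():
    net = repvit_m1_5().eval()
    with torch.no_grad():
        for f in net(torch.randn(1, 3, 640, 640)):
            print(tuple(f.shape))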
|
| 234 |
+
|
| 235 |
+
from timm.models import register_model


def repvit_m1_1(pretrained=False, num_classes=1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
    """
    Constructs a RepViT-M1.1 model
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 128, 0, 0, 2],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 256, 0, 1, 2],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 512, 0, 1, 2],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1]
    ]
    return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)


def repvit_m1_5(pretrained=False, num_classes=1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
    """
    Constructs a RepViT-M1.5 model
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 1, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 64, 0, 0, 1],
        [3, 2, 128, 0, 0, 2],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 1, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 128, 0, 0, 1],
        [3, 2, 256, 0, 1, 2],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 1, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 256, 0, 1, 1],
        [3, 2, 512, 0, 1, 2],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1],
        [3, 2, 512, 1, 1, 1],
        [3, 2, 512, 0, 1, 1]
    ]
    return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)


def repvit_m2_3(pretrained=False, num_classes=1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
    """
    Constructs a RepViT-M2.3 model
    """
    cfgs = [
        # k, t, c, SE, HS, s
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 1, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 80, 0, 0, 1],
        [3, 2, 160, 0, 0, 2],
        [3, 2, 160, 1, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 1, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 1, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 160, 0, 0, 1],
        [3, 2, 320, 0, 1, 2],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 1, 1, 1],
        [3, 2, 320, 0, 1, 1],
        # [3, 2, 320, 1, 1, 1],
        # [3, 2, 320, 0, 1, 1],
        [3, 2, 320, 0, 1, 1],
        [3, 2, 640, 0, 1, 2],
        [3, 2, 640, 1, 1, 1],
        [3, 2, 640, 0, 1, 1],
        # [3, 2, 640, 1, 1, 1],
        # [3, 2, 640, 0, 1, 1]
    ]
    return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)


cfgs = [
    # k, t, c, SE, HS, s
    [3, 2, 64*2, 1, 0, 1],
    [3, 2, 64*2, 0, 0, 1],
    [3, 2, 64*2, 1, 0, 1],
    [3, 2, 64*2, 0, 0, 1],
    [3, 2, 64*2, 0, 0, 1],
    [3, 2, 128*2, 0, 0, 2],
    [3, 2, 128*2, 1, 0, 1],
    [3, 2, 128*2, 0, 0, 1],
    [3, 2, 128*2, 1, 0, 1],
    [3, 2, 128*2, 0, 0, 1],
    [3, 2, 128*2, 0, 0, 1],
    [3, 2, 256*2, 0, 1, 2],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 1, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 256*2, 0, 1, 1],
    [3, 2, 512*2, 0, 1, 2],
    [3, 2, 512*2, 1, 1, 1],
    [3, 2, 512*2, 0, 1, 1],
    [3, 2, 512*2, 1, 1, 1],
    [3, 2, 512*2, 0, 1, 1]
]

if __name__ == "__main__":
    model = RepViT(cfgs)
    t1 = torch.rand(1, 3, 640, 640)
    x = model(t1)
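The `__main__` block above is the file's own smoke test: it builds the doubled-width configuration and pushes a single 640x640 tensor through the backbone, which returns one feature map per selected index in `out_indices` (the trailing `# 320 160 80 40 20` comment records the spatial sizes the author expects at this resolution). A minimal sketch of extending that test to print the pyramid shapes, not part of the uploaded file and assuming the repository root is on `PYTHONPATH` so `models.repvit` imports cleanly:

import torch
from models.repvit import RepViT, cfgs  # module-level cfgs defined at the end of models/repvit.py

model = RepViT(cfgs)
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 640, 640))  # forward() collects maps at self.out_indices
for f in feats:
    print(tuple(f.shape))  # spatial size halves at each stride-2 stage of the config
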
models/tf.py
ADDED
@@ -0,0 +1,596 @@
import argparse
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
# ROOT = ROOT.relative_to(Path.cwd())  # relative

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
from tensorflow import keras

from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
                           DWConvTranspose2d, Focus, autopad)
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect, Segment
from utils.activations import SiLU
from utils.general import LOGGER, make_divisible, print_args


class TFBN(keras.layers.Layer):
    # TensorFlow BatchNormalization wrapper
    def __init__(self, w=None):
        super().__init__()
        self.bn = keras.layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(w.bias.numpy()),
            gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
            moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
            moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
            epsilon=w.eps)

    def call(self, inputs):
        return self.bn(inputs)


class TFPad(keras.layers.Layer):
    # Pad inputs in spatial dimensions 1 and 2
    def __init__(self, pad):
        super().__init__()
        if isinstance(pad, int):
            self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
        else:  # tuple/list
            self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])

    def call(self, inputs):
        return tf.pad(inputs, self.pad, mode='constant', constant_values=0)


class TFConv(keras.layers.Layer):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        # ch_in, ch_out, weights, kernel, stride, padding, groups
        super().__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
        # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
        conv = keras.layers.Conv2D(
            filters=c2,
            kernel_size=k,
            strides=s,
            padding='SAME' if s == 1 else 'VALID',
            use_bias=not hasattr(w, 'bn'),
            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
        self.act = activations(w.act) if act else tf.identity

    def call(self, inputs):
        return self.act(self.bn(self.conv(inputs)))


class TFDWConv(keras.layers.Layer):
    # Depthwise convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
        # ch_in, ch_out, weights, kernel, stride, padding, groups
        super().__init__()
        assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
        conv = keras.layers.DepthwiseConv2D(
            kernel_size=k,
            depth_multiplier=c2 // c1,
            strides=s,
            padding='SAME' if s == 1 else 'VALID',
            use_bias=not hasattr(w, 'bn'),
            depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
        self.act = activations(w.act) if act else tf.identity

    def call(self, inputs):
        return self.act(self.bn(self.conv(inputs)))


class TFDWConvTranspose2d(keras.layers.Layer):
    # Depthwise ConvTranspose2d
    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
        # ch_in, ch_out, weights, kernel, stride, padding, groups
        super().__init__()
        assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
        assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
        weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
        self.c1 = c1
        self.conv = [
            keras.layers.Conv2DTranspose(filters=1,
                                         kernel_size=k,
                                         strides=s,
                                         padding='VALID',
                                         output_padding=p2,
                                         use_bias=True,
                                         kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
                                         bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]

    def call(self, inputs):
        return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]


class TFFocus(keras.layers.Layer):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
        # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)

    def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
        # inputs = inputs / 255  # normalize 0-255 to 0-1
        inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
        return self.conv(tf.concat(inputs, 3))


class TFBottleneck(keras.layers.Layer):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFCrossConv(keras.layers.Layer):
    # Cross Convolution
    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
        self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
        self.add = shortcut and c1 == c2

    def call(self, inputs):
        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))


class TFConv2d(keras.layers.Layer):
    # Substitution for PyTorch nn.Conv2D
    def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
        super().__init__()
        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
        self.conv = keras.layers.Conv2D(filters=c2,
                                        kernel_size=k,
                                        strides=s,
                                        padding='VALID',
                                        use_bias=bias,
                                        kernel_initializer=keras.initializers.Constant(
                                            w.weight.permute(2, 3, 1, 0).numpy()),
                                        bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)

    def call(self, inputs):
        return self.conv(inputs)


class TFBottleneckCSP(keras.layers.Layer):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
        self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
        self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
        self.bn = TFBN(w.bn)
        self.act = lambda x: keras.activations.swish(x)
        self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        y1 = self.cv3(self.m(self.cv1(inputs)))
        y2 = self.cv2(inputs)
        return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))


class TFC3(keras.layers.Layer):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
        self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
        self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFC3x(keras.layers.Layer):
    # 3 module with cross-convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
        self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
        self.m = keras.Sequential([
            TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])

    def call(self, inputs):
        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))


class TFSPP(keras.layers.Layer):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13), w=None):
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]

    def call(self, inputs):
        x = self.cv1(inputs)
        return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))


class TFSPPF(keras.layers.Layer):
    # Spatial pyramid pooling-Fast layer
    def __init__(self, c1, c2, k=5, w=None):
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
        self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
        self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

    def call(self, inputs):
        x = self.cv1(inputs)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))


class TFDetect(keras.layers.Layer):
    # TF YOLO Detect layer
    def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None):  # detection layer
        super().__init__()
        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [tf.zeros(1)] * self.nl  # init grid
        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
        self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
        self.training = False  # set to False after building model
        self.imgsz = imgsz
        for i in range(self.nl):
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            self.grid[i] = self._make_grid(nx, ny)

    def call(self, inputs):
        z = []  # inference output
        x = []
        for i in range(self.nl):
            x.append(self.m[i](inputs[i]))
            # x(bs,20,20,255) to x(bs,3,20,20,85)
            ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
            x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])

            if not self.training:  # inference
                y = x[i]
                grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
                anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
                xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i]  # xy
                wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
                # Normalize xywh to 0-1 to reduce calibration error
                xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
                y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
                z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))

        return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
        return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)


class TFSegment(TFDetect):
    # YOLO Segment head for segmentation models
    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
        super().__init__(nc, anchors, ch, imgsz, w)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor
        self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]  # output conv
        self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto)  # protos
        self.detect = TFDetect.call

    def call(self, x):
        p = self.proto(x[0])
        # p = TFUpsample(None, scale_factor=4, mode='nearest')(self.proto(x[0]))  # (optional) full-size protos
        p = tf.transpose(p, [0, 3, 1, 2])  # from shape(1,160,160,32) to shape(1,32,160,160)
        x = self.detect(self, x)
        return (x, p) if self.training else (x[0], p)


class TFProto(keras.layers.Layer):

    def __init__(self, c1, c_=256, c2=32, w=None):
        super().__init__()
        self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
        self.upsample = TFUpsample(None, scale_factor=2, mode='nearest')
        self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
        self.cv3 = TFConv(c_, c2, w=w.cv3)

    def call(self, inputs):
        return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))


class TFUpsample(keras.layers.Layer):
    # TF version of torch.nn.Upsample()
    def __init__(self, size, scale_factor, mode, w=None):  # warning: all arguments needed including 'w'
        super().__init__()
        assert scale_factor % 2 == 0, "scale_factor must be multiple of 2"
        self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * scale_factor, x.shape[2] * scale_factor), mode)
        # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
        # with default arguments: align_corners=False, half_pixel_centers=False
        # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
        #                                                            size=(x.shape[1] * 2, x.shape[2] * 2))

    def call(self, inputs):
        return self.upsample(inputs)


class TFConcat(keras.layers.Layer):
    # TF version of torch.concat()
    def __init__(self, dimension=1, w=None):
        super().__init__()
        assert dimension == 1, "convert only NCHW to NHWC concat"
        self.d = 3

    def call(self, inputs):
        return tf.concat(inputs, self.d)


def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m_str = m
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except NameError:
                pass

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [
                nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3x]:
            c1, c2 = ch[f], args[0]
            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3x]:
                args.insert(2, n)
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
        elif m in [Detect, Segment]:
            args.append([ch[x + 1] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
            args.append(imgsz)
        else:
            c2 = ch[f]

        tf_m = eval('TF' + m_str.replace('nn.', ''))
        m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
            else tf_m(*args, w=model.model[i])  # module

        torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in torch_m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        ch.append(c2)
    return keras.Sequential(layers), sorted(save)


class TFModel:
    # TF YOLO model
    def __init__(self, cfg='yolo.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)):  # model, channels, classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict

        # Define model
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)

    def predict(self,
                inputs,
                tf_nms=False,
                agnostic_nms=False,
                topk_per_class=100,
                topk_all=100,
                iou_thres=0.45,
                conf_thres=0.25):
        y = []  # outputs
        x = inputs
        for m in self.model.layers:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            x = m(x)  # run
            y.append(x if m.i in self.savelist else None)  # save output

        # Add TensorFlow NMS
        if tf_nms:
            boxes = self._xywh2xyxy(x[0][..., :4])
            probs = x[0][:, :, 4:5]
            classes = x[0][:, :, 5:]
            scores = probs * classes
            if agnostic_nms:
                nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
            else:
                boxes = tf.expand_dims(boxes, 2)
                nms = tf.image.combined_non_max_suppression(boxes,
                                                            scores,
                                                            topk_per_class,
                                                            topk_all,
                                                            iou_thres,
                                                            conf_thres,
                                                            clip_boxes=False)
            return (nms,)
        return x  # output [1,6300,85] = [xywh, conf, class0, class1, ...]
        # x = x[0]  # [x(1,6300,85), ...] to x(6300,85)
        # xywh = x[..., :4]  # x(6300,4) boxes
        # conf = x[..., 4:5]  # x(6300,1) confidences
        # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1))  # x(6300,1) classes
        # return tf.concat([conf, cls, xywh], 1)

    @staticmethod
    def _xywh2xyxy(xywh):
        # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
        x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
        return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)


class AgnosticNMS(keras.layers.Layer):
    # TF Agnostic NMS
    def call(self, input, topk_all, iou_thres, conf_thres):
        # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
        return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
                         input,
                         fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
                         name='agnostic_nms')

    @staticmethod
    def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25):  # agnostic NMS
        boxes, classes, scores = x
        class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
        scores_inp = tf.reduce_max(scores, -1)
        selected_inds = tf.image.non_max_suppression(boxes,
                                                     scores_inp,
                                                     max_output_size=topk_all,
                                                     iou_threshold=iou_thres,
                                                     score_threshold=conf_thres)
        selected_boxes = tf.gather(boxes, selected_inds)
        padded_boxes = tf.pad(selected_boxes,
                              paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
                              mode="CONSTANT",
                              constant_values=0.0)
        selected_scores = tf.gather(scores_inp, selected_inds)
        padded_scores = tf.pad(selected_scores,
                               paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
                               mode="CONSTANT",
                               constant_values=-1.0)
        selected_classes = tf.gather(class_inds, selected_inds)
        padded_classes = tf.pad(selected_classes,
                                paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
                                mode="CONSTANT",
                                constant_values=-1.0)
        valid_detections = tf.shape(selected_inds)[0]
        return padded_boxes, padded_scores, padded_classes, valid_detections


def activations(act=nn.SiLU):
    # Returns TF activation from input PyTorch activation
    if isinstance(act, nn.LeakyReLU):
        return lambda x: keras.activations.relu(x, alpha=0.1)
    elif isinstance(act, nn.Hardswish):
        return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
    elif isinstance(act, (nn.SiLU, SiLU)):
        return lambda x: keras.activations.swish(x)
    else:
        raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')


def representative_dataset_gen(dataset, ncalib=100):
    # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
    for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
        im = np.transpose(img, [1, 2, 0])
        im = np.expand_dims(im, axis=0).astype(np.float32)
        im /= 255
        yield [im]
        if n >= ncalib:
            break


def run(
        weights=ROOT / 'yolo.pt',  # weights path
        imgsz=(640, 640),  # inference size h,w
        batch_size=1,  # batch size
        dynamic=False,  # dynamic batch size
):
    # PyTorch model
    im = torch.zeros((batch_size, 3, *imgsz))  # BCHW image
    model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
    _ = model(im)  # inference
    model.info()

    # TensorFlow model
    im = tf.zeros((batch_size, *imgsz, 3))  # BHWC image
    tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
    _ = tf_model.predict(im)  # inference

    # Keras model
    im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
    keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
    keras_model.summary()

    LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt))
    return opt


def main(opt):
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
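models/tf.py is a verification helper: run() loads a PyTorch checkpoint, rebuilds the same graph with the TF/Keras wrapper layers defined above, and prints a Keras summary so the two implementations can be compared before export. A minimal sketch of driving it from Python, assuming a checkpoint named yolo.pt sits in the repository root (the path is illustrative; no checkpoint ships with this upload):

from models.tf import run

# Rebuild the detection graph in Keras from the PyTorch weights and print its summary.
run(weights='yolo.pt', imgsz=(640, 640), batch_size=1, dynamic=False)

The same check is reachable from the command line through parse_opt(), e.g. python models/tf.py --weights yolo.pt --imgsz 640.
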
models/yolo.py
ADDED
@@ -0,0 +1,771 @@
import argparse
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import *
from models.experimental import *
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
                               time_sync)
from utils.tal.anchor_generator import make_anchors, dist2bbox

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


class Detect(nn.Module):
    # YOLO Detect head for detection models
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
        self.stride = torch.zeros(self.nl)  # strides computed during build

        c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128))))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        # Initialize Detect() biases, WARNING: requires stride availability
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (5 objects and 80 classes per 640 image)


class DDetect(nn.Module):
    # YOLO Detect head for detection models
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
        self.stride = torch.zeros(self.nl)  # strides computed during build

        c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128))))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch)
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        # Initialize Detect() biases, WARNING: requires stride availability
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (5 objects and 80 classes per 640 image)


class DualDetect(nn.Module):
    # YOLO Detect head for detection models
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch) // 2  # number of detection layers
        self.reg_max = 16
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
        self.stride = torch.zeros(self.nl)  # strides computed during build

        c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128))))  # channels
        c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128))))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
        self.cv4 = nn.ModuleList(
            nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:])
        self.cv5 = nn.ModuleList(
            nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
        self.dfl = DFL(self.reg_max)
        self.dfl2 = DFL(self.reg_max)

    def forward(self, x):
        shape = x[0].shape  # BCHW
        d1 = []
        d2 = []
        for i in range(self.nl):
            d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
            d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
        if self.training:
            return [d1, d2]
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
            self.shape = shape

        box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
        dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
        return y if self.export else (y, [d1, d2])

    def bias_init(self):
        # Initialize Detect() biases, WARNING: requires stride availability
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (5 objects and 80 classes per 640 image)
        for a, b, s in zip(m.cv4, m.cv5, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (5 objects and 80 classes per 640 image)


class DualDDetect(nn.Module):
    # YOLO Detect head for detection models
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch) // 2  # number of detection layers
        self.reg_max = 16
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
        self.stride = torch.zeros(self.nl)  # strides computed during build

        c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128))))  # channels
        c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128))))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
        self.cv4 = nn.ModuleList(
            nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
        self.cv5 = nn.ModuleList(
            nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
        self.dfl = DFL(self.reg_max)
        self.dfl2 = DFL(self.reg_max)

    def forward(self, x):
        shape = x[0].shape  # BCHW
        d1 = []
        d2 = []
        for i in range(self.nl):
            d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
            d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
        if self.training:
            return [d1, d2]
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
            self.shape = shape

        box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
        dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
        return y if self.export else (y, [d1, d2])
        # y = torch.cat((dbox2, cls2.sigmoid()), 1)
        # return y if self.export else (y, d2)
        # y1 = torch.cat((dbox, cls.sigmoid()), 1)
        # y2 = torch.cat((dbox2, cls2.sigmoid()), 1)
        # return [y1, y2] if self.export else [(y1, d1), (y2, d2)]
        # return [y1, y2] if self.export else [(y1, y2), (d1, d2)]

    def bias_init(self):
        # Initialize Detect() biases, WARNING: requires stride availability
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (5 objects and 80 classes per 640 image)
        for a, b, s in zip(m.cv4, m.cv5, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (5 objects and 80 classes per 640 image)


class TripleDetect(nn.Module):
    # YOLO Detect head for detection models
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch) // 3  # number of detection layers
        self.reg_max = 16
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
        self.stride = torch.zeros(self.nl)  # strides computed during build

        c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128))))  # channels
        c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128))))  # channels
        c6, c7 = max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), max((ch[self.nl * 2], min((self.nc * 2, 128))))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
        self.cv3 = nn.ModuleList(
            nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
        self.cv4 = nn.ModuleList(
            nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:self.nl*2])
        self.cv5 = nn.ModuleList(
            nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
        self.cv6 = nn.ModuleList(
            nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, 4 * self.reg_max, 1)) for x in ch[self.nl*2:self.nl*3])
        self.cv7 = nn.ModuleList(
            nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
        self.dfl = DFL(self.reg_max)
        self.dfl2 = DFL(self.reg_max)
        self.dfl3 = DFL(self.reg_max)

    def forward(self, x):
        shape = x[0].shape  # BCHW
        d1 = []
        d2 = []
        d3 = []
        for i in range(self.nl):
            d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
            d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
            d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
        if self.training:
            return [d1, d2, d3]
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
            self.shape = shape

        box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
        dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
        dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
        return y if self.export else (y, [d1, d2, d3])

    def bias_init(self):
        # Initialize Detect() biases, WARNING: requires stride availability
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
| 323 |
+
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
| 324 |
+
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
|
| 325 |
+
a[-1].bias.data[:] = 1.0 # box
|
| 326 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
|
| 327 |
+
for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
|
| 328 |
+
a[-1].bias.data[:] = 1.0 # box
|
| 329 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
|
| 330 |
+
for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
|
| 331 |
+
a[-1].bias.data[:] = 1.0 # box
|
| 332 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class TripleDDetect(nn.Module):
|
| 336 |
+
# YOLO Detect head for detection models
|
| 337 |
+
dynamic = False # force grid reconstruction
|
| 338 |
+
export = False # export mode
|
| 339 |
+
shape = None
|
| 340 |
+
anchors = torch.empty(0) # init
|
| 341 |
+
strides = torch.empty(0) # init
|
| 342 |
+
|
| 343 |
+
def __init__(self, nc=80, ch=(), inplace=True): # detection layer
|
| 344 |
+
super().__init__()
|
| 345 |
+
self.nc = nc # number of classes
|
| 346 |
+
self.nl = len(ch) // 3 # number of detection layers
|
| 347 |
+
self.reg_max = 16
|
| 348 |
+
self.no = nc + self.reg_max * 4 # number of outputs per anchor
|
| 349 |
+
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
| 350 |
+
self.stride = torch.zeros(self.nl) # strides computed during build
|
| 351 |
+
|
| 352 |
+
c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), \
|
| 353 |
+
max((ch[0], min((self.nc * 2, 128)))) # channels
|
| 354 |
+
c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), \
|
| 355 |
+
max((ch[self.nl], min((self.nc * 2, 128)))) # channels
|
| 356 |
+
c6, c7 = make_divisible(max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), 4), \
|
| 357 |
+
max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
|
| 358 |
+
self.cv2 = nn.ModuleList(
|
| 359 |
+
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4),
|
| 360 |
+
nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
|
| 361 |
+
self.cv3 = nn.ModuleList(
|
| 362 |
+
nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
|
| 363 |
+
self.cv4 = nn.ModuleList(
|
| 364 |
+
nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4),
|
| 365 |
+
nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:self.nl*2])
|
| 366 |
+
self.cv5 = nn.ModuleList(
|
| 367 |
+
nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
|
| 368 |
+
self.cv6 = nn.ModuleList(
|
| 369 |
+
nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3, g=4),
|
| 370 |
+
nn.Conv2d(c6, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl*2:self.nl*3])
|
| 371 |
+
self.cv7 = nn.ModuleList(
|
| 372 |
+
nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
|
| 373 |
+
self.dfl = DFL(self.reg_max)
|
| 374 |
+
self.dfl2 = DFL(self.reg_max)
|
| 375 |
+
self.dfl3 = DFL(self.reg_max)
|
| 376 |
+
|
| 377 |
+
def forward(self, x):
|
| 378 |
+
shape = x[0].shape # BCHW
|
| 379 |
+
d1 = []
|
| 380 |
+
d2 = []
|
| 381 |
+
d3 = []
|
| 382 |
+
for i in range(self.nl):
|
| 383 |
+
d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
|
| 384 |
+
d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
|
| 385 |
+
d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
|
| 386 |
+
if self.training:
|
| 387 |
+
return [d1, d2, d3]
|
| 388 |
+
elif self.dynamic or self.shape != shape:
|
| 389 |
+
self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
|
| 390 |
+
self.shape = shape
|
| 391 |
+
|
| 392 |
+
box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
|
| 393 |
+
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
| 394 |
+
box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
|
| 395 |
+
dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
| 396 |
+
box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
|
| 397 |
+
dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
| 398 |
+
#y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
|
| 399 |
+
#return y if self.export else (y, [d1, d2, d3])
|
| 400 |
+
y = torch.cat((dbox3, cls3.sigmoid()), 1)
|
| 401 |
+
return y if self.export else (y, d3)
|
| 402 |
+
|
| 403 |
+
def bias_init(self):
|
| 404 |
+
# Initialize Detect() biases, WARNING: requires stride availability
|
| 405 |
+
m = self # self.model[-1] # Detect() module
|
| 406 |
+
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
|
| 407 |
+
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
|
| 408 |
+
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
|
| 409 |
+
a[-1].bias.data[:] = 1.0 # box
|
| 410 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
|
| 411 |
+
for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
|
| 412 |
+
a[-1].bias.data[:] = 1.0 # box
|
| 413 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
|
| 414 |
+
for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
|
| 415 |
+
a[-1].bias.data[:] = 1.0 # box
|
| 416 |
+
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class Segment(Detect):
|
| 420 |
+
# YOLO Segment head for segmentation models
|
| 421 |
+
def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
|
| 422 |
+
super().__init__(nc, ch, inplace)
|
| 423 |
+
self.nm = nm # number of masks
|
| 424 |
+
self.npr = npr # number of protos
|
| 425 |
+
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
| 426 |
+
self.detect = Detect.forward
|
| 427 |
+
|
| 428 |
+
c4 = max(ch[0] // 4, self.nm)
|
| 429 |
+
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
| 430 |
+
|
| 431 |
+
def forward(self, x):
|
| 432 |
+
p = self.proto(x[0])
|
| 433 |
+
bs = p.shape[0]
|
| 434 |
+
|
| 435 |
+
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
| 436 |
+
x = self.detect(self, x)
|
| 437 |
+
if self.training:
|
| 438 |
+
return x, mc, p
|
| 439 |
+
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class Panoptic(Detect):
|
| 443 |
+
# YOLO Panoptic head for panoptic segmentation models
|
| 444 |
+
def __init__(self, nc=80, sem_nc=93, nm=32, npr=256, ch=(), inplace=True):
|
| 445 |
+
super().__init__(nc, ch, inplace)
|
| 446 |
+
self.sem_nc = sem_nc
|
| 447 |
+
self.nm = nm # number of masks
|
| 448 |
+
self.npr = npr # number of protos
|
| 449 |
+
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
| 450 |
+
self.uconv = UConv(ch[0], ch[0]//4, self.sem_nc+self.nc)
|
| 451 |
+
self.detect = Detect.forward
|
| 452 |
+
|
| 453 |
+
c4 = max(ch[0] // 4, self.nm)
|
| 454 |
+
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def forward(self, x):
|
| 458 |
+
p = self.proto(x[0])
|
| 459 |
+
s = self.uconv(x[0])
|
| 460 |
+
bs = p.shape[0]
|
| 461 |
+
|
| 462 |
+
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
| 463 |
+
x = self.detect(self, x)
|
| 464 |
+
if self.training:
|
| 465 |
+
return x, mc, p, s
|
| 466 |
+
return (torch.cat([x, mc], 1), p, s) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p, s))
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class BaseModel(nn.Module):
|
| 470 |
+
# YOLO base model
|
| 471 |
+
def forward(self, x, profile=False, visualize=False):
|
| 472 |
+
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
| 473 |
+
|
| 474 |
+
def _forward_once(self, x, profile=False, visualize=False):
|
| 475 |
+
y, dt = [], [] # outputs
|
| 476 |
+
for m in self.model:
|
| 477 |
+
if m.f != -1: # if not from previous layer
|
| 478 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
| 479 |
+
if profile:
|
| 480 |
+
self._profile_one_layer(m, x, dt)
|
| 481 |
+
#print(m)
|
| 482 |
+
x = m(x) # run
|
| 483 |
+
y.append(x if m.i in self.save else None) # save output
|
| 484 |
+
if visualize:
|
| 485 |
+
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
| 486 |
+
return x
|
| 487 |
+
|
| 488 |
+
def _profile_one_layer(self, m, x, dt):
|
| 489 |
+
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
| 490 |
+
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
| 491 |
+
t = time_sync()
|
| 492 |
+
for _ in range(10):
|
| 493 |
+
m(x.copy() if c else x)
|
| 494 |
+
dt.append((time_sync() - t) * 100)
|
| 495 |
+
if m == self.model[0]:
|
| 496 |
+
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
| 497 |
+
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
| 498 |
+
if c:
|
| 499 |
+
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
| 500 |
+
|
| 501 |
+
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
| 502 |
+
LOGGER.info('Fusing layers... ')
|
| 503 |
+
for m in self.model.modules():
|
| 504 |
+
if isinstance(m, (RepConvN)) and hasattr(m, 'fuse_convs'):
|
| 505 |
+
m.fuse_convs()
|
| 506 |
+
m.forward = m.forward_fuse # update forward
|
| 507 |
+
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
| 508 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
| 509 |
+
delattr(m, 'bn') # remove batchnorm
|
| 510 |
+
m.forward = m.forward_fuse # update forward
|
| 511 |
+
self.info()
|
| 512 |
+
return self
|
| 513 |
+
|
| 514 |
+
def info(self, verbose=False, img_size=640): # print model information
|
| 515 |
+
model_info(self, verbose, img_size)
|
| 516 |
+
|
| 517 |
+
def _apply(self, fn):
|
| 518 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
| 519 |
+
self = super()._apply(fn)
|
| 520 |
+
m = self.model[-1] # Detect()
|
| 521 |
+
if isinstance(m, (Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, Panoptic)):
|
| 522 |
+
m.stride = fn(m.stride)
|
| 523 |
+
m.anchors = fn(m.anchors)
|
| 524 |
+
m.strides = fn(m.strides)
|
| 525 |
+
# m.grid = list(map(fn, m.grid))
|
| 526 |
+
return self
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class DetectionModel(BaseModel):
|
| 530 |
+
# YOLO detection model
|
| 531 |
+
def __init__(self, cfg='yolo.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
| 532 |
+
super().__init__()
|
| 533 |
+
if isinstance(cfg, dict):
|
| 534 |
+
self.yaml = cfg # model dict
|
| 535 |
+
else: # is *.yaml
|
| 536 |
+
import yaml # for torch hub
|
| 537 |
+
self.yaml_file = Path(cfg).name
|
| 538 |
+
with open(cfg, encoding='ascii', errors='ignore') as f:
|
| 539 |
+
self.yaml = yaml.safe_load(f) # model dict
|
| 540 |
+
|
| 541 |
+
# Define model
|
| 542 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
| 543 |
+
if nc and nc != self.yaml['nc']:
|
| 544 |
+
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
| 545 |
+
self.yaml['nc'] = nc # override yaml value
|
| 546 |
+
if anchors:
|
| 547 |
+
LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
|
| 548 |
+
self.yaml['anchors'] = round(anchors) # override yaml value
|
| 549 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
| 550 |
+
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
| 551 |
+
self.inplace = self.yaml.get('inplace', True)
|
| 552 |
+
|
| 553 |
+
# Build strides, anchors
|
| 554 |
+
m = self.model[-1] # Detect()
|
| 555 |
+
if isinstance(m, (Detect, DDetect, Segment, Panoptic)):
|
| 556 |
+
s = 256 # 2x min stride
|
| 557 |
+
m.inplace = self.inplace
|
| 558 |
+
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Panoptic)) else self.forward(x)
|
| 559 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
| 560 |
+
# check_anchor_order(m)
|
| 561 |
+
# m.anchors /= m.stride.view(-1, 1, 1)
|
| 562 |
+
self.stride = m.stride
|
| 563 |
+
m.bias_init() # only run once
|
| 564 |
+
if isinstance(m, (DualDetect, TripleDetect, DualDDetect, TripleDDetect)):
|
| 565 |
+
s = 256 # 2x min stride
|
| 566 |
+
m.inplace = self.inplace
|
| 567 |
+
#forward = lambda x: self.forward(x)[0][0] if isinstance(m, (DualSegment, DualPanoptic)) else self.forward(x)[0]
|
| 568 |
+
forward = lambda x: self.forward(x)[0]
|
| 569 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
| 570 |
+
# check_anchor_order(m)
|
| 571 |
+
# m.anchors /= m.stride.view(-1, 1, 1)
|
| 572 |
+
self.stride = m.stride
|
| 573 |
+
m.bias_init() # only run once
|
| 574 |
+
|
| 575 |
+
# Init weights, biases
|
| 576 |
+
initialize_weights(self)
|
| 577 |
+
self.info()
|
| 578 |
+
LOGGER.info('')
|
| 579 |
+
|
| 580 |
+
def forward(self, x, augment=False, profile=False, visualize=False):
|
| 581 |
+
if augment:
|
| 582 |
+
return self._forward_augment(x) # augmented inference, None
|
| 583 |
+
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
| 584 |
+
|
| 585 |
+
def _forward_augment(self, x):
|
| 586 |
+
img_size = x.shape[-2:] # height, width
|
| 587 |
+
s = [1, 0.83, 0.67] # scales
|
| 588 |
+
f = [None, 3, None] # flips (2-ud, 3-lr)
|
| 589 |
+
y = [] # outputs
|
| 590 |
+
for si, fi in zip(s, f):
|
| 591 |
+
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
|
| 592 |
+
yi = self._forward_once(xi)[0] # forward
|
| 593 |
+
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
| 594 |
+
yi = self._descale_pred(yi, fi, si, img_size)
|
| 595 |
+
y.append(yi)
|
| 596 |
+
y = self._clip_augmented(y) # clip augmented tails
|
| 597 |
+
return torch.cat(y, 1), None # augmented inference, train
|
| 598 |
+
|
| 599 |
+
def _descale_pred(self, p, flips, scale, img_size):
|
| 600 |
+
# de-scale predictions following augmented inference (inverse operation)
|
| 601 |
+
if self.inplace:
|
| 602 |
+
p[..., :4] /= scale # de-scale
|
| 603 |
+
if flips == 2:
|
| 604 |
+
p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
|
| 605 |
+
elif flips == 3:
|
| 606 |
+
p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
|
| 607 |
+
else:
|
| 608 |
+
x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
|
| 609 |
+
if flips == 2:
|
| 610 |
+
y = img_size[0] - y # de-flip ud
|
| 611 |
+
elif flips == 3:
|
| 612 |
+
x = img_size[1] - x # de-flip lr
|
| 613 |
+
p = torch.cat((x, y, wh, p[..., 4:]), -1)
|
| 614 |
+
return p
|
| 615 |
+
|
| 616 |
+
def _clip_augmented(self, y):
|
| 617 |
+
# Clip YOLO augmented inference tails
|
| 618 |
+
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
| 619 |
+
g = sum(4 ** x for x in range(nl)) # grid points
|
| 620 |
+
e = 1 # exclude layer count
|
| 621 |
+
i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
|
| 622 |
+
y[0] = y[0][:, :-i] # large
|
| 623 |
+
i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
| 624 |
+
y[-1] = y[-1][:, i:] # small
|
| 625 |
+
return y
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
Model = DetectionModel # retain YOLO 'Model' class for backwards compatibility
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
class SegmentationModel(DetectionModel):
|
| 632 |
+
# YOLO segmentation model
|
| 633 |
+
def __init__(self, cfg='yolo-seg.yaml', ch=3, nc=None, anchors=None):
|
| 634 |
+
super().__init__(cfg, ch, nc, anchors)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
class ClassificationModel(BaseModel):
|
| 638 |
+
# YOLO classification model
|
| 639 |
+
def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
|
| 640 |
+
super().__init__()
|
| 641 |
+
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
|
| 642 |
+
|
| 643 |
+
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
| 644 |
+
# Create a YOLO classification model from a YOLO detection model
|
| 645 |
+
if isinstance(model, DetectMultiBackend):
|
| 646 |
+
model = model.model # unwrap DetectMultiBackend
|
| 647 |
+
model.model = model.model[:cutoff] # backbone
|
| 648 |
+
m = model.model[-1] # last layer
|
| 649 |
+
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
| 650 |
+
c = Classify(ch, nc) # Classify()
|
| 651 |
+
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
| 652 |
+
model.model[-1] = c # replace
|
| 653 |
+
self.model = model.model
|
| 654 |
+
self.stride = model.stride
|
| 655 |
+
self.save = []
|
| 656 |
+
self.nc = nc
|
| 657 |
+
|
| 658 |
+
def _from_yaml(self, cfg):
|
| 659 |
+
# Create a YOLO classification model from a *.yaml file
|
| 660 |
+
self.model = None
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def parse_model(d, ch): # model_dict, input_channels(3)
|
| 664 |
+
# Parse a YOLO model.yaml dictionary
|
| 665 |
+
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
|
| 666 |
+
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
|
| 667 |
+
if act:
|
| 668 |
+
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
| 669 |
+
RepConvN.default_act = eval(act)  # redefine default activation, i.e. RepConvN.default_act = nn.SiLU()
|
| 670 |
+
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
| 671 |
+
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
|
| 672 |
+
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
|
| 673 |
+
|
| 674 |
+
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
| 675 |
+
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
| 676 |
+
m = eval(m) if isinstance(m, str) else m # eval strings
|
| 677 |
+
for j, a in enumerate(args):
|
| 678 |
+
with contextlib.suppress(NameError):
|
| 679 |
+
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
| 680 |
+
|
| 681 |
+
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
| 682 |
+
if m in {
|
| 683 |
+
Conv, AConv, ConvTranspose,
|
| 684 |
+
Bottleneck, SPP, SPPF, DWConv, BottleneckCSP, nn.ConvTranspose2d, DWConvTranspose2d, SPPCSPC, ADown,
|
| 685 |
+
RepNCSPELAN4, SPPELAN}:
|
| 686 |
+
c1, c2 = ch[f], args[0]
|
| 687 |
+
if c2 != no: # if not output
|
| 688 |
+
c2 = make_divisible(c2 * gw, 8)
|
| 689 |
+
|
| 690 |
+
args = [c1, c2, *args[1:]]
|
| 691 |
+
if m in {BottleneckCSP, SPPCSPC}:
|
| 692 |
+
args.insert(2, n) # number of repeats
|
| 693 |
+
n = 1
|
| 694 |
+
elif m is nn.BatchNorm2d:
|
| 695 |
+
args = [ch[f]]
|
| 696 |
+
elif m in [Down0,Down1,Down2,Down3,Down4]:
|
| 697 |
+
c2 = args[0]
|
| 698 |
+
|
| 699 |
+
elif m is Concat:
|
| 700 |
+
c2 = sum(ch[x] for x in f)
|
| 701 |
+
elif m is Shortcut:
|
| 702 |
+
c2 = ch[f[0]]
|
| 703 |
+
elif m is ReOrg:
|
| 704 |
+
c2 = ch[f] * 4
|
| 705 |
+
elif m is CBLinear:
|
| 706 |
+
c2 = args[0]
|
| 707 |
+
c1 = ch[f]
|
| 708 |
+
args = [c1, c2, *args[1:]]
|
| 709 |
+
elif m is CBFuse:
|
| 710 |
+
c2 = ch[f[-1]]
|
| 711 |
+
# TODO: channel, gw, gd
|
| 712 |
+
elif m in {Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, Panoptic}:
|
| 713 |
+
args.append([ch[x] for x in f])
|
| 714 |
+
# if isinstance(args[1], int): # number of anchors
|
| 715 |
+
# args[1] = [list(range(args[1] * 2))] * len(f)
|
| 716 |
+
if m in {Segment, Panoptic}:
|
| 717 |
+
args[2] = make_divisible(args[2] * gw, 8)
|
| 718 |
+
elif m is Contract:
|
| 719 |
+
c2 = ch[f] * args[0] ** 2
|
| 720 |
+
elif m is Expand:
|
| 721 |
+
c2 = ch[f] // args[0] ** 2
|
| 722 |
+
else:
|
| 723 |
+
c2 = ch[f]
|
| 724 |
+
|
| 725 |
+
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
| 726 |
+
t = str(m)[8:-2].replace('__main__.', '') # module type
|
| 727 |
+
np = sum(x.numel() for x in m_.parameters()) # number params
|
| 728 |
+
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
|
| 729 |
+
LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
|
| 730 |
+
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
| 731 |
+
layers.append(m_)
|
| 732 |
+
if i == 0:
|
| 733 |
+
ch = []
|
| 734 |
+
ch.append(c2)
|
| 735 |
+
return nn.Sequential(*layers), sorted(save)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
if __name__ == '__main__':
|
| 739 |
+
parser = argparse.ArgumentParser()
|
| 740 |
+
parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml')
|
| 741 |
+
parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
|
| 742 |
+
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
| 743 |
+
parser.add_argument('--profile', action='store_true', help='profile model speed')
|
| 744 |
+
parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
|
| 745 |
+
parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
|
| 746 |
+
opt = parser.parse_args()
|
| 747 |
+
opt.cfg = check_yaml(opt.cfg) # check YAML
|
| 748 |
+
print_args(vars(opt))
|
| 749 |
+
device = select_device(opt.device)
|
| 750 |
+
|
| 751 |
+
# Create model
|
| 752 |
+
im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
|
| 753 |
+
model = Model(opt.cfg).to(device)
|
| 754 |
+
model.eval()
|
| 755 |
+
|
| 756 |
+
# Options
|
| 757 |
+
if opt.line_profile: # profile layer by layer
|
| 758 |
+
model(im, profile=True)
|
| 759 |
+
|
| 760 |
+
elif opt.profile: # profile forward-backward
|
| 761 |
+
results = profile(input=im, ops=[model], n=3)
|
| 762 |
+
|
| 763 |
+
elif opt.test: # test all models
|
| 764 |
+
for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
|
| 765 |
+
try:
|
| 766 |
+
_ = Model(cfg)
|
| 767 |
+
except Exception as e:
|
| 768 |
+
print(f'Error in {cfg}: {e}')
|
| 769 |
+
|
| 770 |
+
else: # report fused model summary
|
| 771 |
+
model.fuse()
|
spark repvit/repvit_1kpretrained_timm_style.pth
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89abdcdc1a2865f96822bb61e3096159087c6f5d331961dd1fed8e0a9c58988e
|
| 3 |
+
size 269763237
|
spark/downstream_d2/README.md
ADDED
|
@@ -0,0 +1,101 @@
|
| 1 |
+
## About code isolation
|
| 2 |
+
|
| 3 |
+
This `downstream_d2` folder is isolated from the pre-training code. One can treat `downstream_d2` as an independent codebase 🛠️.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## Fine-tuned ResNet-50 weights, log files, and performance
|
| 7 |
+
|
| 8 |
+
<div align="center">
|
| 9 |
+
|
| 10 |
+
[[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link)]
|
| 11 |
+
[[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1Ue7SiQ1E_AwgtYo56Fm-iUlQPZ8vIwYj/view?usp=share_link)]
|
| 12 |
+
[[`metrics.json`](https://drive.google.com/file/d/1wfbUWh4svV8sPWya_0PAhsLHVayDQRCi/view?usp=share_link)]
|
| 13 |
+
[[`log.txt`](https://drive.google.com/file/d/11zVo_87pe9DMAmfNQK9FUfyjQWHTRKxV/view?usp=share_link)]
|
| 14 |
+
[[`tensorboard file`](https://drive.google.com/file/d/1aM1qj8c3-Uka1dZuYmKhgp1lNJpeMDMl/view?usp=share_link)]
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
<p align="center">
|
| 18 |
+
<img src="https://user-images.githubusercontent.com/39692511/211497479-0563e891-f2ad-4cf1-b682-a21c2be1442d.png" width=80%>
|
| 19 |
+
<p>
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Install [Detectron2 v0.6](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ResNet on COCO
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
1. Set up a Python environment, e.g.:
|
| 26 |
+
```shell script
|
| 27 |
+
$ conda create -n spark python=3.8 -y
|
| 28 |
+
$ conda activate spark
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
2. Install `detectron2==0.6` (e.g., with `torch==1.10.0` and `cuda11.3`):
|
| 32 |
+
```shell script
|
| 33 |
+
$ pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
You can also find instructions for different pytorch/cuda versions on [this page](https://github.com/facebookresearch/detectron2/releases/tag/v0.6).
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
3. Put the COCO dataset folder at `downstream_d2/datasets/coco`.
|
| 40 |
+
The folder should follow the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) required by `Detectron2`, which should look like this:
|
| 41 |
+
```
|
| 42 |
+
downstream_d2/datasets/coco:
|
| 43 |
+
annotations/:
|
| 44 |
+
captions_train2017.json captions_val2017.json
|
| 45 |
+
instances_train2017.json instances_val2017.json
|
| 46 |
+
person_keypoints_train2017.json person_keypoints_val2017.json
|
| 47 |
+
train2017/:
|
| 48 |
+
a_lot_images.jpg
|
| 49 |
+
val2017/:
|
| 50 |
+
a_lot_images.jpg
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
## Training from pre-trained checkpoint
|
| 55 |
+
|
| 56 |
+
The script file for COCO fine-tuning (object detection and instance segmentation) is [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py),
|
| 57 |
+
which is a modification of [Detectron2's tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py).
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
Before fine-tuning a ResNet50 pre-trained by SparK, you should first convert our checkpoint file to a Detectron2-style `.pkl` file:
|
| 61 |
+
|
| 62 |
+
```shell script
|
| 63 |
+
$ cd /path/to/SparK/downstream_d2
|
| 64 |
+
$ python3 convert-timm-to-d2.py /some/path/to/resnet50_1kpretrained_timm_style.pth d2-style.pkl
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
For a ResNet50, you should see a log reporting `len(state)==318`:
|
| 68 |
+
```text
|
| 69 |
+
[convert] .pkl is generated! (from `/some/path/to/resnet50_1kpretrained_timm_style.pth`, to `d2-style.pkl`, len(state)==318)
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Then run fine-tuning on a single machine with 8 GPUs:
|
| 73 |
+
|
| 74 |
+
```shell script
|
| 75 |
+
$ cd /path/to/SparK/downstream_d2
|
| 76 |
+
$ python3 ./train_net.py --resume --num-gpus 8 --config-file ./configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml \
|
| 77 |
+
MODEL.WEIGHTS d2-style.pkl \
|
| 78 |
+
OUTPUT_DIR <your_output_dir>
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
For multiple machines, add these args:
|
| 82 |
+
```shell script
|
| 83 |
+
--num-machines <total_num> --machine-rank <this_rank> --dist-url <url:port>
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
In `<your_output_dir>` you'll see the log files generated by `Detectron2`.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
## Details: how we modify the official Detectron2's [tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py) to get our [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py)
|
| 90 |
+
|
| 91 |
+
1. We add two new hyperparameters:
|
| 92 |
+
- str `SOLVER.OPTIMIZER`: use 'ADAM' (the same as 'ADAMW') or 'SGD' optimizer
|
| 93 |
+
- float `SOLVER.LR_DECAY`: the decay ratio (from 0. to 1.) of the layer-wise learning rate decay trick
|
| 94 |
+
|
| 95 |
+
2. We implement layer-wise lr decay in [downstream_d2/lr_decay.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/lr_decay.py); see the sketch at the end of this README.
|
| 96 |
+
|
| 97 |
+
3. We write a script to convert our timm-style pre-trained ResNet weights to Detectron2-style in [downstream_d2/convert-timm-to-d2.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/convert-timm-to-d2.py).
|
| 98 |
+
|
| 99 |
+
4. We also add a hook for logging results to `cfg.OUTPUT_DIR/d2_coco_log.txt`.
|
| 100 |
+
|
| 101 |
+
All of our modifications to the original are commented with `# [modification] ...` in [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py) or other files.
|
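For intuition, here is a minimal, self-contained sketch of the layer-wise lr decay implemented in [downstream_d2/lr_decay.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/lr_decay.py). It is illustrative only: the real `lr_factor_func` maps full Detectron2 parameter names to decay factors, and `toy_lr_factor` below is a hypothetical helper. The idea is that each group's learning rate is the base lr scaled by `LR_DECAY ** (distance from the head)`, so layers closer to the input train more slowly.

```python
# Illustrative sketch only (not the code in lr_decay.py): layer-wise lr decay for a ResNet-50.
# Group ids: 0 = stem, 1..5 = res2..res5 (res4 is split into two groups), 6 = detection heads.
def toy_lr_factor(group_id: int, dec: float = 0.6, num_groups: int = 5) -> float:
    # deeper groups (closer to the head) get an exponent closer to 0, i.e. a factor closer to 1
    return dec ** (num_groups + 1 - group_id)

base_lr = 0.00025  # SOLVER.BASE_LR and LR_DECAY=0.6 come from the config above
for group_id, name in enumerate(['stem', 'res2', 'res3', 'res4 (early)', 'res4 (late)', 'res5', 'heads']):
    print(f'{name:>12s}: lr = {base_lr * toy_lr_factor(group_id):.2e}')
```

With `LR_DECAY=0.6`, the heads keep the full base lr while the stem is trained with roughly `0.6 ** 6 ≈ 0.047` of it.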
spark/downstream_d2/configs/Base-RCNN-FPN.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
| 1 |
+
MODEL:
|
| 2 |
+
META_ARCHITECTURE: "GeneralizedRCNN"
|
| 3 |
+
BACKBONE:
|
| 4 |
+
NAME: "build_resnet_fpn_backbone"
|
| 5 |
+
RESNETS:
|
| 6 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
| 7 |
+
FPN:
|
| 8 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
| 9 |
+
ANCHOR_GENERATOR:
|
| 10 |
+
SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
|
| 11 |
+
ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
|
| 12 |
+
RPN:
|
| 13 |
+
IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
|
| 14 |
+
PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
|
| 15 |
+
PRE_NMS_TOPK_TEST: 1000 # Per FPN level
|
| 16 |
+
# Detectron1 uses 2000 proposals per-batch,
|
| 17 |
+
# (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
|
| 18 |
+
# which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
|
| 19 |
+
POST_NMS_TOPK_TRAIN: 1000
|
| 20 |
+
POST_NMS_TOPK_TEST: 1000
|
| 21 |
+
ROI_HEADS:
|
| 22 |
+
NAME: "StandardROIHeads"
|
| 23 |
+
IN_FEATURES: ["p2", "p3", "p4", "p5"]
|
| 24 |
+
ROI_BOX_HEAD:
|
| 25 |
+
NAME: "FastRCNNConvFCHead"
|
| 26 |
+
NUM_FC: 2
|
| 27 |
+
POOLER_RESOLUTION: 7
|
| 28 |
+
ROI_MASK_HEAD:
|
| 29 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
| 30 |
+
NUM_CONV: 4
|
| 31 |
+
POOLER_RESOLUTION: 14
|
| 32 |
+
DATASETS:
|
| 33 |
+
TRAIN: ("coco_2017_train",)
|
| 34 |
+
TEST: ("coco_2017_val",)
|
| 35 |
+
SOLVER:
|
| 36 |
+
IMS_PER_BATCH: 16
|
| 37 |
+
BASE_LR: 0.02
|
| 38 |
+
STEPS: (60000, 80000)
|
| 39 |
+
MAX_ITER: 90000
|
| 40 |
+
INPUT:
|
| 41 |
+
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
|
| 42 |
+
VERSION: 2
|
spark/downstream_d2/configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml
ADDED
|
@@ -0,0 +1,57 @@
|
| 1 |
+
_BASE_: "Base-RCNN-FPN.yaml"
|
| 2 |
+
MODEL:
|
| 3 |
+
WEIGHTS: "<see instructions>"
|
| 4 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
| 5 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
| 6 |
+
|
| 7 |
+
MASK_ON: True
|
| 8 |
+
BACKBONE:
|
| 9 |
+
FREEZE_AT: 0
|
| 10 |
+
RESNETS:
|
| 11 |
+
DEPTH: 50
|
| 12 |
+
NORM: "SyncBN"
|
| 13 |
+
STRIDE_IN_1X1: False
|
| 14 |
+
FPN:
|
| 15 |
+
NORM: "SyncBN"
|
| 16 |
+
ROI_BOX_HEAD:
|
| 17 |
+
NAME: "FastRCNNConvFCHead"
|
| 18 |
+
NUM_FC: 1
|
| 19 |
+
NUM_CONV: 4
|
| 20 |
+
POOLER_RESOLUTION: 7
|
| 21 |
+
NORM: "SyncBN"
|
| 22 |
+
ROI_MASK_HEAD:
|
| 23 |
+
NAME: "MaskRCNNConvUpsampleHead"
|
| 24 |
+
NUM_CONV: 4
|
| 25 |
+
POOLER_RESOLUTION: 14
|
| 26 |
+
NORM: "SyncBN"
|
| 27 |
+
|
| 28 |
+
INPUT:
|
| 29 |
+
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896)
|
| 30 |
+
CROP:
|
| 31 |
+
ENABLED: False
|
| 32 |
+
TYPE: "absolute_range"
|
| 33 |
+
SIZE: (384, 600)
|
| 34 |
+
FORMAT: "RGB"
|
| 35 |
+
TEST:
|
| 36 |
+
EVAL_PERIOD: 5000
|
| 37 |
+
PRECISE_BN:
|
| 38 |
+
ENABLED: True
|
| 39 |
+
|
| 40 |
+
SOLVER:
|
| 41 |
+
STEPS: (60000, 80000)
|
| 42 |
+
MAX_ITER: 90000
|
| 43 |
+
GAMMA: 0.25
|
| 44 |
+
BASE_LR: 0.00025
|
| 45 |
+
WARMUP_FACTOR: 0.01
|
| 46 |
+
WARMUP_ITERS: 1000
|
| 47 |
+
WEIGHT_DECAY: 0.0001
|
| 48 |
+
CHECKPOINT_PERIOD: 5000
|
| 49 |
+
CLIP_GRADIENTS:
|
| 50 |
+
ENABLED: False
|
| 51 |
+
CLIP_TYPE: "value"
|
| 52 |
+
CLIP_VALUE: 1.0
|
| 53 |
+
NORM_TYPE: 2.0
|
| 54 |
+
|
| 55 |
+
# compared to standard detectron2, we add these two new configurations:
|
| 56 |
+
OPTIMIZER: "ADAMW"
|
| 57 |
+
LR_DECAY: 0.6
|
spark/downstream_d2/convert-timm-to-d2.py
ADDED
|
@@ -0,0 +1,43 @@
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
import pickle as pkl
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# we use `timm.models.ResNet` in pre-training, so keys are timm-style
|
| 15 |
+
def timm_resnet_to_detectron2_resnet(source_file, target_file):
|
| 16 |
+
pretrained: dict = torch.load(source_file, map_location='cpu')
|
| 17 |
+
for mod_k in {'state_dict', 'state', 'module', 'model'}:
|
| 18 |
+
if mod_k in pretrained:
|
| 19 |
+
pretrained = pretrained[mod_k]
|
| 20 |
+
if any(k.startswith('module.encoder_q.') for k in pretrained.keys()):
|
| 21 |
+
pretrained = {k.replace('module.encoder_q.', ''): v for k, v in pretrained.items() if k.startswith('module.encoder_q.')}
|
| 22 |
+
|
| 23 |
+
pkl_state = {}
|
| 24 |
+
for k, v in pretrained.items(): # convert resnet's keys from timm-style to d2-style
|
| 25 |
+
if 'layer' not in k:
|
| 26 |
+
k = 'stem.' + k
|
| 27 |
+
for t in [1, 2, 3, 4]:
|
| 28 |
+
k = k.replace(f'layer{t}', f'res{t+1}')
|
| 29 |
+
for t in [1, 2, 3]:
|
| 30 |
+
k = k.replace(f'bn{t}', f'conv{t}.norm')
|
| 31 |
+
k = k.replace('downsample.0', 'shortcut')
|
| 32 |
+
k = k.replace('downsample.1', 'shortcut.norm')
|
| 33 |
+
|
| 34 |
+
pkl_state[k] = v.detach().numpy()
|
| 35 |
+
|
| 36 |
+
with open(target_file, 'wb') as fp:
|
| 37 |
+
print(f'[convert] .pkl is generated! (from `{source_file}`, to `{target_file}`, len(state)=={len(pkl_state)})')
|
| 38 |
+
pkl.dump({'model': pkl_state, '__author__': 'https://github.com/keyu-tian/SparK', 'matching_heuristics': True}, fp)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == '__main__':
|
| 42 |
+
import sys
|
| 43 |
+
timm_resnet_to_detectron2_resnet(sys.argv[1], sys.argv[2])
|
spark/downstream_d2/lr_decay.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
from typing import List, Dict, Set, Optional, Callable, Any
|
| 2 |
+
import torch
|
| 3 |
+
import copy
|
| 4 |
+
|
| 5 |
+
from detectron2.solver.build import reduce_param_groups
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def lr_factor_func(para_name: str, is_resnet50, dec: float, debug=False) -> float:
|
| 9 |
+
if dec == 0:
|
| 10 |
+
dec = 1.
|
| 11 |
+
|
| 12 |
+
N = 5 if is_resnet50 else 11
|
| 13 |
+
if '.stem.' in para_name:
|
| 14 |
+
layer_id = 0
|
| 15 |
+
elif '.res' in para_name:
|
| 16 |
+
ls = para_name.split('.res')[1].split('.')
|
| 17 |
+
if ls[0].isnumeric() and ls[1].isnumeric():
|
| 18 |
+
stage_id, block_id = int(ls[0]), int(ls[1])
|
| 19 |
+
if stage_id == 2: # res2
|
| 20 |
+
layer_id = 1
|
| 21 |
+
elif stage_id == 3: # res3
|
| 22 |
+
layer_id = 2
|
| 23 |
+
elif stage_id == 4: # res4
|
| 24 |
+
layer_id = 3 + block_id // 3 # 3, 4 or 4, 5
|
| 25 |
+
else: # res5
|
| 26 |
+
layer_id = N
|
| 27 |
+
else:
|
| 28 |
+
assert para_name.startswith('roi_heads.res5.norm.')
|
| 29 |
+
layer_id = N + 1 # roi_heads.res5.norm.weight and roi_heads.res5.norm.bias of C4
|
| 30 |
+
else:
|
| 31 |
+
layer_id = N + 1
|
| 32 |
+
|
| 33 |
+
exp = N + 1 - layer_id
|
| 34 |
+
return f'{dec:g} ** {exp}' if debug else dec ** exp
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# [modification] see: https://github.com/facebookresearch/detectron2/blob/v0.6/detectron2/solver/build.py#L134
|
| 38 |
+
# add the `lr_factor_func` to implement lr decay
|
| 39 |
+
def get_default_optimizer_params(
|
| 40 |
+
model: torch.nn.Module,
|
| 41 |
+
base_lr: Optional[float] = None,
|
| 42 |
+
weight_decay: Optional[float] = None,
|
| 43 |
+
weight_decay_norm: Optional[float] = None,
|
| 44 |
+
bias_lr_factor: Optional[float] = 1.0,
|
| 45 |
+
weight_decay_bias: Optional[float] = None,
|
| 46 |
+
lr_factor_func: Optional[Callable] = None,
|
| 47 |
+
overrides: Optional[Dict[str, Dict[str, float]]] = None,
|
| 48 |
+
) -> List[Dict[str, Any]]:
|
| 49 |
+
"""
|
| 50 |
+
Get default param list for optimizer, with support for a few types of
|
| 51 |
+
overrides. If no overrides needed, this is equivalent to `model.parameters()`.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
|
| 55 |
+
weight_decay: weight decay for every group by default. Can be omitted to use the one
|
| 56 |
+
in optimizer.
|
| 57 |
+
weight_decay_norm: override weight decay for params in normalization layers
|
| 58 |
+
bias_lr_factor: multiplier of lr for bias parameters.
|
| 59 |
+
weight_decay_bias: override weight decay for bias parameters.
|
| 60 |
+
lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
|
| 61 |
+
corresponding lr decay rate. Note that setting this option requires
|
| 62 |
+
also setting ``base_lr``.
|
| 63 |
+
overrides: if not `None`, provides values for optimizer hyperparameters
|
| 64 |
+
(LR, weight decay) for module parameters with a given name; e.g.
|
| 65 |
+
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
|
| 66 |
+
weight decay values for all module parameters named `embedding`.
|
| 67 |
+
|
| 68 |
+
For common detection models, ``weight_decay_norm`` is the only option
|
| 69 |
+
needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
|
| 70 |
+
from Detectron1 that are not found useful.
|
| 71 |
+
|
| 72 |
+
Example:
|
| 73 |
+
::
|
| 74 |
+
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
|
| 75 |
+
lr=0.01, weight_decay=1e-4, momentum=0.9)
|
| 76 |
+
"""
|
| 77 |
+
if overrides is None:
|
| 78 |
+
overrides = {}
|
| 79 |
+
defaults = {}
|
| 80 |
+
if base_lr is not None:
|
| 81 |
+
defaults["lr"] = base_lr
|
| 82 |
+
if weight_decay is not None:
|
| 83 |
+
defaults["weight_decay"] = weight_decay
|
| 84 |
+
bias_overrides = {}
|
| 85 |
+
if bias_lr_factor is not None and bias_lr_factor != 1.0:
|
| 86 |
+
# NOTE: unlike Detectron v1, we now by default make bias hyperparameters
|
| 87 |
+
# exactly the same as regular weights.
|
| 88 |
+
if base_lr is None:
|
| 89 |
+
raise ValueError("bias_lr_factor requires base_lr")
|
| 90 |
+
bias_overrides["lr"] = base_lr * bias_lr_factor
|
| 91 |
+
if weight_decay_bias is not None:
|
| 92 |
+
bias_overrides["weight_decay"] = weight_decay_bias
|
| 93 |
+
if len(bias_overrides):
|
| 94 |
+
if "bias" in overrides:
|
| 95 |
+
raise ValueError("Conflicting overrides for 'bias'")
|
| 96 |
+
overrides["bias"] = bias_overrides
|
| 97 |
+
if lr_factor_func is not None:
|
| 98 |
+
if base_lr is None:
|
| 99 |
+
raise ValueError("lr_factor_func requires base_lr")
|
| 100 |
+
norm_module_types = (
|
| 101 |
+
torch.nn.BatchNorm1d,
|
| 102 |
+
torch.nn.BatchNorm2d,
|
| 103 |
+
torch.nn.BatchNorm3d,
|
| 104 |
+
torch.nn.SyncBatchNorm,
|
| 105 |
+
# NaiveSyncBatchNorm inherits from BatchNorm2d
|
| 106 |
+
torch.nn.GroupNorm,
|
| 107 |
+
torch.nn.InstanceNorm1d,
|
| 108 |
+
torch.nn.InstanceNorm2d,
|
| 109 |
+
torch.nn.InstanceNorm3d,
|
| 110 |
+
torch.nn.LayerNorm,
|
| 111 |
+
torch.nn.LocalResponseNorm,
|
| 112 |
+
)
|
| 113 |
+
params: List[Dict[str, Any]] = []
|
| 114 |
+
memo: Set[torch.nn.parameter.Parameter] = set()
|
| 115 |
+
for module_name, module in model.named_modules():
|
| 116 |
+
for module_param_name, value in module.named_parameters(recurse=False):
|
| 117 |
+
if not value.requires_grad:
|
| 118 |
+
continue
|
| 119 |
+
# Avoid duplicating parameters
|
| 120 |
+
if value in memo:
|
| 121 |
+
continue
|
| 122 |
+
memo.add(value)
|
| 123 |
+
|
| 124 |
+
hyperparams = copy.copy(defaults)
|
| 125 |
+
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
|
| 126 |
+
hyperparams["weight_decay"] = weight_decay_norm
|
| 127 |
+
if lr_factor_func is not None:
|
| 128 |
+
hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
|
| 129 |
+
|
| 130 |
+
hyperparams.update(overrides.get(module_param_name, {}))
|
| 131 |
+
params.append({"params": [value], **hyperparams})
|
| 132 |
+
return reduce_param_groups(params)
|
spark/downstream_d2/train_net.py
ADDED
|
@@ -0,0 +1,322 @@
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This source code is licensed under the license found in the
|
| 7 |
+
# LICENSE file in the root directory of this source tree.
|
| 8 |
+
|
| 9 |
+
import datetime
|
| 10 |
+
import json
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
import time
|
| 14 |
+
from collections import OrderedDict, defaultdict
|
| 15 |
+
from functools import partial
|
| 16 |
+
from pprint import pformat
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import detectron2.utils.comm as comm
|
| 21 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
| 22 |
+
from detectron2.config import get_cfg
|
| 23 |
+
from detectron2.data import MetadataCatalog
|
| 24 |
+
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch, PeriodicWriter
|
| 25 |
+
from detectron2.evaluation import (
|
| 26 |
+
CityscapesInstanceEvaluator,
|
| 27 |
+
CityscapesSemSegEvaluator,
|
| 28 |
+
COCOEvaluator,
|
| 29 |
+
COCOPanopticEvaluator,
|
| 30 |
+
DatasetEvaluators,
|
| 31 |
+
LVISEvaluator,
|
| 32 |
+
PascalVOCDetectionEvaluator,
|
| 33 |
+
SemSegEvaluator,
|
| 34 |
+
verify_results,
|
| 35 |
+
)
|
| 36 |
+
from detectron2.layers import get_norm
|
| 37 |
+
from detectron2.modeling import GeneralizedRCNNWithTTA
|
| 38 |
+
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
|
| 39 |
+
from detectron2.solver.build import maybe_add_gradient_clipping
|
| 40 |
+
from detectron2.utils.events import EventWriter
|
| 41 |
+
|
| 42 |
+
from lr_decay import get_default_optimizer_params, lr_factor_func
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# [modification] for better logging
|
| 46 |
+
def _ex_repr(self):
|
| 47 |
+
d = vars(self)
|
| 48 |
+
ex = ', '.join(f'{k}={v}' for k, v in d.items() if not k.startswith('__') and k not in [
|
| 49 |
+
'trainer', 'before_train', 'after_train', 'before_step', 'after_step', 'state_dict',
|
| 50 |
+
'_model', '_data_loader', 'logger',
|
| 51 |
+
])
|
| 52 |
+
return f'{type(self).__name__}({ex})'
|
| 53 |
+
hooks.HookBase.__repr__ = _ex_repr
|
| 54 |
+
EventWriter.__repr__ = _ex_repr
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# [modification] add norm
|
| 58 |
+
@ROI_HEADS_REGISTRY.register()
|
| 59 |
+
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
|
| 60 |
+
"""
|
| 61 |
+
As described in the MOCO paper, there is an extra BN layer
|
| 62 |
+
following the res5 stage.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def _build_res5_block(self, cfg):
|
| 66 |
+
seq, out_channels = super()._build_res5_block(cfg)
|
| 67 |
+
norm = cfg.MODEL.RESNETS.NORM
|
| 68 |
+
norm = get_norm(norm, out_channels)
|
| 69 |
+
seq.add_module("norm", norm)
|
| 70 |
+
return seq, out_channels
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Trainer(DefaultTrainer):
|
| 74 |
+
"""
|
| 75 |
+
We use the "DefaultTrainer" which contains pre-defined default logic for
|
| 76 |
+
standard training workflow. They may not work for you, especially if you
|
| 77 |
+
are working on a new research project. In that case you can write your
|
| 78 |
+
own training loop. You can use "tools/plain_train_net.py" as an example.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
# [modification] override the `build_optimizer` for using Adam and layer-wise lr decay
|
| 82 |
+
lr_decay_ratio: float = 1.0
|
| 83 |
+
@classmethod
|
| 84 |
+
def build_optimizer(cls, cfg, model):
|
| 85 |
+
is_resnet50 = int(cfg.MODEL.RESNETS.DEPTH) == 50
|
| 86 |
+
if comm.is_main_process():
|
| 87 |
+
dbg = defaultdict(list)
|
| 88 |
+
for module_name, module in model.named_modules():
|
| 89 |
+
for module_param_name, value in module.named_parameters(recurse=False):
|
| 90 |
+
if not value.requires_grad:
|
| 91 |
+
continue
|
| 92 |
+
lrf = lr_factor_func(f"{module_name}.{module_param_name}", is_resnet50=is_resnet50, dec=cls.lr_decay_ratio, debug=True)
|
| 93 |
+
dbg[lrf].append(f"{module_name}.{module_param_name}")
|
| 94 |
+
for k in sorted(dbg.keys()):
|
| 95 |
+
print(f'[{k}] {sorted(dbg[k])}')
|
| 96 |
+
print()
|
| 97 |
+
|
| 98 |
+
params = get_default_optimizer_params(
|
| 99 |
+
model,
|
| 100 |
+
base_lr=cfg.SOLVER.BASE_LR,
|
| 101 |
+
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
|
| 102 |
+
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
|
| 103 |
+
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
|
| 104 |
+
lr_factor_func=partial(lr_factor_func, is_resnet50=is_resnet50, dec=cls.lr_decay_ratio, debug=False)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
opt_clz = {
|
| 108 |
+
'sgd': partial(torch.optim.SGD, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV),
|
| 109 |
+
'adamw': torch.optim.AdamW,
|
| 110 |
+
'adam': torch.optim.AdamW,
|
| 111 |
+
}[cfg.SOLVER.OPTIMIZER.lower()]
|
| 112 |
+
        return maybe_add_gradient_clipping(cfg, opt_clz)(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        return build_evaluator(cfg, dataset_name, output_folder)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # [modification] we add these two new keys
    cfg.SOLVER.OPTIMIZER, cfg.SOLVER.LR_DECAY = 'sgd', 1.0  # by default using SGD and no lr_decay
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # [modification] for implementing lr decay and for logging
    Trainer.lr_decay_ratio = cfg.SOLVER.LR_DECAY

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    # [modification] just skip some warnings
    import warnings
    comm.synchronize()
    warnings.filterwarnings('ignore', category=UserWarning)
    _ = np.arange(3, dtype=np.int).astype(np.bool)
    _ = np.array(torch.ones(3, dtype=torch.int32).numpy(), dtype=np.int)
    _ = np.array(torch.ones(3, dtype=torch.int64).numpy(), dtype=np.int)
    _ = np.array(torch.ones(3, dtype=torch.long).numpy(), dtype=np.int)
    _ = torch.rand(100) // 5
    _ = torch.meshgrid(torch.ones(1))
    warnings.resetwarnings()
    comm.synchronize()

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    for h in trainer._hooks:
        if isinstance(h, PeriodicWriter):
            h._period = 1000  # [modification] less logging

    # [modification] we add some hooks for logging
    is_local_master = comm.get_rank() % args.num_gpus == 0
    if comm.is_main_process():
        print(f'[default hooks] {pformat(trainer._hooks, indent=2, width=300)}')
    ex_hooks = [
        hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model)) if cfg.TEST.AUG.ENABLED else None,
        LogHook(cfg.TEST.EVAL_PERIOD, args.config_file, cfg.OUTPUT_DIR, is_local_master) if comm.is_main_process() else None,
    ]
    trainer.register_hooks(ex_hooks)
    if comm.is_main_process():
        print(f'[extra hooks] {pformat(ex_hooks, indent=2, width=300)}')

    return trainer.train()


# [modification] we add a hook for logging results to `cfg.OUTPUT_DIR/d2_coco_log.txt`
class LogHook(hooks.HookBase):
    def __init__(self, eval_period, config_file, output_dir, is_local_master):
        self.eval_period = eval_period
        self.log_period = eval_period // 4
        self.log = {}

        self.is_master = comm.is_main_process()
        self.is_local_master = is_local_master

        self.config_file = config_file
        self.out_dir = output_dir
        self.log_txt_name = os.path.join(self.out_dir, 'd2_coco_log.txt')

    def __write_to_log_file(self, d):
        if self.is_local_master:
            self.log.update(d)
            with open(self.log_txt_name, 'w') as fp:
                json.dump(self.log, fp)
                fp.write('\n')

    def update_and_write_to_local_log(self):
        stat = self.trainer.storage.latest()
        self.log['boxAP'], self.log['bAP50'], self.log['bAP75'] = stat['bbox/AP'][0], stat['bbox/AP50'][0], stat['bbox/AP75'][0]
        self.log['mskAP'], self.log['mAP50'], self.log['mAP75'] = stat['segm/AP'][0], stat['segm/AP50'][0], stat['segm/AP75'][0]
        self.log['bAP-l'], self.log['bAP-m'], self.log['bAP-s'] = stat['bbox/APl'][0], stat['bbox/APm'][0], stat['bbox/APs'][0]
        self.log['mAP-l'], self.log['mAP-m'], self.log['mAP-s'] = stat['segm/APl'][0], stat['segm/APm'][0], stat['segm/APs'][0]
        all_ap = sorted([(v[0], k.split('AP-')[-1].strip()) for k, v in stat.items() if k.startswith('bbox/AP-')])
        all_ap = [tu[1] for tu in all_ap]
        self.log['easy'] = ' | '.join(all_ap[-7:])
        self.log['hard'] = ' | '.join(all_ap[:7])
        for k in self.log.keys():
            if 'AP' in k:
                self.log[k] = round(self.log[k], 3)
        self.__write_to_log_file({})

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self.eval_period > 0 and next_iter % self.eval_period == 0:
            self.update_and_write_to_local_log()

        if self.log_period > 0 and next_iter % self.log_period == 0:
            stat = self.trainer.storage.latest()
            remain_secs = round(stat['eta_seconds'][0])
            d = {
                'cfg': self.config_file,
                'rema': str(datetime.timedelta(seconds=remain_secs)), 'fini': time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs)),
                'cur_iter': f'{next_iter}/{self.trainer.max_iter}',
            }
            self.__write_to_log_file(d)

    def after_train(self):
        self.update_and_write_to_local_log()
        last_boxAP, last_mskAP = round(self.log['boxAP'], 3), round(self.log['mskAP'], 3)
        self.__write_to_log_file({
            'rema': '-', 'fini': time.strftime("%m-%d %H:%M", time.localtime(time.time() - 120)),
            'last_boxAP': last_boxAP,
            'last_mskAP': last_mskAP,
        })
        time.sleep(5)
        if self.is_master:
            print(f'\n[finished] ========== last_boxAP={last_boxAP}, last_mskAP={last_mskAP} ==========\n')


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )


def build_evaluator(cfg, dataset_name, output_folder=None):
    """
    Create evaluator(s) for a given dataset.
    This uses the special metadata "evaluator_type" associated with each builtin dataset.
    For your own dataset, you can simply create an evaluator manually in your
    script and do not have to worry about the hacky if-else logic here.
    """
    if output_folder is None:
        output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
    evaluator_list = []
    evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
    if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
        evaluator_list.append(
            SemSegEvaluator(
                dataset_name,
                distributed=True,
                output_dir=output_folder,
            )
        )
    if evaluator_type in ["coco", "coco_panoptic_seg"]:
        evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
    if evaluator_type == "coco_panoptic_seg":
        evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
    if evaluator_type == "cityscapes_instance":
        return CityscapesInstanceEvaluator(dataset_name)
    if evaluator_type == "cityscapes_sem_seg":
        return CityscapesSemSegEvaluator(dataset_name)
    elif evaluator_type == "pascal_voc":
        return PascalVOCDetectionEvaluator(dataset_name)
    elif evaluator_type == "lvis":
        return LVISEvaluator(dataset_name, output_dir=output_folder)
    if len(evaluator_list) == 0:
        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
        )
    elif len(evaluator_list) == 1:
        return evaluator_list[0]
    return DatasetEvaluators(evaluator_list)
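Editor's note: `LogHook.__write_to_log_file` above rewrites `<OUTPUT_DIR>/d2_coco_log.txt` from scratch on every update, so the file always holds a single JSON object. A minimal sketch for inspecting it (not part of the uploaded files; `'output'` is a placeholder for your `OUTPUT_DIR`):

```python
import json
import os

# d2_coco_log.txt is overwritten as one JSON object per update by LogHook above
with open(os.path.join('output', 'd2_coco_log.txt')) as fp:   # placeholder OUTPUT_DIR
    log = json.load(fp)
print(log.get('cur_iter'), log.get('boxAP'), log.get('mskAP'))
```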
spark/downstream_imagenet/README.md
ADDED
@@ -0,0 +1,54 @@
## About code isolation

This `downstream_imagenet` is isolated from the pre-training code. One can treat this `downstream_imagenet` as an independent codebase 🛠️.


## Preparation for ImageNet-1k fine-tuning

See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.

**Note: for network definitions, we directly use `timm.models.ResNet` and the [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**


## Fine-tuning on ImageNet-1k from pre-trained weights

Run [/downstream_imagenet/main.py](/downstream_imagenet/main.py) via `torchrun`.
**It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, created automatically if they don't exist), the model name (`--model`; for valid choices see the keys of `HP_DEFAULT_VALUES` in [/downstream_imagenet/arg.py line14](/downstream_imagenet/arg.py#L14)), and the pretrained weight file `--resume_from` to run fine-tuning.

All the other configurations have their default values, listed in [/downstream_imagenet/arg.py#L13](/downstream_imagenet/arg.py#L13).
You can override any default via the command line, e.g. `--bs=1024`.


Here is an example to fine-tune a ConvNeXt-Small on an 8-GPU single machine:
```shell script
$ cd /path/to/SparK/downstream_imagenet
$ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<some_port> main.py \
  --data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
  --model=convnext_small --resume_from=/some/path/to/convnextS_1kpretrained_official_style.pth
```

For multiple machines, change `--nnodes` and `--master_addr` to your configuration. E.g.:
```shell script
$ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_from_0> --master_addr=<some_address> --master_port=<some_port> main.py \
  ...
```


## Logging

See files under `--exp_dir` to track your experiment:

- `<model>_1kfinetuned_last.pth`: the latest model weights
- `<model>_1kfinetuned_best.pth`: model weights with the highest acc
- `<model>_1kfinetuned_best_ema.pth`: EMA weights with the highest acc
- `finetune_log.txt`: records some important information such as:
    - `git_commit_id`: git version
    - `cmd`: all arguments passed to the script

    It also reports training loss/acc, best evaluation acc, and remaining time at each epoch.

- `tensorboard_log/`: contains the tensorboard logs; you can visualize accuracies, loss values, learning rates, gradient norms and more via `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333`.

## Resuming

Use `--resume_from` again, like `--resume_from=path/to/<model>_1kfinetuned_last.pth`.
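Editor's note: `finetune_log.txt` mentioned above is written by `FineTuneArgs.log_epoch` (in `arg.py` below) as one JSON object per non-empty line, a header record first and then one record per epoch. A minimal reading sketch under that assumption (not part of the uploaded files):

```python
import json

def read_finetune_log(path='finetune_log.txt'):
    # one JSON object per non-empty line: [header, epoch_1, epoch_2, ...]
    with open(path) as fp:
        return [json.loads(line) for line in fp if line.strip()]

records = read_finetune_log()
header, epochs = records[0], records[1:]
print('cmd:', header.get('cmd'), '| git:', header.get('git_commit_id'))
if epochs:
    print('last epoch:', epochs[-1].get('cur_ep'), 'best val acc:', epochs[-1].get('best_val_acc'))
```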
spark/downstream_imagenet/arg.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys

from tap import Tap

HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
HP_DEFAULT_VALUES = {
    'convnext_small': (4096, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
    'convnext_base': (4096, 400, 20, 'adam', 0.0001, 0.7, 0.01, 0.8, 3, 0.4, 0.9999),
    'convnext_large': (4096, 200, 10, 'adam', 0.0001, 0.7, 0.02, 0.8, 3, 0.5, 0.9999),
    'convnext_large_384': (1024, 200, 20, 'adam', 0.00006, 0.7, 0.01, 0.8, 3, 0.5, 0.99995),

    'resnet50': (4096, 300, 5, 'lamb', 0.002, 0.7, 0.02, 0.1, 0, 0.05, 0.9999),
    'resnet101': (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
    'resnet152': (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
    'resnet200': (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
}


class FineTuneArgs(Tap):
    # environment
    exp_name: str
    exp_dir: str
    data_path: str
    model: str
    resume_from: str = ''       # resume from some checkpoint.pth

    img_size: int = 640
    dataloader_workers: int = 8

    # ImageNet classification fine-tuning hyperparameters; see `HP_DEFAULT_VALUES` above for detailed default values
    # - batch size, epoch
    bs: int = 0                 # global batch size (== batch_size_per_gpu * num_gpus)
    ep: int = 0                 # number of epochs
    wp_ep: int = 0              # epochs for warmup

    # - optimization
    opt: str = ''               # optimizer; 'adam' or 'lamb'
    base_lr: float = 0.         # lr == base_lr * glb_batch_size / 256
    lr_scale: float = 0.        # see file `lr_decay.py` for more details
    clip: int = -1              # use gradient clipping if clip > 0

    # - regularization tricks
    wd: float = 0.              # weight decay
    mixup: float = 0.           # use mixup if mixup > 0
    rep_aug: int = 0            # use repeated augmentation if rep_aug > 0
    drop_path: float = 0.       # drop_path ratio

    # - other tricks
    ema: float = 0.             # use EMA if ema > 0
    sbn: bool = True            # use SyncBatchNorm

    # NO NEED TO SPECIFY; each of these args is updated at runtime automatically
    lr: float = None
    batch_size_per_gpu: int = 0
    glb_batch_size: int = 0
    device: str = 'cpu'
    world_size: int = 1
    global_rank: int = 0
    local_rank: int = 0         # we DO USE this arg
    is_master: bool = False
    is_local_master: bool = False
    cmd: str = ' '.join(sys.argv[1:])
    commit_id: str = os.popen(f'git rev-parse HEAD').read().strip()
    commit_msg: str = os.popen(f'git log -1').read().strip().splitlines()[-1].strip()
    log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
    tb_lg_dir: str = ''         # tensorboard log directory

    train_loss: float = 0.
    train_acc: float = 0.
    best_val_acc: float = 0.
    cur_ep: str = ''
    remain_time: str = ''
    finish_time: str = ''
    first_logging: bool = True

    def log_epoch(self):
        if not self.is_local_master:
            return

        if self.first_logging:
            self.first_logging = False
            with open(self.log_txt_name, 'w') as fp:
                json.dump({
                    'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
                    'model': self.model,
                }, fp)
                fp.write('\n\n')

        with open(self.log_txt_name, 'a') as fp:
            json.dump({
                'cur_ep': self.cur_ep,
                'train_L': self.train_loss, 'train_acc': self.train_acc,
                'best_val_acc': self.best_val_acc,
                'rema': self.remain_time, 'fini': self.finish_time,
            }, fp)
            fp.write('\n')


def get_args(world_size, global_rank, local_rank, device) -> FineTuneArgs:
    # parse args and prepare directories
    args = FineTuneArgs(explicit_bool=True).parse_args()
    d_name, b_name = os.path.dirname(os.path.abspath(args.exp_dir)), os.path.basename(os.path.abspath(args.exp_dir))
    b_name = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in b_name)
    args.exp_dir = os.path.join(d_name, b_name)
    os.makedirs(args.exp_dir, exist_ok=True)
    args.log_txt_name = os.path.join(args.exp_dir, 'finetune_log.txt')

    args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
    try: os.makedirs(args.tb_lg_dir, exist_ok=True)
    except: pass

    # fill in args.bs, args.ep, etc. with their default values (if their values are not explicitly specified, i.e., if bool(they) == False)
    if args.model == 'convnext_large' and args.img_size == 384:
        default_values = HP_DEFAULT_VALUES['convnext_large_384']
    else:
        default_values = HP_DEFAULT_VALUES[args.model]
    for k, v in zip(HP_DEFAULT_NAMES, default_values):
        if bool(getattr(args, k)) == False:
            setattr(args, k, v)

    # update other runtime args
    args.world_size, args.global_rank, args.local_rank, args.device = world_size, global_rank, local_rank, device
    args.is_master = global_rank == 0
    args.is_local_master = local_rank == 0
    args.batch_size_per_gpu = args.bs // world_size
    args.glb_batch_size = args.batch_size_per_gpu * world_size
    args.lr = args.base_lr * args.glb_batch_size / 256

    return args
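Editor's note: a small worked example (illustrative only, not part of the uploaded files) of how `get_args` above resolves the per-GPU batch size and the actual learning rate for `--model=convnext_small` on 8 GPUs with the default hyperparameters:

```python
HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
hp = dict(zip(HP_DEFAULT_NAMES, (4096, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999)))

world_size = 8                                      # 8 GPUs in total
batch_size_per_gpu = hp['bs'] // world_size         # 4096 // 8 = 512
glb_batch_size = batch_size_per_gpu * world_size    # 4096
lr = hp['base_lr'] * glb_batch_size / 256           # 0.0002 * 4096 / 256 = 0.0032
print(batch_size_per_gpu, glb_batch_size, lr)       # 512 4096 0.0032
```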
spark/downstream_imagenet/data.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import time

import PIL.Image as PImage
import numpy as np
import torch
import torchvision
from timm.data import AutoAugment as TimmAutoAugment
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.distributed_sampler import RepeatAugSampler
from timm.data.transforms_factory import transforms_imagenet_eval
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torchvision.transforms import AutoAugment as TorchAutoAugment
from torchvision.transforms import transforms, TrivialAugmentWide

try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except:
    import PIL
    interpolation = PIL.Image.BICUBIC


def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_size_per_gpu, world_size, global_rank):
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning)

    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    trans_train = create_transform(
        is_training=True, input_size=img_size,
        auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1,
        mean=mean, std=std,
    )
    if img_size < 384:
        for i, t in enumerate(trans_train.transforms):
            if isinstance(t, (TorchAutoAugment, TimmAutoAugment)):
                trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation)
                break
        trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std)
    else:
        trans_val = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=interpolation),
            transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
        ])
    print_transform(trans_train, '[train]')
    print_transform(trans_val, '[val]')

    imagenet_folder = os.path.abspath(data_path)
    for postfix in ('train', 'val'):
        if imagenet_folder.endswith(postfix):
            imagenet_folder = imagenet_folder[:-len(postfix)]
    dataset_train = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'train'), trans_train)
    dataset_val = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'val'), trans_val)

    if rep_aug:
        print(f'[dataset] using repeated augmentation: count={rep_aug}')
        train_sp = RepeatAugSampler(dataset_train, shuffle=True, num_repeats=rep_aug)
    else:
        train_sp = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True, drop_last=True)

    loader_train = DataLoader(
        dataset=dataset_train, num_workers=workers, pin_memory=True,
        batch_size=batch_size_per_gpu, sampler=train_sp, persistent_workers=workers > 0,
        worker_init_fn=worker_init_fn,
    )
    iters_train = len(loader_train)
    print(f'[dataset: train] bs={world_size}x{batch_size_per_gpu}={world_size * batch_size_per_gpu}, num_iters={iters_train}')

    val_ratio = 2
    loader_val = DataLoader(
        dataset=dataset_val, num_workers=workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(world_size, global_rank, len(dataset_val), glb_batch_size=val_ratio * batch_size_per_gpu, filling=False, shuffle=False),
        worker_init_fn=worker_init_fn,
    )
    iters_val = len(loader_val)
    print(f'[dataset: val] bs={world_size}x{val_ratio * batch_size_per_gpu}={val_ratio * world_size * batch_size_per_gpu}, num_iters={iters_val}')

    time.sleep(3)
    warnings.resetwarnings()
    return loader_train, iters_train, iter(loader_val), iters_val


def worker_init_fn(worker_id):
    # see: https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def print_transform(transform, s):
    print(f'Transform {s} = ')
    for t in transform.transforms:
        print(t)
    print('---------------------------\n')


class DistInfiniteBatchSampler(Sampler):
    def __init__(self, world_size, global_rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, global_rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size

        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
        else:
            global_indices = torch.arange(self.dataset_len)
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.filling:
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        global_indices = tuple(global_indices.numpy().tolist())

        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices

    def __iter__(self):
        self.epoch = 0
        while True:
            self.epoch += 1
            p, q = 0, 0
            while p < self.max_p:
                q = p + self.batch_size
                yield self.indices[p:q]
                p = q
            if self.shuffle:
                self.indices = self.gener_indices()

    def __len__(self):
        return self.iters_per_ep
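Editor's note: a tiny, illustrative demo (not part of the uploaded files) of `DistInfiniteBatchSampler` above. It slices the dataset per rank and never stops yielding, which is why `create_classification_dataset` returns `iter(loader_val)` together with an explicit `iters_val` count instead of a normal for-loop over the loader:

```python
import itertools

sampler = DistInfiniteBatchSampler(world_size=2, global_rank=0, dataset_len=10,
                                   glb_batch_size=4, shuffle=False)
print(len(sampler))                          # iters_per_ep == ceil(10 / 4) == 3
for batch in itertools.islice(iter(sampler), 4):
    # rank-0 index batches of size glb_batch_size // world_size (the last batch of an epoch may be shorter)
    print(batch)
```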
spark/downstream_imagenet/lr_decay.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from pprint import pformat


def lr_wd_annealing(optimizer, peak_lr, wd, cur_it, wp_it, max_it):
    wp_it = round(wp_it)
    if cur_it < wp_it:
        cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
    else:
        ratio = (cur_it - wp_it) / (max_it - 1 - wp_it)
        cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))

    min_lr, max_lr = cur_lr, cur_lr
    min_wd, max_wd = wd, wd
    for param_group in optimizer.param_groups:
        scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1)  # 'lr_scale' could be assigned
        min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr)
        scaled_wd = param_group['weight_decay'] = wd * param_group.get('weight_decay_scale', 1)  # 'weight_decay_scale' could be assigned
        min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd)
    return min_lr, max_lr, min_wd, max_wd


def get_param_groups(model, nowd_keys=(), lr_scale=0.0):
    using_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0.0 < lr_scale < 1.0
    print(f'[get_ft_param_groups][lr decay] using_lr_scale={using_lr_scale}, ft_lr_scale={lr_scale}')
    para_groups, para_groups_dbg = {}, {}

    for name, para in model.named_parameters():
        if not para.requires_grad:
            continue  # frozen weights
        if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys):
            wd_scale, group_name = 0., 'no_decay'
        else:
            wd_scale, group_name = 1., 'decay'

        if using_lr_scale:
            layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
            group_name = f'layer{layer_id}_' + group_name
            this_lr_scale = lr_scale ** scale_exp
            dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
        else:
            this_lr_scale = 1
            dbg = f'[no scale]'

        if group_name not in para_groups:
            para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': this_lr_scale}
            para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg}
        para_groups[group_name]['params'].append(para)
        para_groups_dbg[group_name]['params'].append(name)

    for g in para_groups_dbg.values():
        g['params'] = pformat(', '.join(g['params']), width=200)

    print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
    return list(para_groups.values())
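Editor's note: a hedged usage sketch of the two functions above (the real wiring lives in `util.create_model_opt` and `main.fine_tune_one_epoch`, not in this snippet): `get_param_groups` is called once when building the optimizer, then `lr_wd_annealing` re-schedules lr and weight decay on every iteration.

```python
import torch

model = torch.nn.Linear(8, 2)   # stand-in model; the real backbones also expose get_layer_id_and_scale_exp
opt = torch.optim.AdamW(get_param_groups(model, nowd_keys=('bias',), lr_scale=0.7),
                        lr=1e-3, weight_decay=0.01)

iters_per_ep, wp_ep, total_ep = 100, 5, 100
for cur_it in range(iters_per_ep):  # normally cur_it = it + ep * iters_per_ep
    min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
        opt, peak_lr=3.2e-3, wd=0.01,
        cur_it=cur_it, wp_it=wp_ep * iters_per_ep, max_it=total_ep * iters_per_ep)
```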
spark/downstream_imagenet/main.py
ADDED
@@ -0,0 +1,189 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import time

import torch
import torch.distributed as tdist
from timm.utils import ModelEmaV2
from torch.utils.tensorboard import SummaryWriter

from arg import get_args, FineTuneArgs
from models import ConvNeXt, ResNet
__for_timm_registration = ConvNeXt, ResNet
from lr_decay import lr_wd_annealing
from util import init_distributed_environ, create_model_opt, load_checkpoint, save_checkpoint
from data import create_classification_dataset


def main_ft():
    world_size, global_rank, local_rank, device = init_distributed_environ()
    args: FineTuneArgs = get_args(world_size, global_rank, local_rank, device)
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
    ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema, optimizer)

    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f' [*] [FT already done] Max/Last Acc: {performance_desc}')
    else:
        tb_lg = SummaryWriter(args.tb_lg_dir) if args.is_master else None
        loader_train, iters_train, iterator_val, iters_val = create_classification_dataset(
            args.data_path, args.img_size, args.rep_aug,
            args.dataloader_workers, args.batch_size_per_gpu, args.world_size, args.global_rank
        )

        # train & eval
        tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
        max_acc = last_acc
        max_acc_e = last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)[-1]
        print(f'[fine-tune] initial acc={last_acc:.2f}, ema={last_acc_e:.2f}')

        ep_eval = set(range(0, args.ep//3, 5)) | set(range(args.ep//3, args.ep))
        print(f'[FT start] ep_eval={sorted(ep_eval)} ')
        print(f'[FT start] from ep{ep_start}')

        params_req_grad = [p for p in model.parameters() if p.requires_grad]
        ft_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            if hasattr(loader_train, 'sampler') and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(ep)
                if 0 <= ep <= 3:
                    print(f'[loader_train.sampler.set_epoch({ep})]')

            train_loss, train_acc = fine_tune_one_epoch(ep, args, tb_lg, loader_train, iters_train, criterion, mixup_fn, model, model_ema, optimizer, params_req_grad)
            if ep in ep_eval:
                eval_start_time = time.time()
                tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
                tot_pred_e, last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)
                eval_cost = round(time.time() - eval_start_time, 2)
                performance_desc = f'Max (Last) Acc: {max(max_acc, last_acc):.2f} ({last_acc:.2f} o {tot_pred}) EMA: {max(max_acc_e, last_acc_e):.2f} ({last_acc_e:.2f} o {tot_pred_e})'
                states = model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict()
                if last_acc > max_acc:
                    max_acc = last_acc
                    save_checkpoint(f'{args.model}_1kfinetuned_best.pth', args, ep, performance_desc, *states)
                if last_acc_e > max_acc_e:
                    max_acc_e = last_acc_e
                    save_checkpoint(f'{args.model}_1kfinetuned_best_ema.pth', args, ep, performance_desc, *states)
                save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, ep, performance_desc, *states)
            else:
                eval_cost = '-'

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep-1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f'[ep{ep}/{args.ep}] {performance_desc} Ep cost: {ep_cost}s, Ev cost: {eval_cost}, Remain: {remain_time}, Finish @ {finish_time}')
            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.train_loss, args.train_acc, args.best_val_acc = train_loss, train_acc, max(max_acc, max_acc_e)
            args.log_epoch()

            if args.is_master:
                tb_lg.add_scalar(f'ft_train/ep_loss', train_loss, ep)
                tb_lg.add_scalar(f'ft_eval/max_acc', max_acc, ep)
                tb_lg.add_scalar(f'ft_eval/last_acc', last_acc, ep)
                tb_lg.add_scalar(f'ft_eval/max_acc_ema', max_acc_e, ep)
                tb_lg.add_scalar(f'ft_eval/last_acc_ema', last_acc_e, ep)
                tb_lg.add_scalar(f'ft_z_burnout/rest_hours', round(remain_secs/60/60, 2), ep)
                tb_lg.flush()

        # finish fine-tuning
        result_acc = max(max_acc, max_acc_e)
        if args.is_master:
            tb_lg.add_scalar('ft_result/result_acc', result_acc, ep_start)
            tb_lg.add_scalar('ft_result/result_acc', result_acc, args.ep)
            tb_lg.flush()
            tb_lg.close()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f' [*] [FT finished] {performance_desc} Total Cost: {(time.time() - ft_start_time) / 60 / 60:.1f}h\n')
        print(f' [*] [FT finished] max(max_acc, max_acc_e)={result_acc} EMA better={max_acc_e>max_acc}')
        print('\n\n')
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def fine_tune_one_epoch(ep, args: FineTuneArgs, tb_lg: SummaryWriter, loader_train, iters_train, criterion, mixup_fn, model, model_ema: ModelEmaV2, optimizer, params_req_grad):
    model.train()
    tot_loss = tot_acc = 0.0
    log_freq = max(1, round(iters_train * 0.7))
    ep_start_time = time.time()
    for it, (inp, tar) in enumerate(loader_train):
        # adjust lr and wd
        cur_it = it + ep * iters_train
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, cur_it, args.wp_ep * iters_train, args.ep * iters_train)

        # forward
        inp = inp.to(args.device, non_blocking=True)
        raw_tar = tar = tar.to(args.device, non_blocking=True)
        if mixup_fn is not None:
            inp, tar, raw_tar = mixup_fn(inp, tar)
        oup = model(inp)
        pred = oup.data.argmax(dim=1)
        if mixup_fn is None:
            acc = pred.eq(tar).float().mean().item() * 100
            tot_acc += acc
        else:
            acc = (pred.eq(raw_tar) | pred.eq(raw_tar.flip(0))).float().mean().item() * 100
            tot_acc += acc

        # backward
        optimizer.zero_grad()
        loss = criterion(oup, tar)
        loss.backward()
        loss = loss.item()
        tot_loss += loss
        if args.clip > 0:
            orig_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        else:
            orig_norm = None
        optimizer.step()
        model_ema.update(model)
        torch.cuda.synchronize()

        # log
        if args.is_master and cur_it % log_freq == 0:
            tb_lg.add_scalar(f'ft_train/it_loss', loss, cur_it)
            tb_lg.add_scalar(f'ft_train/it_acc', acc, cur_it)
            tb_lg.add_scalar(f'ft_hp/min_lr', min_lr, cur_it), tb_lg.add_scalar(f'ft_hp/max_lr', max_lr, cur_it)
            tb_lg.add_scalar(f'ft_hp/min_wd', min_wd, cur_it), tb_lg.add_scalar(f'ft_hp/max_wd', max_wd, cur_it)
            if orig_norm is not None:
                tb_lg.add_scalar(f'ft_hp/orig_norm', orig_norm, cur_it)

        if it in [3, iters_train//2, iters_train-1]:
            remain_secs = (iters_train-1 - it) * (time.time() - ep_start_time) / (it + 1)
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            print(f'[ep{ep} it{it:3d}/{iters_train}] L: {loss:.4f} Acc: {acc:.2f} lr: {min_lr:.1e}~{max_lr:.1e} Remain: {remain_time}')

    return tot_loss / iters_train, tot_acc / iters_train


@torch.no_grad()
def evaluate(dev, iterator_val, iters_val, model):
    training = model.training
    model.train(False)
    tot_pred, tot_correct = 0., 0.
    for _ in range(iters_val):
        inp, tar = next(iterator_val)
        tot_pred += tar.shape[0]
        inp = inp.to(dev, non_blocking=True)
        tar = tar.to(dev, non_blocking=True)
        oup = model(inp)
        tot_correct += oup.argmax(dim=1).eq(tar).sum().item()
    model.train(training)
    t = torch.tensor([tot_pred, tot_correct]).to(dev)
    tdist.all_reduce(t)
    return t[0].item(), (t[1] / t[0]).item() * 100.


if __name__ == '__main__':
    main_ft()
spark/downstream_imagenet/mixup.py
ADDED
@@ -0,0 +1,168 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This file is a modified version of timm.data.Mixup
# Fixed error of "Batch size should be even when using this"

""" Mixup and Cutmix

Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch

Hacked together by / Copyright 2019, Ross Wightman
"""
import numpy as np
import torch


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)


def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
    return y1 * lam + y2 * (1. - lam)


def rand_bbox(img_shape, lam, margin=0., count=None):
    """ Standard CutMix bounding-box
    Generates a random square bbox based on lambda value. This impl includes
    support for enforcing a border margin as percent of bbox dimensions.

    Args:
        img_shape (tuple): Image shape as tuple
        lam (float): Cutmix lambda value
        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
        count (int): Number of bbox to generate
    """
    ratio = np.sqrt(1 - lam)
    img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh


def rand_bbox_minmax(img_shape, minmax, count=None):
    """ Min-Max CutMix bounding-box
    Inspired by Darknet cutmix impl, generates a random rectangular bbox
    based on min/max percent values applied to each dimension of the input image.

    Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.

    Args:
        img_shape (tuple): Image shape as tuple
        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
        count (int): Number of bbox to generate
    """
    assert len(minmax) == 2
    img_h, img_w = img_shape[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
    xl = np.random.randint(0, img_w - cut_w, size=count)
    yu = yl + cut_h
    xu = xl + cut_w
    return yl, yu, xl, xu


def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
    """ Generate bbox and apply lambda correction.
    """
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
    if correct_lam or ratio_minmax is not None:
        bbox_area = (yu - yl) * (xu - xl)
        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
    return (yl, yu, xl, xu), lam


class BatchMixup:
    """ Mixup/Cutmix that applies different params to each element or whole batch

    Args:
        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
        prob (float): probability of applying mixup or cutmix per batch or element
        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
        label_smoothing (float): apply label smoothing to the mixed target tensor
        num_classes (int): number of classes for target
    """
    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
        assert mode == 'batch'
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = mode
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

    def _params_per_batch(self):
        lam = 1.
        use_cutmix = False
        if self.mixup_enabled and np.random.rand() < self.mix_prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.cutmix_alpha > 0.:
                use_cutmix = True
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = float(lam_mix)
        return lam, use_cutmix

    def _mix_batch(self, x):
        lam, use_cutmix = self._params_per_batch()
        if lam == 1.:
            return 1.
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
        else:
            x_flipped = x.flip(0).mul_(1. - lam)
            x.mul_(lam).add_(x_flipped)
        return lam

    def __call__(self, x, raw_target):
        if x.shape[0] % 2 == 1:
            x, raw_target = torch.cat((x[:1], x), dim=0), torch.cat((raw_target[:1], raw_target), dim=0)
            # assert len(x) % 2 == 0, 'Batch size should be even when using this'
        lam = self._mix_batch(x)
        target = mixup_target(raw_target, self.num_classes, lam, self.label_smoothing, x.device)
        return x, target, raw_target
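Editor's note: a minimal usage sketch of `BatchMixup` (illustrative assumption, not part of the uploaded files; in this repo the instance is built in `util.create_model_opt` and applied in `main.fine_tune_one_epoch` as `inp, tar, raw_tar = mixup_fn(inp, tar)`), showing the odd-batch fix in action:

```python
import torch

mixup_fn = BatchMixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
images = torch.randn(5, 3, 224, 224)             # odd batch: the first sample gets duplicated
labels = torch.randint(0, 1000, (5,))
mixed, soft_targets, raw = mixup_fn(images, labels)
print(mixed.shape, soft_targets.shape, raw.shape)  # torch.Size([6, 3, 224, 224]) torch.Size([6, 1000]) torch.Size([6])
```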
spark/downstream_imagenet/models/__init__.py
ADDED
@@ -0,0 +1,104 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from timm.data import Mixup
from timm.loss import BinaryCrossEntropy, SoftTargetCrossEntropy
from timm.models.layers import drop
from timm.models.resnet import ResNet

from .convnext_official import ConvNeXt


def convnext_get_layer_id_and_scale_exp(self: ConvNeXt, para_name: str):
    N = 12 if len(self.stages[-2]) > 9 else 6
    if para_name.startswith("downsample_layers"):
        stage_id = int(para_name.split('.')[1])
        if stage_id == 0:
            layer_id = 0
        elif stage_id == 1 or stage_id == 2:
            layer_id = stage_id + 1
        else:  # stage_id == 3:
            layer_id = N
    elif para_name.startswith("stages"):
        stage_id = int(para_name.split('.')[1])
        block_id = int(para_name.split('.')[2])
        if stage_id == 0 or stage_id == 1:
            layer_id = stage_id + 1
        elif stage_id == 2:
            layer_id = 3 + block_id // 3
        else:  # stage_id == 3:
            layer_id = N
    else:
        layer_id = N + 1  # after backbone

    return layer_id, N + 1 - layer_id


def resnets_get_layer_id_and_scale_exp(self: ResNet, para_name: str):
    # stages:
    # 50 : [3, 4, 6, 3]
    # 101 : [3, 4, 23, 3]
    # 152 : [3, 8, 36, 3]
    # 200 : [3, 24, 36, 3]
    # eca269d: [3, 30, 48, 8]

    L2, L3 = len(self.layer2), len(self.layer3)
    if L2 == 4 and L3 == 6:
        blk2, blk3 = 2, 3
    elif L2 == 4 and L3 == 23:
        blk2, blk3 = 2, 3
    elif L2 == 8 and L3 == 36:
        blk2, blk3 = 4, 4
    elif L2 == 24 and L3 == 36:
        blk2, blk3 = 4, 4
    elif L2 == 30 and L3 == 48:
        blk2, blk3 = 5, 6
    else:
        raise NotImplementedError

    N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5)
    N = 2 + N2 + N3
    if para_name.startswith('layer'):  # 1, 2, 3, 4, 5
        stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1])
        if stage_id == 1:
            layer_id = 1
        elif stage_id == 2:
            layer_id = 2 + block_id // blk2  # 2, 3
        elif stage_id == 3:
            layer_id = 2 + N2 + block_id // blk3  # r50: 4, 5    r101: 4, 5, ..., 11
        else:  # == 4
            layer_id = N  # r50: 6    r101: 12
    elif para_name.startswith('fc.'):
        layer_id = N + 1  # r50: 7    r101: 13
    else:
        layer_id = 0

    return layer_id, N + 1 - layer_id  # r50: 0-7, 7-0    r101: 0-13, 13-0


def _ex_repr(self):
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )


# IMPORTANT: update some member functions
__UPDATED = False
if not __UPDATED:
    for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, BinaryCrossEntropy, Mixup, drop.DropPath):
        if hasattr(clz, 'extra_repr'):
            clz.extra_repr = _ex_repr
        else:
            clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
    ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
    ConvNeXt.get_layer_id_and_scale_exp = convnext_get_layer_id_and_scale_exp
    __UPDATED = True
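Editor's note: a small worked example (illustrative only, not part of the uploaded files) of the layer-id hooks above. `lr_decay.get_param_groups` turns `scale_exp` into a per-parameter lr multiplier of `lr_scale ** scale_exp`, so deep layers and the head keep more of the base lr than the stem:

```python
from timm.models.resnet import resnet50

model = resnet50()
lr_scale = 0.7
for name in ('conv1.weight', 'layer1.0.conv1.weight', 'layer4.2.conv3.weight', 'fc.weight'):
    layer_id, scale_exp = resnets_get_layer_id_and_scale_exp(model, name)
    print(f'{name}: layer_id={layer_id}, lr multiplier={lr_scale ** scale_exp:.4f}')
# the stem ends up with 0.7 ** 7, while the classifier head keeps the full lr (multiplier 1.0)
```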
spark/downstream_imagenet/models/convnext_official.py
ADDED
@@ -0,0 +1,201 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This file is exactly the same as: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s` -
        https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}

@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
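Because each variant above is decorated with timm's `@register_model`, it can be built through `timm.create_model` once this module has been imported. A brief sketch (the import path is assumed, and no pretrained weights are downloaded here):

```python
# Sketch: instantiate the ConvNeXt-T defined above via the timm registry (random init).
import torch
from timm import create_model

import convnext_official  # noqa: F401  (assumed import path; importing it runs the @register_model decorators)

model = create_model('convnext_tiny', pretrained=False, num_classes=1000, drop_path_rate=0.1)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])
```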
spark/downstream_imagenet/requirements.txt
ADDED
@@ -0,0 +1,5 @@
numpy
Pillow
typed-argument-parser
timm==0.5.4
tensorboardx
spark/downstream_imagenet/util.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import os
import sys
from functools import partial
from typing import List, Tuple, Callable

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as tmp
from timm import create_model
from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
from timm.optim import AdamW, Lamb
from timm.utils import ModelEmaV2
from torch.nn.parallel import DistributedDataParallel
from torch.optim.optimizer import Optimizer

from arg import FineTuneArgs
from downstream_imagenet.mixup import BatchMixup
from lr_decay import get_param_groups


def time_str(for_dirname=False):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')


def init_distributed_environ():
    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if tmp.get_start_method(allow_none=True) is None:
        tmp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    tdist.init_process_group(backend='nccl')
    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    # print only when local_rank == 0 or print(..., force=True)
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def prt(msg, *args, **kwargs):
        force = kwargs.pop('force', False)
        if local_rank == 0 or force:
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)

    __builtin__.print = prt
    tdist.barrier()
    return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device


def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
    num_classes = 1000
    model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
    model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
    # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
    model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
    if args.sbn:
        model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
    print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
    model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
    model.train()
    opt_cls = {
        'adam': AdamW, 'adamw': AdamW,
        'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
    }
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
    # param_groups[0] is like this: {'params': List[nn.Parameters], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
    optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
    print(f'[optimizer={type(optimizer)}]')
    mixup_fn = BatchMixup(
        mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=1.0, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=num_classes
    )
    mixup_fn.mixup_enabled = args.mixup > 0.0
    if 'lamb' in args.opt:
        # label smoothing is solved in AdaptiveMixup with `label_smoothing`, so here smoothing=0
        criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
    else:
        criterion = SoftTargetCrossEntropy()
    print(f'[loss_fn] {criterion}')
    print(f'[mixup_fn] {mixup_fn}')
    return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer


def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
    if len(resume_from) == 0 or not os.path.exists(resume_from):
        raise AttributeError(f'ckpt `{resume_from}` not found!')
        # return 0, '[no performance_desc]'
    print(f'[try to resume from file `{resume_from}`]')
    checkpoint = torch.load(resume_from, map_location='cpu')
    assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_withdecoder_1kpretrained_spark_style.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained_timm_style.pth` or `*_1kfinetuned*.pth` instead.'

    ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
    missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
    print(f'[load_checkpoint] missing_keys={missing}')
    print(f'[load_checkpoint] unexpected_keys={unexpected}')
    print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')

    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'ema' in checkpoint:
        ema_module.load_state_dict(checkpoint['ema'])
    return ep_start, performance_desc


def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
    checkpoint_path = os.path.join(args.exp_dir, save_to)
    if args.is_local_master:
        to_save = {
            'args': str(args),
            'arch': args.model,
            'epoch': epoch,
            'performance_desc': performance_desc,
            'module': model_without_ddp_state,
            'ema': ema_state,
            'optimizer': optimizer_state,
            'is_pretrain': False,
        }
        torch.save(to_save, checkpoint_path)
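The `param_groups[0] is like this` comment above describes the per-group dict layout handed to the optimizer. As a tiny self-contained sketch (made-up values, independent of `lr_decay.get_param_groups`), extra keys such as `lr_scale` are simply carried along by the PyTorch optimizer:

```python
# Minimal sketch of the param-group layout described above (illustrative values only).
import torch

net = torch.nn.Linear(4, 4)
param_groups = [
    {'params': list(net.parameters()), 'lr': 1e-3, 'lr_scale': 0.65, 'weight_decay': 0.05},
]
opt = torch.optim.AdamW(param_groups, lr=1e-3, weight_decay=0)  # per-group values override the defaults
print(opt.param_groups[0]['lr'], opt.param_groups[0]['lr_scale'], opt.param_groups[0]['weight_decay'])
```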
spark/downstream_mmdet/README.md
ADDED
@@ -0,0 +1,76 @@
## About code isolation

This `downstream_mmdet` is isolated from pre-training codes. One can treat this `downstream_mmdet` as an independent codebase 🛠️.

## Fine-tuned ConvNeXt-B weights, log files, and performance

<div align="center">

[[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link)]
[[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1t10dmzg5KOO27o2yIglK-gQepB5gR4zR/view?usp=share_link)]
[[`log.json`](https://drive.google.com/file/d/1TuNboXl1qwjf1tggZ3QOssI67uU7Jtig/view?usp=share_link)]
[[`log`](https://drive.google.com/file/d/1JY5CkL_MX08zJ8P1FBIeC60OJsuIiyZc/view?usp=sharing)]
</div>

<p align="center">
<img src="https://user-images.githubusercontent.com/39692511/211497396-cd031318-ef54-45a4-a283-cd9810c15603.png" width=80%>
</p>

## Installation [MMDetection with commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d) before fine-tuning ConvNeXt on COCO

We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions.

Note the COCO dataset folder should be at `downstream_mmdet/data/coco`.
The folder should follow the directory structure required by `MMDetection`, which should look like this:
```
downstream_mmdet/data/coco:
    annotations/:
        captions_train2017.json  captions_val2017.json
        instances_train2017.json  instances_val2017.json
        person_keypoints_train2017.json  person_keypoints_val2017.json
    train2017/:
        a_lot_images.jpg
    val2017/:
        a_lot_images.jpg
```


### Training

To train a detector with pre-trained models, run:
```
# single-gpu training
python tools/train.py <CONFIG_FILE> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]
```
For example, to train a Mask R-CNN model with a SparK pretrained `ConvNeXt-B` backbone and 4 gpus, run:
```
tools/dist_train.sh configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py 4 \
  --cfg-options model.pretrained=/some/path/to/official_convnext_base_1kpretrained.pth
```

The Mask R-CNN 3x fine-tuning config file can be found at [`configs/convnext_spark`](configs/convnext_spark). This config is basically a copy of [https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py](https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py).

### Inference
```
# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm
```

## Acknowledgment

We appreciate these useful codebases:

- [MMDetection](https://github.com/open-mmlab/mmdetection)
- [ConvNeXt](https://github.com/facebookresearch/ConvNeXt)
- [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection)
spark/downstream_mmdet/configs/_base_/default_runtime.py
ADDED
@@ -0,0 +1,16 @@
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='CustomizedTextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
spark/downstream_mmdet/configs/_base_/models/cascade_mask_rcnn_convnext_fpn.py
ADDED
@@ -0,0 +1,208 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# model settings
model = dict(
    type='CascadeRCNN',
    pretrained=None,
    backbone=dict(
        type='ConvNeXt',
        in_chans=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.2,
        layer_scale_init_value=1e-6,
        out_indices=[0, 1, 2, 3],
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg = dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_across_levels=False,
            nms_pre=2000,
            nms_post=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg = dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
spark/downstream_mmdet/configs/_base_/models/mask_rcnn_convnext_fpn.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# model settings
model = dict(
    type='MaskRCNN',
    pretrained=None,
    backbone=dict(
        type='ConvNeXt',
        in_chans=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.2,
        layer_scale_init_value=1e-6,
        out_indices=[0, 1, 2, 3],
    ),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
spark/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
ADDED
@@ -0,0 +1,95 @@
"""
We directly take the ConvNeXt-T+MaskRCNN 3x recipe from https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
And we modify this ConvNeXt-T+MaskRCNN 3x recipe to our ConvNeXt-B+MaskRCNN 3x recipe.
The modifications (commented as [modified] below) are according to:
- 1. tiny-to-base: (some configs of ConvNext-T are updated to those of ConvNext-B, referring to https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/cascade_mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco_in22k.py)
    - model.backbone.{depths, dims, drop_path_rate}
    - models.neck
    - optimizer.paramwise_cfg.num_layers

- 2. our paper (https://openreview.net/forum?id=NRxydtWup1S, or https://arxiv.org/abs/2301.03580):
    - LR layer decay (optimizer.paramwise_cfg.decay_rate): 0.65
    - LR scheduled ratio (lr_config.gamma): 0.2
    - Learning rate (optimizer.lr): 0.0002
    - optimizer_config.use_fp16: False (we just use fp32 by default; actually we didn't test the performance of using fp16)
"""

_base_ = [
    '../_base_/models/mask_rcnn_convnext_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(
    backbone=dict(
        in_chans=3,
        depths=[3, 3, 27, 3],        # [modified] according to tiny-to-base
        dims=[128, 256, 512, 1024],  # [modified] according to tiny-to-base
        drop_path_rate=0.5,          # [modified] according to tiny-to-base
        layer_scale_init_value=1.0,
        out_indices=[0, 1, 2, 3],
    ),
    neck=dict(in_channels=[128, 256, 512, 1024]))  # [modified] according to tiny-to-base

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                                 (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                                 (736, 1333), (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
                 lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,  # [modified] according to our paper
                 paramwise_cfg={'decay_rate': 0.65,                 # [modified] according to our paper
                                'decay_type': 'layer_wise',
                                'num_layers': 12})                  # [modified] according to tiny-to-base
lr_config = dict(step=[27, 33], gamma=0.2)  # [modified] according to our paper
runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)

# do not use mmdet version fp16
fp16 = None
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=False,  # [modified] True => False
)
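To make the layer-wise decay in this config concrete: with `lr=0.0002`, `decay_rate=0.65` and `num_layers=12` (the optimizer constructor adds 2 internally), each parameter group's learning rate is scaled by `decay_rate ** (num_layers + 2 - layer_id - 1)`, so the stem trains orders of magnitude slower than the last stage. A small worked example of that arithmetic (assumed to mirror `LearningRateDecayOptimizerConstructor`):

```python
# Worked example of the layer-wise LR decay implied by this config (illustration only).
base_lr, decay_rate, num_layers = 2e-4, 0.65, 12 + 2
for layer_id in (0, 7, 13):  # stem-level, middle, and top-most parameter groups
    scale = decay_rate ** (num_layers - layer_id - 1)
    print(f'layer {layer_id:2d}: lr = {base_lr * scale:.2e}')
# layer  0: lr = 7.39e-07, layer  7: lr = 1.51e-05, layer 13: lr = 2.00e-04
```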
spark/downstream_mmdet/mmcv_custom/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# -*- coding: utf-8 -*-

from .checkpoint import load_checkpoint
from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor
from .customized_text import CustomizedTextLoggerHook

__all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook']
spark/downstream_mmdet/mmcv_custom/customized_text.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import datetime
from collections import OrderedDict

import torch

import mmcv
from mmcv.runner import HOOKS
from mmcv.runner import TextLoggerHook


@HOOKS.register_module()
class CustomizedTextLoggerHook(TextLoggerHook):
    """Customized Text Logger hook.

    This logger prints out both lr and layer_0_lr.

    """

    def _log_info(self, log_dict, runner):
        # print exp name for users to distinguish experiments
        # at every ``interval_exp_name`` iterations and the end of each epoch
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)

        if log_dict['mode'] == 'train':
            lr_str = {}
            for lr_type in ['lr', 'layer_0_lr']:
                if isinstance(log_dict[lr_type], dict):
                    # collect the per-key lr strings, then join them
                    lr_str[lr_type] = []
                    for k, val in log_dict[lr_type].items():
                        lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
                    lr_str[lr_type] = ' '.join(lr_str[lr_type])
                else:
                    lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'

            # by epoch: Epoch [4][100/1000]
            # by iter: Iter [100/100000]
            if self.by_epoch:
                log_str = f'Epoch [{log_dict["epoch"]}]' \
                          f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
            else:
                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
            log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '

            if 'time' in log_dict.keys():
                self.time_sec_tot += (log_dict['time'] * self.interval)
                time_sec_avg = self.time_sec_tot / (
                    runner.iter - self.start_iter + 1)
                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
                log_str += f'eta: {eta_str}, '
                log_str += f'time: {log_dict["time"]:.3f}, ' \
                           f'data_time: {log_dict["data_time"]:.3f}, '
                # statistic memory
                if torch.cuda.is_available():
                    log_str += f'memory: {log_dict["memory"]}, '
        else:
            # val/test time
            # here 1000 is the length of the val dataloader
            # by epoch: Epoch[val] [4][1000]
            # by iter: Iter[val] [1000]
            if self.by_epoch:
                log_str = f'Epoch({log_dict["mode"]}) ' \
                          f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
            else:
                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

        log_items = []
        for name, val in log_dict.items():
            # TODO: resolve this hack
            # these items have been in log_str
            if name in [
                    'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
                    'memory', 'epoch'
            ]:
                continue
            if isinstance(val, float):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ', '.join(log_items)

        runner.logger.info(log_str)


    def log(self, runner):
        if 'eval_iter_num' in runner.log_buffer.output:
            # this doesn't modify runner.iter and is regardless of by_epoch
            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
        else:
            cur_iter = self.get_iter(runner, inner_iter=True)

        log_dict = OrderedDict(
            mode=self.get_mode(runner),
            epoch=self.get_epoch(runner),
            iter=cur_iter)

        # record lr and layer_0_lr
        cur_lr = runner.current_lr()
        if isinstance(cur_lr, list):
            log_dict['layer_0_lr'] = min(cur_lr)
            log_dict['lr'] = max(cur_lr)
        else:
            assert isinstance(cur_lr, dict)
            log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                log_dict['layer_0_lr'].update({k: min(lr_)})
                log_dict['lr'].update({k: max(lr_)})

        if 'time' in runner.log_buffer.output:
            # statistic memory
            if torch.cuda.is_available():
                log_dict['memory'] = self._get_max_memory(runner)

        log_dict = dict(log_dict, **runner.log_buffer.output)

        self._log_info(log_dict, runner)
        self._dump_log(log_dict, runner)
        return log_dict
spark/downstream_mmdet/mmcv_custom/layer_decay_optimizer_constructor.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import json
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import get_dist_info


def get_num_layer_layer_wise(var_name, num_max_layer=12):

    if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"):
        return 0
    elif var_name.startswith("backbone.downsample_layers"):
        stage_id = int(var_name.split('.')[2])
        if stage_id == 0:
            layer_id = 0
        elif stage_id == 1:
            layer_id = 2
        elif stage_id == 2:
            layer_id = 3
        elif stage_id == 3:
            layer_id = num_max_layer
        return layer_id
    elif var_name.startswith("backbone.stages"):
        stage_id = int(var_name.split('.')[2])
        block_id = int(var_name.split('.')[3])
        if stage_id == 0:
            layer_id = 1
        elif stage_id == 1:
            layer_id = 2
        elif stage_id == 2:
            layer_id = 3 + block_id // 3
        elif stage_id == 3:
            layer_id = num_max_layer
        return layer_id
    else:
        return num_max_layer + 1


def get_num_layer_stage_wise(var_name, num_max_layer):
    if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"):
        return 0
    elif var_name.startswith("backbone.downsample_layers"):
        return 0
    elif var_name.startswith("backbone.stages"):
        stage_id = int(var_name.split('.')[2])
        return stage_id + 1
    else:
        return num_max_layer - 1


@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}
        print(self.paramwise_cfg)
        num_layers = self.paramwise_cfg.get('num_layers') + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', "layer_wise")
        print("Build LearningRateDecayOptimizerConstructor %s %f - %d" % (decay_type, decay_rate, num_layers))
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay

            if decay_type == "layer_wise":
                layer_id = get_num_layer_layer_wise(name, self.paramwise_cfg.get('num_layers'))
            elif decay_type == "stage_wise":
                layer_id = get_num_layer_stage_wise(name, num_layers)

            group_name = "layer_%d_%s" % (layer_id, group_name)

            if group_name not in parameter_groups:
                scale = decay_rate ** (num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": scale,
                    "group_name": group_name,
                    "lr": scale * self.base_lr,
                }

            parameter_groups[group_name]["params"].append(param)
            parameter_groups[group_name]["param_names"].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    "param_names": parameter_groups[key]["param_names"],
                    "lr_scale": parameter_groups[key]["lr_scale"],
                    "lr": parameter_groups[key]["lr"],
                    "weight_decay": parameter_groups[key]["weight_decay"],
                }
            print("Param groups = %s" % json.dumps(to_display, indent=2))

        params.extend(parameter_groups.values())
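As a quick, hypothetical check of how `get_num_layer_layer_wise` buckets ConvNeXt parameter names (using `num_layers=12` as in the SparK fine-tuning config; the parameter names below are illustrative):

```python
# Illustrative mapping from parameter names to layer ids (not part of the repo's code).
for name in (
    'backbone.downsample_layers.0.0.weight',  # stem                 -> 0
    'backbone.stages.0.2.dwconv.weight',      # stage 0              -> 1
    'backbone.stages.2.17.pwconv1.weight',    # stage 2, block 17    -> 3 + 17 // 3 = 8
    'backbone.stages.3.1.gamma',              # stage 3              -> 12 (num_max_layer)
    'neck.lateral_convs.0.conv.weight',       # outside the backbone -> 13
):
    print(f'{name:42s} -> layer {get_num_layer_layer_wise(name, 12)}')
```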
spark/downstream_mmdet/mmcv_custom/runner/checkpoint.py
ADDED
@@ -0,0 +1,85 @@
| 1 |
+
# Copyright (c) Open-MMLab. All rights reserved.
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import time
|
| 4 |
+
from tempfile import TemporaryDirectory
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.optim import Optimizer
|
| 8 |
+
|
| 9 |
+
import mmcv
|
| 10 |
+
from mmcv.parallel import is_module_wrapper
|
| 11 |
+
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import apex
|
| 15 |
+
except:
|
| 16 |
+
print('apex is not installed')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
| 20 |
+
"""Save checkpoint to file.
|
| 21 |
+
|
| 22 |
+
The checkpoint will have 4 fields: ``meta``, ``state_dict`` and
|
| 23 |
+
``optimizer``, ``amp``. By default ``meta`` will contain version
|
| 24 |
+
and time info.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
model (Module): Module whose params are to be saved.
|
| 28 |
+
filename (str): Checkpoint filename.
|
| 29 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
| 30 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
| 31 |
+
"""
|
| 32 |
+
if meta is None:
|
| 33 |
+
meta = {}
|
| 34 |
+
elif not isinstance(meta, dict):
|
| 35 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
| 36 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
| 37 |
+
|
| 38 |
+
if is_module_wrapper(model):
|
| 39 |
+
model = model.module
|
| 40 |
+
|
| 41 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
| 42 |
+
# save class name to the meta
|
| 43 |
+
meta.update(CLASSES=model.CLASSES)
|
| 44 |
+
|
| 45 |
+
checkpoint = {
|
| 46 |
+
'meta': meta,
|
| 47 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
| 48 |
+
}
|
| 49 |
+
# save optimizer state dict in the checkpoint
|
| 50 |
+
if isinstance(optimizer, Optimizer):
|
| 51 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
| 52 |
+
elif isinstance(optimizer, dict):
|
| 53 |
+
checkpoint['optimizer'] = {}
|
| 54 |
+
for name, optim in optimizer.items():
|
| 55 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
| 56 |
+
|
| 57 |
+
# save amp state dict in the checkpoint
|
| 58 |
+
# checkpoint['amp'] = apex.amp.state_dict()
|
| 59 |
+
|
| 60 |
+
if filename.startswith('pavi://'):
|
| 61 |
+
try:
|
| 62 |
+
from pavi import modelcloud
|
| 63 |
+
from pavi.exception import NodeNotFoundError
|
| 64 |
+
except ImportError:
|
| 65 |
+
raise ImportError(
|
| 66 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
| 67 |
+
model_path = filename[7:]
|
| 68 |
+
root = modelcloud.Folder()
|
| 69 |
+
model_dir, model_name = osp.split(model_path)
|
| 70 |
+
try:
|
| 71 |
+
model = modelcloud.get(model_dir)
|
| 72 |
+
except NodeNotFoundError:
|
| 73 |
+
model = root.create_training_model(model_dir)
|
| 74 |
+
with TemporaryDirectory() as tmp_dir:
|
| 75 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
| 76 |
+
with open(checkpoint_file, 'wb') as f:
|
| 77 |
+
torch.save(checkpoint, f)
|
| 78 |
+
f.flush()
|
| 79 |
+
model.create_file(checkpoint_file, name=model_name)
|
| 80 |
+
else:
|
| 81 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
| 82 |
+
# immediately flush buffer
|
| 83 |
+
with open(filename, 'wb') as f:
|
| 84 |
+
torch.save(checkpoint, f)
|
| 85 |
+
f.flush()
|
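A hedged usage sketch of `save_checkpoint` above (the model, optimizer, and path are placeholders, and mmcv must be installed for the helper to import):

```python
import torch
import torchvision

# toy stand-ins for a real detector and its optimizer
model = torchvision.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

# writes 'meta', 'state_dict' and 'optimizer' fields into one .pth file
save_checkpoint(model, 'work_dirs/epoch_1.pth', optimizer=optimizer,
                meta=dict(epoch=1, iter=1000))

ckpt = torch.load('work_dirs/epoch_1.pth', map_location='cpu')
print(sorted(ckpt.keys()))  # ['meta', 'optimizer', 'state_dict']
```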
spark/downstream_mmdet/mmdet/models/backbones/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .darknet import Darknet
from .detectors_resnet import DetectoRS_ResNet
from .detectors_resnext import DetectoRS_ResNeXt
from .hourglass import HourglassNet
from .hrnet import HRNet
from .regnet import RegNet
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .trident_resnet import TridentResNet
from .swin_transformer import SwinTransformer
from .convnext import ConvNeXt

__all__ = [
    'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
    'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
    'ResNeSt', 'TridentResNet', 'SwinTransformer', 'ConvNeXt'
]
spark/downstream_mmdet/mmdet/models/backbones/convnext.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

from mmcv_custom import load_checkpoint
from mmdet.utils import get_root_logger
from ..builder import BACKBONES

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

@BACKBONES.register_module()
class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s` -
        https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward_features(self, x):
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return tuple(outs)

    def forward(self, x):
        x = self.forward_features(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
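A hedged sketch of exercising the backbone registered above, assuming a working mmdet/mmcv installation (the config values below are illustrative, not taken from this repo's configs):

```python
import torch
from mmdet.models import build_backbone

cfg = dict(
    type='ConvNeXt',                 # name registered via @BACKBONES.register_module()
    in_chans=3,
    depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
    drop_path_rate=0.2,
    out_indices=[0, 1, 2, 3],
)
backbone = build_backbone(cfg)
backbone.init_weights(pretrained=None)

# one feature map per stage, at strides 4 / 8 / 16 / 32
feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
```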
spark/pretrain/README.md
ADDED
@@ -0,0 +1,118 @@
## Preparation for ImageNet-1k pretraining

See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.

**Note: for neural network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**


## Tutorial for pretraining your own CNN model

See [/pretrain/models/custom.py](/pretrain/models/custom.py). Your todo list is:

- implement `get_downsample_ratio` in [/pretrain/models/custom.py line20](/pretrain/models/custom.py#L20).
- implement `get_feature_map_channels` in [/pretrain/models/custom.py line29](/pretrain/models/custom.py#L29).
- implement `forward` in [/pretrain/models/custom.py line38](/pretrain/models/custom.py#L38).
- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py line54](/pretrain/models/custom.py#L53-L54).
- add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34).
- **Note: see [#54](/../../issues/54) if your CNN contains SE module or global average pooling layer, and see [#56](/../../issues/56) if it contains GroupNorm**.

Then run the experiment with `--model=your_convnet`.
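For a quick picture of what those three methods need to return, here is a minimal sketch of a tiny hierarchical CNN; `ToyCNN` and `toy_cnn` are made-up names for illustration (the real template lives in `/pretrain/models/custom.py`, and a real model also needs its entry in `pretrain_default_model_kwargs`):

```python
from typing import List
import torch
import torch.nn as nn
from timm.models.registry import register_model


class ToyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # stride-4 stem followed by three stride-2 stages => overall downsample ratio 32
        dims = [32, 64, 128, 256]
        self.stem = nn.Sequential(nn.Conv2d(3, dims[0], 4, stride=4),
                                  nn.BatchNorm2d(dims[0]), nn.ReLU(inplace=True))
        self.stages = nn.ModuleList([
            nn.Sequential(nn.Conv2d(dims[i], dims[i + 1], 3, stride=2, padding=1),
                          nn.BatchNorm2d(dims[i + 1]), nn.ReLU(inplace=True))
            for i in range(3)
        ])
        self.dims = dims

    def get_downsample_ratio(self) -> int:
        return 32                 # stride of the smallest feature map

    def get_feature_map_channels(self) -> List[int]:
        return self.dims          # channels of the 4 feature maps, large to small resolution

    def forward(self, x: torch.Tensor, hierarchical=False):
        feats = [self.stem(x)]
        for stage in self.stages:
            feats.append(stage(feats[-1]))
        return feats if hierarchical else feats[-1].mean((2, 3))


@register_model
def toy_cnn(pretrained=False, **kwargs):   # then pretrain with --model=toy_cnn
    return ToyCNN()
```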


## Tutorial for pretraining on your own dataset

See the comment of `build_dataset_to_pretrain` in [line55 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L55). Your todo list:

- Define a subclass of `torch.utils.data.Dataset` for your own unlabeled dataset, to replace our `ImageNetDataset`.
- Use `args.data_path` and `args.input_size` to help build your dataset, with `--data_path=... --input_size=...` to specify them.
- Note the batch size `--bs` is the total batch size of all GPUs, which may need to be adjusted based on your dataset size. FYI: we use `--bs=4096` for ImageNet, which contains 1.28 million images.

**If your dataset is relatively small**, you can try `--init_weight=/path/to/res50_withdecoder_1kpretrained_spark_style.pth` to do your pretraining *from our pretrained weights*, rather than *from scratch*.

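A minimal sketch of such an unlabeled dataset (the class name and folder layout are made up, and the augmentation below only mirrors a typical pretraining pipeline rather than this repo's exact transform):

```python
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MyUnlabeledDataset(Dataset):
    def __init__(self, data_path: str, input_size: int):
        self.files = [os.path.join(data_path, f) for f in os.listdir(data_path)
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert('RGB')
        return self.transform(img)   # no label: pretraining only needs the image tensor
```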
## Debug on 1 GPU (without DistributedDataParallel)

Use a small batch size `--bs=32` to avoid OOM.

```shell script
python3 main.py --exp_name=debug --data_path=/path/to/imagenet --model=resnet50 --bs=32
```


## Pretraining Any Model on ImageNet-1k (224x224)

For pretraining, run [/pretrain/main.py](/pretrain/main.py) with `torchrun`.
**It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, automatically created if not exists), and the model name (`--model`, valid choices see the keys of 'pretrain_default_model_kwargs' in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34)).

We use the **same** pretraining configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts) in 224 pretraining.
Their **names** and **default values** are in [/pretrain/utils/arg_util.py line23-44](/pretrain/utils/arg_util.py#L23-L44).
All these default configurations (like batch size 4096) would be used, unless you specify some like `--bs=512`.

**Note: the batch size `--bs` is the total batch size of all GPUs, and the learning rate `--base_lr` is the base lr. The actual lr would be `lr = base_lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131). So do not use `--lr` to specify a lr (it would be ignored).**

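As a quick sanity check of that scaling rule (the numbers below are only an illustration, not the repo's defaults):

```python
base_lr, bs = 2e-4, 4096      # illustrative base lr and total batch size
lr = base_lr * bs / 256       # actual peak learning rate used by the optimizer
print(lr)                     # 0.0032
```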
Here is an example to pretrain a ResNet50 on an 8-GPU single machine (we use DistributedDataParallel), overwriting the default batch size to 512:
```shell script
$ cd /path/to/SparK/pretrain
$ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<some_port> main.py \
  --data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
  --model=resnet50 --bs=512
```

For multiple machines, change the `--nnodes`, `--node_rank`, `--master_addr` and `--master_port` to your configurations. E.g.:
```shell script
$ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_from_0> --master_addr=<some_address> --master_port=<some_port> main.py \
  ...
```

## Pretraining ConvNeXt-Large on ImageNet-1k (384x384)

For 384 pretraining we use a larger mask ratio (0.75), a half batch size (2048), and a double base learning rate (4e-4):

```shell script
$ cd /path/to/SparK/pretrain
$ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_from_0> --master_addr=<some_address> --master_port=<some_port> main.py \
  --data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
  --model=convnext_large --input_size=384 --mask=0.75 --bs=2048 --base_lr=4e-4
```

## Logging

See files in your `--exp_dir` to track your experiment:

- `<model>_withdecoder_1kpretrained_spark_style.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc.; can be used to resume pretraining; can also be used for visualization in [/pretrain/viz_reconstruction.ipynb](/pretrain/viz_reconstruction.ipynb)
- `<model>_1kpretrained_timm_style.pth`: can be used for downstream finetuning
- `pretrain_log.txt`: records some important information such as:
    - `git_commit_id`: git version
    - `cmd`: the command of this experiment

    It also reports the loss and remaining pretraining time.

- `tensorboard_log/`: saves tensorboard logs including loss values, learning rates, gradient norms, and more. Use `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333` for viz.
- `stdout_backup.txt` and `stderr_backup.txt`: backups of stdout and stderr.

## Resuming

Specify `--resume_from=path/to/<model>_withdecoder_1kpretrained_spark_style.pth` to resume pretraining. Note this is different from `--init_weight`:

- `--resume_from` will load three things: model weights, optimizer states, and current epoch, so it is used to resume some interrupted experiment (will start from that 'current epoch').
- `--init_weight` ONLY loads the model weights, so it's just like a model initialization (will start from epoch 0).


## Regarding sparse convolution

We do not use sparse convolutions in this pytorch implementation, due to their limited optimization on modern hardware.
As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py).
All these "sparse" layers are implemented through pytorch built-in operators.

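The trick boils down to zeroing the masked positions around every dense convolution; a stand-alone sketch of the idea (tensor shapes and the mask below are arbitrary):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 16, 14, 14)
active = (torch.rand(1, 1, 14, 14) > 0.6)   # True = visible, False = masked

y = conv(x * active)   # masked inputs contribute nothing to the convolution
y = y * active         # masked outputs are zeroed again, mimicking submanifold sparse conv
```

This is essentially what the `Sparse*` wrappers in `/pretrain/encoder.py` do by overriding `forward`.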

## Some details: how we mask images and how to set the patch size

In SparK, the mask patch size **equals** the downsample ratio of the CNN model (so there is no configuration like `--patch_size=32`).

Here is the reason: when we do mask, we:

1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py#L86-L87), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_h, fmap_w]`, and would be used to mask the smallest feature map.
2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., dim=2)` and `repeat_interleave(..., dim=3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py#L21)) with larger resolutions.

So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model).
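A toy illustration of that two-step masking, with shapes picked for a 224 input and a downsample ratio of 32 (the mask ratio and the mask-sampling logic below are illustrative, not copied from `spark.py`):

```python
import torch

B, fmap_h, fmap_w, mask_ratio = 2, 7, 7, 0.6

# step 1: binary mask on the smallest (7x7) feature map, keeping ~40% of the patches
len_keep = round(fmap_h * fmap_w * (1 - mask_ratio))
idx = torch.rand(B, fmap_h * fmap_w).argsort(dim=1)[:, :len_keep]
active = torch.zeros(B, fmap_h * fmap_w)
active.scatter_(1, idx, 1.0)
active_b1ff = active.view(B, 1, fmap_h, fmap_w).bool()

# step 2: progressively upsample it to mask higher-resolution feature maps (here the stride-4, 56x56 map)
active_b1hw = active_b1ff.repeat_interleave(8, dim=2).repeat_interleave(8, dim=3)
print(active_b1ff.shape, active_b1hw.shape)   # (2, 1, 7, 7) (2, 1, 56, 56)
```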
spark/pretrain/decoder.py
ADDED
@@ -0,0 +1,74 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

from utils.misc import is_pow2n


class UNetBlock(nn.Module):
    def __init__(self, cin, cout, bn2d):
        """
        a UNet block with 2x up sampling
        """
        super().__init__()
        self.up_sample = nn.ConvTranspose2d(cin, cin, kernel_size=4, stride=2, padding=1, bias=True)
        self.conv = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cin), nn.ReLU6(inplace=True),
            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cout),
        )

    def forward(self, x):
        x = self.up_sample(x)
        return self.conv(x)


class LightDecoder(nn.Module):
    def __init__(self, up_sample_ratio, width=768, sbn=True):  # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
        super().__init__()
        self.width = width
        assert is_pow2n(up_sample_ratio)
        n = round(math.log2(up_sample_ratio))
        channels = [self.width // 2 ** i for i in range(n + 1)]  # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
        bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
        self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])])
        self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)

        self.initialize()

    def forward(self, to_dec: List[torch.Tensor]):
        x = 0
        for i, d in enumerate(self.dec):
            if i < len(to_dec) and to_dec[i] is not None:
                x = x + to_dec[i]
            x = self.dec[i](x)
        return self.proj(x)

    def extra_repr(self) -> str:
        return f'width={self.width}'

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
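A hedged sketch of driving `LightDecoder` on its own (run from `/pretrain` so the import resolves; the width, spatial sizes, and dummy features below are illustrative, and `to_dec` is ordered from the smallest-resolution feature map to the largest, as in SparK's densified skip connections):

```python
import torch
from decoder import LightDecoder

dec = LightDecoder(up_sample_ratio=32, width=512, sbn=False)   # 5 UNet blocks: 512->256->128->64->32->16 channels

B = 2
to_dec = [
    torch.randn(B, 512, 7, 7),     # top-level feature, stride 32
    torch.randn(B, 256, 14, 14),   # stride 16
    torch.randn(B, 128, 28, 28),   # stride 8
    torch.randn(B, 64, 56, 56),    # stride 4
    None,                          # no skip connection at this scale
]
rec = dec(to_dec)
print(rec.shape)   # torch.Size([2, 3, 224, 224]) -- the reconstructed RGB image
```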
spark/pretrain/dist.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import List
from typing import Union

import sys
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
__initialized = False


def initialized():
    return __initialized


def initialize(backend='nccl'):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
        return

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)
    tdist.init_process_group(backend=backend)

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def barrier():
    if __initialized:
        tdist.barrier()


def parallelize(net, syncbn=False):
    if syncbn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda()
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    return net


def allreduce(t: torch.Tensor) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.all_reduce(cu)
            t.copy_(cu.cpu())
        else:
            tdist.all_reduce(t)


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)
spark/pretrain/encoder.py
ADDED
@@ -0,0 +1,208 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from timm.models.layers import DropPath


_cur_active: torch.Tensor = None  # B1ff
# todo: try to use `gather` for speed?
def _get_active_ex_or_ii(H, W, returning_active_ex=True):
    h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1]
    active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3)
    return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True)  # ii: bi, hi, wi


def sp_conv_forward(self, x: torch.Tensor):
    x = super(type(self), self).forward(x)
    x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)  # (BCHW) *= (B1HW), mask the output of conv
    return x


def sp_bn_forward(self, x: torch.Tensor):
    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)

    bhwc = x.permute(0, 2, 3, 1)
    nc = bhwc[ii]  # select the features on non-masked positions to form a flatten feature `nc`
    nc = super(type(self), self).forward(nc)  # use BN1d to normalize this flatten feature `nc`

    bchw = torch.zeros_like(bhwc)
    bchw[ii] = nc
    bchw = bchw.permute(0, 3, 1, 2)
    return bchw


class SparseConv2d(nn.Conv2d):
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseMaxPooling(nn.MaxPool2d):
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseAvgPooling(nn.AvgPool2d):
    forward = sp_conv_forward  # hack: override the forward function; see `sp_conv_forward` above for more details


class SparseBatchNorm2d(nn.BatchNorm1d):
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details


class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    forward = sp_bn_forward  # hack: override the forward function; see `sp_bn_forward` above for more details


class SparseConvNeXtLayerNorm(nn.LayerNorm):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        super().__init__(normalized_shape, eps, elementwise_affine=True)
        self.data_format = data_format
        self.sparse = sparse

    def forward(self, x):
        if x.ndim == 4:  # BHWC or BCHW
            if self.data_format == "channels_last":  # BHWC
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False)
                    nc = x[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(x)
                    x[ii] = nc
                    return x
                else:
                    return super(SparseConvNeXtLayerNorm, self).forward(x)
            else:  # channels_first, BCHW
                if self.sparse:
                    ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
                    bhwc = x.permute(0, 2, 3, 1)
                    nc = bhwc[ii]
                    nc = super(SparseConvNeXtLayerNorm, self).forward(nc)

                    x = torch.zeros_like(bhwc)
                    x[ii] = nc
                    return x.permute(0, 3, 1, 2)
                else:
                    u = x.mean(1, keepdim=True)
                    s = (x - u).pow(2).mean(1, keepdim=True)
                    x = (x - u) / torch.sqrt(s + self.eps)
                    x = self.weight[:, None, None] * x + self.bias[:, None, None]
                    return x
        else:  # BLC or BC
            if self.sparse:
                raise NotImplementedError
            else:
                return super(SparseConvNeXtLayerNorm, self).forward(x)

    def __repr__(self):
        return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'


class SparseConvNeXtBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim)  # depthwise conv
        self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sparse = sparse

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)  # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        if self.sparse:
            x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)

        x = input + self.drop_path(x)
        return x

    def __repr__(self):
        return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'


class SparseEncoder(nn.Module):
    def __init__(self, cnn, input_size, sbn=False, verbose=False):
        super(SparseEncoder, self).__init__()
        self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
        self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels()

    @staticmethod
    def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
        oup = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            oup = SparseConv2d(
                m.in_channels, m.out_channels,
                kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
            )
            oup.weight.data.copy_(m.weight.data)
            if bias:
                oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m: nn.BatchNorm2d
            oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
            oup.running_mean.data.copy_(m.running_mean.data)
            oup.running_var.data.copy_(m.running_var.data)
            oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
            if hasattr(m, "qconfig"):
                oup.qconfig = m.qconfig
        elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
            m: nn.LayerNorm
            oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
            oup.weight.data.copy_(m.weight.data)
            oup.bias.data.copy_(m.bias.data)
        elif isinstance(m, (nn.Conv1d,)):
            raise NotImplementedError

        for name, child in m.named_children():
            oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
        del m
        return oup

    def forward(self, x):
        return self.sp_cnn(x, hierarchical=True)
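A hedged sketch of `dense_model_to_sparse` in action, run from `/pretrain` so `import encoder` resolves; the tiny CNN, mask pattern, and shapes below are made up for illustration:

```python
import torch
import torch.nn as nn
import encoder
from encoder import SparseEncoder

dense = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
)
sp = SparseEncoder.dense_model_to_sparse(dense, sbn=False)  # Conv2d -> SparseConv2d, BatchNorm2d -> SparseBatchNorm2d

# the module-level `_cur_active` mask: True = visible, False = masked (here it matches the 4x4 feature map)
active = torch.ones(2, 1, 4, 4, dtype=torch.bool)
active[:, :, ::2, ::2] = False        # mask out a quarter of the positions
encoder._cur_active = active

out = sp(torch.randn(2, 3, 8, 8))
masked = ~active.expand_as(out)
print(out.shape)                      # torch.Size([2, 8, 4, 4])
print(bool((out[masked] == 0).all())) # True: masked positions stay zero through conv + BN + ReLU
```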
spark/pretrain/main.py
ADDED
@@ -0,0 +1,191 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import math
import sys
import time
from functools import partial
from typing import List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

import dist
import encoder
from decoder import LightDecoder
from models import build_sparse_encoder
from sampler import DistInfiniteBatchSampler, worker_init_fn
from spark import SparK
from utils import arg_util, misc, lamb
from utils.imagenet import build_dataset_to_pretrain
from utils.lr_control import lr_wd_annealing, get_param_groups


class LocalDDP(torch.nn.Module):
    def __init__(self, module):
        super(LocalDDP, self).__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def main_pt():
    args: arg_util.Args = arg_util.init_dist_and_get_args()
    print(f'initial args:\n{str(args)}')
    args.log_epoch()

    # build data
    print(f'[build data for pre-training] ...\n')
    dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
    data_loader_train = DataLoader(
        dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
        batch_sampler=DistInfiniteBatchSampler(
            dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
            shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
        ), worker_init_fn=worker_init_fn
    )
    itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
    print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')

    # build encoder and decoder
    enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
    dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
    model_without_ddp = SparK(
        sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
        densify_norm=args.densify_norm, sbn=args.sbn,
    ).to(args.device)
    print(f'[PT model] model = {model_without_ddp}\n')

    # the model has been randomly initialized in their construction time
    # now try to load some checkpoint as model weight initialization; this ONLY loads the model weights
    misc.initialize_weight(args.init_weight, model_without_ddp)

    if dist.initialized():
        model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
    else:
        model = LocalDDP(model_without_ddp)

    # build optimizer and lr_scheduler
    param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
    opt_clz = {
        'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
        'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
        'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
    }[args.opt]
    optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
    print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')

    # try to resume the experiment from some checkpoint.pth; this will load model weights, optimizer states, and last epoch (ep_start)
    # if loaded, ep_start will be greater than 0
    ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
    if ep_start >= args.ep:  # load from a complete checkpoint file
        print(f'  [*] [PT already done]    Min/Last Recon Loss: {performance_desc}')
    else:  # perform pre-training
        tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
        min_loss = 1e9
        print(f'[PT start] from ep{ep_start}')

        pt_start_time = time.time()
        for ep in range(ep_start, args.ep):
            ep_start_time = time.time()
            tb_lg.set_step(ep * iters_train)
            if hasattr(itrt_train, 'set_epoch'):
                itrt_train.set_epoch(ep)

            stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
            last_loss = stats['last_loss']
            min_loss = min(min_loss, last_loss)
            performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
            misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_1kpretrained_spark_style.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
            misc.save_checkpoint_model_weights_only(f'{args.model}_1kpretrained_timm_style.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())

            ep_cost = round(time.time() - ep_start_time, 2) + 1  # +1s: approximate the following logging cost
            remain_secs = (args.ep-1 - ep) * ep_cost
            remain_time = datetime.timedelta(seconds=round(remain_secs))
            finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
            print(f'  [*] [ep{ep}/{args.ep}]    Min/Last Recon Loss: {performance_desc},    Cost: {ep_cost}s,    Remain: {remain_time},    Finish @ {finish_time}')

            args.cur_ep = f'{ep + 1}/{args.ep}'
            args.remain_time, args.finish_time = str(remain_time), str(finish_time)
            args.last_loss = last_loss
            args.log_epoch()

            tb_lg.update(min_loss=min_loss, head='train', step=ep)
            tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
            tb_lg.flush()

        # finish pre-training
        tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
        tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
        tb_lg.flush()
        print(f'final args:\n{str(args)}')
        print('\n\n')
        print(f'  [*] [PT finished]    Min/Last Recon Loss: {performance_desc},    Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
        print('\n\n')
        tb_lg.close()
        time.sleep(10)

    args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
    args.log_epoch()


def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
    model.train()
    me = misc.MetricLogger(delimiter='  ')
    me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
    header = f'[PT] Epoch {ep}:'

    optimizer.zero_grad()
    early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
    late_clipping = hasattr(optimizer, 'global_grad_norm')
    if early_clipping:
        params_req_grad = [p for p in model.parameters() if p.requires_grad]

    for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
        # adjust lr and wd
        min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)

        # forward and backward
        inp = inp.to(args.device, non_blocking=True)
        SparK.forward
        loss = model(inp, active_b1ff=None, vis=False)
        optimizer.zero_grad()
        loss.backward()
        loss = loss.item()
        if not math.isfinite(loss):
            print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
            sys.exit(-1)

        # optimize
        grad_norm = None
        if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
        optimizer.step()
        if late_clipping: grad_norm = optimizer.global_grad_norm
        torch.cuda.synchronize()

        # log
        me.update(last_loss=loss)
        me.update(max_lr=max_lr)
        tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
        tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
        tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
        tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
        tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')

        if grad_norm is not None:
            me.update(orig_norm=grad_norm)
            tb_lg.update(orig_norm=grad_norm, head='train_hp')
        tb_lg.set_step()

    me.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in me.meters.items()}


if __name__ == '__main__':
    main_pt()
spark/pretrain/models/__init__.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from timm import create_model
from timm.loss import SoftTargetCrossEntropy
from timm.models.layers import drop


from models.convnext import ConvNeXt
from models.resnet import ResNet
from models.custom import YourConvNet
_import_resnets_for_timm_registration = (ResNet,)


# log more
def _ex_repr(self):
    return ', '.join(
        f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
        for k, v in vars(self).items()
        if not k.startswith('_') and k != 'training'
        and not isinstance(v, (torch.nn.Module, torch.Tensor))
    )
for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
    if hasattr(clz, 'extra_repr'):
        clz.extra_repr = _ex_repr
    else:
        clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'


pretrain_default_model_kwargs = {
    'V9back': dict(),

    'resnet50': dict(drop_path_rate=0.05),
    'resnet101': dict(drop_path_rate=0.08),
    'resnet152': dict(drop_path_rate=0.10),
    'resnet200': dict(drop_path_rate=0.15),
    'convnext_small': dict(sparse=True, drop_path_rate=0.2),
    'convnext_base': dict(sparse=True, drop_path_rate=0.3),
    'convnext_large': dict(sparse=True, drop_path_rate=0.4),

}
for kw in pretrain_default_model_kwargs.values():
    kw['pretrained'] = False
    kw['num_classes'] = 0
    kw['global_pool'] = ''


def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
    from encoder import SparseEncoder

    kwargs = pretrain_default_model_kwargs[name]
    if drop_path_rate != 0:
        kwargs['drop_path_rate'] = drop_path_rate
    print(f'[build_sparse_encoder] model kwargs={kwargs}')
    cnn = create_model(name, **kwargs)

    return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
spark/pretrain/models/convnext.py
ADDED
@@ -0,0 +1,125 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
+from typing import List
+
+import torch
+import torch.nn as nn
+from timm.models.layers import trunc_normal_
+from timm.models.registry import register_model
+
+from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm
+
+
+class ConvNeXt(nn.Module):
+    r""" ConvNeXt
+        A PyTorch impl of : `A ConvNet for the 2020s` -
+          https://arxiv.org/pdf/2201.03545.pdf
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+    """
+
+    def __init__(self, in_chans=3, num_classes=1000,
+                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+                 layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
+                 sparse=True,
+                 ):
+        super().__init__()
+        self.dims: List[int] = dims
+        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
+        stem = nn.Sequential(
+            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+            SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
+        )
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
+                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
+        self.drop_path_rate = drop_path_rate
+        self.layer_scale_init_value = layer_scale_init_value
+        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        cur = 0
+        for i in range(4):
+            stage = nn.Sequential(
+                *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
+                                      layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+        self.depths = depths
+
+        self.apply(self._init_weights)
+        if num_classes > 0:
+            self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False)  # final norm layer for LE/FT; should not be sparse
+            self.fc = nn.Linear(dims[-1], num_classes)
+        else:
+            self.norm = nn.Identity()
+            self.fc = nn.Identity()
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=.02)
+            nn.init.constant_(m.bias, 0)
+
+    def get_downsample_ratio(self) -> int:
+        return 32
+
+    def get_feature_map_channels(self) -> List[int]:
+        return self.dims
+
+    def forward(self, x, hierarchical=False):
+        if hierarchical:
+            ls = []
+            for i in range(4):
+                x = self.downsample_layers[i](x)
+                x = self.stages[i](x)
+                ls.append(x)
+            return ls
+        else:
+            return self.fc(self.norm(x.mean([-2, -1])))  # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
+
+    def get_classifier(self):
+        return self.fc
+
+    def extra_repr(self):
+        return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'
+
+
+@register_model
+def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    return model
+
+
+@register_model
+def convnext_small(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+    return model
+
+
+@register_model
+def convnext_base(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+    return model
+
+
+@register_model
+def convnext_large(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+    return model
+
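Since each variant is registered with timm, the hierarchical forward can be exercised directly through `create_model`. The following sketch is illustrative only: it assumes this local registration is the one timm resolves for the name, and passes `sparse=False` so the blocks behave like dense ConvNeXt blocks outside masked pre-training:

import torch
from timm.models import create_model

cnn = create_model('convnext_small', num_classes=0, global_pool='', sparse=False, drop_path_rate=0.2)
feats = cnn(torch.rand(2, 3, 224, 224), hierarchical=True)
# four stage outputs with channels [96, 192, 384, 768] at strides 4/8/16/32,
# i.e. spatial sizes 56, 28, 14 and 7 for a 224x224 input
print([tuple(t.shape) for t in feats])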
spark/pretrain/models/custom.py
ADDED
@@ -0,0 +1,141 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from typing import List
+from timm.models.registry import register_model
+import torch
+from torch import nn
+import sys
+from HG.HGBlock import HGStem, HGBlock
+from HG.block import DWConv
+from v9back.common import *
+
+
+class YourConvNet(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        self.mlist = nn.ModuleList(
+            [Silence(),
+             Bbackbone(),
+             ]
+        )
+        self.d0 = Down0(64)
+        self.d1 = Down1(128)
+        self.d2 = Down2(256)
+        self.d3 = Down3(512)
+        self.d4 = Down4(1024)
+        self.alld = [self.d0, self.d1, self.d2, self.d3, self.d4]
+        self.cblinear1 = CBLinear(64, [64])
+        self.cblinear3 = CBLinear(128, [64, 128])
+        self.cblinear5 = CBLinear(256, [64, 128, 256])
+        self.cblinear7 = CBLinear(512, [64, 128, 256, 512])
+        self.cblinear9 = CBLinear(1024, [64, 128, 256, 512, 1024])
+        self.allcblinear = [self.cblinear1, self.cblinear3, self.cblinear5, self.cblinear7, self.cblinear9]
+        # conv down 1
+        self.conv1 = Conv(3, 64, 3, 2)
+        self.cbfuse1 = CBFuse([0, 0, 0, 0, 0])
+
+        # conv down 2
+        self.conv2 = Conv(64, 128, 3, 2)
+        self.cbfuse2 = CBFuse([1, 1, 1, 1])
+        self.rep2 = RepNCSPELAN4(128, 256, 128, 64, 2)
+        # avg-conv down fuse 1
+        self.adown3 = ADown(256, 256)
+        self.cbfuse3 = CBFuse([2, 2, 2])
+        self.rep3 = RepNCSPELAN4(256, 512, 256, 128, 2)
+
+        # avg-conv down fuse 2
+        self.adown4 = ADown(512, 512)
+        self.cbfuse4 = CBFuse([3, 3])
+        self.rep4 = RepNCSPELAN4(512, 1024, 512, 256, 2)
+
+        # avg-conv down fuse 3
+        self.adown5 = ADown(1024, 1024)
+        self.cbfuse5 = CBFuse([4])
+        self.rep5 = RepNCSPELAN4(1024, 1024, 512, 256, 2)
+
+    def get_downsample_ratio(self) -> int:
+        return 32
+
+    def get_feature_map_channels(self) -> List[int]:
+        return [256, 512, 1024, 1024]
+
+    def forward(self, x: torch.Tensor, hierarchical=False):
+        if hierarchical:
+            origin = x.clone()
+            ls = []
+            tmp = []
+            bx = None
+            for index, modules in enumerate(self.mlist):
+                x = modules(x)
+                if index == 1:
+                    bx = x
+            for i in range(5):
+                tmp.append(self.allcblinear[i](self.alld[i](bx)))
+
+            fuse1 = self.cbfuse1([tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], self.conv1(origin)])
+            fuse2 = self.cbfuse2([tmp[1], tmp[2], tmp[3], tmp[4], self.conv2(fuse1)])
+            fuse2 = self.rep2(fuse2)
+
+            fuse3 = self.cbfuse3([tmp[2], tmp[3], tmp[4], self.adown3(fuse2)])
+            fuse3 = self.rep3(fuse3)
+
+            fuse4 = self.cbfuse4([tmp[3], tmp[4], self.adown4(fuse3)])
+            fuse4 = self.rep4(fuse4)
+
+            fuse5 = self.cbfuse5([tmp[4], self.adown5(fuse4)])
+            fuse5 = self.rep5(fuse5)
+
+            ls.append(fuse2)
+            ls.append(fuse3)
+            ls.append(fuse4)
+            ls.append(fuse5)
+            return ls
+        else:
+            for modules in self.mlist:
+                x = modules(x)
+            return x
+
+
+@register_model
+def V9back(pretrained=False, **kwargs):
+    return YourConvNet(**kwargs)
+
+
+@torch.no_grad()
+def convnet_test():
+    from timm.models import create_model
+    cnn = create_model('V9back')
+    print('get_downsample_ratio:', cnn.get_downsample_ratio())
+    print('get_feature_map_channels:', cnn.get_feature_map_channels())
+
+    downsample_ratio = cnn.get_downsample_ratio()
+    feature_map_channels = cnn.get_feature_map_channels()
+
+    # check the forward function
+    B, C, H, W = 4, 3, 224, 224
+    inp = torch.rand(B, C, H, W)
+    feats = cnn(inp, hierarchical=True)
+    assert isinstance(feats, list)
+    assert len(feats) == len(feature_map_channels)
+    print([tuple(t.shape) for t in feats])
+
+    # check the downsample ratio
+    feats = cnn(inp, hierarchical=True)
+    assert feats[-1].shape[-2] == H // downsample_ratio
+    assert feats[-1].shape[-1] == W // downsample_ratio
+
+    # check the channel number
+    for feat, ch in zip(feats, feature_map_channels):
+        assert feat.ndim == 4
+        assert feat.shape[1] == ch
+
+
+if __name__ == '__main__':
+    convnet_test()
+
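For orientation, tracing the fuse path above with the 224x224 test input from `convnet_test` (and assuming the YOLOv9-style `CBFuse` blocks resize their auxiliary inputs to the resolution of the final convolutional branch), the four returned maps should sit at strides 4, 8, 16 and 32:

[(4, 256, 56, 56), (4, 512, 28, 28), (4, 1024, 14, 14), (4, 1024, 7, 7)]

which is consistent with `get_feature_map_channels()` returning [256, 512, 1024, 1024] and `get_downsample_ratio()` returning 32.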
spark/pretrain/models/custom_detr.py
ADDED
@@ -0,0 +1,102 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from typing import List
+from timm.models.registry import register_model
+import torch
+from torch import nn
+import sys
+from HG.HGBlock import HGStem, HGBlock
+from HG.block import DWConv
+
+
+class YourConvNet(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        self.mlist = nn.ModuleList(
+            [HGStem(3, 32, 64),
+             HGBlock(64, 64, 128, 3, n=6),
+
+             DWConv(128, 128, 3, 2, 1, False),
+             HGBlock(128, 128, 512, 3, n=6),
+             HGBlock(512, 128, 512, 3, lightconv=False, shortcut=True, n=6),
+
+
+             DWConv(512, 512, 3, 2, 1, False),
+             HGBlock(512, 256, 1024, 5, lightconv=True, shortcut=False, n=6),
+             HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
+             HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
+             HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
+             HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
+
+
+
+             DWConv(1024, 1024, 3, 2, 1, False),
+             HGBlock(1024, 512, 2048, 5, lightconv=True, shortcut=False, n=6),
+             HGBlock(2048, 512, 2048, 5, lightconv=True, shortcut=True, n=6)
+             ]
+        )
+
+
+    def get_downsample_ratio(self) -> int:
+        return 32
+
+    def get_feature_map_channels(self) -> List[int]:
+        return [128, 512, 1024, 2048]
+
+    def forward(self, x: torch.Tensor, hierarchical=False):
+        if hierarchical:
+            ls = []
+            for index, modules in enumerate(self.mlist):
+                x = modules(x)
+                if index in [1, 4, 10, 13]:
+                    ls.append(x)
+            return ls
+        else:
+            for modules in self.mlist:
+                x = modules(x)
+            return x
+
+
+@register_model
+def HGNetv2(pretrained=False, **kwargs):
+    return YourConvNet(**kwargs)
+
+
+@torch.no_grad()
+def convnet_test():
+    from timm.models import create_model
+    cnn = create_model('HGNetv2')
+    print('get_downsample_ratio:', cnn.get_downsample_ratio())
+    print('get_feature_map_channels:', cnn.get_feature_map_channels())
+
+    downsample_ratio = cnn.get_downsample_ratio()
+    feature_map_channels = cnn.get_feature_map_channels()
+
+    # check the forward function
+    B, C, H, W = 4, 3, 224, 224
+    inp = torch.rand(B, C, H, W)
+    feats = cnn(inp, hierarchical=True)
+    assert isinstance(feats, list)
+    assert len(feats) == len(feature_map_channels)
+    print([tuple(t.shape) for t in feats])
+
+    # check the downsample ratio
+    feats = cnn(inp, hierarchical=True)
+    assert feats[-1].shape[-2] == H // downsample_ratio
+    assert feats[-1].shape[-1] == W // downsample_ratio
+
+    # check the channel number
+    for feat, ch in zip(feats, feature_map_channels):
+        assert feat.ndim == 4
+        assert feat.shape[1] == ch
+
+
+if __name__ == '__main__':
+    convnet_test()
spark/pretrain/models/custom_origin.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) ByteDance, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from typing import List
+from timm.models.registry import register_model
+
+
+class YourConvNet(nn.Module):
+    """
+    This is a template for your custom ConvNet.
+    It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
+    You can refer to the implementations in `pretrain\models\resnet.py` for an example.
+    """
+
+    def get_downsample_ratio(self) -> int:
+        """
+        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
+
+        :return: the TOTAL downsample ratio of the ConvNet.
+        E.g., for a ResNet-50, this should return 32.
+        """
+        raise NotImplementedError
+
+    def get_feature_map_channels(self) -> List[int]:
+        """
+        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
+
+        :return: a list of the number of channels of each feature map.
+        E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
+        """
+        raise NotImplementedError
+
+    def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
+        """
+        The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
+
+        :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
+        :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
+        :return:
+        - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
+        - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
+        E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
+        for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
+        """
+        raise NotImplementedError
+
+
+@register_model
+def your_convnet_small(pretrained=False, **kwargs):
+    raise NotImplementedError
+    return YourConvNet(**kwargs)
+
+
+@torch.no_grad()
+def convnet_test():
+    from timm.models import create_model
+    cnn = create_model('your_convnet_small')
+    print('get_downsample_ratio:', cnn.get_downsample_ratio())
+    print('get_feature_map_channels:', cnn.get_feature_map_channels())
+
+    downsample_ratio = cnn.get_downsample_ratio()
+    feature_map_channels = cnn.get_feature_map_channels()
+
+    # check the forward function
+    B, C, H, W = 4, 3, 224, 224
+    inp = torch.rand(B, C, H, W)
+    feats = cnn(inp, hierarchical=True)
+    assert isinstance(feats, list)
+    assert len(feats) == len(feature_map_channels)
+    print([tuple(t.shape) for t in feats])
+
+    # check the downsample ratio
+    feats = cnn(inp, hierarchical=True)
+    assert feats[-1].shape[-2] == H // downsample_ratio
+    assert feats[-1].shape[-1] == W // downsample_ratio
+
+    # check the channel number
+    for feat, ch in zip(feats, feature_map_channels):
+        assert feat.ndim == 4
+        assert feat.shape[1] == ch
+
+
+if __name__ == '__main__':
+    convnet_test()
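To make the template's contract concrete, here is a hypothetical minimal implementation (the names `ToyConvNet` and `toy_convnet` are illustrative and not part of this upload): a four-stage stack of strided convolutions whose stage outputs satisfy `get_downsample_ratio`, `get_feature_map_channels` and `forward(hierarchical=True)` exactly as the docstrings above require.

import torch
import torch.nn as nn
from typing import List
from timm.models.registry import register_model


class ToyConvNet(nn.Module):
    def __init__(self, dims=(64, 128, 256, 512), **kwargs):  # extra timm kwargs are ignored in this sketch
        super().__init__()
        self.dims = list(dims)
        chans = [3] + self.dims
        strides = [4, 2, 2, 2]  # total downsample ratio: 4 * 2 * 2 * 2 = 32
        self.stages = nn.ModuleList(
            nn.Sequential(nn.Conv2d(chans[i], chans[i + 1], kernel_size=strides[i], stride=strides[i]),
                          nn.ReLU(inplace=True))
            for i in range(4)
        )

    def get_downsample_ratio(self) -> int:
        return 32

    def get_feature_map_channels(self) -> List[int]:
        return self.dims

    def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
        x, feats = inp_bchw, []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)  # one map per stage, at strides 4, 8, 16, 32
        return feats if hierarchical else x.mean([-2, -1])  # no classifier head in this toy sketch


@register_model
def toy_convnet(pretrained=False, **kwargs):
    return ToyConvNet(**kwargs)

Run through the `convnet_test` harness above (with `create_model('toy_convnet')`), the hierarchical forward would yield shapes (4, 64, 56, 56), (4, 128, 28, 28), (4, 256, 14, 14) and (4, 512, 7, 7) for the 4x3x224x224 test input.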