mkang315 committed on
Commit 7757a1a · verified · 1 Parent(s): 1a8eb3e

Upload 123 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set of 123 files.
Files changed (50)
  1. .gitattributes +2 -0
  2. data/hyps/hyp.scratch-high.yaml +30 -0
  3. data/images/horses.jpg +3 -0
  4. data/multiplane.yaml +9 -0
  5. models/__init__.py +1 -0
  6. models/common.py +1296 -0
  7. models/detect/pk-yolo.yaml +126 -0
  8. models/detect/yolov9-e.yaml +144 -0
  9. models/experimental.py +275 -0
  10. models/repvit.py +440 -0
  11. models/tf.py +596 -0
  12. models/yolo.py +771 -0
  13. spark repvit/repvit_1kpretrained_timm_style.pth +3 -0
  14. spark/downstream_d2/README.md +101 -0
  15. spark/downstream_d2/configs/Base-RCNN-FPN.yaml +42 -0
  16. spark/downstream_d2/configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml +57 -0
  17. spark/downstream_d2/convert-timm-to-d2.py +43 -0
  18. spark/downstream_d2/lr_decay.py +132 -0
  19. spark/downstream_d2/train_net.py +322 -0
  20. spark/downstream_imagenet/README.md +54 -0
  21. spark/downstream_imagenet/arg.py +137 -0
  22. spark/downstream_imagenet/data.py +151 -0
  23. spark/downstream_imagenet/lr_decay.py +61 -0
  24. spark/downstream_imagenet/main.py +189 -0
  25. spark/downstream_imagenet/mixup.py +168 -0
  26. spark/downstream_imagenet/models/__init__.py +104 -0
  27. spark/downstream_imagenet/models/convnext_official.py +201 -0
  28. spark/downstream_imagenet/requirements.txt +5 -0
  29. spark/downstream_imagenet/util.py +131 -0
  30. spark/downstream_mmdet/README.md +76 -0
  31. spark/downstream_mmdet/configs/_base_/default_runtime.py +16 -0
  32. spark/downstream_mmdet/configs/_base_/models/cascade_mask_rcnn_convnext_fpn.py +208 -0
  33. spark/downstream_mmdet/configs/_base_/models/mask_rcnn_convnext_fpn.py +128 -0
  34. spark/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py +95 -0
  35. spark/downstream_mmdet/mmcv_custom/__init__.py +15 -0
  36. spark/downstream_mmdet/mmcv_custom/customized_text.py +130 -0
  37. spark/downstream_mmdet/mmcv_custom/layer_decay_optimizer_constructor.py +123 -0
  38. spark/downstream_mmdet/mmcv_custom/runner/checkpoint.py +85 -0
  39. spark/downstream_mmdet/mmdet/models/backbones/__init__.py +20 -0
  40. spark/downstream_mmdet/mmdet/models/backbones/convnext.py +180 -0
  41. spark/pretrain/README.md +118 -0
  42. spark/pretrain/decoder.py +74 -0
  43. spark/pretrain/dist.py +118 -0
  44. spark/pretrain/encoder.py +208 -0
  45. spark/pretrain/main.py +191 -0
  46. spark/pretrain/models/__init__.py +62 -0
  47. spark/pretrain/models/convnext.py +125 -0
  48. spark/pretrain/models/custom.py +141 -0
  49. spark/pretrain/models/custom_detr.py +102 -0
  50. spark/pretrain/models/custom_origin.py +89 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/images/horses.jpg filter=lfs diff=lfs merge=lfs -text
+ spark/pretrain/viz_imgs/recon.png filter=lfs diff=lfs merge=lfs -text
data/hyps/hyp.scratch-high.yaml ADDED
@@ -0,0 +1,30 @@
+ lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
+ lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
+ momentum: 0.937 # SGD momentum/Adam beta1
+ weight_decay: 0.0005 # optimizer weight decay 5e-4
+ warmup_epochs: 3.0 # warmup epochs (fractions ok)
+ warmup_momentum: 0.8 # warmup initial momentum
+ warmup_bias_lr: 0.1 # warmup initial bias lr
+ box: 7.5 # box loss gain
+ cls: 0.5 # cls loss gain
+ cls_pw: 1.0 # cls BCELoss positive_weight
+ obj: 0.7 # obj loss gain (scale with pixels)
+ obj_pw: 1.0 # obj BCELoss positive_weight
+ dfl: 1.5 # dfl loss gain
+ iou_t: 0.20 # IoU training threshold
+ anchor_t: 5.0 # anchor-multiple threshold
+ # anchors: 3 # anchors per output layer (0 to ignore)
+ fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
+ hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
+ hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
+ hsv_v: 0.4 # image HSV-Value augmentation (fraction)
+ degrees: 0.0 # image rotation (+/- deg)
+ translate: 0.1 # image translation (+/- fraction)
+ scale: 0.9 # image scale (+/- gain)
+ shear: 0.0 # image shear (+/- deg)
+ perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
+ flipud: 0.0 # image flip up-down (probability)
+ fliplr: 0.5 # image flip left-right (probability)
+ mosaic: 1.0 # image mosaic (probability)
+ mixup: 0.15 # image mixup (probability)
+ copy_paste: 0.3 # segment copy-paste (probability)
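
Note (not part of this commit): a hyperparameter file like this is consumed as a flat YAML dictionary by the training script. A minimal loading sketch, assuming plain PyYAML rather than the repository's own yaml_load helper:

import yaml

with open('data/hyps/hyp.scratch-high.yaml', errors='ignore') as f:
    hyp = yaml.safe_load(f)              # flat dict of scalar hyperparameters
print(hyp['lr0'], hyp['lrf'])            # 0.01 0.01 -> final OneCycleLR learning rate = lr0 * lrf = 1e-4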
data/images/horses.jpg ADDED

Git LFS Details

  • SHA256: c8f0a677a1356569e2ce71d2fa88c1030c0ae57ecf5e14170e02d9a86a20dcb4
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
data/multiplane.yaml ADDED
@@ -0,0 +1,9 @@
+ train: ./axial_t1wce_2_class/images/train
+ val: ./axial_t1wce_2_class/images/test
+ # train: ./coronal_t1wce_2_class/images/train
+ # val: ./coronal_t1wce_2_class/images/test
+ # train: ./sagittal_t1wce_2_class/images/train
+ # val: ./sagittal_t1wce_2_class/images/test
+
+ nc: 2
+ names: ['negative','positive']
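
Note (not part of this commit): this dataset YAML points at the axial plane by default; the coronal and sagittal pairs are commented out and can be swapped in. A minimal sanity-check sketch, again assuming plain PyYAML:

import yaml

with open('data/multiplane.yaml', errors='ignore') as f:
    data = yaml.safe_load(f)
assert data['nc'] == len(data['names'])   # 2 classes: negative, positive
print(data['train'], data['val'])         # ./axial_t1wce_2_class/images/train ./axial_t1wce_2_class/images/test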
models/__init__.py ADDED
@@ -0,0 +1 @@
+ # init
models/common.py ADDED
@@ -0,0 +1,1296 @@
1
+ import ast
2
+ import contextlib
3
+ import json
4
+ import math
5
+ import platform
6
+ import warnings
7
+ import zipfile
8
+ from collections import OrderedDict, namedtuple
9
+ from copy import copy
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ from typing import Optional
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import pandas as pd
18
+ import requests
19
+ import torch
20
+ import torch.nn as nn
21
+ from IPython.display import display
22
+ from PIL import Image
23
+ from torch.cuda import amp
24
+
25
+ from models.repvit import RepViT
26
+ from utils import TryExcept
27
+ from utils.dataloaders import exif_transpose, letterbox
28
+ from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
29
+ increment_path, is_notebook, make_divisible, non_max_suppression, scale_boxes,
30
+ xywh2xyxy, xyxy2xywh, yaml_load)
31
+ from utils.plots import Annotator, colors, save_one_box
32
+ from utils.torch_utils import copy_attr, smart_inference_mode
33
+
34
+
35
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
36
+ # Pad to 'same' shape outputs
37
+ if d > 1:
38
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
39
+ if p is None:
40
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
41
+ return p
42
+
43
+
44
+ class Conv(nn.Module):
45
+ # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
46
+ default_act = nn.SiLU() # default activation
47
+
48
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
49
+ super().__init__()
50
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
51
+ self.bn = nn.BatchNorm2d(c2)
52
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
53
+
54
+ def forward(self, x):
55
+ return self.act(self.bn(self.conv(x)))
56
+
57
+ def forward_fuse(self, x):
58
+ return self.act(self.conv(x))
59
+
60
+
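Note (not part of this commit): autopad returns k // 2 for the (dilation-adjusted) kernel, so a stride-1 Conv keeps the spatial size. A small sketch of the intended behaviour, assuming this file imports as models.common:

import torch
from models.common import Conv, autopad

assert autopad(3) == 1 and autopad(5) == 2   # 'same' padding for odd kernels
assert autopad(3, d=2) == 2                  # dilation 2 inflates a 3x3 kernel to an effective 5x5
m = Conv(16, 32, k=3, s=1)
x = torch.randn(1, 16, 64, 64)
assert m(x).shape == (1, 32, 64, 64)         # stride 1 preserves H and W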
61
+ class AConv(nn.Module):
62
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
63
+ super().__init__()
64
+ self.cv1 = Conv(c1, c2, 3, 2, 1)
65
+
66
+ def forward(self, x):
67
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
68
+ return self.cv1(x)
69
+
70
+
71
+ class ADown(nn.Module):
72
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
73
+ super().__init__()
74
+ self.c = c2 // 2
75
+ self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
76
+ self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
77
+
78
+ def forward(self, x):
79
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
80
+ x1, x2 = x.chunk(2, 1)
81
+ x1 = self.cv1(x1)
82
+ x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
83
+ x2 = self.cv2(x2)
84
+ return torch.cat((x1, x2), 1)
85
+
86
+
87
+ class RepConvN(nn.Module):
88
+ """RepConv is a basic rep-style block, including training and deploy status
89
+ This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
90
+ """
91
+ default_act = nn.SiLU() # default activation
92
+
93
+ def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
94
+ super().__init__()
95
+ assert k == 3 and p == 1
96
+ self.g = g
97
+ self.c1 = c1
98
+ self.c2 = c2
99
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
100
+
101
+ self.bn = None
102
+ self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
103
+ self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
104
+
105
+ def forward_fuse(self, x):
106
+ """Forward process"""
107
+ return self.act(self.conv(x))
108
+
109
+ def forward(self, x):
110
+ """Forward process"""
111
+ id_out = 0 if self.bn is None else self.bn(x)
112
+ return self.act(self.conv1(x) + self.conv2(x) + id_out)
113
+
114
+ def get_equivalent_kernel_bias(self):
115
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
116
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
117
+ kernelid, biasid = self._fuse_bn_tensor(self.bn)
118
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
119
+
120
+ def _avg_to_3x3_tensor(self, avgp):
121
+ channels = self.c1
122
+ groups = self.g
123
+ kernel_size = avgp.kernel_size
124
+ input_dim = channels // groups
125
+ k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
126
+ k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
127
+ return k
128
+
129
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
130
+ if kernel1x1 is None:
131
+ return 0
132
+ else:
133
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
134
+
135
+ def _fuse_bn_tensor(self, branch):
136
+ if branch is None:
137
+ return 0, 0
138
+ if isinstance(branch, Conv):
139
+ kernel = branch.conv.weight
140
+ running_mean = branch.bn.running_mean
141
+ running_var = branch.bn.running_var
142
+ gamma = branch.bn.weight
143
+ beta = branch.bn.bias
144
+ eps = branch.bn.eps
145
+ elif isinstance(branch, nn.BatchNorm2d):
146
+ if not hasattr(self, 'id_tensor'):
147
+ input_dim = self.c1 // self.g
148
+ kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
149
+ for i in range(self.c1):
150
+ kernel_value[i, i % input_dim, 1, 1] = 1
151
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
152
+ kernel = self.id_tensor
153
+ running_mean = branch.running_mean
154
+ running_var = branch.running_var
155
+ gamma = branch.weight
156
+ beta = branch.bias
157
+ eps = branch.eps
158
+ std = (running_var + eps).sqrt()
159
+ t = (gamma / std).reshape(-1, 1, 1, 1)
160
+ return kernel * t, beta - running_mean * gamma / std
161
+
162
+ def fuse_convs(self):
163
+ if hasattr(self, 'conv'):
164
+ return
165
+ kernel, bias = self.get_equivalent_kernel_bias()
166
+ self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
167
+ out_channels=self.conv1.conv.out_channels,
168
+ kernel_size=self.conv1.conv.kernel_size,
169
+ stride=self.conv1.conv.stride,
170
+ padding=self.conv1.conv.padding,
171
+ dilation=self.conv1.conv.dilation,
172
+ groups=self.conv1.conv.groups,
173
+ bias=True).requires_grad_(False)
174
+ self.conv.weight.data = kernel
175
+ self.conv.bias.data = bias
176
+ for para in self.parameters():
177
+ para.detach_()
178
+ self.__delattr__('conv1')
179
+ self.__delattr__('conv2')
180
+ if hasattr(self, 'nm'):
181
+ self.__delattr__('nm')
182
+ if hasattr(self, 'bn'):
183
+ self.__delattr__('bn')
184
+ if hasattr(self, 'id_tensor'):
185
+ self.__delattr__('id_tensor')
186
+
187
+
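Note (not part of this commit): fuse_convs folds the 3x3 and padded 1x1 branches (plus their BatchNorms) into a single 3x3 convolution, so after fusion the forward pass should match the multi-branch forward in eval mode up to floating-point error. A quick equivalence sketch, assuming models.common is importable:

import torch
from models.common import RepConvN

m = RepConvN(8, 8).eval()          # eval() so BatchNorm uses running statistics
x = torch.randn(1, 8, 32, 32)
y_train = m(x)                     # 3x3 branch + 1x1 branch
m.fuse_convs()                     # re-parameterize into one 3x3 conv
y_fused = m.forward_fuse(x)
assert torch.allclose(y_train, y_fused, atol=1e-5)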
188
+ class SP(nn.Module):
189
+ def __init__(self, k=3, s=1):
190
+ super(SP, self).__init__()
191
+ self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
192
+
193
+ def forward(self, x):
194
+ return self.m(x)
195
+
196
+
197
+ class MP(nn.Module):
198
+ # Max pooling
199
+ def __init__(self, k=2):
200
+ super(MP, self).__init__()
201
+ self.m = nn.MaxPool2d(kernel_size=k, stride=k)
202
+
203
+ def forward(self, x):
204
+ return self.m(x)
205
+
206
+
207
+ class ConvTranspose(nn.Module):
208
+ # Convolution transpose 2d layer
209
+ default_act = nn.SiLU() # default activation
210
+
211
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
212
+ super().__init__()
213
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
214
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
215
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
216
+
217
+ def forward(self, x):
218
+ return self.act(self.bn(self.conv_transpose(x)))
219
+
220
+
221
+ class DWConv(Conv):
222
+ # Depth-wise convolution
223
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
224
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
225
+
226
+
227
+ class DWConvTranspose2d(nn.ConvTranspose2d):
228
+ # Depth-wise transpose convolution
229
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
230
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
231
+
232
+
233
+ class DFL(nn.Module):
234
+ # DFL module
235
+ def __init__(self, c1=17):
236
+ super().__init__()
237
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
238
+ self.conv.weight.data[:] = nn.Parameter(torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)) # / 120.0
239
+ self.c1 = c1
240
+ # self.bn = nn.BatchNorm2d(4)
241
+
242
+ def forward(self, x):
243
+ b, c, a = x.shape # batch, channels, anchors
244
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
245
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
246
+
247
+
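Note (not part of this commit): with its conv weight frozen to arange(c1), DFL reduces each of the four per-side bin distributions to its expected bin index (the integral form used by distribution focal loss). A sketch verifying that, assuming models.common is importable:

import torch
from models.common import DFL

c1, anchors = 16, 100
m = DFL(c1)
x = torch.randn(2, 4 * c1, anchors)                          # raw logits: 4 box sides x c1 bins per anchor
y = m(x)                                                     # shape (2, 4, anchors)
probs = x.view(2, 4, c1, anchors).transpose(2, 1).softmax(1)
expected = (probs * torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)).sum(1)
assert torch.allclose(y, expected, atol=1e-5)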
248
+ class BottleneckBase(nn.Module):
249
+ # Standard bottleneck
250
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(1, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
251
+ super().__init__()
252
+ c_ = int(c2 * e) # hidden channels
253
+ self.cv1 = Conv(c1, c_, k[0], 1)
254
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
255
+ self.add = shortcut and c1 == c2
256
+
257
+ def forward(self, x):
258
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
259
+
260
+
261
+ class RBottleneckBase(nn.Module):
262
+ # Standard bottleneck
263
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
264
+ super().__init__()
265
+ c_ = int(c2 * e) # hidden channels
266
+ self.cv1 = Conv(c1, c_, k[0], 1)
267
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
268
+ self.add = shortcut and c1 == c2
269
+
270
+ def forward(self, x):
271
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
272
+
273
+
274
+ class RepNRBottleneckBase(nn.Module):
275
+ # Standard bottleneck
276
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
277
+ super().__init__()
278
+ c_ = int(c2 * e) # hidden channels
279
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
280
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
281
+ self.add = shortcut and c1 == c2
282
+
283
+ def forward(self, x):
284
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
285
+
286
+
287
+ class Bottleneck(nn.Module):
288
+ # Standard bottleneck
289
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
290
+ super().__init__()
291
+ c_ = int(c2 * e) # hidden channels
292
+ self.cv1 = Conv(c1, c_, k[0], 1)
293
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
294
+ self.add = shortcut and c1 == c2
295
+
296
+ def forward(self, x):
297
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
298
+
299
+
300
+ class RepNBottleneck(nn.Module):
301
+ # Standard bottleneck
302
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
303
+ super().__init__()
304
+ c_ = int(c2 * e) # hidden channels
305
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
306
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
307
+ self.add = shortcut and c1 == c2
308
+
309
+ def forward(self, x):
310
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
311
+
312
+ class Backbone(nn.Module):
313
+ def __init__(self):
314
+ super(Backbone, self).__init__()
315
+ self.cfgs = [
316
+ # k, t, c, SE, HS, s
317
+ [3, 2, 64 * 2, 1, 0, 1],
318
+ [3, 2, 64 * 2, 0, 0, 1],
319
+ [3, 2, 64 * 2, 1, 0, 1],
320
+ [3, 2, 64 * 2, 0, 0, 1],
321
+ [3, 2, 64 * 2, 0, 0, 1],
322
+ [3, 2, 128 * 2, 0, 0, 2],
323
+ [3, 2, 128 * 2, 1, 0, 1],
324
+ [3, 2, 128 * 2, 0, 0, 1],
325
+ [3, 2, 128 * 2, 1, 0, 1],
326
+ [3, 2, 128 * 2, 0, 0, 1],
327
+ [3, 2, 128 * 2, 0, 0, 1],
328
+ [3, 2, 256 * 2, 0, 1, 2],
329
+ [3, 2, 256 * 2, 1, 1, 1],
330
+ [3, 2, 256 * 2, 0, 1, 1],
331
+ [3, 2, 256 * 2, 1, 1, 1],
332
+ [3, 2, 256 * 2, 0, 1, 1],
333
+ [3, 2, 256 * 2, 1, 1, 1],
334
+ [3, 2, 256 * 2, 0, 1, 1],
335
+ [3, 2, 256 * 2, 1, 1, 1],
336
+ [3, 2, 256 * 2, 0, 1, 1],
337
+ [3, 2, 256 * 2, 1, 1, 1],
338
+ [3, 2, 256 * 2, 0, 1, 1],
339
+ [3, 2, 256 * 2, 1, 1, 1],
340
+ [3, 2, 256 * 2, 0, 1, 1],
341
+ [3, 2, 256 * 2, 1, 1, 1],
342
+ [3, 2, 256 * 2, 0, 1, 1],
343
+ [3, 2, 256 * 2, 1, 1, 1],
344
+ [3, 2, 256 * 2, 0, 1, 1],
345
+ [3, 2, 256 * 2, 1, 1, 1],
346
+ [3, 2, 256 * 2, 0, 1, 1],
347
+ [3, 2, 256 * 2, 1, 1, 1],
348
+ [3, 2, 256 * 2, 0, 1, 1],
349
+ [3, 2, 256 * 2, 1, 1, 1],
350
+ [3, 2, 256 * 2, 0, 1, 1],
351
+ [3, 2, 256 * 2, 1, 1, 1],
352
+ [3, 2, 256 * 2, 0, 1, 1],
353
+ [3, 2, 256 * 2, 0, 1, 1],
354
+ [3, 2, 512 * 2, 0, 1, 2],
355
+ [3, 2, 512 * 2, 1, 1, 1],
356
+ [3, 2, 512 * 2, 0, 1, 1],
357
+ [3, 2, 512 * 2, 1, 1, 1],
358
+ [3, 2, 512 * 2, 0, 1, 1]
359
+ ]
360
+ self.backbone = RepViT(self.cfgs)
361
+
362
+ def forward(self, x):
363
+ outputs = self.backbone(x)
364
+ return outputs
365
+ class Down0(nn.Module):
366
+ def __init__(self,inp):
367
+ super(Down0, self).__init__()
368
+
369
+ def forward(self, x):
370
+ return x[0]
371
+ class Down1(nn.Module):
372
+ def __init__(self,inp):
373
+ super(Down1, self).__init__()
374
+
375
+ def forward(self, x):
376
+ return x[1]
377
+ class Down2(nn.Module):
378
+ def __init__(self,inp):
379
+ super(Down2, self).__init__()
380
+
381
+ def forward(self, x):
382
+ return x[2]
383
+ class Down3(nn.Module):
384
+ def __init__(self,inp):
385
+ super(Down3, self).__init__()
386
+
387
+ def forward(self, x):
388
+ return x[3]
389
+
390
+ class Down4(nn.Module):
391
+ def __init__(self,inp):
392
+ super(Down4, self).__init__()
393
+
394
+ def forward(self, x):
395
+ return x[4]
396
+ class Res(nn.Module):
397
+ # ResNet bottleneck
398
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
399
+ super(Res, self).__init__()
400
+ c_ = int(c2 * e) # hidden channels
401
+ self.cv1 = Conv(c1, c_, 1, 1)
402
+ self.cv2 = Conv(c_, c_, 3, 1, g=g)
403
+ self.cv3 = Conv(c_, c2, 1, 1)
404
+ self.add = shortcut and c1 == c2
405
+
406
+ def forward(self, x):
407
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
408
+
409
+
410
+ class RepNRes(nn.Module):
411
+ # ResNet bottleneck
412
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
413
+ super(RepNRes, self).__init__()
414
+ c_ = int(c2 * e) # hidden channels
415
+ self.cv1 = Conv(c1, c_, 1, 1)
416
+ self.cv2 = RepConvN(c_, c_, 3, 1, g=g)
417
+ self.cv3 = Conv(c_, c2, 1, 1)
418
+ self.add = shortcut and c1 == c2
419
+
420
+ def forward(self, x):
421
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
422
+
423
+
424
+ class BottleneckCSP(nn.Module):
425
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
426
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
427
+ super().__init__()
428
+ c_ = int(c2 * e) # hidden channels
429
+ self.cv1 = Conv(c1, c_, 1, 1)
430
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
431
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
432
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
433
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
434
+ self.act = nn.SiLU()
435
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
436
+
437
+ def forward(self, x):
438
+ y1 = self.cv3(self.m(self.cv1(x)))
439
+ y2 = self.cv2(x)
440
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
441
+
442
+
443
+ class CSP(nn.Module):
444
+ # CSP Bottleneck with 3 convolutions
445
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
446
+ super().__init__()
447
+ c_ = int(c2 * e) # hidden channels
448
+ self.cv1 = Conv(c1, c_, 1, 1)
449
+ self.cv2 = Conv(c1, c_, 1, 1)
450
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
451
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
452
+
453
+ def forward(self, x):
454
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
455
+
456
+
457
+ class RepNCSP(nn.Module):
458
+ # CSP Bottleneck with 3 convolutions
459
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
460
+ super().__init__()
461
+ c_ = int(c2 * e) # hidden channels
462
+ self.cv1 = Conv(c1, c_, 1, 1)
463
+ self.cv2 = Conv(c1, c_, 1, 1)
464
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
465
+ self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
466
+
467
+ def forward(self, x):
468
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
469
+
470
+
471
+ class CSPBase(nn.Module):
472
+ # CSP Bottleneck with 3 convolutions
473
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
474
+ super().__init__()
475
+ c_ = int(c2 * e) # hidden channels
476
+ self.cv1 = Conv(c1, c_, 1, 1)
477
+ self.cv2 = Conv(c1, c_, 1, 1)
478
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
479
+ self.m = nn.Sequential(*(BottleneckBase(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
480
+
481
+ def forward(self, x):
482
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
483
+
484
+
485
+ class SPP(nn.Module):
486
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
487
+ def __init__(self, c1, c2, k=(5, 9, 13)):
488
+ super().__init__()
489
+ c_ = c1 // 2 # hidden channels
490
+ self.cv1 = Conv(c1, c_, 1, 1)
491
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
492
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
493
+
494
+ def forward(self, x):
495
+ x = self.cv1(x)
496
+ with warnings.catch_warnings():
497
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
498
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
499
+
500
+
501
+ class ASPP(torch.nn.Module):
502
+
503
+ def __init__(self, in_channels, out_channels):
504
+ super().__init__()
505
+ kernel_sizes = [1, 3, 3, 1]
506
+ dilations = [1, 3, 6, 1]
507
+ paddings = [0, 3, 6, 0]
508
+ self.aspp = torch.nn.ModuleList()
509
+ for aspp_idx in range(len(kernel_sizes)):
510
+ conv = torch.nn.Conv2d(
511
+ in_channels,
512
+ out_channels,
513
+ kernel_size=kernel_sizes[aspp_idx],
514
+ stride=1,
515
+ dilation=dilations[aspp_idx],
516
+ padding=paddings[aspp_idx],
517
+ bias=True)
518
+ self.aspp.append(conv)
519
+ self.gap = torch.nn.AdaptiveAvgPool2d(1)
520
+ self.aspp_num = len(kernel_sizes)
521
+ for m in self.modules():
522
+ if isinstance(m, torch.nn.Conv2d):
523
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
524
+ m.weight.data.normal_(0, math.sqrt(2. / n))
525
+ m.bias.data.fill_(0)
526
+
527
+ def forward(self, x):
528
+ avg_x = self.gap(x)
529
+ out = []
530
+ for aspp_idx in range(self.aspp_num):
531
+ inp = avg_x if (aspp_idx == self.aspp_num - 1) else x
532
+ out.append(F.relu_(self.aspp[aspp_idx](inp)))
533
+ out[-1] = out[-1].expand_as(out[-2])
534
+ out = torch.cat(out, dim=1)
535
+ return out
536
+
537
+
538
+ class SPPCSPC(nn.Module):
539
+ # CSP SPP https://github.com/WongKinYiu/CrossStagePartialNetworks
540
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
541
+ super(SPPCSPC, self).__init__()
542
+ c_ = int(2 * c2 * e) # hidden channels
543
+ self.cv1 = Conv(c1, c_, 1, 1)
544
+ self.cv2 = Conv(c1, c_, 1, 1)
545
+ self.cv3 = Conv(c_, c_, 3, 1)
546
+ self.cv4 = Conv(c_, c_, 1, 1)
547
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
548
+ self.cv5 = Conv(4 * c_, c_, 1, 1)
549
+ self.cv6 = Conv(c_, c_, 3, 1)
550
+ self.cv7 = Conv(2 * c_, c2, 1, 1)
551
+
552
+ def forward(self, x):
553
+ x1 = self.cv4(self.cv3(self.cv1(x)))
554
+ y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
555
+ y2 = self.cv2(x)
556
+ return self.cv7(torch.cat((y1, y2), dim=1))
557
+
558
+
559
+ class SPPF(nn.Module):
560
+ # Spatial Pyramid Pooling - Fast (SPPF) layer by Glenn Jocher
561
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
562
+ super().__init__()
563
+ c_ = c1 // 2 # hidden channels
564
+ self.cv1 = Conv(c1, c_, 1, 1)
565
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
566
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
567
+ # self.m = SoftPool2d(kernel_size=k, stride=1, padding=k // 2)
568
+
569
+ def forward(self, x):
570
+ x = self.cv1(x)
571
+ with warnings.catch_warnings():
572
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
573
+ y1 = self.m(x)
574
+ y2 = self.m(y1)
575
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
576
+
577
+
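Note (not part of this commit): the comment above calls SPPF(k=5) equivalent to SPP(k=(5, 9, 13)); the reason is that stacked stride-1 5x5 max-pools reproduce the 9x9 and 13x13 receptive fields exactly. A standalone check using plain nn.MaxPool2d:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 40, 40)
p5 = nn.MaxPool2d(5, stride=1, padding=2)
p9 = nn.MaxPool2d(9, stride=1, padding=4)
p13 = nn.MaxPool2d(13, stride=1, padding=6)
assert torch.equal(p5(p5(x)), p9(x))         # two 5x5 pools == one 9x9 pool
assert torch.equal(p5(p5(p5(x))), p13(x))    # three 5x5 pools == one 13x13 pool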
578
+ import torch.nn.functional as F
579
+ from torch.nn.modules.utils import _pair
580
+
581
+
582
+ class ReOrg(nn.Module):
583
+ # yolo
584
+ def __init__(self):
585
+ super(ReOrg, self).__init__()
586
+
587
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
588
+ return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
589
+
590
+
591
+ class Contract(nn.Module):
592
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
593
+ def __init__(self, gain=2):
594
+ super().__init__()
595
+ self.gain = gain
596
+
597
+ def forward(self, x):
598
+ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
599
+ s = self.gain
600
+ x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
601
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
602
+ return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
603
+
604
+
605
+ class Expand(nn.Module):
606
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
607
+ def __init__(self, gain=2):
608
+ super().__init__()
609
+ self.gain = gain
610
+
611
+ def forward(self, x):
612
+ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
613
+ s = self.gain
614
+ x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
615
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
616
+ return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
617
+
618
+
619
+ class Concat(nn.Module):
620
+ # Concatenate a list of tensors along dimension
621
+ def __init__(self, dimension=1):
622
+ super().__init__()
623
+ self.d = dimension
624
+
625
+ def forward(self, x):
626
+ return torch.cat(x, self.d)
627
+
628
+
629
+ class Shortcut(nn.Module):
630
+ def __init__(self, dimension=0):
631
+ super(Shortcut, self).__init__()
632
+ self.d = dimension
633
+
634
+ def forward(self, x):
635
+ return x[0]+x[1]
636
+
637
+
638
+ class Silence(nn.Module):
639
+ def __init__(self):
640
+ super(Silence, self).__init__()
641
+ def forward(self, x):
642
+ return x
643
+
644
+
645
+ ##### GELAN #####
646
+
647
+ class SPPELAN(nn.Module):
648
+ # spp-elan
649
+ def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
650
+ super().__init__()
651
+ self.c = c3
652
+ self.cv1 = Conv(c1, c3, 1, 1)
653
+ self.cv2 = SP(5)
654
+ self.cv3 = SP(5)
655
+ self.cv4 = SP(5)
656
+ self.cv5 = Conv(4*c3, c2, 1, 1)
657
+
658
+ def forward(self, x):
659
+ y = [self.cv1(x)]
660
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
661
+ return self.cv5(torch.cat(y, 1))
662
+
663
+
664
+ class RepNCSPELAN4(nn.Module):
665
+ # csp-elan
666
+ def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
667
+ super().__init__()
668
+ self.c = c3//2
669
+ self.cv1 = Conv(c1, c3, 1, 1)
670
+ self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
671
+ self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
672
+ self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
673
+
674
+ def forward(self, x):
675
+ y = list(self.cv1(x).chunk(2, 1))
676
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
677
+ return self.cv4(torch.cat(y, 1))
678
+
679
+ def forward_split(self, x):
680
+ y = list(self.cv1(x).split((self.c, self.c), 1))
681
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
682
+ return self.cv4(torch.cat(y, 1))
683
+
684
+ #################
685
+
686
+
687
+ ##### YOLOR #####
688
+
689
+ class ImplicitA(nn.Module):
690
+ def __init__(self, channel):
691
+ super(ImplicitA, self).__init__()
692
+ self.channel = channel
693
+ self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
694
+ nn.init.normal_(self.implicit, std=.02)
695
+
696
+ def forward(self, x):
697
+ return self.implicit + x
698
+
699
+
700
+ class ImplicitM(nn.Module):
701
+ def __init__(self, channel):
702
+ super(ImplicitM, self).__init__()
703
+ self.channel = channel
704
+ self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
705
+ nn.init.normal_(self.implicit, mean=1., std=.02)
706
+
707
+ def forward(self, x):
708
+ return self.implicit * x
709
+
710
+ #################
711
+
712
+
713
+ ##### CBNet #####
714
+
715
+ class CBLinear(nn.Module):
716
+ def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
717
+ super(CBLinear, self).__init__()
718
+ self.c2s = c2s
719
+ self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
720
+
721
+ def forward(self, x):
722
+ outs = self.conv(x).split(self.c2s, dim=1)
723
+ return outs
724
+
725
+ class CBFuse(nn.Module):
726
+ def __init__(self, idx):
727
+ super(CBFuse, self).__init__()
728
+ self.idx = idx
729
+
730
+ def forward(self, xs):
731
+ target_size = xs[-1].shape[2:]
732
+ res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
733
+ out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
734
+ return out
735
+
736
+ #################
737
+
738
+
739
+ class DetectMultiBackend(nn.Module):
740
+ # YOLO MultiBackend class for python inference on various backends
741
+ def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
742
+ # Usage:
743
+ # PyTorch: weights = *.pt
744
+ # TorchScript: *.torchscript
745
+ # ONNX Runtime: *.onnx
746
+ # ONNX OpenCV DNN: *.onnx --dnn
747
+ # OpenVINO: *_openvino_model
748
+ # CoreML: *.mlmodel
749
+ # TensorRT: *.engine
750
+ # TensorFlow SavedModel: *_saved_model
751
+ # TensorFlow GraphDef: *.pb
752
+ # TensorFlow Lite: *.tflite
753
+ # TensorFlow Edge TPU: *_edgetpu.tflite
754
+ # PaddlePaddle: *_paddle_model
755
+ from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
756
+
757
+ super().__init__()
758
+ w = str(weights[0] if isinstance(weights, list) else weights)
759
+ pt, jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
760
+ fp16 &= pt or jit or onnx or engine # FP16
761
+ nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
762
+ stride = 32 # default stride
763
+ cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
764
+ if not (pt or triton):
765
+ w = attempt_download(w) # download if not local
766
+
767
+ if pt: # PyTorch
768
+ model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
769
+ stride = max(int(model.stride.max()), 32) # model stride
770
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
771
+ model.half() if fp16 else model.float()
772
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
773
+ elif jit: # TorchScript
774
+ LOGGER.info(f'Loading {w} for TorchScript inference...')
775
+ extra_files = {'config.txt': ''} # model metadata
776
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
777
+ model.half() if fp16 else model.float()
778
+ if extra_files['config.txt']: # load metadata dict
779
+ d = json.loads(extra_files['config.txt'],
780
+ object_hook=lambda d: {int(k) if k.isdigit() else k: v
781
+ for k, v in d.items()})
782
+ stride, names = int(d['stride']), d['names']
783
+ elif dnn: # ONNX OpenCV DNN
784
+ LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
785
+ check_requirements('opencv-python>=4.5.4')
786
+ net = cv2.dnn.readNetFromONNX(w)
787
+ elif onnx: # ONNX Runtime
788
+ LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
789
+ check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
790
+ import onnxruntime
791
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
792
+ session = onnxruntime.InferenceSession(w, providers=providers)
793
+ output_names = [x.name for x in session.get_outputs()]
794
+ meta = session.get_modelmeta().custom_metadata_map # metadata
795
+ if 'stride' in meta:
796
+ stride, names = int(meta['stride']), eval(meta['names'])
797
+ elif xml: # OpenVINO
798
+ LOGGER.info(f'Loading {w} for OpenVINO inference...')
799
+ check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
800
+ from openvino.runtime import Core, Layout, get_batch
801
+ ie = Core()
802
+ if not Path(w).is_file(): # if not *.xml
803
+ w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
804
+ network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
805
+ if network.get_parameters()[0].get_layout().empty:
806
+ network.get_parameters()[0].set_layout(Layout("NCHW"))
807
+ batch_dim = get_batch(network)
808
+ if batch_dim.is_static:
809
+ batch_size = batch_dim.get_length()
810
+ executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
811
+ stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
812
+ elif engine: # TensorRT
813
+ LOGGER.info(f'Loading {w} for TensorRT inference...')
814
+ import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
815
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
816
+ if device.type == 'cpu':
817
+ device = torch.device('cuda:0')
818
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
819
+ logger = trt.Logger(trt.Logger.INFO)
820
+ with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
821
+ model = runtime.deserialize_cuda_engine(f.read())
822
+ context = model.create_execution_context()
823
+ bindings = OrderedDict()
824
+ output_names = []
825
+ fp16 = False # default updated below
826
+ dynamic = False
827
+ for i in range(model.num_bindings):
828
+ name = model.get_binding_name(i)
829
+ dtype = trt.nptype(model.get_binding_dtype(i))
830
+ if model.binding_is_input(i):
831
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
832
+ dynamic = True
833
+ context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
834
+ if dtype == np.float16:
835
+ fp16 = True
836
+ else: # output
837
+ output_names.append(name)
838
+ shape = tuple(context.get_binding_shape(i))
839
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
840
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
841
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
842
+ batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
843
+ elif coreml: # CoreML
844
+ LOGGER.info(f'Loading {w} for CoreML inference...')
845
+ import coremltools as ct
846
+ model = ct.models.MLModel(w)
847
+ elif saved_model: # TF SavedModel
848
+ LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
849
+ import tensorflow as tf
850
+ keras = False # assume TF1 saved_model
851
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
852
+ elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
853
+ LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
854
+ import tensorflow as tf
855
+
856
+ def wrap_frozen_graph(gd, inputs, outputs):
857
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
858
+ ge = x.graph.as_graph_element
859
+ return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
860
+
861
+ def gd_outputs(gd):
862
+ name_list, input_list = [], []
863
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
864
+ name_list.append(node.name)
865
+ input_list.extend(node.input)
866
+ return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
867
+
868
+ gd = tf.Graph().as_graph_def() # TF GraphDef
869
+ with open(w, 'rb') as f:
870
+ gd.ParseFromString(f.read())
871
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
872
+ elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
873
+ try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
874
+ from tflite_runtime.interpreter import Interpreter, load_delegate
875
+ except ImportError:
876
+ import tensorflow as tf
877
+ Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
878
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
879
+ LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
880
+ delegate = {
881
+ 'Linux': 'libedgetpu.so.1',
882
+ 'Darwin': 'libedgetpu.1.dylib',
883
+ 'Windows': 'edgetpu.dll'}[platform.system()]
884
+ interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
885
+ else: # TFLite
886
+ LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
887
+ interpreter = Interpreter(model_path=w) # load TFLite model
888
+ interpreter.allocate_tensors() # allocate
889
+ input_details = interpreter.get_input_details() # inputs
890
+ output_details = interpreter.get_output_details() # outputs
891
+ # load metadata
892
+ with contextlib.suppress(zipfile.BadZipFile):
893
+ with zipfile.ZipFile(w, "r") as model:
894
+ meta_file = model.namelist()[0]
895
+ meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
896
+ stride, names = int(meta['stride']), meta['names']
897
+ elif tfjs: # TF.js
898
+ raise NotImplementedError('ERROR: YOLO TF.js inference is not supported')
899
+ elif paddle: # PaddlePaddle
900
+ LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
901
+ check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
902
+ import paddle.inference as pdi
903
+ if not Path(w).is_file(): # if not *.pdmodel
904
+ w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
905
+ weights = Path(w).with_suffix('.pdiparams')
906
+ config = pdi.Config(str(w), str(weights))
907
+ if cuda:
908
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
909
+ predictor = pdi.create_predictor(config)
910
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
911
+ output_names = predictor.get_output_names()
912
+ elif triton: # NVIDIA Triton Inference Server
913
+ LOGGER.info(f'Using {w} as Triton Inference Server...')
914
+ check_requirements('tritonclient[all]')
915
+ from utils.triton import TritonRemoteModel
916
+ model = TritonRemoteModel(url=w)
917
+ nhwc = model.runtime.startswith("tensorflow")
918
+ else:
919
+ raise NotImplementedError(f'ERROR: {w} is not a supported format')
920
+
921
+ # class names
922
+ if 'names' not in locals():
923
+ names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
924
+ if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
925
+ names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
926
+
927
+ self.__dict__.update(locals()) # assign all variables to self
928
+
929
+ def forward(self, im, augment=False, visualize=False):
930
+ # YOLO MultiBackend inference
931
+ b, ch, h, w = im.shape # batch, channel, height, width
932
+ if self.fp16 and im.dtype != torch.float16:
933
+ im = im.half() # to FP16
934
+ if self.nhwc:
935
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
936
+
937
+ if self.pt: # PyTorch
938
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
939
+ elif self.jit: # TorchScript
940
+ y = self.model(im)
941
+ elif self.dnn: # ONNX OpenCV DNN
942
+ im = im.cpu().numpy() # torch to numpy
943
+ self.net.setInput(im)
944
+ y = self.net.forward()
945
+ elif self.onnx: # ONNX Runtime
946
+ im = im.cpu().numpy() # torch to numpy
947
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
948
+ elif self.xml: # OpenVINO
949
+ im = im.cpu().numpy() # FP32
950
+ y = list(self.executable_network([im]).values())
951
+ elif self.engine: # TensorRT
952
+ if self.dynamic and im.shape != self.bindings['images'].shape:
953
+ i = self.model.get_binding_index('images')
954
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
955
+ self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
956
+ for name in self.output_names:
957
+ i = self.model.get_binding_index(name)
958
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
959
+ s = self.bindings['images'].shape
960
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
961
+ self.binding_addrs['images'] = int(im.data_ptr())
962
+ self.context.execute_v2(list(self.binding_addrs.values()))
963
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
964
+ elif self.coreml: # CoreML
965
+ im = im.cpu().numpy()
966
+ im = Image.fromarray((im[0] * 255).astype('uint8'))
967
+ # im = im.resize((192, 320), Image.ANTIALIAS)
968
+ y = self.model.predict({'image': im}) # coordinates are xywh normalized
969
+ if 'confidence' in y:
970
+ box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
971
+ conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(float)  # np.float was removed in NumPy 1.24
972
+ y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
973
+ else:
974
+ y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
975
+ elif self.paddle: # PaddlePaddle
976
+ im = im.cpu().numpy().astype(np.float32)
977
+ self.input_handle.copy_from_cpu(im)
978
+ self.predictor.run()
979
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
980
+ elif self.triton: # NVIDIA Triton Inference Server
981
+ y = self.model(im)
982
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
983
+ im = im.cpu().numpy()
984
+ if self.saved_model: # SavedModel
985
+ y = self.model(im, training=False) if self.keras else self.model(im)
986
+ elif self.pb: # GraphDef
987
+ y = self.frozen_func(x=self.tf.constant(im))
988
+ else: # Lite or Edge TPU
989
+ input = self.input_details[0]
990
+ int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
991
+ if int8:
992
+ scale, zero_point = input['quantization']
993
+ im = (im / scale + zero_point).astype(np.uint8) # de-scale
994
+ self.interpreter.set_tensor(input['index'], im)
995
+ self.interpreter.invoke()
996
+ y = []
997
+ for output in self.output_details:
998
+ x = self.interpreter.get_tensor(output['index'])
999
+ if int8:
1000
+ scale, zero_point = output['quantization']
1001
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
1002
+ y.append(x)
1003
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
1004
+ y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
1005
+
1006
+ if isinstance(y, (list, tuple)):
1007
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
1008
+ else:
1009
+ return self.from_numpy(y)
1010
+
1011
+ def from_numpy(self, x):
1012
+ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
1013
+
1014
+ def warmup(self, imgsz=(1, 3, 640, 640)):
1015
+ # Warmup model by running inference once
1016
+ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
1017
+ if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
1018
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
1019
+ for _ in range(2 if self.jit else 1): #
1020
+ self.forward(im) # warmup
1021
+
1022
+ @staticmethod
1023
+ def _model_type(p='path/to/model.pt'):
1024
+ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
1025
+ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
1026
+ from export import export_formats
1027
+ from utils.downloads import is_url
1028
+ sf = list(export_formats().Suffix) # export suffixes
1029
+ if not is_url(p, check=False):
1030
+ check_suffix(p, sf) # checks
1031
+ url = urlparse(p) # if url may be Triton inference server
1032
+ types = [s in Path(p).name for s in sf]
1033
+ types[8] &= not types[9] # tflite &= not edgetpu
1034
+ triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
1035
+ return types + [triton]
1036
+
1037
+ @staticmethod
1038
+ def _load_metadata(f=Path('path/to/meta.yaml')):
1039
+ # Load metadata from meta.yaml if it exists
1040
+ if f.exists():
1041
+ d = yaml_load(f)
1042
+ return d['stride'], d['names'] # assign stride, names
1043
+ return None, None
1044
+
1045
+
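Note (not part of this commit): DetectMultiBackend hides the per-format loading logic behind a single forward() call. A minimal usage sketch with a hypothetical checkpoint path ('yolo.pt' is a placeholder) and the dataset file added in this commit:

import torch
from models.common import DetectMultiBackend

model = DetectMultiBackend('yolo.pt', device=torch.device('cpu'), data='data/multiplane.yaml')
model.warmup(imgsz=(1, 3, 640, 640))   # no-op on CPU; runs a dummy forward pass on GPU/Triton
im = torch.zeros(1, 3, 640, 640)       # BCHW float input
pred = model(im)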
1046
+ class AutoShape(nn.Module):
1047
+ # YOLO input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
1048
+ conf = 0.25 # NMS confidence threshold
1049
+ iou = 0.45 # NMS IoU threshold
1050
+ agnostic = False # NMS class-agnostic
1051
+ multi_label = False # NMS multiple labels per box
1052
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
1053
+ max_det = 1000 # maximum number of detections per image
1054
+ amp = False # Automatic Mixed Precision (AMP) inference
1055
+
1056
+ def __init__(self, model, verbose=True):
1057
+ super().__init__()
1058
+ if verbose:
1059
+ LOGGER.info('Adding AutoShape... ')
1060
+ copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
1061
+ self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
1062
+ self.pt = not self.dmb or model.pt # PyTorch model
1063
+ self.model = model.eval()
1064
+ if self.pt:
1065
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
1066
+ m.inplace = False # Detect.inplace=False for safe multithread inference
1067
+ m.export = True # do not output loss values
1068
+
1069
+ def _apply(self, fn):
1070
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
1071
+ self = super()._apply(fn)
1072
+ from models.yolo import Detect, Segment
1073
+ if self.pt:
1074
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
1075
+ if isinstance(m, (Detect, Segment)):
1076
+ for k in 'stride', 'anchor_grid', 'stride_grid', 'grid':
1077
+ x = getattr(m, k)
1078
+ setattr(m, k, list(map(fn, x))) if isinstance(x, (list, tuple)) else setattr(m, k, fn(x))
1079
+ return self
1080
+
1081
+ @smart_inference_mode()
1082
+ def forward(self, ims, size=640, augment=False, profile=False):
1083
+ # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
1084
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
1085
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
1086
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
1087
+ # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
1088
+ # numpy: = np.zeros((640,1280,3)) # HWC
1089
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
1090
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
1091
+
1092
+ dt = (Profile(), Profile(), Profile())
1093
+ with dt[0]:
1094
+ if isinstance(size, int): # expand
1095
+ size = (size, size)
1096
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
1097
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
1098
+ if isinstance(ims, torch.Tensor): # torch
1099
+ with amp.autocast(autocast):
1100
+ return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
1101
+
1102
+ # Pre-process
1103
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
1104
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
1105
+ for i, im in enumerate(ims):
1106
+ f = f'image{i}' # filename
1107
+ if isinstance(im, (str, Path)): # filename or uri
1108
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
1109
+ im = np.asarray(exif_transpose(im))
1110
+ elif isinstance(im, Image.Image): # PIL Image
1111
+ im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
1112
+ files.append(Path(f).with_suffix('.jpg').name)
1113
+ if im.shape[0] < 5: # image in CHW
1114
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
1115
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
1116
+ s = im.shape[:2] # HWC
1117
+ shape0.append(s) # image shape
1118
+ g = max(size) / max(s) # gain
1119
+ shape1.append([int(y * g) for y in s])
1120
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
1121
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] # inf shape
1122
+ x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
1123
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
1124
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
1125
+
1126
+ with amp.autocast(autocast):
1127
+ # Inference
1128
+ with dt[1]:
1129
+ y = self.model(x, augment=augment) # forward
1130
+
1131
+ # Post-process
1132
+ with dt[2]:
1133
+ y = non_max_suppression(y if self.dmb else y[0],
1134
+ self.conf,
1135
+ self.iou,
1136
+ self.classes,
1137
+ self.agnostic,
1138
+ self.multi_label,
1139
+ max_det=self.max_det) # NMS
1140
+ for i in range(n):
1141
+ scale_boxes(shape1, y[i][:, :4], shape0[i])
1142
+
1143
+ return Detections(ims, y, files, dt, self.names, x.shape)
1144
+
1145
+
1146
+ class Detections:
1147
+ # YOLO detections class for inference results
1148
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
1149
+ super().__init__()
1150
+ d = pred[0].device # device
1151
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
1152
+ self.ims = ims # list of images as numpy arrays
1153
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
1154
+ self.names = names # class names
1155
+ self.files = files # image filenames
1156
+ self.times = times # profiling times
1157
+ self.xyxy = pred # xyxy pixels
1158
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
1159
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
1160
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
1161
+ self.n = len(self.pred) # number of images (batch size)
1162
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
1163
+ self.s = tuple(shape) # inference BCHW shape
1164
+
1165
+ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
1166
+ s, crops = '', []
1167
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
1168
+ s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
1169
+ if pred.shape[0]:
1170
+ for c in pred[:, -1].unique():
1171
+ n = (pred[:, -1] == c).sum() # detections per class
1172
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
1173
+ s = s.rstrip(', ')
1174
+ if show or save or render or crop:
1175
+ annotator = Annotator(im, example=str(self.names))
1176
+ for *box, conf, cls in reversed(pred): # xyxy, confidence, class
1177
+ label = f'{self.names[int(cls)]} {conf:.2f}'
1178
+ if crop:
1179
+ file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
1180
+ crops.append({
1181
+ 'box': box,
1182
+ 'conf': conf,
1183
+ 'cls': cls,
1184
+ 'label': label,
1185
+ 'im': save_one_box(box, im, file=file, save=save)})
1186
+ else: # all others
1187
+ annotator.box_label(box, label if labels else '', color=colors(cls))
1188
+ im = annotator.im
1189
+ else:
1190
+ s += '(no detections)'
1191
+
1192
+ im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
1193
+ if show:
1194
+ display(im) if is_notebook() else im.show(self.files[i])
1195
+ if save:
1196
+ f = self.files[i]
1197
+ im.save(save_dir / f) # save
1198
+ if i == self.n - 1:
1199
+ LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
1200
+ if render:
1201
+ self.ims[i] = np.asarray(im)
1202
+ if pprint:
1203
+ s = s.lstrip('\n')
1204
+ return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
1205
+ if crop:
1206
+ if save:
1207
+ LOGGER.info(f'Saved results to {save_dir}\n')
1208
+ return crops
1209
+
1210
+ @TryExcept('Showing images is not supported in this environment')
1211
+ def show(self, labels=True):
1212
+ self._run(show=True, labels=labels) # show results
1213
+
1214
+ def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
1215
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
1216
+ self._run(save=True, labels=labels, save_dir=save_dir) # save results
1217
+
1218
+ def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
1219
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
1220
+ return self._run(crop=True, save=save, save_dir=save_dir) # crop results
1221
+
1222
+ def render(self, labels=True):
1223
+ self._run(render=True, labels=labels) # render results
1224
+ return self.ims
1225
+
1226
+ def pandas(self):
1227
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
1228
+ new = copy(self) # return copy
1229
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
1230
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
1231
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
1232
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
1233
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
1234
+ return new
1235
+
1236
+ def tolist(self):
1237
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
1238
+ r = range(self.n) # iterable
1239
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
1240
+ # for d in x:
1241
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
1242
+ # setattr(d, k, getattr(d, k)[0]) # pop out of list
1243
+ return x
1244
+
1245
+ def print(self):
1246
+ LOGGER.info(self.__str__())
1247
+
1248
+ def __len__(self): # override len(results)
1249
+ return self.n
1250
+
1251
+ def __str__(self): # override print(results)
1252
+ return self._run(pprint=True) # print results
1253
+
1254
+ def __repr__(self):
1255
+ return f'YOLO {self.__class__} instance\n' + self.__str__()
1256
+
1257
+
1258
+ class Proto(nn.Module):
1259
+ # YOLO mask Proto module for segmentation models
1260
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
1261
+ super().__init__()
1262
+ self.cv1 = Conv(c1, c_, k=3)
1263
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
1264
+ self.cv2 = Conv(c_, c_, k=3)
1265
+ self.cv3 = Conv(c_, c2)
1266
+
1267
+ def forward(self, x):
1268
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
1269
+
1270
+
1271
+ class UConv(nn.Module):
1272
+ def __init__(self, c1, c_=256, c2=256): # ch_in, hidden channels, ch_out
1273
+ super().__init__()
1274
+
1275
+ self.cv1 = Conv(c1, c_, k=3)
1276
+ self.cv2 = nn.Conv2d(c_, c2, 1, 1)
1277
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1278
+
1279
+ def forward(self, x):
1280
+ return self.up(self.cv2(self.cv1(x)))
1281
+
1282
+
1283
+ class Classify(nn.Module):
1284
+ # YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)
1285
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
1286
+ super().__init__()
1287
+ c_ = 1280 # efficientnet_b0 size
1288
+ self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
1289
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
1290
+ self.drop = nn.Dropout(p=0.0, inplace=True)
1291
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
1292
+
1293
+ def forward(self, x):
1294
+ if isinstance(x, list):
1295
+ x = torch.cat(x, 1)
1296
+ return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
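models/common.py above defines the AutoShape pre/post-processing path (letterbox resize, AMP forward, NMS) and the Detections results container with its print/pandas/save helpers. A minimal usage sketch, assuming the AutoShape wrapper defined earlier in models/common.py and a placeholder weights file name:

    from models.common import AutoShape
    from models.experimental import attempt_load

    detector = AutoShape(attempt_load('pk-yolo.pt', device='cpu'))  # 'pk-yolo.pt' is a placeholder checkpoint name
    results = detector(['data/images/horses.jpg'], size=640)        # accepts paths, numpy arrays or PIL images
    results.print()                                                 # per-image detection summary
    df = results.pandas().xyxy[0]                                   # xmin, ymin, xmax, ymax, confidence, class, name
    results.save(save_dir='runs/detect/exp')                        # annotated copies of the inputs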
models/detect/pk-yolo.yaml ADDED
@@ -0,0 +1,126 @@
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 2 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ # activation: nn.LeakyReLU(0.1)
8
+ # activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # YOLOv9 backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ [-1, 1, Backbone, []],
19
+ # conv down
20
+ [1, 1, Down0, [64]], #2 320 1
21
+ [1, 1, Down1, [128]], # 3 160 3
22
+ [1, 1, Down2, [256]],# 4 80 5
23
+ [1, 1, Down3, [512]], #5 40 7
24
+ [1, 1, Down4, [1024]], #6 20 9
25
+
26
+ # routing
27
+ [ 2, 1, CBLinear, [ [ 64 ] ] ], # 10
28
+ [ 3, 1, CBLinear, [ [ 64, 128 ] ] ], # 11
29
+ [ 4, 1, CBLinear, [ [ 64, 128, 256 ] ] ], # 12
30
+ [ 5, 1, CBLinear, [ [ 64, 128, 256, 512 ] ] ], # 13
31
+ [ 6, 1, CBLinear, [ [ 64, 128, 256, 512, 1024 ] ] ], # 14 -3
32
+
33
+ # conv down fuse
34
+ [ 0, 1, Conv, [ 64, 3, 2 ] ], # 15-P1/2
35
+ [ [ 7, 8, 9, 10, 11, -1 ], 1, CBFuse, [ [ 0, 0, 0, 0, 0 ] ] ], # 16
36
+
37
+ # conv down fuse
38
+ [ -1, 1, Conv, [ 128, 3, 2 ] ], # 17-P2/4
39
+ [ [ 8, 9, 10, 11, -1 ], 1, CBFuse, [ [ 1, 1, 1, 1 ] ] ], # 18
40
+
41
+ # elan-1 block
42
+ [ -1, 1, RepNCSPELAN4, [ 256, 128, 64, 2 ] ], # 19
43
+
44
+ # avg-conv down fuse
45
+ [ -1, 1, ADown, [ 256 ] ], # 20-P3/8
46
+ [ [ 9, 10, 11, -1 ], 1, CBFuse, [ [ 2, 2, 2 ] ] ], # 21
47
+
48
+ # elan-2 block
49
+ [ -1, 1, RepNCSPELAN4, [ 512, 256, 128, 2 ] ], # 22
50
+
51
+ # avg-conv down fuse
52
+ [ -1, 1, ADown, [ 512 ] ], # 23-P4/16
53
+ [ [ 10, 11, -1 ], 1, CBFuse, [ [ 3, 3 ] ] ], # 24
54
+
55
+ # elan-2 block
56
+ [ -1, 1, RepNCSPELAN4, [ 1024, 512, 256, 2 ] ], # 25
57
+
58
+ # avg-conv down fuse
59
+ [ -1, 1, ADown, [ 1024 ] ], # 26-P5/32
60
+ [ [ 11, -1 ], 1, CBFuse, [ [ 4 ] ] ], # 27
61
+
62
+ # elan-2 block
63
+ [ -1, 1, RepNCSPELAN4, [ 1024, 512, 256, 2 ] ], # 28 25
64
+
65
+ ]
66
+
67
+ # YOLOv9 head
68
+ head:
69
+ [
70
+ # multi-level auxiliary branch
71
+
72
+ # elan-spp block
73
+ [6, 1, SPPELAN, [512, 256]], # 29
74
+
75
+ # up-concat merge
76
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
77
+ [[-1, 5], 1, Concat, [1]], # cat backbone P4
78
+
79
+ # csp-elan block
80
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 32
81
+
82
+ # up-concat merge
83
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
84
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
85
+
86
+ # csp-elan block
87
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]], # 35
88
+
89
+
90
+
91
+ # main branch
92
+
93
+ # elan-spp block
94
+ [25, 1, SPPELAN, [512, 256]], # 36
95
+
96
+ # up-concat merge
97
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
98
+ [[-1, 22], 1, Concat, [1]], # cat backbone P4
99
+
100
+ # csp-elan block
101
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 39
102
+
103
+ # up-concat merge
104
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
105
+ [[-1, 19], 1, Concat, [1]], # cat backbone P3
106
+
107
+ # csp-elan block
108
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]], # 42 (P3/8-small)
109
+
110
+ # avg-conv-down merge
111
+ [-1, 1, ADown, [256]],
112
+ [[-1, 36], 1, Concat, [1]], # cat head P4
113
+
114
+ # csp-elan block
115
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 45 (P4/16-medium)
116
+
117
+ # avg-conv-down merge
118
+ [-1, 1, ADown, [512]],
119
+ [[-1, 33], 1, Concat, [1]], # cat head P5
120
+
121
+ # csp-elan block
122
+ [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]], # 48 (P5/32-large)
123
+
124
+ # detect
125
+ [[32, 29, 26, 39, 42, 45], 1, DualDDetect, [nc]], # DualDDetect(A3, A4, A5, P3, P4, P5)
126
+ ]
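Each row in the backbone and head lists follows the [from, number, module, args] convention consumed by parse_model in models/yolo.py; a from value of -1 indexes the previous layer, and the final DualDDetect row collects three auxiliary and three main feature levels, as its comment notes. A short build-and-probe sketch, assuming the Model alias exposed by models/yolo.py and a 640x640 input as in the YOLOv9 code this repo follows (the custom Backbone/Down* modules referenced above must also be importable):

    import torch
    from models.yolo import Model

    model = Model('models/detect/pk-yolo.yaml', ch=3, nc=2)  # nc=2 matches the header above
    y = model(torch.zeros(1, 3, 640, 640))                   # DualDDetect yields auxiliary + main predictions
    print(len(model.model), model.yaml['nc'])                 # parsed layer count, number of classes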
models/detect/yolov9-e.yaml ADDED
@@ -0,0 +1,144 @@
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 2 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # YOLOv9 backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ # conv down
19
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
20
+
21
+ # conv down
22
+ [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
23
+
24
+ # csp-elan block
25
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]], # 3
26
+
27
+ # avg-conv down
28
+ [-1, 1, ADown, [256]], # 4-P3/8
29
+
30
+ # csp-elan block
31
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]], # 5
32
+
33
+ # avg-conv down
34
+ [-1, 1, ADown, [512]], # 6-P4/16
35
+
36
+ # csp-elan block
37
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 7
38
+
39
+ # avg-conv down
40
+ [-1, 1, ADown, [1024]], # 8-P5/32
41
+
42
+ # csp-elan block
43
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 9
44
+
45
+ # routing
46
+ [1, 1, CBLinear, [[64]]], # 10
47
+ [3, 1, CBLinear, [[64, 128]]], # 11
48
+ [5, 1, CBLinear, [[64, 128, 256]]], # 12
49
+ [7, 1, CBLinear, [[64, 128, 256, 512]]], # 13
50
+ [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]], # 14
51
+
52
+ # conv down
53
+ [0, 1, Conv, [64, 3, 2]], # 15-P1/2
54
+ [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]], # 16
55
+
56
+ # conv down
57
+ [-1, 1, Conv, [128, 3, 2]], # 17-P2/4
58
+ [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]], # 18
59
+
60
+ # csp-elan block
61
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]], # 19
62
+
63
+ # avg-conv down fuse
64
+ [-1, 1, ADown, [256]], # 20-P3/8
65
+ [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]], # 21
66
+
67
+ # csp-elan block
68
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]], # 22
69
+
70
+ # avg-conv down fuse
71
+ [-1, 1, ADown, [512]], # 23-P4/16
72
+ [[13, 14, -1], 1, CBFuse, [[3, 3]]], # 24
73
+
74
+ # csp-elan block
75
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 25
76
+
77
+ # avg-conv down fuse
78
+ [-1, 1, ADown, [1024]], # 26-P5/32
79
+ [[14, -1], 1, CBFuse, [[4]]], # 27
80
+
81
+ # csp-elan block
82
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 28
83
+ ]
84
+
85
+ # YOLOv9 head
86
+ head:
87
+ [
88
+ # multi-level auxiliary branch
89
+
90
+ # elan-spp block
91
+ [9, 1, SPPELAN, [512, 256]], # 29
92
+
93
+ # up-concat merge
94
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
95
+ [[-1, 7], 1, Concat, [1]], # cat backbone P4
96
+
97
+ # csp-elan block
98
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 32
99
+
100
+ # up-concat merge
101
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
102
+ [[-1, 5], 1, Concat, [1]], # cat backbone P3
103
+
104
+ # csp-elan block
105
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]], # 35
106
+
107
+
108
+
109
+ # main branch
110
+
111
+ # elan-spp block
112
+ [28, 1, SPPELAN, [512, 256]], # 36
113
+
114
+ # up-concat merge
115
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
116
+ [[-1, 25], 1, Concat, [1]], # cat backbone P4
117
+
118
+ # csp-elan block
119
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 39
120
+
121
+ # up-concat merge
122
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
123
+ [[-1, 22], 1, Concat, [1]], # cat backbone P3
124
+
125
+ # csp-elan block
126
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]], # 42 (P3/8-small)
127
+
128
+ # avg-conv-down merge
129
+ [-1, 1, ADown, [256]],
130
+ [[-1, 39], 1, Concat, [1]], # cat head P4
131
+
132
+ # csp-elan block
133
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 45 (P4/16-medium)
134
+
135
+ # avg-conv-down merge
136
+ [-1, 1, ADown, [512]],
137
+ [[-1, 36], 1, Concat, [1]], # cat head P5
138
+
139
+ # csp-elan block
140
+ [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]], # 48 (P5/32-large)
141
+
142
+ # detect
143
+ [[35, 32, 29, 42, 45, 48], 1, DualDDetect, [nc]], # DualDDetect(A3, A4, A5, P3, P4, P5)
144
+ ]
models/experimental.py ADDED
@@ -0,0 +1,275 @@
1
+ import math
+ import random # used by ORT_NMS.forward below to build dummy indices during ONNX tracing
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from utils.downloads import attempt_download
8
+
9
+
10
+ class Sum(nn.Module):
11
+ # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
12
+ def __init__(self, n, weight=False): # n: number of inputs
13
+ super().__init__()
14
+ self.weight = weight # apply weights boolean
15
+ self.iter = range(n - 1) # iter object
16
+ if weight:
17
+ self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
18
+
19
+ def forward(self, x):
20
+ y = x[0] # no weight
21
+ if self.weight:
22
+ w = torch.sigmoid(self.w) * 2
23
+ for i in self.iter:
24
+ y = y + x[i + 1] * w[i]
25
+ else:
26
+ for i in self.iter:
27
+ y = y + x[i + 1]
28
+ return y
29
+
30
+
31
+ class MixConv2d(nn.Module):
32
+ # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
33
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
34
+ super().__init__()
35
+ n = len(k) # number of convolutions
36
+ if equal_ch: # equal c_ per group
37
+ i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
38
+ c_ = [(i == g).sum() for g in range(n)] # intermediate channels
39
+ else: # equal weight.numel() per group
40
+ b = [c2] + [0] * n
41
+ a = np.eye(n + 1, n, k=-1)
42
+ a -= np.roll(a, 1, axis=1)
43
+ a *= np.array(k) ** 2
44
+ a[0] = 1
45
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
46
+
47
+ self.m = nn.ModuleList([
48
+ nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
49
+ self.bn = nn.BatchNorm2d(c2)
50
+ self.act = nn.SiLU()
51
+
52
+ def forward(self, x):
53
+ return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
54
+
55
+
56
+ class Ensemble(nn.ModuleList):
57
+ # Ensemble of models
58
+ def __init__(self):
59
+ super().__init__()
60
+
61
+ def forward(self, x, augment=False, profile=False, visualize=False):
62
+ y = [module(x, augment, profile, visualize)[0] for module in self]
63
+ # y = torch.stack(y).max(0)[0] # max ensemble
64
+ # y = torch.stack(y).mean(0) # mean ensemble
65
+ y = torch.cat(y, 1) # nms ensemble
66
+ return y, None # inference, train output
67
+
68
+
69
+ class ORT_NMS(torch.autograd.Function):
70
+ '''ONNX-Runtime NMS operation'''
71
+ @staticmethod
72
+ def forward(ctx,
73
+ boxes,
74
+ scores,
75
+ max_output_boxes_per_class=torch.tensor([100]),
76
+ iou_threshold=torch.tensor([0.45]),
77
+ score_threshold=torch.tensor([0.25])):
78
+ device = boxes.device
79
+ batch = scores.shape[0]
80
+ num_det = random.randint(0, 100)
81
+ batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
82
+ idxs = torch.arange(100, 100 + num_det).to(device)
83
+ zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
84
+ selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
85
+ selected_indices = selected_indices.to(torch.int64)
86
+ return selected_indices
87
+
88
+ @staticmethod
89
+ def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
90
+ return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
91
+
92
+
93
+ class TRT_NMS(torch.autograd.Function):
94
+ '''TensorRT NMS operation'''
95
+ @staticmethod
96
+ def forward(
97
+ ctx,
98
+ boxes,
99
+ scores,
100
+ background_class=-1,
101
+ box_coding=1,
102
+ iou_threshold=0.45,
103
+ max_output_boxes=100,
104
+ plugin_version="1",
105
+ score_activation=0,
106
+ score_threshold=0.25,
107
+ ):
108
+
109
+ batch_size, num_boxes, num_classes = scores.shape
110
+ num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
111
+ det_boxes = torch.randn(batch_size, max_output_boxes, 4)
112
+ det_scores = torch.randn(batch_size, max_output_boxes)
113
+ det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
114
+ return num_det, det_boxes, det_scores, det_classes
115
+
116
+ @staticmethod
117
+ def symbolic(g,
118
+ boxes,
119
+ scores,
120
+ background_class=-1,
121
+ box_coding=1,
122
+ iou_threshold=0.45,
123
+ max_output_boxes=100,
124
+ plugin_version="1",
125
+ score_activation=0,
126
+ score_threshold=0.25):
127
+ out = g.op("TRT::EfficientNMS_TRT",
128
+ boxes,
129
+ scores,
130
+ background_class_i=background_class,
131
+ box_coding_i=box_coding,
132
+ iou_threshold_f=iou_threshold,
133
+ max_output_boxes_i=max_output_boxes,
134
+ plugin_version_s=plugin_version,
135
+ score_activation_i=score_activation,
136
+ score_threshold_f=score_threshold,
137
+ outputs=4)
138
+ nums, boxes, scores, classes = out
139
+ return nums, boxes, scores, classes
140
+
141
+
142
+ class ONNX_ORT(nn.Module):
143
+ '''onnx module with ONNX-Runtime NMS operation.'''
144
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
145
+ super().__init__()
146
+ self.device = device if device else torch.device("cpu")
147
+ self.max_obj = torch.tensor([max_obj]).to(device)
148
+ self.iou_threshold = torch.tensor([iou_thres]).to(device)
149
+ self.score_threshold = torch.tensor([score_thres]).to(device)
150
+ self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
151
+ self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
152
+ dtype=torch.float32,
153
+ device=self.device)
154
+ self.n_classes=n_classes
155
+
156
+ def forward(self, x):
157
+ ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
158
+ ## thanks https://github.com/thaitc-hust
159
+ if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
160
+ x = x[1]
161
+ x = x.permute(0, 2, 1)
162
+ bboxes_x = x[..., 0:1]
163
+ bboxes_y = x[..., 1:2]
164
+ bboxes_w = x[..., 2:3]
165
+ bboxes_h = x[..., 3:4]
166
+ bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
167
+ bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
168
+ obj_conf = x[..., 4:]
169
+ scores = obj_conf
170
+ bboxes @= self.convert_matrix
171
+ max_score, category_id = scores.max(2, keepdim=True)
172
+ dis = category_id.float() * self.max_wh
173
+ nmsbox = bboxes + dis
174
+ max_score_tp = max_score.transpose(1, 2).contiguous()
175
+ selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
176
+ X, Y = selected_indices[:, 0], selected_indices[:, 2]
177
+ selected_boxes = bboxes[X, Y, :]
178
+ selected_categories = category_id[X, Y, :].float()
179
+ selected_scores = max_score[X, Y, :]
180
+ X = X.unsqueeze(1).float()
181
+ return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
182
+
183
+
184
+ class ONNX_TRT(nn.Module):
185
+ '''onnx module with TensorRT NMS operation.'''
186
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
187
+ super().__init__()
188
+ assert max_wh is None
189
+ self.device = device if device else torch.device('cpu')
190
+ self.background_class = -1,
191
+ self.box_coding = 1,
192
+ self.iou_threshold = iou_thres
193
+ self.max_obj = max_obj
194
+ self.plugin_version = '1'
195
+ self.score_activation = 0
196
+ self.score_threshold = score_thres
197
+ self.n_classes=n_classes
198
+
199
+ def forward(self, x):
200
+ ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
201
+ ## thanks https://github.com/thaitc-hust
202
+ if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
203
+ x = x[1]
204
+ x = x.permute(0, 2, 1)
205
+ bboxes_x = x[..., 0:1]
206
+ bboxes_y = x[..., 1:2]
207
+ bboxes_w = x[..., 2:3]
208
+ bboxes_h = x[..., 3:4]
209
+ bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
210
+ bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
211
+ obj_conf = x[..., 4:]
212
+ scores = obj_conf
213
+ num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(bboxes, scores, self.background_class, self.box_coding,
214
+ self.iou_threshold, self.max_obj,
215
+ self.plugin_version, self.score_activation,
216
+ self.score_threshold)
217
+ return num_det, det_boxes, det_scores, det_classes
218
+
219
+ class End2End(nn.Module):
220
+ '''export onnx or tensorrt model with NMS operation.'''
221
+ def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
222
+ super().__init__()
223
+ device = device if device else torch.device('cpu')
224
+ assert isinstance(max_wh,(int)) or max_wh is None
225
+ self.model = model.to(device)
226
+ self.model.model[-1].end2end = True
227
+ self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
228
+ self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
229
+ self.end2end.eval()
230
+
231
+ def forward(self, x):
232
+ x = self.model(x)
233
+ x = self.end2end(x)
234
+ return x
235
+
236
+
237
+ def attempt_load(weights, device=None, inplace=True, fuse=True):
238
+ # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
239
+ from models.yolo import Detect, Model
240
+
241
+ model = Ensemble()
242
+ for w in weights if isinstance(weights, list) else [weights]:
243
+ ckpt = torch.load(attempt_download(w), map_location='cpu') # load
244
+ ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
245
+
246
+ # Model compatibility updates
247
+ if not hasattr(ckpt, 'stride'):
248
+ ckpt.stride = torch.tensor([32.])
249
+ if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
250
+ ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
251
+
252
+ model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
253
+
254
+ # Module compatibility updates
255
+ for m in model.modules():
256
+ t = type(m)
257
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
258
+ m.inplace = inplace # torch 1.7.0 compatibility
259
+ # if t is Detect and not isinstance(m.anchor_grid, list):
260
+ # delattr(m, 'anchor_grid')
261
+ # setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
262
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
263
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
264
+
265
+ # Return model
266
+ if len(model) == 1:
267
+ return model[-1]
268
+
269
+ # Return detection ensemble
270
+ print(f'Ensemble created with {weights}\n')
271
+ for k in 'names', 'nc', 'yaml':
272
+ setattr(model, k, getattr(model[0], k))
273
+ model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
274
+ assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
275
+ return model
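attempt_load returns the bare model when given a single checkpoint and an Ensemble (per-model outputs concatenated before NMS) when given a list. A minimal sketch with placeholder checkpoint names:

    from models.experimental import attempt_load

    single = attempt_load('best.pt', device='cpu', fuse=True)      # fused, eval-mode model
    ensemble = attempt_load(['best.pt', 'last.pt'], device='cpu')   # Ensemble of two checkpoints
    print(single.stride, single.names)                              # attributes patched by the compatibility block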
models/repvit.py ADDED
@@ -0,0 +1,440 @@
1
+ import torch.nn as nn
2
+ import numpy as np
3
+ import itertools
4
+
5
+ def _make_divisible(v, divisor, min_value=None):
6
+ """
7
+ This function is taken from the original tf repo.
8
+ It ensures that all layers have a channel number that is divisible by 8
9
+ It can be seen here:
10
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
11
+ :param v:
12
+ :param divisor:
13
+ :param min_value:
14
+ :return:
15
+ """
16
+ if min_value is None:
17
+ min_value = divisor
18
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
19
+ # Make sure that round down does not go down by more than 10%.
20
+ if new_v < 0.9 * v:
21
+ new_v += divisor
22
+ return new_v
23
+
24
+ from timm.models.layers import SqueezeExcite
25
+
26
+ import torch
27
+
28
+ class Conv2d_BN(torch.nn.Sequential):
29
+ def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
30
+ groups=1, bn_weight_init=1, resolution=-10000):
31
+ super().__init__()
32
+ self.add_module('c', torch.nn.Conv2d(
33
+ a, b, ks, stride, pad, dilation, groups, bias=False))
34
+ self.add_module('bn', torch.nn.BatchNorm2d(b))
35
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
36
+ torch.nn.init.constant_(self.bn.bias, 0)
37
+
38
+ @torch.no_grad()
39
+ def fuse(self):
40
+ c, bn = self._modules.values()
41
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
42
+ w = c.weight * w[:, None, None, None]
43
+ b = bn.bias - bn.running_mean * bn.weight / \
44
+ (bn.running_var + bn.eps)**0.5
45
+ m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
46
+ 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
47
+ device=c.weight.device)
48
+ m.weight.data.copy_(w)
49
+ m.bias.data.copy_(b)
50
+ return m
51
+
52
+ class Residual(torch.nn.Module):
53
+ def __init__(self, m, drop=0.):
54
+ super().__init__()
55
+ self.m = m
56
+ self.drop = drop
57
+
58
+ def forward(self, x):
59
+ if self.training and self.drop > 0:
60
+ return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
61
+ device=x.device).ge_(self.drop).div(1 - self.drop).detach()
62
+ else:
63
+ return x + self.m(x)
64
+
65
+ @torch.no_grad()
66
+ def fuse(self):
67
+ if isinstance(self.m, Conv2d_BN):
68
+ m = self.m.fuse()
69
+ assert(m.groups == m.in_channels)
70
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
71
+ identity = torch.nn.functional.pad(identity, [1,1,1,1])
72
+ m.weight += identity.to(m.weight.device)
73
+ return m
74
+ elif isinstance(self.m, torch.nn.Conv2d):
75
+ m = self.m
76
+ assert(m.groups != m.in_channels)
77
+ identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
78
+ identity = torch.nn.functional.pad(identity, [1,1,1,1])
79
+ m.weight += identity.to(m.weight.device)
80
+ return m
81
+ else:
82
+ return self
83
+
84
+
85
+ class RepVGGDW(torch.nn.Module):
86
+ def __init__(self, ed) -> None:
87
+ super().__init__()
88
+ self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
89
+ self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
90
+ self.dim = ed
91
+ self.bn = torch.nn.BatchNorm2d(ed)
92
+
93
+ def forward(self, x):
94
+ return self.bn((self.conv(x) + self.conv1(x)) + x)
95
+
96
+ @torch.no_grad()
97
+ def fuse(self):
98
+ conv = self.conv.fuse()
99
+ conv1 = self.conv1
100
+
101
+ conv_w = conv.weight
102
+ conv_b = conv.bias
103
+ conv1_w = conv1.weight
104
+ conv1_b = conv1.bias
105
+
106
+ conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
107
+
108
+ identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
109
+
110
+ final_conv_w = conv_w + conv1_w + identity
111
+ final_conv_b = conv_b + conv1_b
112
+
113
+ conv.weight.data.copy_(final_conv_w)
114
+ conv.bias.data.copy_(final_conv_b)
115
+
116
+ bn = self.bn
117
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
118
+ w = conv.weight * w[:, None, None, None]
119
+ b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
120
+ (bn.running_var + bn.eps)**0.5
121
+ conv.weight.data.copy_(w)
122
+ conv.bias.data.copy_(b)
123
+ return conv
124
+
125
+
126
+ class RepViTBlock(nn.Module):
127
+ def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
128
+ super(RepViTBlock, self).__init__()
129
+ assert stride in [1, 2]
130
+
131
+ self.identity = stride == 1 and inp == oup
132
+ assert(hidden_dim == 2 * inp)
133
+
134
+ if stride == 2:
135
+ self.token_mixer = nn.Sequential(
136
+ Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
137
+ SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
138
+ Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
139
+ )
140
+ self.channel_mixer = Residual(nn.Sequential(
141
+ # pw
142
+ Conv2d_BN(oup, 2 * oup, 1, 1, 0),
143
+ nn.GELU() if use_hs else nn.GELU(),
144
+ # pw-linear
145
+ Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
146
+ ))
147
+ else:
148
+ assert(self.identity)
149
+ self.token_mixer = nn.Sequential(
150
+ RepVGGDW(inp),
151
+ SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
152
+ )
153
+ self.channel_mixer = Residual(nn.Sequential(
154
+ # pw
155
+ Conv2d_BN(inp, hidden_dim, 1, 1, 0),
156
+ nn.GELU() if use_hs else nn.GELU(),
157
+ # pw-linear
158
+ Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
159
+ ))
160
+
161
+ def forward(self, x):
162
+ return self.channel_mixer(self.token_mixer(x))
163
+
164
+ from timm.models.vision_transformer import trunc_normal_
165
+ class BN_Linear(torch.nn.Sequential):
166
+ def __init__(self, a, b, bias=True, std=0.02):
167
+ super().__init__()
168
+ self.add_module('bn', torch.nn.BatchNorm1d(a))
169
+ self.add_module('l', torch.nn.Linear(a, b, bias=bias))
170
+ trunc_normal_(self.l.weight, std=std)
171
+ if bias:
172
+ torch.nn.init.constant_(self.l.bias, 0)
173
+
174
+ @torch.no_grad()
175
+ def fuse(self):
176
+ bn, l = self._modules.values()
177
+ w = bn.weight / (bn.running_var + bn.eps)**0.5
178
+ b = bn.bias - self.bn.running_mean * \
179
+ self.bn.weight / (bn.running_var + bn.eps)**0.5
180
+ w = l.weight * w[None, :]
181
+ if l.bias is None:
182
+ b = b @ self.l.weight.T
183
+ else:
184
+ b = (l.weight @ b[:, None]).view(-1) + self.l.bias
185
+ m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
186
+ m.weight.data.copy_(w)
187
+ m.bias.data.copy_(b)
188
+ return m
189
+
190
+ class RepViT(nn.Module):
191
+ def __init__(self, cfgs, distillation=False, pretrained=None, init_cfg=None, out_indices=[]):
192
+ super(RepViT, self).__init__()
193
+ # setting of inverted residual blocks
194
+ self.cfgs = cfgs
195
+
196
+ # building first layer
197
+ input_channel = self.cfgs[0][2]
198
+ patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU() )
199
+ layers = [patch_embed]
200
+ patch_embed2 = torch.nn.Sequential(Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1), torch.nn.GELU())
201
+ layers.append(patch_embed2)
202
+
203
+ # building inverted residual blocks
204
+ block = RepViTBlock
205
+ for k, t, c, use_se, use_hs, s in self.cfgs:
206
+ output_channel = _make_divisible(c, 8)
207
+ exp_size = _make_divisible(input_channel * t, 8)
208
+ layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
209
+ input_channel = output_channel
210
+ self.features = nn.ModuleList(layers)
211
+ #
212
+ # self.init_cfg = init_cfg
213
+ # assert(self.init_cfg is not None)
214
+ self.out_indices = out_indices
215
+
216
+ #self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
217
+ self.train()
218
+ self.out_indices=[0,5,11, 37, 42]
219
+ # 320 160 80 40 20
220
+ def train(self, mode=True):
221
+ """Convert the model into training mode while keep layers freezed."""
222
+ super(RepViT, self).train(mode)
223
+
224
+ def forward(self, x):
225
+ outs = []
226
+ for i, f in enumerate(self.features):
227
+ x = f(x)
228
+ #print(x.shape)
229
+ if i in self.out_indices:
230
+ outs.append(x)
231
+ #print(x.shape)
232
+ # assert(len(outs) == 4)
233
+ return outs
234
+
235
+ from timm.models import register_model
236
+ def repvit_m1_1(pretrained=False, num_classes = 1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
237
+ """
238
+ Constructs a RepViT-M1.1 model
239
+ """
240
+ cfgs = [
241
+ # k, t, c, SE, HS, s
242
+ [3, 2, 64, 1, 0, 1],
243
+ [3, 2, 64, 0, 0, 1],
244
+ [3, 2, 64, 0, 0, 1],
245
+ [3, 2, 128, 0, 0, 2],
246
+ [3, 2, 128, 1, 0, 1],
247
+ [3, 2, 128, 0, 0, 1],
248
+ [3, 2, 128, 0, 0, 1],
249
+ [3, 2, 256, 0, 1, 2],
250
+ [3, 2, 256, 1, 1, 1],
251
+ [3, 2, 256, 0, 1, 1],
252
+ [3, 2, 256, 1, 1, 1],
253
+ [3, 2, 256, 0, 1, 1],
254
+ [3, 2, 256, 1, 1, 1],
255
+ [3, 2, 256, 0, 1, 1],
256
+ [3, 2, 256, 1, 1, 1],
257
+ [3, 2, 256, 0, 1, 1],
258
+ [3, 2, 256, 1, 1, 1],
259
+ [3, 2, 256, 0, 1, 1],
260
+ [3, 2, 256, 1, 1, 1],
261
+ [3, 2, 256, 0, 1, 1],
262
+ [3, 2, 256, 0, 1, 1],
263
+ [3, 2, 512, 0, 1, 2],
264
+ [3, 2, 512, 1, 1, 1],
265
+ [3, 2, 512, 0, 1, 1]
266
+ ]
267
+ return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)
268
+
269
+ def repvit_m1_5(pretrained=False, num_classes = 1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
270
+ """
271
+ Constructs a RepViT-M1.5 model
272
+ """
273
+ cfgs = [
274
+ # k, t, c, SE, HS, s
275
+ [3, 2, 64, 1, 0, 1],
276
+ [3, 2, 64, 0, 0, 1],
277
+ [3, 2, 64, 1, 0, 1],
278
+ [3, 2, 64, 0, 0, 1],
279
+ [3, 2, 64, 0, 0, 1],
280
+ [3, 2, 128, 0, 0, 2],
281
+ [3, 2, 128, 1, 0, 1],
282
+ [3, 2, 128, 0, 0, 1],
283
+ [3, 2, 128, 1, 0, 1],
284
+ [3, 2, 128, 0, 0, 1],
285
+ [3, 2, 128, 0, 0, 1],
286
+ [3, 2, 256, 0, 1, 2],
287
+ [3, 2, 256, 1, 1, 1],
288
+ [3, 2, 256, 0, 1, 1],
289
+ [3, 2, 256, 1, 1, 1],
290
+ [3, 2, 256, 0, 1, 1],
291
+ [3, 2, 256, 1, 1, 1],
292
+ [3, 2, 256, 0, 1, 1],
293
+ [3, 2, 256, 1, 1, 1],
294
+ [3, 2, 256, 0, 1, 1],
295
+ [3, 2, 256, 1, 1, 1],
296
+ [3, 2, 256, 0, 1, 1],
297
+ [3, 2, 256, 1, 1, 1],
298
+ [3, 2, 256, 0, 1, 1],
299
+ [3, 2, 256, 1, 1, 1],
300
+ [3, 2, 256, 0, 1, 1],
301
+ [3, 2, 256, 1, 1, 1],
302
+ [3, 2, 256, 0, 1, 1],
303
+ [3, 2, 256, 1, 1, 1],
304
+ [3, 2, 256, 0, 1, 1],
305
+ [3, 2, 256, 1, 1, 1],
306
+ [3, 2, 256, 0, 1, 1],
307
+ [3, 2, 256, 1, 1, 1],
308
+ [3, 2, 256, 0, 1, 1],
309
+ [3, 2, 256, 1, 1, 1],
310
+ [3, 2, 256, 0, 1, 1],
311
+ [3, 2, 256, 0, 1, 1],
312
+ [3, 2, 512, 0, 1, 2],
313
+ [3, 2, 512, 1, 1, 1],
314
+ [3, 2, 512, 0, 1, 1],
315
+ [3, 2, 512, 1, 1, 1],
316
+ [3, 2, 512, 0, 1, 1]
317
+ ]
318
+ return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)
319
+
320
+
321
+ def repvit_m2_3(pretrained=False, num_classes = 1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
322
+ """
323
+ Constructs a RepViT-M2.3 model
324
+ """
325
+ cfgs = [
326
+ # k, t, c, SE, HS, s
327
+ [3, 2, 80, 1, 0, 1],
328
+ [3, 2, 80, 0, 0, 1],
329
+ [3, 2, 80, 1, 0, 1],
330
+ [3, 2, 80, 0, 0, 1],
331
+ [3, 2, 80, 1, 0, 1],
332
+ [3, 2, 80, 0, 0, 1],
333
+ [3, 2, 80, 0, 0, 1],
334
+ [3, 2, 160, 0, 0, 2],
335
+ [3, 2, 160, 1, 0, 1],
336
+ [3, 2, 160, 0, 0, 1],
337
+ [3, 2, 160, 1, 0, 1],
338
+ [3, 2, 160, 0, 0, 1],
339
+ [3, 2, 160, 1, 0, 1],
340
+ [3, 2, 160, 0, 0, 1],
341
+ [3, 2, 160, 0, 0, 1],
342
+ [3, 2, 320, 0, 1, 2],
343
+ [3, 2, 320, 1, 1, 1],
344
+ [3, 2, 320, 0, 1, 1],
345
+ [3, 2, 320, 1, 1, 1],
346
+ [3, 2, 320, 0, 1, 1],
347
+ [3, 2, 320, 1, 1, 1],
348
+ [3, 2, 320, 0, 1, 1],
349
+ [3, 2, 320, 1, 1, 1],
350
+ [3, 2, 320, 0, 1, 1],
351
+ [3, 2, 320, 1, 1, 1],
352
+ [3, 2, 320, 0, 1, 1],
353
+ [3, 2, 320, 1, 1, 1],
354
+ [3, 2, 320, 0, 1, 1],
355
+ [3, 2, 320, 1, 1, 1],
356
+ [3, 2, 320, 0, 1, 1],
357
+ [3, 2, 320, 1, 1, 1],
358
+ [3, 2, 320, 0, 1, 1],
359
+ [3, 2, 320, 1, 1, 1],
360
+ [3, 2, 320, 0, 1, 1],
361
+ [3, 2, 320, 1, 1, 1],
362
+ [3, 2, 320, 0, 1, 1],
363
+ [3, 2, 320, 1, 1, 1],
364
+ [3, 2, 320, 0, 1, 1],
365
+ [3, 2, 320, 1, 1, 1],
366
+ [3, 2, 320, 0, 1, 1],
367
+ [3, 2, 320, 1, 1, 1],
368
+ [3, 2, 320, 0, 1, 1],
369
+ [3, 2, 320, 1, 1, 1],
370
+ [3, 2, 320, 0, 1, 1],
371
+ [3, 2, 320, 1, 1, 1],
372
+ [3, 2, 320, 0, 1, 1],
373
+ [3, 2, 320, 1, 1, 1],
374
+ [3, 2, 320, 0, 1, 1],
375
+ [3, 2, 320, 1, 1, 1],
376
+ [3, 2, 320, 0, 1, 1],
377
+ # [3, 2, 320, 1, 1, 1],
378
+ # [3, 2, 320, 0, 1, 1],
379
+ [3, 2, 320, 0, 1, 1],
380
+ [3, 2, 640, 0, 1, 2],
381
+ [3, 2, 640, 1, 1, 1],
382
+ [3, 2, 640, 0, 1, 1],
383
+ # [3, 2, 640, 1, 1, 1],
384
+ # [3, 2, 640, 0, 1, 1]
385
+ ]
386
+ return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)
387
+
388
+
389
+
390
+
391
+ cfgs = [
392
+ # k, t, c, SE, HS, s
393
+ [3, 2, 64*2, 1, 0, 1],
394
+ [3, 2, 64*2, 0, 0, 1],
395
+ [3, 2, 64*2, 1, 0, 1],
396
+ [3, 2, 64*2, 0, 0, 1],
397
+ [3, 2, 64*2, 0, 0, 1],
398
+ [3, 2, 128*2, 0, 0, 2],
399
+ [3, 2, 128*2, 1, 0, 1],
400
+ [3, 2, 128*2, 0, 0, 1],
401
+ [3, 2, 128*2, 1, 0, 1],
402
+ [3, 2, 128*2, 0, 0, 1],
403
+ [3, 2, 128*2, 0, 0, 1],
404
+ [3, 2, 256*2, 0, 1, 2],
405
+ [3, 2, 256*2, 1, 1, 1],
406
+ [3, 2, 256*2, 0, 1, 1],
407
+ [3, 2, 256*2, 1, 1, 1],
408
+ [3, 2, 256*2, 0, 1, 1],
409
+ [3, 2, 256*2, 1, 1, 1],
410
+ [3, 2, 256*2, 0, 1, 1],
411
+ [3, 2, 256*2, 1, 1, 1],
412
+ [3, 2, 256*2, 0, 1, 1],
413
+ [3, 2, 256*2, 1, 1, 1],
414
+ [3, 2, 256*2, 0, 1, 1],
415
+ [3, 2, 256*2, 1, 1, 1],
416
+ [3, 2, 256*2, 0, 1, 1],
417
+ [3, 2, 256*2, 1, 1, 1],
418
+ [3, 2, 256*2, 0, 1, 1],
419
+ [3, 2, 256*2, 1, 1, 1],
420
+ [3, 2, 256*2, 0, 1, 1],
421
+ [3, 2, 256*2, 1, 1, 1],
422
+ [3, 2, 256*2, 0, 1, 1],
423
+ [3, 2, 256*2, 1, 1, 1],
424
+ [3, 2, 256*2, 0, 1, 1],
425
+ [3, 2, 256*2, 1, 1, 1],
426
+ [3, 2, 256*2, 0, 1, 1],
427
+ [3, 2, 256*2, 1, 1, 1],
428
+ [3, 2, 256*2, 0, 1, 1],
429
+ [3, 2, 256*2, 0, 1, 1],
430
+ [3, 2, 512*2, 0, 1, 2],
431
+ [3, 2, 512*2, 1, 1, 1],
432
+ [3, 2, 512*2, 0, 1, 1],
433
+ [3, 2, 512*2, 1, 1, 1],
434
+ [3, 2, 512*2, 0, 1, 1]
435
+ ]
436
+
437
+ if __name__ == "__main__":
438
+ model = RepViT(cfgs)
439
+ t1 = torch.rand(1,3,640,640)
440
+ x = model(t1)
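With the hard-coded out_indices = [0, 5, 11, 37, 42], the backbone returns five feature maps whose spatial sizes follow the "# 320 160 80 40 20" comment for a 640x640 input (strides 2 through 32). A small sketch extending the __main__ check above:

    import torch
    from models.repvit import RepViT, cfgs   # module-level names defined in this file

    backbone = RepViT(cfgs)
    feats = backbone(torch.rand(1, 3, 640, 640))
    for f in feats:
        print(tuple(f.shape))                # expected spatial sizes: 320, 160, 80, 40, 20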
models/tf.py ADDED
@@ -0,0 +1,596 @@
1
+ import argparse
2
+ import sys
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+
6
+ FILE = Path(__file__).resolve()
7
+ ROOT = FILE.parents[1] # YOLO root directory
8
+ if str(ROOT) not in sys.path:
9
+ sys.path.append(str(ROOT)) # add ROOT to PATH
10
+ # ROOT = ROOT.relative_to(Path.cwd()) # relative
11
+
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ import torch
15
+ import torch.nn as nn
16
+ from tensorflow import keras
17
+
18
+ from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
19
+ DWConvTranspose2d, Focus, autopad)
20
+ from models.experimental import MixConv2d, attempt_load
21
+ from models.yolo import Detect, Segment
22
+ from utils.activations import SiLU
23
+ from utils.general import LOGGER, make_divisible, print_args
24
+
25
+
26
+ class TFBN(keras.layers.Layer):
27
+ # TensorFlow BatchNormalization wrapper
28
+ def __init__(self, w=None):
29
+ super().__init__()
30
+ self.bn = keras.layers.BatchNormalization(
31
+ beta_initializer=keras.initializers.Constant(w.bias.numpy()),
32
+ gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
33
+ moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
34
+ moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
35
+ epsilon=w.eps)
36
+
37
+ def call(self, inputs):
38
+ return self.bn(inputs)
39
+
40
+
41
+ class TFPad(keras.layers.Layer):
42
+ # Pad inputs in spatial dimensions 1 and 2
43
+ def __init__(self, pad):
44
+ super().__init__()
45
+ if isinstance(pad, int):
46
+ self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
47
+ else: # tuple/list
48
+ self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])
49
+
50
+ def call(self, inputs):
51
+ return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
52
+
53
+
54
+ class TFConv(keras.layers.Layer):
55
+ # Standard convolution
56
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
57
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
58
+ super().__init__()
59
+ assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
60
+ # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
61
+ # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
62
+ conv = keras.layers.Conv2D(
63
+ filters=c2,
64
+ kernel_size=k,
65
+ strides=s,
66
+ padding='SAME' if s == 1 else 'VALID',
67
+ use_bias=not hasattr(w, 'bn'),
68
+ kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
69
+ bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
70
+ self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
71
+ self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
72
+ self.act = activations(w.act) if act else tf.identity
73
+
74
+ def call(self, inputs):
75
+ return self.act(self.bn(self.conv(inputs)))
76
+
77
+
78
+ class TFDWConv(keras.layers.Layer):
79
+ # Depthwise convolution
80
+ def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
81
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
82
+ super().__init__()
83
+ assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
84
+ conv = keras.layers.DepthwiseConv2D(
85
+ kernel_size=k,
86
+ depth_multiplier=c2 // c1,
87
+ strides=s,
88
+ padding='SAME' if s == 1 else 'VALID',
89
+ use_bias=not hasattr(w, 'bn'),
90
+ depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
91
+ bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
92
+ self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
93
+ self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
94
+ self.act = activations(w.act) if act else tf.identity
95
+
96
+ def call(self, inputs):
97
+ return self.act(self.bn(self.conv(inputs)))
98
+
99
+
100
+ class TFDWConvTranspose2d(keras.layers.Layer):
101
+ # Depthwise ConvTranspose2d
102
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
103
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
104
+ super().__init__()
105
+ assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
106
+ assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
107
+ weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
108
+ self.c1 = c1
109
+ self.conv = [
110
+ keras.layers.Conv2DTranspose(filters=1,
111
+ kernel_size=k,
112
+ strides=s,
113
+ padding='VALID',
114
+ output_padding=p2,
115
+ use_bias=True,
116
+ kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
117
+ bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]
118
+
119
+ def call(self, inputs):
120
+ return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]
121
+
122
+
123
+ class TFFocus(keras.layers.Layer):
124
+ # Focus wh information into c-space
125
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
126
+ # ch_in, ch_out, kernel, stride, padding, groups
127
+ super().__init__()
128
+ self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
129
+
130
+ def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
131
+ # inputs = inputs / 255 # normalize 0-255 to 0-1
132
+ inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
133
+ return self.conv(tf.concat(inputs, 3))
134
+
135
+
136
+ class TFBottleneck(keras.layers.Layer):
137
+ # Standard bottleneck
138
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
139
+ super().__init__()
140
+ c_ = int(c2 * e) # hidden channels
141
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
142
+ self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
143
+ self.add = shortcut and c1 == c2
144
+
145
+ def call(self, inputs):
146
+ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
147
+
148
+
149
+ class TFCrossConv(keras.layers.Layer):
150
+ # Cross Convolution
151
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
152
+ super().__init__()
153
+ c_ = int(c2 * e) # hidden channels
154
+ self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
155
+ self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
156
+ self.add = shortcut and c1 == c2
157
+
158
+ def call(self, inputs):
159
+ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
160
+
161
+
162
+ class TFConv2d(keras.layers.Layer):
163
+ # Substitution for PyTorch nn.Conv2D
164
+ def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
165
+ super().__init__()
166
+ assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
167
+ self.conv = keras.layers.Conv2D(filters=c2,
168
+ kernel_size=k,
169
+ strides=s,
170
+ padding='VALID',
171
+ use_bias=bias,
172
+ kernel_initializer=keras.initializers.Constant(
173
+ w.weight.permute(2, 3, 1, 0).numpy()),
174
+ bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)
175
+
176
+ def call(self, inputs):
177
+ return self.conv(inputs)
178
+
179
+
180
+ class TFBottleneckCSP(keras.layers.Layer):
181
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
182
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
183
+ # ch_in, ch_out, number, shortcut, groups, expansion
184
+ super().__init__()
185
+ c_ = int(c2 * e) # hidden channels
186
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
187
+ self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
188
+ self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
189
+ self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
190
+ self.bn = TFBN(w.bn)
191
+ self.act = lambda x: keras.activations.swish(x)
192
+ self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
193
+
194
+ def call(self, inputs):
195
+ y1 = self.cv3(self.m(self.cv1(inputs)))
196
+ y2 = self.cv2(inputs)
197
+ return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
198
+
199
+
200
+ class TFC3(keras.layers.Layer):
201
+ # CSP Bottleneck with 3 convolutions
202
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
203
+ # ch_in, ch_out, number, shortcut, groups, expansion
204
+ super().__init__()
205
+ c_ = int(c2 * e) # hidden channels
206
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
207
+ self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
208
+ self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
209
+ self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
210
+
211
+ def call(self, inputs):
212
+ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
213
+
214
+
215
+ class TFC3x(keras.layers.Layer):
216
+ # 3 module with cross-convolutions
217
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
218
+ # ch_in, ch_out, number, shortcut, groups, expansion
219
+ super().__init__()
220
+ c_ = int(c2 * e) # hidden channels
221
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
222
+ self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
223
+ self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
224
+ self.m = keras.Sequential([
225
+ TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])
226
+
227
+ def call(self, inputs):
228
+ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
229
+
230
+
231
+ class TFSPP(keras.layers.Layer):
232
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
233
+ def __init__(self, c1, c2, k=(5, 9, 13), w=None):
234
+ super().__init__()
235
+ c_ = c1 // 2 # hidden channels
236
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
237
+ self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
238
+ self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
239
+
240
+ def call(self, inputs):
241
+ x = self.cv1(inputs)
242
+ return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
243
+
244
+
245
+ class TFSPPF(keras.layers.Layer):
246
+ # Spatial pyramid pooling-Fast layer
247
+ def __init__(self, c1, c2, k=5, w=None):
248
+ super().__init__()
249
+ c_ = c1 // 2 # hidden channels
250
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
251
+ self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
252
+ self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
253
+
254
+ def call(self, inputs):
255
+ x = self.cv1(inputs)
256
+ y1 = self.m(x)
257
+ y2 = self.m(y1)
258
+ return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
259
+
260
+
261
+ class TFDetect(keras.layers.Layer):
262
+ # TF YOLO Detect layer
263
+ def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
264
+ super().__init__()
265
+ self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
266
+ self.nc = nc # number of classes
267
+ self.no = nc + 5 # number of outputs per anchor
268
+ self.nl = len(anchors) # number of detection layers
269
+ self.na = len(anchors[0]) // 2 # number of anchors
270
+ self.grid = [tf.zeros(1)] * self.nl # init grid
271
+ self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
272
+ self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
273
+ self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
274
+ self.training = False # set to False after building model
275
+ self.imgsz = imgsz
276
+ for i in range(self.nl):
277
+ ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
278
+ self.grid[i] = self._make_grid(nx, ny)
279
+
280
+ def call(self, inputs):
281
+ z = [] # inference output
282
+ x = []
283
+ for i in range(self.nl):
284
+ x.append(self.m[i](inputs[i]))
285
+ # x(bs,20,20,255) to x(bs,3,20,20,85)
286
+ ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
287
+ x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
288
+
289
+ if not self.training: # inference
290
+ y = x[i]
291
+ grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
292
+ anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
293
+ xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i] # xy
294
+ wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
295
+ # Normalize xywh to 0-1 to reduce calibration error
296
+ xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
297
+ wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
298
+ y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
299
+ z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
300
+
301
+ return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),)
302
+
303
+ @staticmethod
304
+ def _make_grid(nx=20, ny=20):
305
+ # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
306
+ # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
307
+ xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
308
+ return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
309
+
310
+
311
+ class TFSegment(TFDetect):
312
+ # YOLO Segment head for segmentation models
313
+ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
314
+ super().__init__(nc, anchors, ch, imgsz, w)
315
+ self.nm = nm # number of masks
316
+ self.npr = npr # number of protos
317
+ self.no = 5 + nc + self.nm # number of outputs per anchor
318
+ self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] # output conv
319
+ self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto) # protos
320
+ self.detect = TFDetect.call
321
+
322
+ def call(self, x):
323
+ p = self.proto(x[0])
324
+ # p = TFUpsample(None, scale_factor=4, mode='nearest')(self.proto(x[0])) # (optional) full-size protos
325
+ p = tf.transpose(p, [0, 3, 1, 2]) # from shape(1,160,160,32) to shape(1,32,160,160)
326
+ x = self.detect(self, x)
327
+ return (x, p) if self.training else (x[0], p)
328
+
329
+
330
+ class TFProto(keras.layers.Layer):
331
+
332
+ def __init__(self, c1, c_=256, c2=32, w=None):
333
+ super().__init__()
334
+ self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
335
+ self.upsample = TFUpsample(None, scale_factor=2, mode='nearest')
336
+ self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
337
+ self.cv3 = TFConv(c_, c2, w=w.cv3)
338
+
339
+ def call(self, inputs):
340
+ return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))
341
+
342
+
343
+ class TFUpsample(keras.layers.Layer):
344
+ # TF version of torch.nn.Upsample()
345
+ def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
346
+ super().__init__()
347
+ assert scale_factor % 2 == 0, "scale_factor must be multiple of 2"
348
+ self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * scale_factor, x.shape[2] * scale_factor), mode)
349
+ # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
350
+ # with default arguments: align_corners=False, half_pixel_centers=False
351
+ # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
352
+ # size=(x.shape[1] * 2, x.shape[2] * 2))
353
+
354
+ def call(self, inputs):
355
+ return self.upsample(inputs)
356
+
357
+
358
+ class TFConcat(keras.layers.Layer):
359
+ # TF version of torch.concat()
360
+ def __init__(self, dimension=1, w=None):
361
+ super().__init__()
362
+ assert dimension == 1, "convert only NCHW to NHWC concat"
363
+ self.d = 3
364
+
365
+ def call(self, inputs):
366
+ return tf.concat(inputs, self.d)
367
+
368
+
369
+ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
370
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
371
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
372
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
373
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
374
+
375
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
376
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
377
+ m_str = m
378
+ m = eval(m) if isinstance(m, str) else m # eval strings
379
+ for j, a in enumerate(args):
380
+ try:
381
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
382
+ except NameError:
383
+ pass
384
+
385
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
386
+ if m in [
387
+ nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
388
+ BottleneckCSP, C3, C3x]:
389
+ c1, c2 = ch[f], args[0]
390
+ c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
391
+
392
+ args = [c1, c2, *args[1:]]
393
+ if m in [BottleneckCSP, C3, C3x]:
394
+ args.insert(2, n)
395
+ n = 1
396
+ elif m is nn.BatchNorm2d:
397
+ args = [ch[f]]
398
+ elif m is Concat:
399
+ c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
400
+ elif m in [Detect, Segment]:
401
+ args.append([ch[x + 1] for x in f])
402
+ if isinstance(args[1], int): # number of anchors
403
+ args[1] = [list(range(args[1] * 2))] * len(f)
404
+ if m is Segment:
405
+ args[3] = make_divisible(args[3] * gw, 8)
406
+ args.append(imgsz)
407
+ else:
408
+ c2 = ch[f]
409
+
410
+ tf_m = eval('TF' + m_str.replace('nn.', ''))
411
+ m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
412
+ else tf_m(*args, w=model.model[i]) # module
413
+
414
+ torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
415
+ t = str(m)[8:-2].replace('__main__.', '') # module type
416
+ np = sum(x.numel() for x in torch_m_.parameters()) # number params
417
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
418
+ LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print
419
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
420
+ layers.append(m_)
421
+ ch.append(c2)
422
+ return keras.Sequential(layers), sorted(save)
423
+
424
+
425
+ class TFModel:
426
+ # TF YOLO model
427
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
428
+ super().__init__()
429
+ if isinstance(cfg, dict):
430
+ self.yaml = cfg # model dict
431
+ else: # is *.yaml
432
+ import yaml # for torch hub
433
+ self.yaml_file = Path(cfg).name
434
+ with open(cfg) as f:
435
+ self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
436
+
437
+ # Define model
438
+ if nc and nc != self.yaml['nc']:
439
+ LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
440
+ self.yaml['nc'] = nc # override yaml value
441
+ self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
442
+
443
+ def predict(self,
444
+ inputs,
445
+ tf_nms=False,
446
+ agnostic_nms=False,
447
+ topk_per_class=100,
448
+ topk_all=100,
449
+ iou_thres=0.45,
450
+ conf_thres=0.25):
451
+ y = [] # outputs
452
+ x = inputs
453
+ for m in self.model.layers:
454
+ if m.f != -1: # if not from previous layer
455
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
456
+
457
+ x = m(x) # run
458
+ y.append(x if m.i in self.savelist else None) # save output
459
+
460
+ # Add TensorFlow NMS
461
+ if tf_nms:
462
+ boxes = self._xywh2xyxy(x[0][..., :4])
463
+ probs = x[0][:, :, 4:5]
464
+ classes = x[0][:, :, 5:]
465
+ scores = probs * classes
466
+ if agnostic_nms:
467
+ nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
468
+ else:
469
+ boxes = tf.expand_dims(boxes, 2)
470
+ nms = tf.image.combined_non_max_suppression(boxes,
471
+ scores,
472
+ topk_per_class,
473
+ topk_all,
474
+ iou_thres,
475
+ conf_thres,
476
+ clip_boxes=False)
477
+ return (nms,)
478
+ return x # output [1,6300,85] = [xywh, conf, class0, class1, ...]
479
+ # x = x[0] # [x(1,6300,85), ...] to x(6300,85)
480
+ # xywh = x[..., :4] # x(6300,4) boxes
481
+ # conf = x[..., 4:5] # x(6300,1) confidences
482
+ # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
483
+ # return tf.concat([conf, cls, xywh], 1)
484
+
485
+ @staticmethod
486
+ def _xywh2xyxy(xywh):
487
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
488
+ x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
489
+ return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
490
+
491
+
492
+ class AgnosticNMS(keras.layers.Layer):
493
+ # TF Agnostic NMS
494
+ def call(self, input, topk_all, iou_thres, conf_thres):
495
+ # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
496
+ return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
497
+ input,
498
+ fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
499
+ name='agnostic_nms')
500
+
501
+ @staticmethod
502
+ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
503
+ boxes, classes, scores = x
504
+ class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
505
+ scores_inp = tf.reduce_max(scores, -1)
506
+ selected_inds = tf.image.non_max_suppression(boxes,
507
+ scores_inp,
508
+ max_output_size=topk_all,
509
+ iou_threshold=iou_thres,
510
+ score_threshold=conf_thres)
511
+ selected_boxes = tf.gather(boxes, selected_inds)
512
+ padded_boxes = tf.pad(selected_boxes,
513
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
514
+ mode="CONSTANT",
515
+ constant_values=0.0)
516
+ selected_scores = tf.gather(scores_inp, selected_inds)
517
+ padded_scores = tf.pad(selected_scores,
518
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
519
+ mode="CONSTANT",
520
+ constant_values=-1.0)
521
+ selected_classes = tf.gather(class_inds, selected_inds)
522
+ padded_classes = tf.pad(selected_classes,
523
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
524
+ mode="CONSTANT",
525
+ constant_values=-1.0)
526
+ valid_detections = tf.shape(selected_inds)[0]
527
+ return padded_boxes, padded_scores, padded_classes, valid_detections
528
+
529
+
530
+ def activations(act=nn.SiLU):
531
+ # Returns TF activation from input PyTorch activation
532
+ if isinstance(act, nn.LeakyReLU):
533
+ return lambda x: keras.activations.relu(x, alpha=0.1)
534
+ elif isinstance(act, nn.Hardswish):
535
+ return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
536
+ elif isinstance(act, (nn.SiLU, SiLU)):
537
+ return lambda x: keras.activations.swish(x)
538
+ else:
539
+ raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')
540
+
541
+
542
+ def representative_dataset_gen(dataset, ncalib=100):
543
+ # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
544
+ for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
545
+ im = np.transpose(img, [1, 2, 0])
546
+ im = np.expand_dims(im, axis=0).astype(np.float32)
547
+ im /= 255
548
+ yield [im]
549
+ if n >= ncalib:
550
+ break
551
+
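+ # Usage sketch (illustrative, not part of this file), assuming `dataset` yields CHW uint8 images as above:
+ #   converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ #   converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
+ #   converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ #   tflite_model = converter.convert()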
552
+
553
+ def run(
554
+ weights=ROOT / 'yolo.pt', # weights path
555
+ imgsz=(640, 640), # inference size h,w
556
+ batch_size=1, # batch size
557
+ dynamic=False, # dynamic batch size
558
+ ):
559
+ # PyTorch model
560
+ im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
561
+ model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
562
+ _ = model(im) # inference
563
+ model.info()
564
+
565
+ # TensorFlow model
566
+ im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
567
+ tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
568
+ _ = tf_model.predict(im) # inference
569
+
570
+ # Keras model
571
+ im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
572
+ keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
573
+ keras_model.summary()
574
+
575
+ LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
576
+
577
+
578
+ def parse_opt():
579
+ parser = argparse.ArgumentParser()
580
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
581
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
582
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
583
+ parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
584
+ opt = parser.parse_args()
585
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
586
+ print_args(vars(opt))
587
+ return opt
588
+
589
+
590
+ def main(opt):
591
+ run(**vars(opt))
592
+
593
+
594
+ if __name__ == "__main__":
595
+ opt = parse_opt()
596
+ main(opt)
models/yolo.py ADDED
@@ -0,0 +1,771 @@
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ FILE = Path(__file__).resolve()
9
+ ROOT = FILE.parents[1] # YOLO root directory
10
+ if str(ROOT) not in sys.path:
11
+ sys.path.append(str(ROOT)) # add ROOT to PATH
12
+ if platform.system() != 'Windows':
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import *
16
+ from models.experimental import *
17
+ from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
18
+ from utils.plots import feature_visualization
19
+ from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
20
+ time_sync)
21
+ from utils.tal.anchor_generator import make_anchors, dist2bbox
22
+
23
+ try:
24
+ import thop # for FLOPs computation
25
+ except ImportError:
26
+ thop = None
27
+
28
+
29
+ class Detect(nn.Module):
30
+ # YOLO Detect head for detection models
31
+ dynamic = False # force grid reconstruction
32
+ export = False # export mode
33
+ shape = None
34
+ anchors = torch.empty(0) # init
35
+ strides = torch.empty(0) # init
36
+
37
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
38
+ super().__init__()
39
+ self.nc = nc # number of classes
40
+ self.nl = len(ch) # number of detection layers
41
+ self.reg_max = 16
42
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
43
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
44
+ self.stride = torch.zeros(self.nl) # strides computed during build
45
+
46
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
47
+ self.cv2 = nn.ModuleList(
48
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
49
+ self.cv3 = nn.ModuleList(
50
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
51
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ shape = x[0].shape # BCHW
55
+ for i in range(self.nl):
56
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
57
+ if self.training:
58
+ return x
59
+ elif self.dynamic or self.shape != shape:
60
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
61
+ self.shape = shape
62
+
63
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
64
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
65
+ y = torch.cat((dbox, cls.sigmoid()), 1)
66
+ return y if self.export else (y, x)
67
+
68
+ def bias_init(self):
69
+ # Initialize Detect() biases, WARNING: requires stride availability
70
+ m = self # self.model[-1] # Detect() module
71
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
72
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
73
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
74
+ a[-1].bias.data[:] = 1.0 # box
75
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
76
+
77
+
78
+ class DDetect(nn.Module):
79
+ # YOLO Detect head for detection models
80
+ dynamic = False # force grid reconstruction
81
+ export = False # export mode
82
+ shape = None
83
+ anchors = torch.empty(0) # init
84
+ strides = torch.empty(0) # init
85
+
86
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
87
+ super().__init__()
88
+ self.nc = nc # number of classes
89
+ self.nl = len(ch) # number of detection layers
90
+ self.reg_max = 16
91
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
92
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
93
+ self.stride = torch.zeros(self.nl) # strides computed during build
94
+
95
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
96
+ self.cv2 = nn.ModuleList(
97
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch)
98
+ self.cv3 = nn.ModuleList(
99
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
100
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
101
+
102
+ def forward(self, x):
103
+ shape = x[0].shape # BCHW
104
+ for i in range(self.nl):
105
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
106
+ if self.training:
107
+ return x
108
+ elif self.dynamic or self.shape != shape:
109
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
110
+ self.shape = shape
111
+
112
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
113
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
114
+ y = torch.cat((dbox, cls.sigmoid()), 1)
115
+ return y if self.export else (y, x)
116
+
117
+ def bias_init(self):
118
+ # Initialize Detect() biases, WARNING: requires stride availability
119
+ m = self # self.model[-1] # Detect() module
120
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
121
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
122
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
123
+ a[-1].bias.data[:] = 1.0 # box
124
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
125
+
126
+
127
+ class DualDetect(nn.Module):
128
+ # YOLO Detect head for detection models
129
+ dynamic = False # force grid reconstruction
130
+ export = False # export mode
131
+ shape = None
132
+ anchors = torch.empty(0) # init
133
+ strides = torch.empty(0) # init
134
+
135
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
136
+ super().__init__()
137
+ self.nc = nc # number of classes
138
+ self.nl = len(ch) // 2 # number of detection layers
139
+ self.reg_max = 16
140
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
141
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
142
+ self.stride = torch.zeros(self.nl) # strides computed during build
143
+
144
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
145
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
146
+ self.cv2 = nn.ModuleList(
147
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
148
+ self.cv3 = nn.ModuleList(
149
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
150
+ self.cv4 = nn.ModuleList(
151
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:])
152
+ self.cv5 = nn.ModuleList(
153
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
154
+ self.dfl = DFL(self.reg_max)
155
+ self.dfl2 = DFL(self.reg_max)
156
+
157
+ def forward(self, x):
158
+ shape = x[0].shape # BCHW
159
+ d1 = []
160
+ d2 = []
161
+ for i in range(self.nl):
162
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
163
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
164
+ if self.training:
165
+ return [d1, d2]
166
+ elif self.dynamic or self.shape != shape:
167
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
168
+ self.shape = shape
169
+
170
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
171
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
172
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
173
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
174
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
175
+ return y if self.export else (y, [d1, d2])
176
+
177
+ def bias_init(self):
178
+ # Initialize Detect() biases, WARNING: requires stride availability
179
+ m = self # self.model[-1] # Detect() module
180
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
181
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
182
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
183
+ a[-1].bias.data[:] = 1.0 # box
184
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
185
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
186
+ a[-1].bias.data[:] = 1.0 # box
187
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
188
+
189
+
190
+ class DualDDetect(nn.Module):
191
+ # YOLO Detect head for detection models
192
+ dynamic = False # force grid reconstruction
193
+ export = False # export mode
194
+ shape = None
195
+ anchors = torch.empty(0) # init
196
+ strides = torch.empty(0) # init
197
+
198
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
199
+ super().__init__()
200
+ self.nc = nc # number of classes
201
+ self.nl = len(ch) // 2 # number of detection layers
202
+ self.reg_max = 16
203
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
204
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
205
+ self.stride = torch.zeros(self.nl) # strides computed during build
206
+
207
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
208
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
209
+ self.cv2 = nn.ModuleList(
210
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
211
+ self.cv3 = nn.ModuleList(
212
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
213
+ self.cv4 = nn.ModuleList(
214
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
215
+ self.cv5 = nn.ModuleList(
216
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
217
+ self.dfl = DFL(self.reg_max)
218
+ self.dfl2 = DFL(self.reg_max)
219
+
220
+ def forward(self, x):
221
+ shape = x[0].shape # BCHW
222
+ d1 = []
223
+ d2 = []
224
+ for i in range(self.nl):
225
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
226
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
227
+ if self.training:
228
+ return [d1, d2]
229
+ elif self.dynamic or self.shape != shape:
230
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
231
+ self.shape = shape
232
+
233
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
234
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
235
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
236
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
237
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
238
+ return y if self.export else (y, [d1, d2])
239
+ #y = torch.cat((dbox2, cls2.sigmoid()), 1)
240
+ #return y if self.export else (y, d2)
241
+ #y1 = torch.cat((dbox, cls.sigmoid()), 1)
242
+ #y2 = torch.cat((dbox2, cls2.sigmoid()), 1)
243
+ #return [y1, y2] if self.export else [(y1, d1), (y2, d2)]
244
+ #return [y1, y2] if self.export else [(y1, y2), (d1, d2)]
245
+
246
+ def bias_init(self):
247
+ # Initialize Detect() biases, WARNING: requires stride availability
248
+ m = self # self.model[-1] # Detect() module
249
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
250
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
251
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
252
+ a[-1].bias.data[:] = 1.0 # box
253
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
254
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
255
+ a[-1].bias.data[:] = 1.0 # box
256
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
257
+
258
+
259
+ class TripleDetect(nn.Module):
260
+ # YOLO Detect head for detection models
261
+ dynamic = False # force grid reconstruction
262
+ export = False # export mode
263
+ shape = None
264
+ anchors = torch.empty(0) # init
265
+ strides = torch.empty(0) # init
266
+
267
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
268
+ super().__init__()
269
+ self.nc = nc # number of classes
270
+ self.nl = len(ch) // 3 # number of detection layers
271
+ self.reg_max = 16
272
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
273
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
274
+ self.stride = torch.zeros(self.nl) # strides computed during build
275
+
276
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
277
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
278
+ c6, c7 = max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
279
+ self.cv2 = nn.ModuleList(
280
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
281
+ self.cv3 = nn.ModuleList(
282
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
283
+ self.cv4 = nn.ModuleList(
284
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:self.nl*2])
285
+ self.cv5 = nn.ModuleList(
286
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
287
+ self.cv6 = nn.ModuleList(
288
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, 4 * self.reg_max, 1)) for x in ch[self.nl*2:self.nl*3])
289
+ self.cv7 = nn.ModuleList(
290
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
291
+ self.dfl = DFL(self.reg_max)
292
+ self.dfl2 = DFL(self.reg_max)
293
+ self.dfl3 = DFL(self.reg_max)
294
+
295
+ def forward(self, x):
296
+ shape = x[0].shape # BCHW
297
+ d1 = []
298
+ d2 = []
299
+ d3 = []
300
+ for i in range(self.nl):
301
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
302
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
303
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
304
+ if self.training:
305
+ return [d1, d2, d3]
306
+ elif self.dynamic or self.shape != shape:
307
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
308
+ self.shape = shape
309
+
310
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
311
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
312
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
313
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
314
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
315
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
316
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
317
+ return y if self.export else (y, [d1, d2, d3])
318
+
319
+ def bias_init(self):
320
+ # Initialize Detect() biases, WARNING: requires stride availability
321
+ m = self # self.model[-1] # Detect() module
322
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
323
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
324
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
325
+ a[-1].bias.data[:] = 1.0 # box
326
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
327
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
328
+ a[-1].bias.data[:] = 1.0 # box
329
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
330
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
331
+ a[-1].bias.data[:] = 1.0 # box
332
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
333
+
334
+
335
+ class TripleDDetect(nn.Module):
336
+ # YOLO Detect head for detection models
337
+ dynamic = False # force grid reconstruction
338
+ export = False # export mode
339
+ shape = None
340
+ anchors = torch.empty(0) # init
341
+ strides = torch.empty(0) # init
342
+
343
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
344
+ super().__init__()
345
+ self.nc = nc # number of classes
346
+ self.nl = len(ch) // 3 # number of detection layers
347
+ self.reg_max = 16
348
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
349
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
350
+ self.stride = torch.zeros(self.nl) # strides computed during build
351
+
352
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), \
353
+ max((ch[0], min((self.nc * 2, 128)))) # channels
354
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), \
355
+ max((ch[self.nl], min((self.nc * 2, 128)))) # channels
356
+ c6, c7 = make_divisible(max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), 4), \
357
+ max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
358
+ self.cv2 = nn.ModuleList(
359
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4),
360
+ nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
361
+ self.cv3 = nn.ModuleList(
362
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
363
+ self.cv4 = nn.ModuleList(
364
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4),
365
+ nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:self.nl*2])
366
+ self.cv5 = nn.ModuleList(
367
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
368
+ self.cv6 = nn.ModuleList(
369
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3, g=4),
370
+ nn.Conv2d(c6, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl*2:self.nl*3])
371
+ self.cv7 = nn.ModuleList(
372
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
373
+ self.dfl = DFL(self.reg_max)
374
+ self.dfl2 = DFL(self.reg_max)
375
+ self.dfl3 = DFL(self.reg_max)
376
+
377
+ def forward(self, x):
378
+ shape = x[0].shape # BCHW
379
+ d1 = []
380
+ d2 = []
381
+ d3 = []
382
+ for i in range(self.nl):
383
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
384
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
385
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
386
+ if self.training:
387
+ return [d1, d2, d3]
388
+ elif self.dynamic or self.shape != shape:
389
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
390
+ self.shape = shape
391
+
392
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
393
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
394
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
395
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
396
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
397
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
398
+ #y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
399
+ #return y if self.export else (y, [d1, d2, d3])
400
+ y = torch.cat((dbox3, cls3.sigmoid()), 1)
401
+ return y if self.export else (y, d3)
402
+
403
+ def bias_init(self):
404
+ # Initialize Detect() biases, WARNING: requires stride availability
405
+ m = self # self.model[-1] # Detect() module
406
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
407
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
408
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
409
+ a[-1].bias.data[:] = 1.0 # box
410
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
411
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
412
+ a[-1].bias.data[:] = 1.0 # box
413
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
414
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
415
+ a[-1].bias.data[:] = 1.0 # box
416
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
417
+
418
+
419
+ class Segment(Detect):
420
+ # YOLO Segment head for segmentation models
421
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
422
+ super().__init__(nc, ch, inplace)
423
+ self.nm = nm # number of masks
424
+ self.npr = npr # number of protos
425
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
426
+ self.detect = Detect.forward
427
+
428
+ c4 = max(ch[0] // 4, self.nm)
429
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
430
+
431
+ def forward(self, x):
432
+ p = self.proto(x[0])
433
+ bs = p.shape[0]
434
+
435
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
436
+ x = self.detect(self, x)
437
+ if self.training:
438
+ return x, mc, p
439
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
440
+
441
+
442
+ class Panoptic(Detect):
443
+ # YOLO Panoptic head for panoptic segmentation models
444
+ def __init__(self, nc=80, sem_nc=93, nm=32, npr=256, ch=(), inplace=True):
445
+ super().__init__(nc, ch, inplace)
446
+ self.sem_nc = sem_nc
447
+ self.nm = nm # number of masks
448
+ self.npr = npr # number of protos
449
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
450
+ self.uconv = UConv(ch[0], ch[0]//4, self.sem_nc+self.nc)
451
+ self.detect = Detect.forward
452
+
453
+ c4 = max(ch[0] // 4, self.nm)
454
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
455
+
456
+
457
+ def forward(self, x):
458
+ p = self.proto(x[0])
459
+ s = self.uconv(x[0])
460
+ bs = p.shape[0]
461
+
462
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
463
+ x = self.detect(self, x)
464
+ if self.training:
465
+ return x, mc, p, s
466
+ return (torch.cat([x, mc], 1), p, s) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p, s))
467
+
468
+
469
+ class BaseModel(nn.Module):
470
+ # YOLO base model
471
+ def forward(self, x, profile=False, visualize=False):
472
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
473
+
474
+ def _forward_once(self, x, profile=False, visualize=False):
475
+ y, dt = [], [] # outputs
476
+ for m in self.model:
477
+ if m.f != -1: # if not from previous layer
478
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
479
+ if profile:
480
+ self._profile_one_layer(m, x, dt)
481
+ #print(m)
482
+ x = m(x) # run
483
+ y.append(x if m.i in self.save else None) # save output
484
+ if visualize:
485
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
486
+ return x
487
+
488
+ def _profile_one_layer(self, m, x, dt):
489
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
490
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
491
+ t = time_sync()
492
+ for _ in range(10):
493
+ m(x.copy() if c else x)
494
+ dt.append((time_sync() - t) * 100)
495
+ if m == self.model[0]:
496
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
497
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
498
+ if c:
499
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
500
+
501
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
502
+ LOGGER.info('Fusing layers... ')
503
+ for m in self.model.modules():
504
+ if isinstance(m, (RepConvN)) and hasattr(m, 'fuse_convs'):
505
+ m.fuse_convs()
506
+ m.forward = m.forward_fuse # update forward
507
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
508
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
509
+ delattr(m, 'bn') # remove batchnorm
510
+ m.forward = m.forward_fuse # update forward
511
+ self.info()
512
+ return self
513
+
514
+ def info(self, verbose=False, img_size=640): # print model information
515
+ model_info(self, verbose, img_size)
516
+
517
+ def _apply(self, fn):
518
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
519
+ self = super()._apply(fn)
520
+ m = self.model[-1] # Detect()
521
+ if isinstance(m, (Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, Panoptic)):
522
+ m.stride = fn(m.stride)
523
+ m.anchors = fn(m.anchors)
524
+ m.strides = fn(m.strides)
525
+ # m.grid = list(map(fn, m.grid))
526
+ return self
527
+
528
+
529
+ class DetectionModel(BaseModel):
530
+ # YOLO detection model
531
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
532
+ super().__init__()
533
+ if isinstance(cfg, dict):
534
+ self.yaml = cfg # model dict
535
+ else: # is *.yaml
536
+ import yaml # for torch hub
537
+ self.yaml_file = Path(cfg).name
538
+ with open(cfg, encoding='ascii', errors='ignore') as f:
539
+ self.yaml = yaml.safe_load(f) # model dict
540
+
541
+ # Define model
542
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
543
+ if nc and nc != self.yaml['nc']:
544
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
545
+ self.yaml['nc'] = nc # override yaml value
546
+ if anchors:
547
+ LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
548
+ self.yaml['anchors'] = round(anchors) # override yaml value
549
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
550
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
551
+ self.inplace = self.yaml.get('inplace', True)
552
+
553
+ # Build strides, anchors
554
+ m = self.model[-1] # Detect()
555
+ if isinstance(m, (Detect, DDetect, Segment, Panoptic)):
556
+ s = 256 # 2x min stride
557
+ m.inplace = self.inplace
558
+ forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Panoptic)) else self.forward(x)
559
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
560
+ # check_anchor_order(m)
561
+ # m.anchors /= m.stride.view(-1, 1, 1)
562
+ self.stride = m.stride
563
+ m.bias_init() # only run once
564
+ if isinstance(m, (DualDetect, TripleDetect, DualDDetect, TripleDDetect)):
565
+ s = 256 # 2x min stride
566
+ m.inplace = self.inplace
567
+ #forward = lambda x: self.forward(x)[0][0] if isinstance(m, (DualSegment, DualPanoptic)) else self.forward(x)[0]
568
+ forward = lambda x: self.forward(x)[0]
569
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
570
+ # check_anchor_order(m)
571
+ # m.anchors /= m.stride.view(-1, 1, 1)
572
+ self.stride = m.stride
573
+ m.bias_init() # only run once
574
+
575
+ # Init weights, biases
576
+ initialize_weights(self)
577
+ self.info()
578
+ LOGGER.info('')
579
+
580
+ def forward(self, x, augment=False, profile=False, visualize=False):
581
+ if augment:
582
+ return self._forward_augment(x) # augmented inference, None
583
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
584
+
585
+ def _forward_augment(self, x):
586
+ img_size = x.shape[-2:] # height, width
587
+ s = [1, 0.83, 0.67] # scales
588
+ f = [None, 3, None] # flips (2-ud, 3-lr)
589
+ y = [] # outputs
590
+ for si, fi in zip(s, f):
591
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
592
+ yi = self._forward_once(xi)[0] # forward
593
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
594
+ yi = self._descale_pred(yi, fi, si, img_size)
595
+ y.append(yi)
596
+ y = self._clip_augmented(y) # clip augmented tails
597
+ return torch.cat(y, 1), None # augmented inference, train
598
+
599
+ def _descale_pred(self, p, flips, scale, img_size):
600
+ # de-scale predictions following augmented inference (inverse operation)
601
+ if self.inplace:
602
+ p[..., :4] /= scale # de-scale
603
+ if flips == 2:
604
+ p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
605
+ elif flips == 3:
606
+ p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
607
+ else:
608
+ x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
609
+ if flips == 2:
610
+ y = img_size[0] - y # de-flip ud
611
+ elif flips == 3:
612
+ x = img_size[1] - x # de-flip lr
613
+ p = torch.cat((x, y, wh, p[..., 4:]), -1)
614
+ return p
615
+
616
+ def _clip_augmented(self, y):
617
+ # Clip YOLO augmented inference tails
618
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
619
+ g = sum(4 ** x for x in range(nl)) # grid points
620
+ e = 1 # exclude layer count
621
+ i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
622
+ y[0] = y[0][:, :-i] # large
623
+ i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
624
+ y[-1] = y[-1][:, i:] # small
625
+ return y
626
+
627
+
628
+ Model = DetectionModel # retain YOLO 'Model' class for backwards compatibility
629
+
630
+
631
+ class SegmentationModel(DetectionModel):
632
+ # YOLO segmentation model
633
+ def __init__(self, cfg='yolo-seg.yaml', ch=3, nc=None, anchors=None):
634
+ super().__init__(cfg, ch, nc, anchors)
635
+
636
+
637
+ class ClassificationModel(BaseModel):
638
+ # YOLO classification model
639
+ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
640
+ super().__init__()
641
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
642
+
643
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
644
+ # Create a YOLO classification model from a YOLO detection model
645
+ if isinstance(model, DetectMultiBackend):
646
+ model = model.model # unwrap DetectMultiBackend
647
+ model.model = model.model[:cutoff] # backbone
648
+ m = model.model[-1] # last layer
649
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
650
+ c = Classify(ch, nc) # Classify()
651
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
652
+ model.model[-1] = c # replace
653
+ self.model = model.model
654
+ self.stride = model.stride
655
+ self.save = []
656
+ self.nc = nc
657
+
658
+ def _from_yaml(self, cfg):
659
+ # Create a YOLO classification model from a *.yaml file
660
+ self.model = None
661
+
662
+
663
+ def parse_model(d, ch): # model_dict, input_channels(3)
664
+ # Parse a YOLO model.yaml dictionary
665
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
666
+ anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
667
+ if act:
668
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
669
+ RepConvN.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
670
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
671
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
672
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
673
+
674
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
675
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
676
+ m = eval(m) if isinstance(m, str) else m # eval strings
677
+ for j, a in enumerate(args):
678
+ with contextlib.suppress(NameError):
679
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
680
+
681
+ n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
682
+ if m in {
683
+ Conv, AConv, ConvTranspose,
684
+ Bottleneck, SPP, SPPF, DWConv, BottleneckCSP, nn.ConvTranspose2d, DWConvTranspose2d, SPPCSPC, ADown,
685
+ RepNCSPELAN4, SPPELAN}:
686
+ c1, c2 = ch[f], args[0]
687
+ if c2 != no: # if not output
688
+ c2 = make_divisible(c2 * gw, 8)
689
+
690
+ args = [c1, c2, *args[1:]]
691
+ if m in {BottleneckCSP, SPPCSPC}:
692
+ args.insert(2, n) # number of repeats
693
+ n = 1
694
+ elif m is nn.BatchNorm2d:
695
+ args = [ch[f]]
696
+ elif m in [Down0, Down1, Down2, Down3, Down4]:
697
+ c2 = args[0]
698
+
699
+ elif m is Concat:
700
+ c2 = sum(ch[x] for x in f)
701
+ elif m is Shortcut:
702
+ c2 = ch[f[0]]
703
+ elif m is ReOrg:
704
+ c2 = ch[f] * 4
705
+ elif m is CBLinear:
706
+ c2 = args[0]
707
+ c1 = ch[f]
708
+ args = [c1, c2, *args[1:]]
709
+ elif m is CBFuse:
710
+ c2 = ch[f[-1]]
711
+ # TODO: channel, gw, gd
712
+ elif m in {Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, Panoptic}:
713
+ args.append([ch[x] for x in f])
714
+ # if isinstance(args[1], int): # number of anchors
715
+ # args[1] = [list(range(args[1] * 2))] * len(f)
716
+ if m in {Segment, Panoptic}:
717
+ args[2] = make_divisible(args[2] * gw, 8)
718
+ elif m is Contract:
719
+ c2 = ch[f] * args[0] ** 2
720
+ elif m is Expand:
721
+ c2 = ch[f] // args[0] ** 2
722
+ else:
723
+ c2 = ch[f]
724
+
725
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
726
+ t = str(m)[8:-2].replace('__main__.', '') # module type
727
+ np = sum(x.numel() for x in m_.parameters()) # number params
728
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
729
+ LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
730
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
731
+ layers.append(m_)
732
+ if i == 0:
733
+ ch = []
734
+ ch.append(c2)
735
+ return nn.Sequential(*layers), sorted(save)
736
+
737
+
738
+ if __name__ == '__main__':
739
+ parser = argparse.ArgumentParser()
740
+ parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml')
741
+ parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
742
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
743
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
744
+ parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
745
+ parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
746
+ opt = parser.parse_args()
747
+ opt.cfg = check_yaml(opt.cfg) # check YAML
748
+ print_args(vars(opt))
749
+ device = select_device(opt.device)
750
+
751
+ # Create model
752
+ im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
753
+ model = Model(opt.cfg).to(device)
754
+ model.eval()
755
+
756
+ # Options
757
+ if opt.line_profile: # profile layer by layer
758
+ model(im, profile=True)
759
+
760
+ elif opt.profile: # profile forward-backward
761
+ results = profile(input=im, ops=[model], n=3)
762
+
763
+ elif opt.test: # test all models
764
+ for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
765
+ try:
766
+ _ = Model(cfg)
767
+ except Exception as e:
768
+ print(f'Error in {cfg}: {e}')
769
+
770
+ else: # report fused model summary
771
+ model.fuse()
spark repvit/repvit_1kpretrained_timm_style.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89abdcdc1a2865f96822bb61e3096159087c6f5d331961dd1fed8e0a9c58988e
3
+ size 269763237
spark/downstream_d2/README.md ADDED
@@ -0,0 +1,101 @@
1
+ ## About code isolation
2
+
3
+ This `downstream_d2` directory is isolated from the pre-training code, so it can be treated as an independent codebase 🛠️.
4
+
5
+
6
+ ## Fine-tuned ResNet-50 weights, log files, and performance
7
+
8
+ <div align="center">
9
+
10
+ [[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1H8605HbxGvrsu4x4rIoNr-Wkd7JkxFPQ/view?usp=share_link)]
11
+ [[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1Ue7SiQ1E_AwgtYo56Fm-iUlQPZ8vIwYj/view?usp=share_link)]
12
+ [[`metrics.json`](https://drive.google.com/file/d/1wfbUWh4svV8sPWya_0PAhsLHVayDQRCi/view?usp=share_link)]
13
+ [[`log.txt`](https://drive.google.com/file/d/11zVo_87pe9DMAmfNQK9FUfyjQWHTRKxV/view?usp=share_link)]
14
+ [[`tensorboard file`](https://drive.google.com/file/d/1aM1qj8c3-Uka1dZuYmKhgp1lNJpeMDMl/view?usp=share_link)]
15
+ </div>
16
+
17
+ <p align="center">
18
+ <img src="https://user-images.githubusercontent.com/39692511/211497479-0563e891-f2ad-4cf1-b682-a21c2be1442d.png" width=80%>
19
+ </p>
20
+
21
+
22
+ ## Install [Detectron2 v0.6](https://github.com/facebookresearch/detectron2/releases/tag/v0.6) before fine-tuning ResNet on COCO
23
+
24
+
25
+ 1. Set up a Python environment, e.g.:
26
+ ```shell script
27
+ $ conda create -n spark python=3.8 -y
28
+ $ conda activate spark
29
+ ```
30
+
31
+ 2. Install `detectron2==0.6` (e.g., with `torch==1.10.0` and `cuda11.3`):
32
+ ```shell script
33
+ $ pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
34
+ ```
35
+
36
+ You can also find instructions for other PyTorch/CUDA versions on [this page](https://github.com/facebookresearch/detectron2/releases/tag/v0.6).
37
+
38
+
39
+ 3. Put the COCO dataset folder at `downstream_d2/datasets/coco`.
40
+ The folder should follow the [directory structure](https://github.com/facebookresearch/detectron2/tree/master/datasets) required by `Detectron2`, which should look like this:
41
+ ```
42
+ downstream_d2/datasets/coco:
43
+ annotations/:
44
+ captions_train2017.json captions_val2017.json
45
+ instances_train2017.json instances_val2017.json
46
+ person_keypoints_train2017.json person_keypoints_val2017.json
47
+ train2017/:
48
+ a_lot_images.jpg
49
+ val2017/:
50
+ a_lot_images.jpg
51
+ ```
52
+
53
+
54
+ ## Training from pre-trained checkpoint
55
+
56
+ The script file for COCO fine-tuning (object detection and instance segmentation) is [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py),
57
+ which is a modification of [Detectron2's tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py).
58
+
59
+
60
+ Before fine-tuning a ResNet50 pre-trained by SparK, you should first convert our checkpoint file to a Detectron2-style `.pkl` file:
61
+
62
+ ```shell script
63
+ $ cd /path/to/SparK/downstream_d2
64
+ $ python3 convert-timm-to-d2.py /some/path/to/resnet50_1kpretrained_timm_style.pth d2-style.pkl
65
+ ```
66
+
67
+ For a ResNet50, you should see a log reporting `len(state)==318`:
68
+ ```text
69
+ [convert] .pkl is generated! (from `/some/path/to/resnet50_1kpretrained_timm_style.pth`, to `d2-style.pkl`, len(state)==318)
70
+ ```
71
+
72
+ Then run fine-tuning on a single machine with 8 GPUs:
73
+
74
+ ```shell script
75
+ $ cd /path/to/SparK/downstream_d2
76
+ $ python3 ./train_net.py --resume --num-gpus 8 --config-file ./configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml \
77
+ MODEL.WEIGHTS d2-style.pkl \
78
+ OUTPUT_DIR <your_output_dir>
79
+ ```
80
+
81
+ For multiple machines, add these args:
82
+ ```shell script
83
+ --num-machines <total_num> --machine-rank <this_rank> --dist-url <url:port>
84
+ ```
85
+
86
+ In `<your_output_dir>` you'll see the log files generated by `Detectron2`.
87
+
88
+
89
+ ## Details: how we modify the official Detectron2's [tools/train_net.py](https://github.com/facebookresearch/detectron2/blob/v0.6/tools/train_net.py) to get our [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py)
90
+
91
+ 1. We add two new hyperparameters (a YAML sketch showing them is at the end of this section):
92
+ - str `SOLVER.OPTIMIZER`: use 'ADAM' (the same as 'ADAMW') or 'SGD' optimizer
93
+ - float `SOLVER.LR_DECAY`: the decay ratio (from 0. to 1.) of layer-wise learning rate decay trick
94
+
95
+ 2. We implement layer-wise lr decay in [downstream_d2/lr_decay.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/lr_decay.py).
96
+
97
+ 3. We write a script to convert our timm-style pre-trained ResNet weights to Detectron2-style in [downstream_d2/convert-timm-to-d2.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/convert-timm-to-d2.py).
98
+
99
+ 4. We also add a hook for logging results to `cfg.OUTPUT_DIR/d2_coco_log.txt`.
100
+
101
+ All of our modifications to the original are commented with `# [modification] ...` in [downstream_d2/train_net.py](https://github.com/keyu-tian/SparK/blob/main/downstream_d2/train_net.py) or other files.
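+
+ As a quick reference, the two new keys live under `SOLVER`; the values below are copied from `configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml` in this folder:
+
+ ```yaml
+ SOLVER:
+   OPTIMIZER: "ADAMW"  # or "SGD"
+   LR_DECAY: 0.6       # layer-wise lr decay ratio; 0. or 1. turns the trick off
+ ```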
spark/downstream_d2/configs/Base-RCNN-FPN.yaml ADDED
@@ -0,0 +1,42 @@
1
+ MODEL:
2
+ META_ARCHITECTURE: "GeneralizedRCNN"
3
+ BACKBONE:
4
+ NAME: "build_resnet_fpn_backbone"
5
+ RESNETS:
6
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
7
+ FPN:
8
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
9
+ ANCHOR_GENERATOR:
10
+ SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
11
+ ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
12
+ RPN:
13
+ IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
14
+ PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
15
+ PRE_NMS_TOPK_TEST: 1000 # Per FPN level
16
+ # Detectron1 uses 2000 proposals per-batch,
17
+ # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
18
+ # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
19
+ POST_NMS_TOPK_TRAIN: 1000
20
+ POST_NMS_TOPK_TEST: 1000
21
+ ROI_HEADS:
22
+ NAME: "StandardROIHeads"
23
+ IN_FEATURES: ["p2", "p3", "p4", "p5"]
24
+ ROI_BOX_HEAD:
25
+ NAME: "FastRCNNConvFCHead"
26
+ NUM_FC: 2
27
+ POOLER_RESOLUTION: 7
28
+ ROI_MASK_HEAD:
29
+ NAME: "MaskRCNNConvUpsampleHead"
30
+ NUM_CONV: 4
31
+ POOLER_RESOLUTION: 14
32
+ DATASETS:
33
+ TRAIN: ("coco_2017_train",)
34
+ TEST: ("coco_2017_val",)
35
+ SOLVER:
36
+ IMS_PER_BATCH: 16
37
+ BASE_LR: 0.02
38
+ STEPS: (60000, 80000)
39
+ MAX_ITER: 90000
40
+ INPUT:
41
+ MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
42
+ VERSION: 2
spark/downstream_d2/configs/coco_R_50_FPN_CONV_1x_moco_adam.yaml ADDED
@@ -0,0 +1,57 @@
1
+ _BASE_: "Base-RCNN-FPN.yaml"
2
+ MODEL:
3
+ WEIGHTS: "<see instructions>"
4
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
5
+ PIXEL_STD: [58.395, 57.120, 57.375]
6
+
7
+ MASK_ON: True
8
+ BACKBONE:
9
+ FREEZE_AT: 0
10
+ RESNETS:
11
+ DEPTH: 50
12
+ NORM: "SyncBN"
13
+ STRIDE_IN_1X1: False
14
+ FPN:
15
+ NORM: "SyncBN"
16
+ ROI_BOX_HEAD:
17
+ NAME: "FastRCNNConvFCHead"
18
+ NUM_FC: 1
19
+ NUM_CONV: 4
20
+ POOLER_RESOLUTION: 7
21
+ NORM: "SyncBN"
22
+ ROI_MASK_HEAD:
23
+ NAME: "MaskRCNNConvUpsampleHead"
24
+ NUM_CONV: 4
25
+ POOLER_RESOLUTION: 14
26
+ NORM: "SyncBN"
27
+
28
+ INPUT:
29
+ MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896)
30
+ CROP:
31
+ ENABLED: False
32
+ TYPE: "absolute_range"
33
+ SIZE: (384, 600)
34
+ FORMAT: "RGB"
35
+ TEST:
36
+ EVAL_PERIOD: 5000
37
+ PRECISE_BN:
38
+ ENABLED: True
39
+
40
+ SOLVER:
41
+ STEPS: (60000, 80000)
42
+ MAX_ITER: 90000
43
+ GAMMA: 0.25
44
+ BASE_LR: 0.00025
45
+ WARMUP_FACTOR: 0.01
46
+ WARMUP_ITERS: 1000
47
+ WEIGHT_DECAY: 0.0001
48
+ CHECKPOINT_PERIOD: 5000
49
+ CLIP_GRADIENTS:
50
+ ENABLED: False
51
+ CLIP_TYPE: "value"
52
+ CLIP_VALUE: 1.0
53
+ NORM_TYPE: 2.0
54
+
55
+ # compared to standard detectron2, we add these two new configurations:
56
+ OPTIMIZER: "ADAMW"
57
+ LR_DECAY: 0.6
spark/downstream_d2/convert-timm-to-d2.py ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/python3
2
+
3
+ # Copyright (c) ByteDance, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ import pickle as pkl
10
+
11
+ import torch
12
+
13
+
14
+ # we use `timm.models.ResNet` in pre-training, so keys are timm-style
15
+ def timm_resnet_to_detectron2_resnet(source_file, target_file):
16
+ pretrained: dict = torch.load(source_file, map_location='cpu')
17
+ for mod_k in {'state_dict', 'state', 'module', 'model'}:
18
+ if mod_k in pretrained:
19
+ pretrained = pretrained[mod_k]
20
+ if any(k.startswith('module.encoder_q.') for k in pretrained.keys()):
21
+ pretrained = {k.replace('module.encoder_q.', ''): v for k, v in pretrained.items() if k.startswith('module.encoder_q.')}
22
+
23
+ pkl_state = {}
24
+ for k, v in pretrained.items(): # convert resnet's keys from timm-style to d2-style
25
+ if 'layer' not in k:
26
+ k = 'stem.' + k
27
+ for t in [1, 2, 3, 4]:
28
+ k = k.replace(f'layer{t}', f'res{t+1}')
29
+ for t in [1, 2, 3]:
30
+ k = k.replace(f'bn{t}', f'conv{t}.norm')
31
+ k = k.replace('downsample.0', 'shortcut')
32
+ k = k.replace('downsample.1', 'shortcut.norm')
33
+
34
+ pkl_state[k] = v.detach().numpy()
35
+
36
+ with open(target_file, 'wb') as fp:
37
+ print(f'[convert] .pkl is generated! (from `{source_file}`, to `{target_file}`, len(state)=={len(pkl_state)})')
38
+ pkl.dump({'model': pkl_state, '__author__': 'https://github.com/keyu-tian/SparK', 'matching_heuristics': True}, fp)
39
+
40
+
41
+ if __name__ == '__main__':
42
+ import sys
43
+ timm_resnet_to_detectron2_resnet(sys.argv[1], sys.argv[2])
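
For reference, a minimal sketch (not part of the repository) of what the key renaming in `timm_resnet_to_detectron2_resnet` produces; the sample keys below are typical `timm.models.ResNet` parameter names chosen purely for illustration:

```python
# Standalone illustration of the timm-style -> detectron2-style key renaming above.
samples = ['conv1.weight', 'bn1.weight', 'layer1.0.conv2.weight', 'layer4.2.downsample.1.running_mean']
for k in samples:
    if 'layer' not in k:
        k = 'stem.' + k
    for t in [1, 2, 3, 4]:
        k = k.replace(f'layer{t}', f'res{t + 1}')
    for t in [1, 2, 3]:
        k = k.replace(f'bn{t}', f'conv{t}.norm')
    k = k.replace('downsample.0', 'shortcut').replace('downsample.1', 'shortcut.norm')
    print(k)
# -> stem.conv1.weight, stem.conv1.norm.weight, res2.0.conv2.weight, res5.2.shortcut.norm.running_mean
```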
spark/downstream_d2/lr_decay.py ADDED
@@ -0,0 +1,132 @@
1
+ from typing import List, Dict, Set, Optional, Callable, Any
2
+ import torch
3
+ import copy
4
+
5
+ from detectron2.solver.build import reduce_param_groups
6
+
7
+
8
+ def lr_factor_func(para_name: str, is_resnet50, dec: float, debug=False) -> float:
9
+ if dec == 0:
10
+ dec = 1.
11
+
12
+ N = 5 if is_resnet50 else 11
13
+ if '.stem.' in para_name:
14
+ layer_id = 0
15
+ elif '.res' in para_name:
16
+ ls = para_name.split('.res')[1].split('.')
17
+ if ls[0].isnumeric() and ls[1].isnumeric():
18
+ stage_id, block_id = int(ls[0]), int(ls[1])
19
+ if stage_id == 2: # res2
20
+ layer_id = 1
21
+ elif stage_id == 3: # res3
22
+ layer_id = 2
23
+ elif stage_id == 4: # res4
24
+ layer_id = 3 + block_id // 3 # 3, 4 or 4, 5
25
+ else: # res5
26
+ layer_id = N
27
+ else:
28
+ assert para_name.startswith('roi_heads.res5.norm.')
29
+ layer_id = N + 1 # roi_heads.res5.norm.weight and roi_heads.res5.norm.bias of C4
30
+ else:
31
+ layer_id = N + 1
32
+
33
+ exp = N + 1 - layer_id
34
+ return f'{dec:g} ** {exp}' if debug else dec ** exp
35
+
36
+
37
+ # [modification] see: https://github.com/facebookresearch/detectron2/blob/v0.6/detectron2/solver/build.py#L134
38
+ # add the `lr_factor_func` to implement lr decay
39
+ def get_default_optimizer_params(
40
+ model: torch.nn.Module,
41
+ base_lr: Optional[float] = None,
42
+ weight_decay: Optional[float] = None,
43
+ weight_decay_norm: Optional[float] = None,
44
+ bias_lr_factor: Optional[float] = 1.0,
45
+ weight_decay_bias: Optional[float] = None,
46
+ lr_factor_func: Optional[Callable] = None,
47
+ overrides: Optional[Dict[str, Dict[str, float]]] = None,
48
+ ) -> List[Dict[str, Any]]:
49
+ """
50
+ Get default param list for optimizer, with support for a few types of
51
+ overrides. If no overrides needed, this is equivalent to `model.parameters()`.
52
+
53
+ Args:
54
+ base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
55
+ weight_decay: weight decay for every group by default. Can be omitted to use the one
56
+ in optimizer.
57
+ weight_decay_norm: override weight decay for params in normalization layers
58
+ bias_lr_factor: multiplier of lr for bias parameters.
59
+ weight_decay_bias: override weight decay for bias parameters.
60
+ lr_factor_func: function to calculate lr decay rate by mapping the parameter names to
61
+ corresponding lr decay rate. Note that setting this option requires
62
+ also setting ``base_lr``.
63
+ overrides: if not `None`, provides values for optimizer hyperparameters
64
+ (LR, weight decay) for module parameters with a given name; e.g.
65
+ ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
66
+ weight decay values for all module parameters named `embedding`.
67
+
68
+ For common detection models, ``weight_decay_norm`` is the only option
69
+ needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
70
+ from Detectron1 that are not found useful.
71
+
72
+ Example:
73
+ ::
74
+ torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
75
+ lr=0.01, weight_decay=1e-4, momentum=0.9)
76
+ """
77
+ if overrides is None:
78
+ overrides = {}
79
+ defaults = {}
80
+ if base_lr is not None:
81
+ defaults["lr"] = base_lr
82
+ if weight_decay is not None:
83
+ defaults["weight_decay"] = weight_decay
84
+ bias_overrides = {}
85
+ if bias_lr_factor is not None and bias_lr_factor != 1.0:
86
+ # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
87
+ # exactly the same as regular weights.
88
+ if base_lr is None:
89
+ raise ValueError("bias_lr_factor requires base_lr")
90
+ bias_overrides["lr"] = base_lr * bias_lr_factor
91
+ if weight_decay_bias is not None:
92
+ bias_overrides["weight_decay"] = weight_decay_bias
93
+ if len(bias_overrides):
94
+ if "bias" in overrides:
95
+ raise ValueError("Conflicting overrides for 'bias'")
96
+ overrides["bias"] = bias_overrides
97
+ if lr_factor_func is not None:
98
+ if base_lr is None:
99
+ raise ValueError("lr_factor_func requires base_lr")
100
+ norm_module_types = (
101
+ torch.nn.BatchNorm1d,
102
+ torch.nn.BatchNorm2d,
103
+ torch.nn.BatchNorm3d,
104
+ torch.nn.SyncBatchNorm,
105
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
106
+ torch.nn.GroupNorm,
107
+ torch.nn.InstanceNorm1d,
108
+ torch.nn.InstanceNorm2d,
109
+ torch.nn.InstanceNorm3d,
110
+ torch.nn.LayerNorm,
111
+ torch.nn.LocalResponseNorm,
112
+ )
113
+ params: List[Dict[str, Any]] = []
114
+ memo: Set[torch.nn.parameter.Parameter] = set()
115
+ for module_name, module in model.named_modules():
116
+ for module_param_name, value in module.named_parameters(recurse=False):
117
+ if not value.requires_grad:
118
+ continue
119
+ # Avoid duplicating parameters
120
+ if value in memo:
121
+ continue
122
+ memo.add(value)
123
+
124
+ hyperparams = copy.copy(defaults)
125
+ if isinstance(module, norm_module_types) and weight_decay_norm is not None:
126
+ hyperparams["weight_decay"] = weight_decay_norm
127
+ if lr_factor_func is not None:
128
+ hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
129
+
130
+ hyperparams.update(overrides.get(module_param_name, {}))
131
+ params.append({"params": [value], **hyperparams})
132
+ return reduce_param_groups(params)
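
A minimal sketch (not part of the repository) of the layer-wise factors that `lr_factor_func` assigns for a ResNet-50 backbone with `LR_DECAY: 0.6`; the parameter names below are typical detectron2 FPN names used only for illustration, and the import assumes the script is run from `spark/downstream_d2/`:

```python
from lr_decay import lr_factor_func

# factor == dec ** (N + 1 - layer_id) with N == 5 for ResNet-50, so shallower layers get a smaller lr
for name in ('backbone.bottom_up.stem.conv1.weight',    # stem        -> 0.6 ** 6
             'backbone.bottom_up.res2.0.conv1.weight',  # res2        -> 0.6 ** 5
             'backbone.bottom_up.res5.2.conv3.weight',  # res5        -> 0.6 ** 1
             'roi_heads.box_head.fc1.weight'):          # heads/necks -> 0.6 ** 0
    print(name, lr_factor_func(name, is_resnet50=True, dec=0.6))
```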
spark/downstream_d2/train_net.py ADDED
@@ -0,0 +1,322 @@
1
+ #!/usr/bin/python3
2
+
3
+ # Copyright (c) ByteDance, Inc. and its affiliates.
4
+ # All rights reserved.
5
+ #
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ import datetime
10
+ import json
11
+ import logging
12
+ import os
13
+ import time
14
+ from collections import OrderedDict, defaultdict
15
+ from functools import partial
16
+ from pprint import pformat
17
+
18
+ import numpy as np
19
+ import torch
20
+ import detectron2.utils.comm as comm
21
+ from detectron2.checkpoint import DetectionCheckpointer
22
+ from detectron2.config import get_cfg
23
+ from detectron2.data import MetadataCatalog
24
+ from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch, PeriodicWriter
25
+ from detectron2.evaluation import (
26
+ CityscapesInstanceEvaluator,
27
+ CityscapesSemSegEvaluator,
28
+ COCOEvaluator,
29
+ COCOPanopticEvaluator,
30
+ DatasetEvaluators,
31
+ LVISEvaluator,
32
+ PascalVOCDetectionEvaluator,
33
+ SemSegEvaluator,
34
+ verify_results,
35
+ )
36
+ from detectron2.layers import get_norm
37
+ from detectron2.modeling import GeneralizedRCNNWithTTA
38
+ from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
39
+ from detectron2.solver.build import maybe_add_gradient_clipping
40
+ from detectron2.utils.events import EventWriter
41
+
42
+ from lr_decay import get_default_optimizer_params, lr_factor_func
43
+
44
+
45
+ # [modification] for better logging
46
+ def _ex_repr(self):
47
+ d = vars(self)
48
+ ex = ', '.join(f'{k}={v}' for k, v in d.items() if not k.startswith('__') and k not in [
49
+ 'trainer', 'before_train', 'after_train', 'before_step', 'after_step', 'state_dict',
50
+ '_model', '_data_loader', 'logger',
51
+ ])
52
+ return f'{type(self).__name__}({ex})'
53
+ hooks.HookBase.__repr__ = _ex_repr
54
+ EventWriter.__repr__ = _ex_repr
55
+
56
+
57
+ # [modification] add norm
58
+ @ROI_HEADS_REGISTRY.register()
59
+ class Res5ROIHeadsExtraNorm(Res5ROIHeads):
60
+ """
61
+ As described in the MOCO paper, there is an extra BN layer
62
+ following the res5 stage.
63
+ """
64
+
65
+ def _build_res5_block(self, cfg):
66
+ seq, out_channels = super()._build_res5_block(cfg)
67
+ norm = cfg.MODEL.RESNETS.NORM
68
+ norm = get_norm(norm, out_channels)
69
+ seq.add_module("norm", norm)
70
+ return seq, out_channels
71
+
72
+
73
+ class Trainer(DefaultTrainer):
74
+ """
75
+ We use the "DefaultTrainer" which contains pre-defined default logic for
76
+ standard training workflow. They may not work for you, especially if you
77
+ are working on a new research project. In that case you can write your
78
+ own training loop. You can use "tools/plain_train_net.py" as an example.
79
+ """
80
+
81
+ # [modification] override the `build_optimizer` for using Adam and layer-wise lr decay
82
+ lr_decay_ratio: float = 1.0
83
+ @classmethod
84
+ def build_optimizer(cls, cfg, model):
85
+ is_resnet50 = int(cfg.MODEL.RESNETS.DEPTH) == 50
86
+ if comm.is_main_process():
87
+ dbg = defaultdict(list)
88
+ for module_name, module in model.named_modules():
89
+ for module_param_name, value in module.named_parameters(recurse=False):
90
+ if not value.requires_grad:
91
+ continue
92
+ lrf = lr_factor_func(f"{module_name}.{module_param_name}", is_resnet50=is_resnet50, dec=cls.lr_decay_ratio, debug=True)
93
+ dbg[lrf].append(f"{module_name}.{module_param_name}")
94
+ for k in sorted(dbg.keys()):
95
+ print(f'[{k}] {sorted(dbg[k])}')
96
+ print()
97
+
98
+ params = get_default_optimizer_params(
99
+ model,
100
+ base_lr=cfg.SOLVER.BASE_LR,
101
+ weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
102
+ bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
103
+ weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
104
+ lr_factor_func=partial(lr_factor_func, is_resnet50=is_resnet50, dec=cls.lr_decay_ratio, debug=False)
105
+ )
106
+
107
+ opt_clz = {
108
+ 'sgd': partial(torch.optim.SGD, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV),
109
+ 'adamw': torch.optim.AdamW,
110
+ 'adam': torch.optim.AdamW,
111
+ }[cfg.SOLVER.OPTIMIZER.lower()]
112
+ return maybe_add_gradient_clipping(cfg, opt_clz)(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
113
+
114
+ @classmethod
115
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
116
+ return build_evaluator(cfg, dataset_name, output_folder)
117
+
118
+ @classmethod
119
+ def test_with_TTA(cls, cfg, model):
120
+ logger = logging.getLogger("detectron2.trainer")
121
+ # In the end of training, run an evaluation with TTA
122
+ # Only support some R-CNN models.
123
+ logger.info("Running inference with test-time augmentation ...")
124
+ model = GeneralizedRCNNWithTTA(cfg, model)
125
+ evaluators = [
126
+ cls.build_evaluator(
127
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
128
+ )
129
+ for name in cfg.DATASETS.TEST
130
+ ]
131
+ res = cls.test(cfg, model, evaluators)
132
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
133
+ return res
134
+
135
+
136
+ def setup(args):
137
+ """
138
+ Create configs and perform basic setups.
139
+ """
140
+ cfg = get_cfg()
141
+ # [modification] we add these two new keys
142
+ cfg.SOLVER.OPTIMIZER, cfg.SOLVER.LR_DECAY = 'sgd', 1.0 # by default using SGD and no lr_decay
143
+ cfg.merge_from_file(args.config_file)
144
+ cfg.merge_from_list(args.opts)
145
+ cfg.freeze()
146
+ default_setup(cfg, args)
147
+ return cfg
148
+
149
+
150
+ def main(args):
151
+ cfg = setup(args)
152
+ os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
153
+
154
+ # [modification] for implementing lr decay and for logging
155
+ Trainer.lr_decay_ratio = cfg.SOLVER.LR_DECAY
156
+
157
+ if args.eval_only:
158
+ model = Trainer.build_model(cfg)
159
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
160
+ cfg.MODEL.WEIGHTS, resume=args.resume
161
+ )
162
+ res = Trainer.test(cfg, model)
163
+ if cfg.TEST.AUG.ENABLED:
164
+ res.update(Trainer.test_with_TTA(cfg, model))
165
+ if comm.is_main_process():
166
+ verify_results(cfg, res)
167
+ return res
168
+
169
+ # [modification] just skip some warnings
170
+ import warnings
171
+ comm.synchronize()
172
+ warnings.filterwarnings('ignore', category=UserWarning)
173
+ _ = np.arange(3, dtype=np.int).astype(np.bool)
174
+ _ = np.array(torch.ones(3, dtype=torch.int32).numpy(), dtype=np.int)
175
+ _ = np.array(torch.ones(3, dtype=torch.int64).numpy(), dtype=np.int)
176
+ _ = np.array(torch.ones(3, dtype=torch.long).numpy(), dtype=np.int)
177
+ _ = torch.rand(100) // 5
178
+ _ = torch.meshgrid(torch.ones(1))
179
+ warnings.resetwarnings()
180
+ comm.synchronize()
181
+
182
+ """
183
+ If you'd like to do anything fancier than the standard training logic,
184
+ consider writing your own training loop (see plain_train_net.py) or
185
+ subclassing the trainer.
186
+ """
187
+ trainer = Trainer(cfg)
188
+ trainer.resume_or_load(resume=args.resume)
189
+ for h in trainer._hooks:
190
+ if isinstance(h, PeriodicWriter):
191
+ h._period = 1000 # [modification] less logging
192
+
193
+ # [modification] we add some hooks for logging
194
+ is_local_master = comm.get_rank() % args.num_gpus == 0
195
+ if comm.is_main_process():
196
+ print(f'[default hooks] {pformat(trainer._hooks, indent=2, width=300)}')
197
+ ex_hooks = [
198
+ hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model)) if cfg.TEST.AUG.ENABLED else None,
199
+ LogHook(cfg.TEST.EVAL_PERIOD, args.config_file, cfg.OUTPUT_DIR, is_local_master) if comm.is_main_process() else None,
200
+ ]
201
+ trainer.register_hooks(ex_hooks)
202
+ if comm.is_main_process():
203
+ print(f'[extra hooks] {pformat(ex_hooks, indent=2, width=300)}')
204
+
205
+ return trainer.train()
206
+
207
+
208
+ # [modification] we add a hook for logging results to `cfg.OUTPUT_DIR/d2_coco_log.txt`
209
+ class LogHook(hooks.HookBase):
210
+ def __init__(self, eval_period, config_file, output_dir, is_local_master):
211
+ self.eval_period = eval_period
212
+ self.log_period = eval_period // 4
213
+ self.log = {}
214
+
215
+ self.is_master = comm.is_main_process()
216
+ self.is_local_master = is_local_master
217
+
218
+ self.config_file = config_file
219
+ self.out_dir = output_dir
220
+ self.log_txt_name = os.path.join(self.out_dir, 'd2_coco_log.txt')
221
+
222
+ def __write_to_log_file(self, d):
223
+ if self.is_local_master:
224
+ self.log.update(d)
225
+ with open(self.log_txt_name, 'w') as fp:
226
+ json.dump(self.log, fp)
227
+ fp.write('\n')
228
+
229
+ def update_and_write_to_local_log(self):
230
+ stat = self.trainer.storage.latest()
231
+ self.log['boxAP'], self.log['bAP50'], self.log['bAP75'] = stat['bbox/AP'][0], stat['bbox/AP50'][0], stat['bbox/AP75'][0]
232
+ self.log['mskAP'], self.log['mAP50'], self.log['mAP75'] = stat['segm/AP'][0], stat['segm/AP50'][0], stat['segm/AP75'][0]
233
+ self.log['bAP-l'], self.log['bAP-m'], self.log['bAP-s'] = stat['bbox/APl'][0], stat['bbox/APm'][0], stat['bbox/APs'][0]
234
+ self.log['mAP-l'], self.log['mAP-m'], self.log['mAP-s'] = stat['segm/APl'][0], stat['segm/APm'][0], stat['segm/APs'][0]
235
+ all_ap = sorted([(v[0], k.split('AP-')[-1].strip()) for k, v in stat.items() if k.startswith('bbox/AP-')])
236
+ all_ap = [tu[1] for tu in all_ap]
237
+ self.log['easy'] = ' | '.join(all_ap[-7:])
238
+ self.log['hard'] = ' | '.join(all_ap[:7])
239
+ for k in self.log.keys():
240
+ if 'AP' in k:
241
+ self.log[k] = round(self.log[k], 3)
242
+ self.__write_to_log_file({})
243
+
244
+ def after_step(self):
245
+ next_iter = self.trainer.iter + 1
246
+ if self.eval_period > 0 and next_iter % self.eval_period == 0:
247
+ self.update_and_write_to_local_log()
248
+
249
+ if self.log_period > 0 and next_iter % self.log_period == 0:
250
+ stat = self.trainer.storage.latest()
251
+ remain_secs = round(stat['eta_seconds'][0])
252
+ d = {
253
+ 'cfg': self.config_file,
254
+ 'rema': str(datetime.timedelta(seconds=remain_secs)), 'fini': time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs)),
255
+ 'cur_iter': f'{next_iter}/{self.trainer.max_iter}',
256
+ }
257
+ self.__write_to_log_file(d)
258
+
259
+ def after_train(self):
260
+ self.update_and_write_to_local_log()
261
+ last_boxAP, last_mskAP = round(self.log['boxAP'], 3), round(self.log['mskAP'], 3)
262
+ self.__write_to_log_file({
263
+ 'rema': '-', 'fini': time.strftime("%m-%d %H:%M", time.localtime(time.time() - 120)),
264
+ 'last_boxAP': last_boxAP,
265
+ 'last_mskAP': last_mskAP,
266
+ })
267
+ time.sleep(5)
268
+ if self.is_master:
269
+ print(f'\n[finished] ========== last_boxAP={last_boxAP}, last_mskAP={last_mskAP} ==========\n')
270
+
271
+
272
+ if __name__ == "__main__":
273
+ args = default_argument_parser().parse_args()
274
+ print("Command Line Args:", args)
275
+ launch(
276
+ main,
277
+ args.num_gpus,
278
+ num_machines=args.num_machines,
279
+ machine_rank=args.machine_rank,
280
+ dist_url=args.dist_url,
281
+ args=(args,),
282
+ )
283
+
284
+
285
+ def build_evaluator(cfg, dataset_name, output_folder=None):
286
+ """
287
+ Create evaluator(s) for a given dataset.
288
+ This uses the special metadata "evaluator_type" associated with each builtin dataset.
289
+ For your own dataset, you can simply create an evaluator manually in your
290
+ script and do not have to worry about the hacky if-else logic here.
291
+ """
292
+ if output_folder is None:
293
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
294
+ evaluator_list = []
295
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
296
+ if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
297
+ evaluator_list.append(
298
+ SemSegEvaluator(
299
+ dataset_name,
300
+ distributed=True,
301
+ output_dir=output_folder,
302
+ )
303
+ )
304
+ if evaluator_type in ["coco", "coco_panoptic_seg"]:
305
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
306
+ if evaluator_type == "coco_panoptic_seg":
307
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
308
+ if evaluator_type == "cityscapes_instance":
309
+ return CityscapesInstanceEvaluator(dataset_name)
310
+ if evaluator_type == "cityscapes_sem_seg":
311
+ return CityscapesSemSegEvaluator(dataset_name)
312
+ elif evaluator_type == "pascal_voc":
313
+ return PascalVOCDetectionEvaluator(dataset_name)
314
+ elif evaluator_type == "lvis":
315
+ return LVISEvaluator(dataset_name, output_dir=output_folder)
316
+ if len(evaluator_list) == 0:
317
+ raise NotImplementedError(
318
+ "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
319
+ )
320
+ elif len(evaluator_list) == 1:
321
+ return evaluator_list[0]
322
+ return DatasetEvaluators(evaluator_list)
spark/downstream_imagenet/README.md ADDED
@@ -0,0 +1,54 @@
1
+ ## About code isolation
2
+
3
+ This `downstream_imagenet` directory is isolated from the pre-training code. One can treat it as an independent codebase 🛠️.
4
+
5
+
6
+ ## Preparation for ImageNet-1k fine-tuning
7
+
8
+ See [INSTALL.md](https://github.com/keyu-tian/SparK/blob/main/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
9
+
10
+ **Note: for network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
11
+
12
+
13
+ ## Fine-tuning on ImageNet-1k from pre-trained weights
14
+
15
+ Run [/downstream_imagenet/main.py](/downstream_imagenet/main.py) via `torchrun`.
16
+ **It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, automatically created if they do not exist), the model name (`--model`; for valid choices, see the keys of `HP_DEFAULT_VALUES` in [/downstream_imagenet/arg.py line14](/downstream_imagenet/arg.py#L14)), and the pretrained weight file (`--resume_from`) to run fine-tuning.
17
+
18
+ All the other configurations have their default values, listed in [/downstream_imagenet/arg.py#L13](/downstream_imagenet/arg.py#L13).
19
+ You can override any of these defaults on the command line, e.g. `--bs=1024`.
20
+
21
+
22
+ Here is an example of fine-tuning a ConvNeXt-Small on a single machine with 8 GPUs:
23
+ ```shell script
24
+ $ cd /path/to/SparK/downstream_imagenet
25
+ $ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<some_port> main.py \
26
+ --data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
27
+ --model=convnext_small --resume_from=/some/path/to/convnextS_1kpretrained_official_style.pth
28
+ ```
29
+
30
+ For multiple machines, change the `--nnodes` and `--master_addr` to your configurations. E.g.:
31
+ ```shell script
32
+ $ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_from_0> --master_addr=<some_address> --master_port=<some_port> main.py \
33
+ ...
34
+ ```
35
+
36
+
37
+ ## Logging
38
+
39
+ See files under `--exp_dir` to track your experiment:
40
+
41
+ - `<model>_1kfinetuned_last.pth`: the latest model weights
42
+ - `<model>_1kfinetuned_best.pth`: model weights with the highest acc
43
+ - `<model>_1kfinetuned_best_ema.pth`: EMA weights with the highest acc
44
+ - `finetune_log.txt`: records some important information such as:
45
+ - `git_commit_id`: git version
46
+ - `cmd`: all arguments passed to the script
47
+
48
+ It also reports training loss/acc, best evaluation acc, and remaining time at each epoch.
49
+
50
+ - `tensorboard_log/`: stores the TensorBoard logs; you can visualize accuracies, loss values, learning rates, gradient norms, and more via `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333`.
51
+
52
+ ## Resuming
53
+
54
+ Use `--resume_from` again, like `--resume_from=path/to/<model>_1kfinetuned_last.pth`.
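
For reference, the `finetune_log.txt` described in the Logging section stores one JSON object per line (a header record first, then one record per epoch), so it can be parsed as below. This is a minimal sketch, not part of the repository; the path is a placeholder:

```python
import json

with open('/path/to/logdir/finetune_log.txt') as fp:
    records = [json.loads(line) for line in fp if line.strip()]

header, epochs = records[0], records[1:]
print(header['name'], header['git_commit_id'])
print('best val acc so far:', max((e['best_val_acc'] for e in epochs), default=None))
```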
spark/downstream_imagenet/arg.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+
11
+ from tap import Tap
12
+
13
+ HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
14
+ HP_DEFAULT_VALUES = {
15
+ 'convnext_small': (4096, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
16
+ 'convnext_base': (4096, 400, 20, 'adam', 0.0001, 0.7, 0.01, 0.8, 3, 0.4, 0.9999),
17
+ 'convnext_large': (4096, 200, 10, 'adam', 0.0001, 0.7, 0.02, 0.8, 3, 0.5, 0.9999),
18
+ 'convnext_large_384': (1024, 200, 20, 'adam', 0.00006, 0.7, 0.01, 0.8, 3, 0.5, 0.99995),
19
+
20
+ 'resnet50': (4096, 300, 5, 'lamb', 0.002, 0.7, 0.02, 0.1, 0, 0.05, 0.9999),
21
+ 'resnet101': (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
22
+ 'resnet152': (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
23
+ 'resnet200': (4096, 300, 5, 'lamb', 0.001, 0.8, 0.02, 0.1, 0, 0.2, 0.9999),
24
+ }
25
+
26
+
27
+ class FineTuneArgs(Tap):
28
+ # environment
29
+ exp_name: str
30
+ exp_dir: str
31
+ data_path: str
32
+ model: str
33
+ resume_from: str = '' # resume from some checkpoint.pth
34
+
35
+ img_size: int = 640
36
+ dataloader_workers: int = 8
37
+
38
+ # ImageNet classification fine-tuning hyperparameters; see `HP_DEFAULT_VALUES` above for detailed default values
39
+ # - batch size, epoch
40
+ bs: int = 0 # global batch size (== batch_size_per_gpu * num_gpus)
41
+ ep: int = 0 # number of epochs
42
+ wp_ep: int = 0 # epochs for warmup
43
+
44
+ # - optimization
45
+ opt: str = '' # optimizer; 'adam' or 'lamb'
46
+ base_lr: float = 0. # peak lr == base_lr * glb_batch_size / 256 (computed in get_args)
47
+ lr_scale: float = 0. # see file `lr_decay.py` for more details
48
+ clip: int = -1 # use gradient clipping if clip > 0
49
+
50
+ # - regularization tricks
51
+ wd: float = 0. # weight decay
52
+ mixup: float = 0. # use mixup if mixup > 0
53
+ rep_aug: int = 0 # use repeated augmentation if rep_aug > 0
54
+ drop_path: float = 0. # drop_path ratio
55
+
56
+ # - other tricks
57
+ ema: float = 0. # use EMA if ema > 0
58
+ sbn: bool = True # use SyncBatchNorm
59
+
60
+ # NO NEED TO SPECIFY; each of these args will be updated automatically at runtime
61
+ lr: float = None
62
+ batch_size_per_gpu: int = 0
63
+ glb_batch_size: int = 0
64
+ device: str = 'cpu'
65
+ world_size: int = 1
66
+ global_rank: int = 0
67
+ local_rank: int = 0 # we DO USE this arg
68
+ is_master: bool = False
69
+ is_local_master: bool = False
70
+ cmd: str = ' '.join(sys.argv[1:])
71
+ commit_id: str = os.popen(f'git rev-parse HEAD').read().strip()
72
+ commit_msg: str = os.popen(f'git log -1').read().strip().splitlines()[-1].strip()
73
+ log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
74
+ tb_lg_dir: str = '' # tensorboard log directory
75
+
76
+ train_loss: float = 0.
77
+ train_acc: float = 0.
78
+ best_val_acc: float = 0.
79
+ cur_ep: str = ''
80
+ remain_time: str = ''
81
+ finish_time: str = ''
82
+ first_logging: bool = True
83
+
84
+ def log_epoch(self):
85
+ if not self.is_local_master:
86
+ return
87
+
88
+ if self.first_logging:
89
+ self.first_logging = False
90
+ with open(self.log_txt_name, 'w') as fp:
91
+ json.dump({
92
+ 'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
93
+ 'model': self.model,
94
+ }, fp)
95
+ fp.write('\n\n')
96
+
97
+ with open(self.log_txt_name, 'a') as fp:
98
+ json.dump({
99
+ 'cur_ep': self.cur_ep,
100
+ 'train_L': self.train_loss, 'train_acc': self.train_acc,
101
+ 'best_val_acc': self.best_val_acc,
102
+ 'rema': self.remain_time, 'fini': self.finish_time,
103
+ }, fp)
104
+ fp.write('\n')
105
+
106
+
107
+ def get_args(world_size, global_rank, local_rank, device) -> FineTuneArgs:
108
+ # parse args and prepare directories
109
+ args = FineTuneArgs(explicit_bool=True).parse_args()
110
+ d_name, b_name = os.path.dirname(os.path.abspath(args.exp_dir)), os.path.basename(os.path.abspath(args.exp_dir))
111
+ b_name = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in b_name)
112
+ args.exp_dir = os.path.join(d_name, b_name)
113
+ os.makedirs(args.exp_dir, exist_ok=True)
114
+ args.log_txt_name = os.path.join(args.exp_dir, 'finetune_log.txt')
115
+
116
+ args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
117
+ try: os.makedirs(args.tb_lg_dir, exist_ok=True)
118
+ except: pass
119
+
120
+ # fill in args.bs, args.ep, etc. with their default values (if their values are not explicitly specified, i.e., if bool(they) == False)
121
+ if args.model == 'convnext_large' and args.img_size == 384:
122
+ default_values = HP_DEFAULT_VALUES['convnext_large_384']
123
+ else:
124
+ default_values = HP_DEFAULT_VALUES[args.model]
125
+ for k, v in zip(HP_DEFAULT_NAMES, default_values):
126
+ if bool(getattr(args, k)) == False:
127
+ setattr(args, k, v)
128
+
129
+ # update other runtime args
130
+ args.world_size, args.global_rank, args.local_rank, args.device = world_size, global_rank, local_rank, device
131
+ args.is_master = global_rank == 0
132
+ args.is_local_master = local_rank == 0
133
+ args.batch_size_per_gpu = args.bs // world_size
134
+ args.glb_batch_size = args.batch_size_per_gpu * world_size
135
+ args.lr = args.base_lr * args.glb_batch_size / 256
136
+
137
+ return args
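
A worked example (not part of the repository) of the linear lr scaling applied at the end of `get_args`, using the `convnext_small` defaults from `HP_DEFAULT_VALUES` (`bs=4096`, `base_lr=0.0002`):

```python
# peak lr == base_lr * glb_batch_size / 256
bs, base_lr = 4096, 0.0002
print(base_lr * bs / 256)  # 0.0032
```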
spark/downstream_imagenet/data.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import random
9
+ import time
10
+
11
+ import PIL.Image as PImage
12
+ import numpy as np
13
+ import torch
14
+ import torchvision
15
+ from timm.data import AutoAugment as TimmAutoAugment
16
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
17
+ from timm.data.distributed_sampler import RepeatAugSampler
18
+ from timm.data.transforms_factory import transforms_imagenet_eval
19
+ from torch.utils.data import DataLoader
20
+ from torch.utils.data.sampler import Sampler
21
+ from torchvision.transforms import AutoAugment as TorchAutoAugment
22
+ from torchvision.transforms import transforms, TrivialAugmentWide
23
+
24
+ try:
25
+ from torchvision.transforms import InterpolationMode
26
+ interpolation = InterpolationMode.BICUBIC
27
+ except:
28
+ import PIL
29
+ interpolation = PIL.Image.BICUBIC
30
+
31
+
32
+ def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_size_per_gpu, world_size, global_rank):
33
+ import warnings
34
+ warnings.filterwarnings('ignore', category=UserWarning)
35
+
36
+ mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
37
+ trans_train = create_transform(
38
+ is_training=True, input_size=img_size,
39
+ auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1,
40
+ mean=mean, std=std,
41
+ )
42
+ if img_size < 384:
43
+ for i, t in enumerate(trans_train.transforms):
44
+ if isinstance(t, (TorchAutoAugment, TimmAutoAugment)):
45
+ trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation)
46
+ break
47
+ trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std)
48
+ else:
49
+ trans_val = transforms.Compose([
50
+ transforms.Resize((img_size, img_size), interpolation=interpolation),
51
+ transforms.ToTensor(), transforms.Normalize(mean=mean, std=std),
52
+ ])
53
+ print_transform(trans_train, '[train]')
54
+ print_transform(trans_val, '[val]')
55
+
56
+ imagenet_folder = os.path.abspath(data_path)
57
+ for postfix in ('train', 'val'):
58
+ if imagenet_folder.endswith(postfix):
59
+ imagenet_folder = imagenet_folder[:-len(postfix)]
60
+ dataset_train = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'train'), trans_train)
61
+ dataset_val = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'val'), trans_val)
62
+
63
+ if rep_aug:
64
+ print(f'[dataset] using repeated augmentation: count={rep_aug}')
65
+ train_sp = RepeatAugSampler(dataset_train, shuffle=True, num_repeats=rep_aug)
66
+ else:
67
+ train_sp = torch.utils.data.distributed.DistributedSampler(dataset_train, shuffle=True, drop_last=True)
68
+
69
+ loader_train = DataLoader(
70
+ dataset=dataset_train, num_workers=workers, pin_memory=True,
71
+ batch_size=batch_size_per_gpu, sampler=train_sp, persistent_workers=workers > 0,
72
+ worker_init_fn=worker_init_fn,
73
+ )
74
+ iters_train = len(loader_train)
75
+ print(f'[dataset: train] bs={world_size}x{batch_size_per_gpu}={world_size * batch_size_per_gpu}, num_iters={iters_train}')
76
+
77
+ val_ratio = 2
78
+ loader_val = DataLoader(
79
+ dataset=dataset_val, num_workers=workers, pin_memory=True,
80
+ batch_sampler=DistInfiniteBatchSampler(world_size, global_rank, len(dataset_val), glb_batch_size=val_ratio * batch_size_per_gpu, filling=False, shuffle=False),
81
+ worker_init_fn=worker_init_fn,
82
+ )
83
+ iters_val = len(loader_val)
84
+ print(f'[dataset: val] bs={world_size}x{val_ratio * batch_size_per_gpu}={val_ratio * world_size * batch_size_per_gpu}, num_iters={iters_val}')
85
+
86
+ time.sleep(3)
87
+ warnings.resetwarnings()
88
+ return loader_train, iters_train, iter(loader_val), iters_val
89
+
90
+
91
+ def worker_init_fn(worker_id):
92
+ # see: https://pytorch.org/docs/stable/notes/randomness.html#dataloader
93
+ worker_seed = torch.initial_seed() % 2 ** 32
94
+ np.random.seed(worker_seed)
95
+ random.seed(worker_seed)
96
+
97
+
98
+ def print_transform(transform, s):
99
+ print(f'Transform {s} = ')
100
+ for t in transform.transforms:
101
+ print(t)
102
+ print('---------------------------\n')
103
+
104
+
105
+ class DistInfiniteBatchSampler(Sampler):
106
+ def __init__(self, world_size, global_rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
107
+ assert glb_batch_size % world_size == 0
108
+ self.world_size, self.rank = world_size, global_rank
109
+ self.dataset_len = dataset_len
110
+ self.glb_batch_size = glb_batch_size
111
+ self.batch_size = glb_batch_size // world_size
112
+
113
+ self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
114
+ self.filling = filling
115
+ self.shuffle = shuffle
116
+ self.epoch = 0
117
+ self.seed = seed
118
+ self.indices = self.gener_indices()
119
+
120
+ def gener_indices(self):
121
+ global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
122
+ if self.shuffle:
123
+ g = torch.Generator()
124
+ g.manual_seed(self.epoch + self.seed)
125
+ global_indices = torch.randperm(self.dataset_len, generator=g)
126
+ else:
127
+ global_indices = torch.arange(self.dataset_len)
128
+ filling = global_max_p - global_indices.shape[0]
129
+ if filling > 0 and self.filling:
130
+ global_indices = torch.cat((global_indices, global_indices[:filling]))
131
+ global_indices = tuple(global_indices.numpy().tolist())
132
+
133
+ seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
134
+ local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
135
+ self.max_p = len(local_indices)
136
+ return local_indices
137
+
138
+ def __iter__(self):
139
+ self.epoch = 0
140
+ while True:
141
+ self.epoch += 1
142
+ p, q = 0, 0
143
+ while p < self.max_p:
144
+ q = p + self.batch_size
145
+ yield self.indices[p:q]
146
+ p = q
147
+ if self.shuffle:
148
+ self.indices = self.gener_indices()
149
+
150
+ def __len__(self):
151
+ return self.iters_per_ep
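
A minimal sketch (not part of the repository) of how `DistInfiniteBatchSampler` partitions a toy dataset across ranks; the sizes are made up for illustration, no distributed setup is needed, and the import assumes the script is run from `spark/downstream_imagenet/`:

```python
from data import DistInfiniteBatchSampler

for rank in (0, 1):
    sampler = DistInfiniteBatchSampler(world_size=2, global_rank=rank, dataset_len=10,
                                       glb_batch_size=4, filling=False, shuffle=False)
    it = iter(sampler)
    # the sampler is infinite; take one epoch's worth of batches (len(sampler) == iters_per_ep == 3)
    print(rank, [next(it) for _ in range(len(sampler))])
# rank 0 yields indices 0..4 in per-GPU batches of 2, rank 1 yields indices 5..9
```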
spark/downstream_imagenet/lr_decay.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from pprint import pformat
9
+
10
+
11
+ def lr_wd_annealing(optimizer, peak_lr, wd, cur_it, wp_it, max_it):
12
+ wp_it = round(wp_it)
13
+ if cur_it < wp_it:
14
+ cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it
15
+ else:
16
+ ratio = (cur_it - wp_it) / (max_it - 1 - wp_it)
17
+ cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio))
18
+
19
+ min_lr, max_lr = cur_lr, cur_lr
20
+ min_wd, max_wd = wd, wd
21
+ for param_group in optimizer.param_groups:
22
+ scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned
23
+ min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr)
24
+ scaled_wd = param_group['weight_decay'] = wd * param_group.get('weight_decay_scale', 1) # 'weight_decay_scale' could be assigned
25
+ min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd)
26
+ return min_lr, max_lr, min_wd, max_wd
27
+
28
+
29
+ def get_param_groups(model, nowd_keys=(), lr_scale=0.0):
30
+ using_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0.0 < lr_scale < 1.0
31
+ print(f'[get_ft_param_groups][lr decay] using_lr_scale={using_lr_scale}, ft_lr_scale={lr_scale}')
32
+ para_groups, para_groups_dbg = {}, {}
33
+
34
+ for name, para in model.named_parameters():
35
+ if not para.requires_grad:
36
+ continue # frozen weights
37
+ if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys):
38
+ wd_scale, group_name = 0., 'no_decay'
39
+ else:
40
+ wd_scale, group_name = 1., 'decay'
41
+
42
+ if using_lr_scale:
43
+ layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
44
+ group_name = f'layer{layer_id}_' + group_name
45
+ this_lr_scale = lr_scale ** scale_exp
46
+ dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
47
+ else:
48
+ this_lr_scale = 1
49
+ dbg = f'[no scale]'
50
+
51
+ if group_name not in para_groups:
52
+ para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': this_lr_scale}
53
+ para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': dbg}
54
+ para_groups[group_name]['params'].append(para)
55
+ para_groups_dbg[group_name]['params'].append(name)
56
+
57
+ for g in para_groups_dbg.values():
58
+ g['params'] = pformat(', '.join(g['params']), width=200)
59
+
60
+ print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n')
61
+ return list(para_groups.values())
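
A minimal sketch (not part of the repository) of the warmup-then-cosine schedule implemented by `lr_wd_annealing`, probed with a throwaway SGD optimizer; the peak lr, weight decay, and iteration counts are arbitrary, and the import assumes the script is run from `spark/downstream_imagenet/`:

```python
import torch
from lr_decay import lr_wd_annealing

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.0)
wp_it, max_it = 100, 1000
for it in (0, 50, 100, 500, 999):
    _, max_lr, _, max_wd = lr_wd_annealing(opt, 2e-3, 0.02, it, wp_it, max_it)
    print(f'it={it:4d}  lr={max_lr:.2e}  wd={max_wd:.2e}')
# lr ramps from 0.005 * peak to peak over the first 100 iterations, then follows a cosine down to 0.001 * peak
```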
spark/downstream_imagenet/main.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import datetime
8
+ import time
9
+
10
+ import torch
11
+ import torch.distributed as tdist
12
+ from timm.utils import ModelEmaV2
13
+ from torch.utils.tensorboard import SummaryWriter
14
+
15
+ from arg import get_args, FineTuneArgs
16
+ from models import ConvNeXt, ResNet
17
+ __for_timm_registration = ConvNeXt, ResNet
18
+ from lr_decay import lr_wd_annealing
19
+ from util import init_distributed_environ, create_model_opt, load_checkpoint, save_checkpoint
20
+ from data import create_classification_dataset
21
+
22
+
23
+ def main_ft():
24
+ world_size, global_rank, local_rank, device = init_distributed_environ()
25
+ args: FineTuneArgs = get_args(world_size, global_rank, local_rank, device)
26
+ print(f'initial args:\n{str(args)}')
27
+ args.log_epoch()
28
+
29
+ criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer = create_model_opt(args)
30
+ ep_start, performance_desc = load_checkpoint(args.resume_from, model_without_ddp, model_ema, optimizer)
31
+
32
+ if ep_start >= args.ep: # load from a complete checkpoint file
33
+ print(f' [*] [FT already done] Max/Last Acc: {performance_desc}')
34
+ else:
35
+ tb_lg = SummaryWriter(args.tb_lg_dir) if args.is_master else None
36
+ loader_train, iters_train, iterator_val, iters_val = create_classification_dataset(
37
+ args.data_path, args.img_size, args.rep_aug,
38
+ args.dataloader_workers, args.batch_size_per_gpu, args.world_size, args.global_rank
39
+ )
40
+
41
+ # train & eval
42
+ tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
43
+ max_acc = last_acc
44
+ max_acc_e = last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)[-1]
45
+ print(f'[fine-tune] initial acc={last_acc:.2f}, ema={last_acc_e:.2f}')
46
+
47
+ ep_eval = set(range(0, args.ep//3, 5)) | set(range(args.ep//3, args.ep))
48
+ print(f'[FT start] ep_eval={sorted(ep_eval)} ')
49
+ print(f'[FT start] from ep{ep_start}')
50
+
51
+ params_req_grad = [p for p in model.parameters() if p.requires_grad]
52
+ ft_start_time = time.time()
53
+ for ep in range(ep_start, args.ep):
54
+ ep_start_time = time.time()
55
+ if hasattr(loader_train, 'sampler') and hasattr(loader_train.sampler, 'set_epoch'):
56
+ loader_train.sampler.set_epoch(ep)
57
+ if 0 <= ep <= 3:
58
+ print(f'[loader_train.sampler.set_epoch({ep})]')
59
+
60
+ train_loss, train_acc = fine_tune_one_epoch(ep, args, tb_lg, loader_train, iters_train, criterion, mixup_fn, model, model_ema, optimizer, params_req_grad)
61
+ if ep in ep_eval:
62
+ eval_start_time = time.time()
63
+ tot_pred, last_acc = evaluate(args.device, iterator_val, iters_val, model)
64
+ tot_pred_e, last_acc_e = evaluate(args.device, iterator_val, iters_val, model_ema.module)
65
+ eval_cost = round(time.time() - eval_start_time, 2)
66
+ performance_desc = f'Max (Last) Acc: {max(max_acc, last_acc):.2f} ({last_acc:.2f} o {tot_pred}) EMA: {max(max_acc_e, last_acc_e):.2f} ({last_acc_e:.2f} o {tot_pred_e})'
67
+ states = model_without_ddp.state_dict(), model_ema.module.state_dict(), optimizer.state_dict()
68
+ if last_acc > max_acc:
69
+ max_acc = last_acc
70
+ save_checkpoint(f'{args.model}_1kfinetuned_best.pth', args, ep, performance_desc, *states)
71
+ if last_acc_e > max_acc_e:
72
+ max_acc_e = last_acc_e
73
+ save_checkpoint(f'{args.model}_1kfinetuned_best_ema.pth', args, ep, performance_desc, *states)
74
+ save_checkpoint(f'{args.model}_1kfinetuned_last.pth', args, ep, performance_desc, *states)
75
+ else:
76
+ eval_cost = '-'
77
+
78
+ ep_cost = round(time.time() - ep_start_time, 2) + 1 # +1s: approximate the following logging cost
79
+ remain_secs = (args.ep-1 - ep) * ep_cost
80
+ remain_time = datetime.timedelta(seconds=round(remain_secs))
81
+ finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
82
+ print(f'[ep{ep}/{args.ep}] {performance_desc} Ep cost: {ep_cost}s, Ev cost: {eval_cost}, Remain: {remain_time}, Finish @ {finish_time}')
83
+ args.cur_ep = f'{ep + 1}/{args.ep}'
84
+ args.remain_time, args.finish_time = str(remain_time), str(finish_time)
85
+ args.train_loss, args.train_acc, args.best_val_acc = train_loss, train_acc, max(max_acc, max_acc_e)
86
+ args.log_epoch()
87
+
88
+ if args.is_master:
89
+ tb_lg.add_scalar(f'ft_train/ep_loss', train_loss, ep)
90
+ tb_lg.add_scalar(f'ft_eval/max_acc', max_acc, ep)
91
+ tb_lg.add_scalar(f'ft_eval/last_acc', last_acc, ep)
92
+ tb_lg.add_scalar(f'ft_eval/max_acc_ema', max_acc_e, ep)
93
+ tb_lg.add_scalar(f'ft_eval/last_acc_ema', last_acc_e, ep)
94
+ tb_lg.add_scalar(f'ft_z_burnout/rest_hours', round(remain_secs/60/60, 2), ep)
95
+ tb_lg.flush()
96
+
97
+ # finish fine-tuning
98
+ result_acc = max(max_acc, max_acc_e)
99
+ if args.is_master:
100
+ tb_lg.add_scalar('ft_result/result_acc', result_acc, ep_start)
101
+ tb_lg.add_scalar('ft_result/result_acc', result_acc, args.ep)
102
+ tb_lg.flush()
103
+ tb_lg.close()
104
+ print(f'final args:\n{str(args)}')
105
+ print('\n\n')
106
+ print(f' [*] [FT finished] {performance_desc} Total Cost: {(time.time() - ft_start_time) / 60 / 60:.1f}h\n')
107
+ print(f' [*] [FT finished] max(max_acc, max_acc_e)={result_acc} EMA better={max_acc_e>max_acc}')
108
+ print('\n\n')
109
+ time.sleep(10)
110
+
111
+ args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
112
+ args.log_epoch()
113
+
114
+
115
+ def fine_tune_one_epoch(ep, args: FineTuneArgs, tb_lg: SummaryWriter, loader_train, iters_train, criterion, mixup_fn, model, model_ema: ModelEmaV2, optimizer, params_req_grad):
116
+ model.train()
117
+ tot_loss = tot_acc = 0.0
118
+ log_freq = max(1, round(iters_train * 0.7))
119
+ ep_start_time = time.time()
120
+ for it, (inp, tar) in enumerate(loader_train):
121
+ # adjust lr and wd
122
+ cur_it = it + ep * iters_train
123
+ min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, cur_it, args.wp_ep * iters_train, args.ep * iters_train)
124
+
125
+ # forward
126
+ inp = inp.to(args.device, non_blocking=True)
127
+ raw_tar = tar = tar.to(args.device, non_blocking=True)
128
+ if mixup_fn is not None:
129
+ inp, tar, raw_tar = mixup_fn(inp, tar)
130
+ oup = model(inp)
131
+ pred = oup.data.argmax(dim=1)
132
+ if mixup_fn is None:
133
+ acc = pred.eq(tar).float().mean().item() * 100
134
+ tot_acc += acc
135
+ else:
136
+ acc = (pred.eq(raw_tar) | pred.eq(raw_tar.flip(0))).float().mean().item() * 100
137
+ tot_acc += acc
138
+
139
+ # backward
140
+ optimizer.zero_grad()
141
+ loss = criterion(oup, tar)
142
+ loss.backward()
143
+ loss = loss.item()
144
+ tot_loss += loss
145
+ if args.clip > 0:
146
+ orig_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
147
+ else:
148
+ orig_norm = None
149
+ optimizer.step()
150
+ model_ema.update(model)
151
+ torch.cuda.synchronize()
152
+
153
+ # log
154
+ if args.is_master and cur_it % log_freq == 0:
155
+ tb_lg.add_scalar(f'ft_train/it_loss', loss, cur_it)
156
+ tb_lg.add_scalar(f'ft_train/it_acc', acc, cur_it)
157
+ tb_lg.add_scalar(f'ft_hp/min_lr', min_lr, cur_it), tb_lg.add_scalar(f'ft_hp/max_lr', max_lr, cur_it)
158
+ tb_lg.add_scalar(f'ft_hp/min_wd', min_wd, cur_it), tb_lg.add_scalar(f'ft_hp/max_wd', max_wd, cur_it)
159
+ if orig_norm is not None:
160
+ tb_lg.add_scalar(f'ft_hp/orig_norm', orig_norm, cur_it)
161
+
162
+ if it in [3, iters_train//2, iters_train-1]:
163
+ remain_secs = (iters_train-1 - it) * (time.time() - ep_start_time) / (it + 1)
164
+ remain_time = datetime.timedelta(seconds=round(remain_secs))
165
+ print(f'[ep{ep} it{it:3d}/{iters_train}] L: {loss:.4f} Acc: {acc:.2f} lr: {min_lr:.1e}~{max_lr:.1e} Remain: {remain_time}')
166
+
167
+ return tot_loss / iters_train, tot_acc / iters_train
168
+
169
+
170
+ @torch.no_grad()
171
+ def evaluate(dev, iterator_val, iters_val, model):
172
+ training = model.training
173
+ model.train(False)
174
+ tot_pred, tot_correct = 0., 0.
175
+ for _ in range(iters_val):
176
+ inp, tar = next(iterator_val)
177
+ tot_pred += tar.shape[0]
178
+ inp = inp.to(dev, non_blocking=True)
179
+ tar = tar.to(dev, non_blocking=True)
180
+ oup = model(inp)
181
+ tot_correct += oup.argmax(dim=1).eq(tar).sum().item()
182
+ model.train(training)
183
+ t = torch.tensor([tot_pred, tot_correct]).to(dev)
184
+ tdist.all_reduce(t)
185
+ return t[0].item(), (t[1] / t[0]).item() * 100.
186
+
187
+
188
+ if __name__ == '__main__':
189
+ main_ft()
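
An illustration (not part of the repository) of the evaluation schedule `ep_eval` built in `main_ft`: the model is evaluated every 5 epochs during the first third of training and every epoch afterwards. With `ep=300` as an example:

```python
ep = 300
ep_eval = set(range(0, ep // 3, 5)) | set(range(ep // 3, ep))
print(sorted(ep_eval)[:8], '...', len(ep_eval))  # [0, 5, 10, 15, 20, 25, 30, 35] ... 220
```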
spark/downstream_imagenet/mixup.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file is a modified version of timm.data.Mixup
8
+ # Fixes the "Batch size should be even when using this" error
9
+
10
+ """ Mixup and Cutmix
11
+
12
+ Papers:
13
+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
14
+
15
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
16
+
17
+ Code Reference:
18
+ CutMix: https://github.com/clovaai/CutMix-PyTorch
19
+
20
+ Hacked together by / Copyright 2019, Ross Wightman
21
+ """
22
+ import numpy as np
23
+ import torch
24
+
25
+
26
+ def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
27
+ x = x.long().view(-1, 1)
28
+ return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
29
+
30
+
31
+ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
32
+ off_value = smoothing / num_classes
33
+ on_value = 1. - smoothing + off_value
34
+ y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
35
+ y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
36
+ return y1 * lam + y2 * (1. - lam)
37
+
38
+
39
+ def rand_bbox(img_shape, lam, margin=0., count=None):
40
+ """ Standard CutMix bounding-box
41
+ Generates a random square bbox based on lambda value. This impl includes
42
+ support for enforcing a border margin as percent of bbox dimensions.
43
+
44
+ Args:
45
+ img_shape (tuple): Image shape as tuple
46
+ lam (float): Cutmix lambda value
47
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
48
+ count (int): Number of bbox to generate
49
+ """
50
+ ratio = np.sqrt(1 - lam)
51
+ img_h, img_w = img_shape[-2:]
52
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
53
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
54
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
55
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
56
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
57
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
58
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
59
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
60
+ return yl, yh, xl, xh
61
+
62
+
63
+ def rand_bbox_minmax(img_shape, minmax, count=None):
64
+ """ Min-Max CutMix bounding-box
65
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
66
+ based on min/max percent values applied to each dimension of the input image.
67
+
68
+ Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
69
+
70
+ Args:
71
+ img_shape (tuple): Image shape as tuple
72
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
73
+ count (int): Number of bbox to generate
74
+ """
75
+ assert len(minmax) == 2
76
+ img_h, img_w = img_shape[-2:]
77
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
78
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
79
+ yl = np.random.randint(0, img_h - cut_h, size=count)
80
+ xl = np.random.randint(0, img_w - cut_w, size=count)
81
+ yu = yl + cut_h
82
+ xu = xl + cut_w
83
+ return yl, yu, xl, xu
84
+
85
+
86
+ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
87
+ """ Generate bbox and apply lambda correction.
88
+ """
89
+ if ratio_minmax is not None:
90
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
91
+ else:
92
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
93
+ if correct_lam or ratio_minmax is not None:
94
+ bbox_area = (yu - yl) * (xu - xl)
95
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
96
+ return (yl, yu, xl, xu), lam
97
+
98
+
99
+ class BatchMixup:
100
+ """ Mixup/Cutmix that applies different params to each element or whole batch
101
+
102
+ Args:
103
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
104
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
105
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
106
+ prob (float): probability of applying mixup or cutmix per batch or element
107
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
108
+ mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
109
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
110
+ label_smoothing (float): apply label smoothing to the mixed target tensor
111
+ num_classes (int): number of classes for target
112
+ """
113
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
114
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
115
+ assert mode == 'batch'
116
+ self.mixup_alpha = mixup_alpha
117
+ self.cutmix_alpha = cutmix_alpha
118
+ self.cutmix_minmax = cutmix_minmax
119
+ if self.cutmix_minmax is not None:
120
+ assert len(self.cutmix_minmax) == 2
121
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
122
+ self.cutmix_alpha = 1.0
123
+ self.mix_prob = prob
124
+ self.switch_prob = switch_prob
125
+ self.label_smoothing = label_smoothing
126
+ self.num_classes = num_classes
127
+ self.mode = mode
128
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
129
+ self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
130
+
131
+ def _params_per_batch(self):
132
+ lam = 1.
133
+ use_cutmix = False
134
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
135
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
136
+ use_cutmix = np.random.rand() < self.switch_prob
137
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
138
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
139
+ elif self.mixup_alpha > 0.:
140
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
141
+ elif self.cutmix_alpha > 0.:
142
+ use_cutmix = True
143
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
144
+ else:
145
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
146
+ lam = float(lam_mix)
147
+ return lam, use_cutmix
148
+
149
+ def _mix_batch(self, x):
150
+ lam, use_cutmix = self._params_per_batch()
151
+ if lam == 1.:
152
+ return 1.
153
+ if use_cutmix:
154
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
155
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
156
+ x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
157
+ else:
158
+ x_flipped = x.flip(0).mul_(1. - lam)
159
+ x.mul_(lam).add_(x_flipped)
160
+ return lam
161
+
162
+ def __call__(self, x, raw_target):
163
+ if x.shape[0] % 2 == 1:
164
+ x, raw_target = torch.cat((x[:1], x), dim=0), torch.cat((raw_target[:1], raw_target), dim=0)
165
+ # assert len(x) % 2 == 0, 'Batch size should be even when using this'
166
+ lam = self._mix_batch(x)
167
+ target = mixup_target(raw_target, self.num_classes, lam, self.label_smoothing, x.device)
168
+ return x, target, raw_target
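
A minimal usage sketch (not part of the repository) of `BatchMixup` on a toy CPU batch; the shapes and hyperparameters are arbitrary, and the import assumes the script is run from `spark/downstream_imagenet/`. Note the odd batch size, which timm's original `Mixup` rejects and this class pads instead:

```python
import torch
from mixup import BatchMixup

mix = BatchMixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=10)
x = torch.randn(5, 3, 32, 32)          # odd batch size on purpose
y = torch.randint(0, 10, (5,))
x_mixed, soft_target, raw_target = mix(x, y)
print(x_mixed.shape, soft_target.shape, raw_target.shape)  # batch is padded from 5 to 6
```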
spark/downstream_imagenet/models/__init__.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+ from timm.data import Mixup
11
+ from timm.loss import BinaryCrossEntropy, SoftTargetCrossEntropy
12
+ from timm.models.layers import drop
13
+ from timm.models.resnet import ResNet
14
+
15
+ from .convnext_official import ConvNeXt
16
+
17
+
18
+ def convnext_get_layer_id_and_scale_exp(self: ConvNeXt, para_name: str):
19
+ N = 12 if len(self.stages[-2]) > 9 else 6
20
+ if para_name.startswith("downsample_layers"):
21
+ stage_id = int(para_name.split('.')[1])
22
+ if stage_id == 0:
23
+ layer_id = 0
24
+ elif stage_id == 1 or stage_id == 2:
25
+ layer_id = stage_id + 1
26
+ else: # stage_id == 3:
27
+ layer_id = N
28
+ elif para_name.startswith("stages"):
29
+ stage_id = int(para_name.split('.')[1])
30
+ block_id = int(para_name.split('.')[2])
31
+ if stage_id == 0 or stage_id == 1:
32
+ layer_id = stage_id + 1
33
+ elif stage_id == 2:
34
+ layer_id = 3 + block_id // 3
35
+ else: # stage_id == 3:
36
+ layer_id = N
37
+ else:
38
+ layer_id = N + 1 # after backbone
39
+
40
+ return layer_id, N + 1 - layer_id
41
+
42
+
43
+ def resnets_get_layer_id_and_scale_exp(self: ResNet, para_name: str):
44
+ # stages:
45
+ # 50 : [3, 4, 6, 3]
46
+ # 101 : [3, 4, 23, 3]
47
+ # 152 : [3, 8, 36, 3]
48
+ # 200 : [3, 24, 36, 3]
49
+ # eca269d: [3, 30, 48, 8]
50
+
51
+ L2, L3 = len(self.layer2), len(self.layer3)
52
+ if L2 == 4 and L3 == 6:
53
+ blk2, blk3 = 2, 3
54
+ elif L2 == 4 and L3 == 23:
55
+ blk2, blk3 = 2, 3
56
+ elif L2 == 8 and L3 == 36:
57
+ blk2, blk3 = 4, 4
58
+ elif L2 == 24 and L3 == 36:
59
+ blk2, blk3 = 4, 4
60
+ elif L2 == 30 and L3 == 48:
61
+ blk2, blk3 = 5, 6
62
+ else:
63
+ raise NotImplementedError
64
+
65
+ N2, N3 = math.ceil(L2 / blk2 - 1e-5), math.ceil(L3 / blk3 - 1e-5)
66
+ N = 2 + N2 + N3
67
+ if para_name.startswith('layer'): # 1, 2, 3, 4, 5
68
+ stage_id, block_id = int(para_name.split('.')[0][5:]), int(para_name.split('.')[1])
69
+ if stage_id == 1:
70
+ layer_id = 1
71
+ elif stage_id == 2:
72
+ layer_id = 2 + block_id // blk2 # 2, 3
73
+ elif stage_id == 3:
74
+ layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 r101: 4, 5, ..., 11
75
+ else: # == 4
76
+ layer_id = N # r50: 6 r101: 12
77
+ elif para_name.startswith('fc.'):
78
+ layer_id = N + 1 # r50: 7 r101: 13
79
+ else:
80
+ layer_id = 0
81
+
82
+ return layer_id, N + 1 - layer_id # r50: 0-7, 7-0 r101: 0-13, 13-0
83
+
84
+
85
+ def _ex_repr(self):
86
+ return ', '.join(
87
+ f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
88
+ for k, v in vars(self).items()
89
+ if not k.startswith('_') and k != 'training'
90
+ and not isinstance(v, (torch.nn.Module, torch.Tensor))
91
+ )
92
+
93
+
94
+ # IMPORTANT: update some member functions
95
+ __UPDATED = False
96
+ if not __UPDATED:
97
+ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, BinaryCrossEntropy, Mixup, drop.DropPath):
98
+ if hasattr(clz, 'extra_repr'):
99
+ clz.extra_repr = _ex_repr
100
+ else:
101
+ clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
102
+ ResNet.get_layer_id_and_scale_exp = resnets_get_layer_id_and_scale_exp
103
+ ConvNeXt.get_layer_id_and_scale_exp = convnext_get_layer_id_and_scale_exp
104
+ __UPDATED = True
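Importing this module applies the patches above; a small sketch of how the attached `get_layer_id_and_scale_exp` hook could then be used to derive layer-wise learning-rate scales (the ResNet-50 model and the 0.7 decay base are illustrative assumptions, not values from this file):

```python
from timm.models.resnet import resnet50

# Illustrative only: query the monkey-patched hook for every trainable parameter.
model = resnet50()
decay_base = 0.7                        # assumed layer-wise LR decay factor

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
    lr_scale = decay_base ** scale_exp  # shallower layers -> larger exponent -> smaller LR
    # an optimizer builder would place `param` into a group with lr = base_lr * lr_scale
```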
spark/downstream_imagenet/models/convnext_official.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This file is exactly the same as: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.registry import register_model
14
+
15
+ class Block(nn.Module):
16
+ r""" ConvNeXt Block. There are two equivalent implementations:
17
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
18
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
19
+ We use (2) as we find it slightly faster in PyTorch
20
+
21
+ Args:
22
+ dim (int): Number of input channels.
23
+ drop_path (float): Stochastic depth rate. Default: 0.0
24
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
25
+ """
26
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
27
+ super().__init__()
28
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29
+ self.norm = LayerNorm(dim, eps=1e-6)
30
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
31
+ self.act = nn.GELU()
32
+ self.pwconv2 = nn.Linear(4 * dim, dim)
33
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
34
+ requires_grad=True) if layer_scale_init_value > 0 else None
35
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
36
+
37
+ def forward(self, x):
38
+ input = x
39
+ x = self.dwconv(x)
40
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
41
+ x = self.norm(x)
42
+ x = self.pwconv1(x)
43
+ x = self.act(x)
44
+ x = self.pwconv2(x)
45
+ if self.gamma is not None:
46
+ x = self.gamma * x
47
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
48
+
49
+ x = input + self.drop_path(x)
50
+ return x
51
+
52
+ class ConvNeXt(nn.Module):
53
+ r""" ConvNeXt
54
+ A PyTorch impl of : `A ConvNet for the 2020s` -
55
+ https://arxiv.org/pdf/2201.03545.pdf
56
+ Args:
57
+ in_chans (int): Number of input image channels. Default: 3
58
+ num_classes (int): Number of classes for classification head. Default: 1000
59
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
60
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
61
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
62
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
63
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
64
+ """
65
+ def __init__(self, in_chans=3, num_classes=1000,
66
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
67
+ layer_scale_init_value=1e-6, head_init_scale=1.,
68
+ ):
69
+ super().__init__()
70
+
71
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
72
+ stem = nn.Sequential(
73
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
74
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
75
+ )
76
+ self.downsample_layers.append(stem)
77
+ for i in range(3):
78
+ downsample_layer = nn.Sequential(
79
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
80
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
81
+ )
82
+ self.downsample_layers.append(downsample_layer)
83
+
84
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
85
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
86
+ cur = 0
87
+ for i in range(4):
88
+ stage = nn.Sequential(
89
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
90
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
91
+ )
92
+ self.stages.append(stage)
93
+ cur += depths[i]
94
+
95
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
96
+ self.head = nn.Linear(dims[-1], num_classes)
97
+
98
+ self.apply(self._init_weights)
99
+ self.head.weight.data.mul_(head_init_scale)
100
+ self.head.bias.data.mul_(head_init_scale)
101
+
102
+ def _init_weights(self, m):
103
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
104
+ trunc_normal_(m.weight, std=.02)
105
+ nn.init.constant_(m.bias, 0)
106
+
107
+ def forward_features(self, x):
108
+ for i in range(4):
109
+ x = self.downsample_layers[i](x)
110
+ x = self.stages[i](x)
111
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
112
+
113
+ def forward(self, x):
114
+ x = self.forward_features(x)
115
+ x = self.head(x)
116
+ return x
117
+
118
+ class LayerNorm(nn.Module):
119
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
120
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
121
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
122
+ with shape (batch_size, channels, height, width).
123
+ """
124
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
125
+ super().__init__()
126
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
127
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
128
+ self.eps = eps
129
+ self.data_format = data_format
130
+ if self.data_format not in ["channels_last", "channels_first"]:
131
+ raise NotImplementedError
132
+ self.normalized_shape = (normalized_shape, )
133
+
134
+ def forward(self, x):
135
+ if self.data_format == "channels_last":
136
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
137
+ elif self.data_format == "channels_first":
138
+ u = x.mean(1, keepdim=True)
139
+ s = (x - u).pow(2).mean(1, keepdim=True)
140
+ x = (x - u) / torch.sqrt(s + self.eps)
141
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
142
+ return x
143
+
144
+
145
+ model_urls = {
146
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
147
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
148
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
149
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
150
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
151
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
152
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
153
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
154
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
155
+ }
156
+
157
+ @register_model
158
+ def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
159
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
160
+ if pretrained:
161
+ url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
162
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
163
+ model.load_state_dict(checkpoint["model"])
164
+ return model
165
+
166
+ @register_model
167
+ def convnext_small(pretrained=False,in_22k=False, **kwargs):
168
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
169
+ if pretrained:
170
+ url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
171
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
172
+ model.load_state_dict(checkpoint["model"])
173
+ return model
174
+
175
+ @register_model
176
+ def convnext_base(pretrained=False, in_22k=False, **kwargs):
177
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
178
+ if pretrained:
179
+ url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
180
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
181
+ model.load_state_dict(checkpoint["model"])
182
+ return model
183
+
184
+ @register_model
185
+ def convnext_large(pretrained=False, in_22k=False, **kwargs):
186
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
187
+ if pretrained:
188
+ url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
189
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
190
+ model.load_state_dict(checkpoint["model"])
191
+ return model
192
+
193
+ @register_model
194
+ def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
195
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
196
+ if pretrained:
197
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
198
+ url = model_urls['convnext_xlarge_22k']
199
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
200
+ model.load_state_dict(checkpoint["model"])
201
+ return model
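A quick sanity-check sketch for the ConvNeXt variants defined above (random weights, shapes only):

```python
import torch

# Illustrative only: build ConvNeXt-Base from the class above and run a forward pass.
model = convnext_base(pretrained=False)  # depths [3, 3, 27, 3], dims [128, 256, 512, 1024]
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                      # torch.Size([2, 1000])
```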
spark/downstream_imagenet/requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ numpy
2
+ Pillow
3
+ typed-argument-parser
4
+ timm==0.5.4
5
+ tensorboardx
spark/downstream_imagenet/util.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import datetime
8
+ import os
9
+ import sys
10
+ from functools import partial
11
+ from typing import List, Tuple, Callable
12
+
13
+ import pytz
14
+ import torch
15
+ import torch.distributed as tdist
16
+ import torch.multiprocessing as tmp
17
+ from timm import create_model
18
+ from timm.loss import SoftTargetCrossEntropy, BinaryCrossEntropy
19
+ from timm.optim import AdamW, Lamb
20
+ from timm.utils import ModelEmaV2
21
+ from torch.nn.parallel import DistributedDataParallel
22
+ from torch.optim.optimizer import Optimizer
23
+
24
+ from arg import FineTuneArgs
25
+ from mixup import BatchMixup
26
+ from lr_decay import get_param_groups
27
+
28
+
29
+ def time_str(for_dirname=False):
30
+ return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('%m-%d_%H-%M-%S' if for_dirname else '[%m-%d %H:%M:%S]')
31
+
32
+
33
+ def init_distributed_environ():
34
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
35
+ if tmp.get_start_method(allow_none=True) is None:
36
+ tmp.set_start_method('spawn')
37
+ global_rank, num_gpus = int(os.environ.get('RANK', 'error')), torch.cuda.device_count()
38
+ local_rank = global_rank % num_gpus
39
+ torch.cuda.set_device(local_rank)
40
+
41
+ tdist.init_process_group(backend='nccl')
42
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
43
+ torch.backends.cudnn.benchmark = True
44
+ torch.backends.cudnn.deterministic = False
45
+
46
+ # print only when local_rank == 0 or print(..., force=True)
47
+ import builtins as __builtin__
48
+ builtin_print = __builtin__.print
49
+
50
+ def prt(msg, *args, **kwargs):
51
+ force = kwargs.pop('force', False)
52
+ if local_rank == 0 or force:
53
+ f_back = sys._getframe().f_back
54
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
55
+ builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=> {msg}', *args, **kwargs)
56
+
57
+ __builtin__.print = prt
58
+ tdist.barrier()
59
+ return tdist.get_world_size(), global_rank, local_rank, torch.empty(1).cuda().device
60
+
61
+
62
+ def create_model_opt(args: FineTuneArgs) -> Tuple[torch.nn.Module, Callable, torch.nn.Module, DistributedDataParallel, ModelEmaV2, Optimizer]:
63
+ num_classes = 1000
64
+ model_without_ddp: torch.nn.Module = create_model(args.model, num_classes=num_classes, drop_path_rate=args.drop_path).to(args.device)
65
+ model_para = f'{sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) / 1e6:.1f}M'
66
+ # create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
67
+ model_ema = ModelEmaV2(model_without_ddp, decay=args.ema, device=args.device)
68
+ if args.sbn:
69
+ model_without_ddp = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_without_ddp)
70
+ print(f'[model={args.model}] [#para={model_para}, drop_path={args.drop_path}, ema={args.ema}] {model_without_ddp}\n')
71
+ model = DistributedDataParallel(model_without_ddp, device_ids=[args.local_rank], find_unused_parameters=False, broadcast_buffers=False)
72
+ model.train()
73
+ opt_cls = {
74
+ 'adam': AdamW, 'adamw': AdamW,
75
+ 'lamb': partial(Lamb, max_grad_norm=1e7, always_adapt=True, bias_correction=False),
76
+ }
77
+ param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'}, lr_scale=args.lr_scale)
78
+ # param_groups[0] is like this: {'params': List[nn.Parameters], 'lr': float, 'lr_scale': float, 'weight_decay': float, 'weight_decay_scale': float}
79
+ optimizer = opt_cls[args.opt](param_groups, lr=args.lr, weight_decay=0)
80
+ print(f'[optimizer={type(optimizer)}]')
81
+ mixup_fn = BatchMixup(
82
+ mixup_alpha=args.mixup, cutmix_alpha=1.0, cutmix_minmax=None,
83
+ prob=1.0, switch_prob=0.5, mode='batch',
84
+ label_smoothing=0.1, num_classes=num_classes
85
+ )
86
+ mixup_fn.mixup_enabled = args.mixup > 0.0
87
+ if 'lamb' in args.opt:
88
+ # label smoothing is handled in BatchMixup via `label_smoothing`, so here smoothing=0
89
+ criterion = BinaryCrossEntropy(smoothing=0, target_threshold=None)
90
+ else:
91
+ criterion = SoftTargetCrossEntropy()
92
+ print(f'[loss_fn] {criterion}')
93
+ print(f'[mixup_fn] {mixup_fn}')
94
+ return criterion, mixup_fn, model_without_ddp, model, model_ema, optimizer
95
+
96
+
97
+ def load_checkpoint(resume_from, model_without_ddp, ema_module, optimizer):
98
+ if len(resume_from) == 0 or not os.path.exists(resume_from):
99
+ raise AttributeError(f'ckpt `{resume_from}` not found!')
100
+ # return 0, '[no performance_desc]'
101
+ print(f'[try to resume from file `{resume_from}`]')
102
+ checkpoint = torch.load(resume_from, map_location='cpu')
103
+ assert checkpoint.get('is_pretrain', False) == False, 'Please do not use `*_withdecoder_1kpretrained_spark_style.pth`, which is ONLY for resuming the pretraining. Use `*_1kpretrained_timm_style.pth` or `*_1kfinetuned*.pth` instead.'
104
+
105
+ ep_start, performance_desc = checkpoint.get('epoch', -1) + 1, checkpoint.get('performance_desc', '[no performance_desc]')
106
+ missing, unexpected = model_without_ddp.load_state_dict(checkpoint.get('module', checkpoint), strict=False)
107
+ print(f'[load_checkpoint] missing_keys={missing}')
108
+ print(f'[load_checkpoint] unexpected_keys={unexpected}')
109
+ print(f'[load_checkpoint] ep_start={ep_start}, performance_desc={performance_desc}')
110
+
111
+ if 'optimizer' in checkpoint:
112
+ optimizer.load_state_dict(checkpoint['optimizer'])
113
+ if 'ema' in checkpoint:
114
+ ema_module.load_state_dict(checkpoint['ema'])
115
+ return ep_start, performance_desc
116
+
117
+
118
+ def save_checkpoint(save_to, args, epoch, performance_desc, model_without_ddp_state, ema_state, optimizer_state):
119
+ checkpoint_path = os.path.join(args.exp_dir, save_to)
120
+ if args.is_local_master:
121
+ to_save = {
122
+ 'args': str(args),
123
+ 'arch': args.model,
124
+ 'epoch': epoch,
125
+ 'performance_desc': performance_desc,
126
+ 'module': model_without_ddp_state,
127
+ 'ema': ema_state,
128
+ 'optimizer': optimizer_state,
129
+ 'is_pretrain': False,
130
+ }
131
+ torch.save(to_save, checkpoint_path)
spark/downstream_mmdet/README.md ADDED
@@ -0,0 +1,76 @@
1
+ ## About code isolation
2
+
3
+ This `downstream_mmdet` is isolated from pre-training codes. One can treat this `downstream_mmdet` as an independent codebase 🛠️.
4
+
5
+ ## Fine-tuned ConvNeXt-B weights, log files, and performance
6
+
7
+
8
+ <div align="center">
9
+
10
+ [[`weights (pre-trained by SparK)`](https://drive.google.com/file/d/1ZjWbqI1qoBcqeQijI5xX9E-YNkxpJcYV/view?usp=share_link)]
11
+ [[`weights (fine-tuned on COCO)`](https://drive.google.com/file/d/1t10dmzg5KOO27o2yIglK-gQepB5gR4zR/view?usp=share_link)]
12
+ [[`log.json`](https://drive.google.com/file/d/1TuNboXl1qwjf1tggZ3QOssI67uU7Jtig/view?usp=share_link)]
13
+ [[`log`](https://drive.google.com/file/d/1JY5CkL_MX08zJ8P1FBIeC60OJsuIiyZc/view?usp=sharing)]
14
+ </div>
15
+
16
+
17
+ <p align="center">
18
+ <img src="https://user-images.githubusercontent.com/39692511/211497396-cd031318-ef54-45a4-a283-cd9810c15603.png" width=80%>
19
+ </p>
20
+
21
+
22
+ ## Installation: [MMDetection at commit 6a979e2](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d), used before fine-tuning ConvNeXt on COCO
23
+
24
+ We refer to the codebases of [ConvNeXt](https://github.com/facebookresearch/ConvNeXt/tree/048efcea897d999aed302f2639b6270aedf8d4c8) and [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/6a979e2164e3fb0de0ca2546545013a4d71b2f7d).
25
+ Please refer to [README.md](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/6a979e2164e3fb0de0ca2546545013a4d71b2f7d/README.md) for installation and dataset preparation instructions.
26
+
27
+ Note the COCO dataset folder should be at `downstream_mmdet/data/coco`.
28
+ The folder should follow the directory structure required by `MMDetection`, which should look like this:
29
+ ```
30
+ downstream_mmdet/data/coco:
31
+ annotations/:
32
+ captions_train2017.json captions_val2017.json
33
+ instances_train2017.json instances_val2017.json
34
+ person_keypoints_train2017.json person_keypoints_val2017.json
35
+ train2017/:
36
+ a_lot_images.jpg
37
+ val2017/:
38
+ a_lot_images.jpg
39
+ ```
40
+
41
+
42
+ ### Training
43
+
44
+ To train a detector with pre-trained models, run:
45
+ ```
46
+ # single-gpu training
47
+ python tools/train.py <CONFIG_FILE> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]
48
+
49
+ # multi-gpu training
50
+ tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]
51
+ ```
52
+ For example, to train a Mask R-CNN model with a SparK-pretrained `ConvNeXt-B` backbone on 4 GPUs, run:
53
+ ```
54
+ tools/dist_train.sh configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py 4 \
55
+ --cfg-options model.pretrained=/some/path/to/official_convnext_base_1kpretrained.pth
56
+ ```
57
+
58
+ The Mask R-CNN 3x fine-tuning config file can be found at [`configs/convnext_spark`](configs/convnext_spark). This config is basically a copy of [https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py](https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py).
59
+
60
+ ### Inference
61
+ ```
62
+ # single-gpu testing
63
+ python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm
64
+
65
+ # multi-gpu testing
66
+ tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm
67
+ ```
68
+
69
+ ## Acknowledgment
70
+
71
+ We appreciate these useful codebases:
72
+
73
+ - [MMDetection](https://github.com/open-mmlab/mmdetection)
74
+ - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt)
75
+ - [Swin-Transformer-Object-Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection)
76
+
spark/downstream_mmdet/configs/_base_/default_runtime.py ADDED
@@ -0,0 +1,16 @@
1
+ checkpoint_config = dict(interval=1)
2
+ # yapf:disable
3
+ log_config = dict(
4
+ interval=50,
5
+ hooks=[
6
+ dict(type='CustomizedTextLoggerHook'),
7
+ # dict(type='TensorboardLoggerHook')
8
+ ])
9
+ # yapf:enable
10
+ custom_hooks = [dict(type='NumClassCheckHook')]
11
+
12
+ dist_params = dict(backend='nccl')
13
+ log_level = 'INFO'
14
+ load_from = None
15
+ resume_from = None
16
+ workflow = [('train', 1)]
spark/downstream_mmdet/configs/_base_/models/cascade_mask_rcnn_convnext_fpn.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ # model settings
10
+ model = dict(
11
+ type='CascadeRCNN',
12
+ pretrained=None,
13
+ backbone=dict(
14
+ type='ConvNeXt',
15
+ in_chans=3,
16
+ depths=[3, 3, 9, 3],
17
+ dims=[96, 192, 384, 768],
18
+ drop_path_rate=0.2,
19
+ layer_scale_init_value=1e-6,
20
+ out_indices=[0, 1, 2, 3],
21
+ ),
22
+ neck=dict(
23
+ type='FPN',
24
+ in_channels=[128, 256, 512, 1024],
25
+ out_channels=256,
26
+ num_outs=5),
27
+ rpn_head=dict(
28
+ type='RPNHead',
29
+ in_channels=256,
30
+ feat_channels=256,
31
+ anchor_generator=dict(
32
+ type='AnchorGenerator',
33
+ scales=[8],
34
+ ratios=[0.5, 1.0, 2.0],
35
+ strides=[4, 8, 16, 32, 64]),
36
+ bbox_coder=dict(
37
+ type='DeltaXYWHBBoxCoder',
38
+ target_means=[.0, .0, .0, .0],
39
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
40
+ loss_cls=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
42
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
43
+ roi_head=dict(
44
+ type='CascadeRoIHead',
45
+ num_stages=3,
46
+ stage_loss_weights=[1, 0.5, 0.25],
47
+ bbox_roi_extractor=dict(
48
+ type='SingleRoIExtractor',
49
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
50
+ out_channels=256,
51
+ featmap_strides=[4, 8, 16, 32]),
52
+ bbox_head=[
53
+ dict(
54
+ type='Shared2FCBBoxHead',
55
+ in_channels=256,
56
+ fc_out_channels=1024,
57
+ roi_feat_size=7,
58
+ num_classes=80,
59
+ bbox_coder=dict(
60
+ type='DeltaXYWHBBoxCoder',
61
+ target_means=[0., 0., 0., 0.],
62
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
63
+ reg_class_agnostic=True,
64
+ loss_cls=dict(
65
+ type='CrossEntropyLoss',
66
+ use_sigmoid=False,
67
+ loss_weight=1.0),
68
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
69
+ loss_weight=1.0)),
70
+ dict(
71
+ type='Shared2FCBBoxHead',
72
+ in_channels=256,
73
+ fc_out_channels=1024,
74
+ roi_feat_size=7,
75
+ num_classes=80,
76
+ bbox_coder=dict(
77
+ type='DeltaXYWHBBoxCoder',
78
+ target_means=[0., 0., 0., 0.],
79
+ target_stds=[0.05, 0.05, 0.1, 0.1]),
80
+ reg_class_agnostic=True,
81
+ loss_cls=dict(
82
+ type='CrossEntropyLoss',
83
+ use_sigmoid=False,
84
+ loss_weight=1.0),
85
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
86
+ loss_weight=1.0)),
87
+ dict(
88
+ type='Shared2FCBBoxHead',
89
+ in_channels=256,
90
+ fc_out_channels=1024,
91
+ roi_feat_size=7,
92
+ num_classes=80,
93
+ bbox_coder=dict(
94
+ type='DeltaXYWHBBoxCoder',
95
+ target_means=[0., 0., 0., 0.],
96
+ target_stds=[0.033, 0.033, 0.067, 0.067]),
97
+ reg_class_agnostic=True,
98
+ loss_cls=dict(
99
+ type='CrossEntropyLoss',
100
+ use_sigmoid=False,
101
+ loss_weight=1.0),
102
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
103
+ ],
104
+ mask_roi_extractor=dict(
105
+ type='SingleRoIExtractor',
106
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
107
+ out_channels=256,
108
+ featmap_strides=[4, 8, 16, 32]),
109
+ mask_head=dict(
110
+ type='FCNMaskHead',
111
+ num_convs=4,
112
+ in_channels=256,
113
+ conv_out_channels=256,
114
+ num_classes=80,
115
+ loss_mask=dict(
116
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
117
+ # model training and testing settings
118
+ train_cfg = dict(
119
+ rpn=dict(
120
+ assigner=dict(
121
+ type='MaxIoUAssigner',
122
+ pos_iou_thr=0.7,
123
+ neg_iou_thr=0.3,
124
+ min_pos_iou=0.3,
125
+ match_low_quality=True,
126
+ ignore_iof_thr=-1),
127
+ sampler=dict(
128
+ type='RandomSampler',
129
+ num=256,
130
+ pos_fraction=0.5,
131
+ neg_pos_ub=-1,
132
+ add_gt_as_proposals=False),
133
+ allowed_border=0,
134
+ pos_weight=-1,
135
+ debug=False),
136
+ rpn_proposal=dict(
137
+ nms_across_levels=False,
138
+ nms_pre=2000,
139
+ nms_post=2000,
140
+ max_per_img=2000,
141
+ nms=dict(type='nms', iou_threshold=0.7),
142
+ min_bbox_size=0),
143
+ rcnn=[
144
+ dict(
145
+ assigner=dict(
146
+ type='MaxIoUAssigner',
147
+ pos_iou_thr=0.5,
148
+ neg_iou_thr=0.5,
149
+ min_pos_iou=0.5,
150
+ match_low_quality=False,
151
+ ignore_iof_thr=-1),
152
+ sampler=dict(
153
+ type='RandomSampler',
154
+ num=512,
155
+ pos_fraction=0.25,
156
+ neg_pos_ub=-1,
157
+ add_gt_as_proposals=True),
158
+ mask_size=28,
159
+ pos_weight=-1,
160
+ debug=False),
161
+ dict(
162
+ assigner=dict(
163
+ type='MaxIoUAssigner',
164
+ pos_iou_thr=0.6,
165
+ neg_iou_thr=0.6,
166
+ min_pos_iou=0.6,
167
+ match_low_quality=False,
168
+ ignore_iof_thr=-1),
169
+ sampler=dict(
170
+ type='RandomSampler',
171
+ num=512,
172
+ pos_fraction=0.25,
173
+ neg_pos_ub=-1,
174
+ add_gt_as_proposals=True),
175
+ mask_size=28,
176
+ pos_weight=-1,
177
+ debug=False),
178
+ dict(
179
+ assigner=dict(
180
+ type='MaxIoUAssigner',
181
+ pos_iou_thr=0.7,
182
+ neg_iou_thr=0.7,
183
+ min_pos_iou=0.7,
184
+ match_low_quality=False,
185
+ ignore_iof_thr=-1),
186
+ sampler=dict(
187
+ type='RandomSampler',
188
+ num=512,
189
+ pos_fraction=0.25,
190
+ neg_pos_ub=-1,
191
+ add_gt_as_proposals=True),
192
+ mask_size=28,
193
+ pos_weight=-1,
194
+ debug=False)
195
+ ]),
196
+ test_cfg = dict(
197
+ rpn=dict(
198
+ nms_across_levels=False,
199
+ nms_pre=1000,
200
+ nms_post=1000,
201
+ max_per_img=1000,
202
+ nms=dict(type='nms', iou_threshold=0.7),
203
+ min_bbox_size=0),
204
+ rcnn=dict(
205
+ score_thr=0.05,
206
+ nms=dict(type='nms', iou_threshold=0.5),
207
+ max_per_img=100,
208
+ mask_thr_binary=0.5)))
spark/downstream_mmdet/configs/_base_/models/mask_rcnn_convnext_fpn.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ # model settings
10
+ model = dict(
11
+ type='MaskRCNN',
12
+ pretrained=None,
13
+ backbone=dict(
14
+ type='ConvNeXt',
15
+ in_chans=3,
16
+ depths=[3, 3, 9, 3],
17
+ dims=[96, 192, 384, 768],
18
+ drop_path_rate=0.2,
19
+ layer_scale_init_value=1e-6,
20
+ out_indices=[0, 1, 2, 3],
21
+ ),
22
+ neck=dict(
23
+ type='FPN',
24
+ in_channels=[128, 256, 512, 1024],
25
+ out_channels=256,
26
+ num_outs=5),
27
+ rpn_head=dict(
28
+ type='RPNHead',
29
+ in_channels=256,
30
+ feat_channels=256,
31
+ anchor_generator=dict(
32
+ type='AnchorGenerator',
33
+ scales=[8],
34
+ ratios=[0.5, 1.0, 2.0],
35
+ strides=[4, 8, 16, 32, 64]),
36
+ bbox_coder=dict(
37
+ type='DeltaXYWHBBoxCoder',
38
+ target_means=[.0, .0, .0, .0],
39
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
40
+ loss_cls=dict(
41
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
42
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
43
+ roi_head=dict(
44
+ type='StandardRoIHead',
45
+ bbox_roi_extractor=dict(
46
+ type='SingleRoIExtractor',
47
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
48
+ out_channels=256,
49
+ featmap_strides=[4, 8, 16, 32]),
50
+ bbox_head=dict(
51
+ type='Shared2FCBBoxHead',
52
+ in_channels=256,
53
+ fc_out_channels=1024,
54
+ roi_feat_size=7,
55
+ num_classes=80,
56
+ bbox_coder=dict(
57
+ type='DeltaXYWHBBoxCoder',
58
+ target_means=[0., 0., 0., 0.],
59
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
60
+ reg_class_agnostic=False,
61
+ loss_cls=dict(
62
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
63
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
64
+ mask_roi_extractor=dict(
65
+ type='SingleRoIExtractor',
66
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
67
+ out_channels=256,
68
+ featmap_strides=[4, 8, 16, 32]),
69
+ mask_head=dict(
70
+ type='FCNMaskHead',
71
+ num_convs=4,
72
+ in_channels=256,
73
+ conv_out_channels=256,
74
+ num_classes=80,
75
+ loss_mask=dict(
76
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
77
+ # model training and testing settings
78
+ train_cfg=dict(
79
+ rpn=dict(
80
+ assigner=dict(
81
+ type='MaxIoUAssigner',
82
+ pos_iou_thr=0.7,
83
+ neg_iou_thr=0.3,
84
+ min_pos_iou=0.3,
85
+ match_low_quality=True,
86
+ ignore_iof_thr=-1),
87
+ sampler=dict(
88
+ type='RandomSampler',
89
+ num=256,
90
+ pos_fraction=0.5,
91
+ neg_pos_ub=-1,
92
+ add_gt_as_proposals=False),
93
+ allowed_border=-1,
94
+ pos_weight=-1,
95
+ debug=False),
96
+ rpn_proposal=dict(
97
+ nms_pre=2000,
98
+ max_per_img=1000,
99
+ nms=dict(type='nms', iou_threshold=0.7),
100
+ min_bbox_size=0),
101
+ rcnn=dict(
102
+ assigner=dict(
103
+ type='MaxIoUAssigner',
104
+ pos_iou_thr=0.5,
105
+ neg_iou_thr=0.5,
106
+ min_pos_iou=0.5,
107
+ match_low_quality=True,
108
+ ignore_iof_thr=-1),
109
+ sampler=dict(
110
+ type='RandomSampler',
111
+ num=512,
112
+ pos_fraction=0.25,
113
+ neg_pos_ub=-1,
114
+ add_gt_as_proposals=True),
115
+ mask_size=28,
116
+ pos_weight=-1,
117
+ debug=False)),
118
+ test_cfg=dict(
119
+ rpn=dict(
120
+ nms_pre=1000,
121
+ max_per_img=1000,
122
+ nms=dict(type='nms', iou_threshold=0.7),
123
+ min_bbox_size=0),
124
+ rcnn=dict(
125
+ score_thr=0.05,
126
+ nms=dict(type='nms', iou_threshold=0.5),
127
+ max_per_img=100,
128
+ mask_thr_binary=0.5)))
spark/downstream_mmdet/configs/convnext_spark/mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ We directly take the ConvNeXt-T+MaskRCNN 3x recipe from https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/mask_rcnn_convnext_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco_in1k.py
3
+ And we modify this ConvNeXt-T+MaskRCNN 3x recipe to our ConvNeXt-B+MaskRCNN 3x recipe.
4
+ The modifications (commented as [modified] below) are according to:
5
+ - 1. tiny-to-base: (some configs of ConvNext-T are updated to those of ConvNext-B, referring to https://github.com/facebookresearch/ConvNeXt/blob/main/object_detection/configs/convnext/cascade_mask_rcnn_convnext_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco_in22k.py)
6
+ - model.backbone.{depths, dims, drop_path_rate}
7
+ - models.neck
8
+ - optimizer.paramwise_cfg.num_layers
9
+
10
+ - 2. our paper (https://openreview.net/forum?id=NRxydtWup1S, or https://arxiv.org/abs/2301.03580):
11
+ - LR layer decay (optimizer.paramwise_cfg.decay_rate): 0.65
12
+ - LR scheduled ratio (lr_config.gamma): 0.2
13
+ - Learning rate (optimizer.lr): 0.0002
14
+ - optimizer_config.use_fp16: False (we just use fp32 by default; actually we didn't test the performance of using fp16)
15
+ """
16
+
17
+ _base_ = [
18
+ '../_base_/models/mask_rcnn_convnext_fpn.py',
19
+ '../_base_/datasets/coco_instance.py',
20
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
21
+ ]
22
+
23
+ model = dict(
24
+ backbone=dict(
25
+ in_chans=3,
26
+ depths=[3, 3, 27, 3], # [modified] according to tiny-to-base
27
+ dims=[128, 256, 512, 1024], # [modified] according to tiny-to-base
28
+ drop_path_rate=0.5, # [modified] according to tiny-to-base
29
+ layer_scale_init_value=1.0,
30
+ out_indices=[0, 1, 2, 3],
31
+ ),
32
+ neck=dict(in_channels=[128, 256, 512, 1024])) # [modified] according to tiny-to-base
33
+
34
+ img_norm_cfg = dict(
35
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
36
+
37
+ # augmentation strategy originates from DETR / Sparse RCNN
38
+ train_pipeline = [
39
+ dict(type='LoadImageFromFile'),
40
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
41
+ dict(type='RandomFlip', flip_ratio=0.5),
42
+ dict(type='AutoAugment',
43
+ policies=[
44
+ [
45
+ dict(type='Resize',
46
+ img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
47
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
48
+ (736, 1333), (768, 1333), (800, 1333)],
49
+ multiscale_mode='value',
50
+ keep_ratio=True)
51
+ ],
52
+ [
53
+ dict(type='Resize',
54
+ img_scale=[(400, 1333), (500, 1333), (600, 1333)],
55
+ multiscale_mode='value',
56
+ keep_ratio=True),
57
+ dict(type='RandomCrop',
58
+ crop_type='absolute_range',
59
+ crop_size=(384, 600),
60
+ allow_negative_crop=True),
61
+ dict(type='Resize',
62
+ img_scale=[(480, 1333), (512, 1333), (544, 1333),
63
+ (576, 1333), (608, 1333), (640, 1333),
64
+ (672, 1333), (704, 1333), (736, 1333),
65
+ (768, 1333), (800, 1333)],
66
+ multiscale_mode='value',
67
+ override=True,
68
+ keep_ratio=True)
69
+ ]
70
+ ]),
71
+ dict(type='Normalize', **img_norm_cfg),
72
+ dict(type='Pad', size_divisor=32),
73
+ dict(type='DefaultFormatBundle'),
74
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
75
+ ]
76
+ data = dict(train=dict(pipeline=train_pipeline))
77
+
78
+ optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
79
+ lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05, # [modified] according to our paper
80
+ paramwise_cfg={'decay_rate': 0.65, # [modified] according to our paper
81
+ 'decay_type': 'layer_wise',
82
+ 'num_layers': 12}) # [modified] according to tiny-to-base
83
+ lr_config = dict(step=[27, 33], gamma=0.2) # [modified] according to our paper
84
+ runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
85
+
86
+ # do not use mmdet version fp16
87
+ fp16 = None
88
+ optimizer_config = dict(
89
+ type="DistOptimizerHook",
90
+ update_interval=1,
91
+ grad_clip=None,
92
+ coalesce=True,
93
+ bucket_size_mb=-1,
94
+ use_fp16=False, # [modified] True => False
95
+ )
spark/downstream_mmdet/mmcv_custom/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ # -*- coding: utf-8 -*-
10
+
11
+ from .checkpoint import load_checkpoint
12
+ from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor
13
+ from .customized_text import CustomizedTextLoggerHook
14
+
15
+ __all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook']
spark/downstream_mmdet/mmcv_custom/customized_text.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import datetime
10
+ from collections import OrderedDict
11
+
12
+ import torch
13
+
14
+ import mmcv
15
+ from mmcv.runner import HOOKS
16
+ from mmcv.runner import TextLoggerHook
17
+
18
+
19
+ @HOOKS.register_module()
20
+ class CustomizedTextLoggerHook(TextLoggerHook):
21
+ """Customized Text Logger hook.
22
+
23
+ This logger prints out both lr and layer_0_lr.
24
+
25
+ """
26
+
27
+ def _log_info(self, log_dict, runner):
28
+ # print exp name for users to distinguish experiments
29
+ # at every ``interval_exp_name`` iterations and the end of each epoch
30
+ if runner.meta is not None and 'exp_name' in runner.meta:
31
+ if (self.every_n_iters(runner, self.interval_exp_name)) or (
32
+ self.by_epoch and self.end_of_epoch(runner)):
33
+ exp_info = f'Exp name: {runner.meta["exp_name"]}'
34
+ runner.logger.info(exp_info)
35
+
36
+ if log_dict['mode'] == 'train':
37
+ lr_str = {}
38
+ for lr_type in ['lr', 'layer_0_lr']:
39
+ if isinstance(log_dict[lr_type], dict):
40
+ lr_str[lr_type] = []
41
+ for k, val in log_dict[lr_type].items():
42
+ lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
43
+ lr_str[lr_type] = ' '.join(lr_str[lr_type])
44
+ else:
45
+ lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'
46
+
47
+ # by epoch: Epoch [4][100/1000]
48
+ # by iter: Iter [100/100000]
49
+ if self.by_epoch:
50
+ log_str = f'Epoch [{log_dict["epoch"]}]' \
51
+ f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
52
+ else:
53
+ log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
54
+ log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '
55
+
56
+ if 'time' in log_dict.keys():
57
+ self.time_sec_tot += (log_dict['time'] * self.interval)
58
+ time_sec_avg = self.time_sec_tot / (
59
+ runner.iter - self.start_iter + 1)
60
+ eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
61
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
62
+ log_str += f'eta: {eta_str}, '
63
+ log_str += f'time: {log_dict["time"]:.3f}, ' \
64
+ f'data_time: {log_dict["data_time"]:.3f}, '
65
+ # statistic memory
66
+ if torch.cuda.is_available():
67
+ log_str += f'memory: {log_dict["memory"]}, '
68
+ else:
69
+ # val/test time
70
+ # here 1000 is the length of the val dataloader
71
+ # by epoch: Epoch[val] [4][1000]
72
+ # by iter: Iter[val] [1000]
73
+ if self.by_epoch:
74
+ log_str = f'Epoch({log_dict["mode"]}) ' \
75
+ f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
76
+ else:
77
+ log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
78
+
79
+ log_items = []
80
+ for name, val in log_dict.items():
81
+ # TODO: resolve this hack
82
+ # these items have been in log_str
83
+ if name in [
84
+ 'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
85
+ 'memory', 'epoch'
86
+ ]:
87
+ continue
88
+ if isinstance(val, float):
89
+ val = f'{val:.4f}'
90
+ log_items.append(f'{name}: {val}')
91
+ log_str += ', '.join(log_items)
92
+
93
+ runner.logger.info(log_str)
94
+
95
+
96
+ def log(self, runner):
97
+ if 'eval_iter_num' in runner.log_buffer.output:
98
+ # this doesn't modify runner.iter and is regardless of by_epoch
99
+ cur_iter = runner.log_buffer.output.pop('eval_iter_num')
100
+ else:
101
+ cur_iter = self.get_iter(runner, inner_iter=True)
102
+
103
+ log_dict = OrderedDict(
104
+ mode=self.get_mode(runner),
105
+ epoch=self.get_epoch(runner),
106
+ iter=cur_iter)
107
+
108
+ # record lr and layer_0_lr
109
+ cur_lr = runner.current_lr()
110
+ if isinstance(cur_lr, list):
111
+ log_dict['layer_0_lr'] = min(cur_lr)
112
+ log_dict['lr'] = max(cur_lr)
113
+ else:
114
+ assert isinstance(cur_lr, dict)
115
+ log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
116
+ for k, lr_ in cur_lr.items():
117
+ assert isinstance(lr_, list)
118
+ log_dict['layer_0_lr'].update({k: min(lr_)})
119
+ log_dict['lr'].update({k: max(lr_)})
120
+
121
+ if 'time' in runner.log_buffer.output:
122
+ # statistic memory
123
+ if torch.cuda.is_available():
124
+ log_dict['memory'] = self._get_max_memory(runner)
125
+
126
+ log_dict = dict(log_dict, **runner.log_buffer.output)
127
+
128
+ self._log_info(log_dict, runner)
129
+ self._dump_log(log_dict, runner)
130
+ return log_dict
spark/downstream_mmdet/mmcv_custom/layer_decay_optimizer_constructor.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import json
10
+ from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
11
+ from mmcv.runner import get_dist_info
12
+
13
+
14
+ def get_num_layer_layer_wise(var_name, num_max_layer=12):
15
+
16
+ if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"):
17
+ return 0
18
+ elif var_name.startswith("backbone.downsample_layers"):
19
+ stage_id = int(var_name.split('.')[2])
20
+ if stage_id == 0:
21
+ layer_id = 0
22
+ elif stage_id == 1:
23
+ layer_id = 2
24
+ elif stage_id == 2:
25
+ layer_id = 3
26
+ elif stage_id == 3:
27
+ layer_id = num_max_layer
28
+ return layer_id
29
+ elif var_name.startswith("backbone.stages"):
30
+ stage_id = int(var_name.split('.')[2])
31
+ block_id = int(var_name.split('.')[3])
32
+ if stage_id == 0:
33
+ layer_id = 1
34
+ elif stage_id == 1:
35
+ layer_id = 2
36
+ elif stage_id == 2:
37
+ layer_id = 3 + block_id // 3
38
+ elif stage_id == 3:
39
+ layer_id = num_max_layer
40
+ return layer_id
41
+ else:
42
+ return num_max_layer + 1
43
+
44
+
45
+ def get_num_layer_stage_wise(var_name, num_max_layer):
46
+ if var_name in ("backbone.cls_token", "backbone.mask_token", "backbone.pos_embed"):
47
+ return 0
48
+ elif var_name.startswith("backbone.downsample_layers"):
49
+ return 0
50
+ elif var_name.startswith("backbone.stages"):
51
+ stage_id = int(var_name.split('.')[2])
52
+ return stage_id + 1
53
+ else:
54
+ return num_max_layer - 1
55
+
56
+
57
+ @OPTIMIZER_BUILDERS.register_module()
58
+ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
59
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
60
+ """Add all parameters of module to the params list.
61
+ The parameters of the given module will be added to the list of param
62
+ groups, with specific rules defined by paramwise_cfg.
63
+ Args:
64
+ params (list[dict]): A list of param groups, it will be modified
65
+ in place.
66
+ module (nn.Module): The module to be added.
67
+ prefix (str): The prefix of the module
68
+ is_dcn_module (int|float|None): If the current module is a
69
+ submodule of DCN, `is_dcn_module` will be passed to
70
+ control conv_offset layer's learning rate. Defaults to None.
71
+ """
72
+ parameter_groups = {}
73
+ print(self.paramwise_cfg)
74
+ num_layers = self.paramwise_cfg.get('num_layers') + 2
75
+ decay_rate = self.paramwise_cfg.get('decay_rate')
76
+ decay_type = self.paramwise_cfg.get('decay_type', "layer_wise")
77
+ print("Build LearningRateDecayOptimizerConstructor %s %f - %d" % (decay_type, decay_rate, num_layers))
78
+ weight_decay = self.base_wd
79
+
80
+ for name, param in module.named_parameters():
81
+ if not param.requires_grad:
82
+ continue # frozen weights
83
+ if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token'):
84
+ group_name = "no_decay"
85
+ this_weight_decay = 0.
86
+ else:
87
+ group_name = "decay"
88
+ this_weight_decay = weight_decay
89
+
90
+ if decay_type == "layer_wise":
91
+ layer_id = get_num_layer_layer_wise(name, self.paramwise_cfg.get('num_layers'))
92
+ elif decay_type == "stage_wise":
93
+ layer_id = get_num_layer_stage_wise(name, num_layers)
94
+
95
+ group_name = "layer_%d_%s" % (layer_id, group_name)
96
+
97
+ if group_name not in parameter_groups:
98
+ scale = decay_rate ** (num_layers - layer_id - 1)
99
+
100
+ parameter_groups[group_name] = {
101
+ "weight_decay": this_weight_decay,
102
+ "params": [],
103
+ "param_names": [],
104
+ "lr_scale": scale,
105
+ "group_name": group_name,
106
+ "lr": scale * self.base_lr,
107
+ }
108
+
109
+ parameter_groups[group_name]["params"].append(param)
110
+ parameter_groups[group_name]["param_names"].append(name)
111
+ rank, _ = get_dist_info()
112
+ if rank == 0:
113
+ to_display = {}
114
+ for key in parameter_groups:
115
+ to_display[key] = {
116
+ "param_names": parameter_groups[key]["param_names"],
117
+ "lr_scale": parameter_groups[key]["lr_scale"],
118
+ "lr": parameter_groups[key]["lr"],
119
+ "weight_decay": parameter_groups[key]["weight_decay"],
120
+ }
121
+ print("Param groups = %s" % json.dumps(to_display, indent=2))
122
+
123
+ params.extend(parameter_groups.values())
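For intuition, the per-group `lr_scale` computed above is `decay_rate ** (num_layers - layer_id - 1)`, where `add_params` adds 2 to the configured `num_layers`; a tiny sketch with the values used in the SparK ConvNeXt-B config (`decay_rate=0.65`, `num_layers=12`):

```python
# Illustrative only: reproduce the LR scales the constructor above would assign.
decay_rate = 0.65
num_layers = 12 + 2                      # add_params() adds 2 to paramwise_cfg['num_layers']

for layer_id in range(num_layers):       # 0 = stem/embeddings ... 13 = heads
    lr_scale = decay_rate ** (num_layers - layer_id - 1)
    print(f'layer {layer_id:2d}: lr_scale = {lr_scale:.4f}')
# layer 13 (everything outside the backbone) keeps the full base LR (scale 1.0),
# while layer 0 (the stem) ends up with roughly 0.65 ** 13 ≈ 0.0037 of it.
```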
spark/downstream_mmdet/mmcv_custom/runner/checkpoint.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import os.path as osp
3
+ import time
4
+ from tempfile import TemporaryDirectory
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+ import mmcv
10
+ from mmcv.parallel import is_module_wrapper
11
+ from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12
+
13
+ try:
14
+ import apex
15
+ except ImportError:
16
+ print('apex is not installed')
17
+
18
+
19
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
20
+ """Save checkpoint to file.
21
+
22
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
23
+ ``optimizer``. By default ``meta`` will contain version
24
+ and time info.
25
+
26
+ Args:
27
+ model (Module): Module whose params are to be saved.
28
+ filename (str): Checkpoint filename.
29
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30
+ meta (dict, optional): Metadata to be saved in checkpoint.
31
+ """
32
+ if meta is None:
33
+ meta = {}
34
+ elif not isinstance(meta, dict):
35
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
36
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
37
+
38
+ if is_module_wrapper(model):
39
+ model = model.module
40
+
41
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
42
+ # save class name to the meta
43
+ meta.update(CLASSES=model.CLASSES)
44
+
45
+ checkpoint = {
46
+ 'meta': meta,
47
+ 'state_dict': weights_to_cpu(get_state_dict(model))
48
+ }
49
+ # save optimizer state dict in the checkpoint
50
+ if isinstance(optimizer, Optimizer):
51
+ checkpoint['optimizer'] = optimizer.state_dict()
52
+ elif isinstance(optimizer, dict):
53
+ checkpoint['optimizer'] = {}
54
+ for name, optim in optimizer.items():
55
+ checkpoint['optimizer'][name] = optim.state_dict()
56
+
57
+ # save amp state dict in the checkpoint
58
+ # checkpoint['amp'] = apex.amp.state_dict()
59
+
60
+ if filename.startswith('pavi://'):
61
+ try:
62
+ from pavi import modelcloud
63
+ from pavi.exception import NodeNotFoundError
64
+ except ImportError:
65
+ raise ImportError(
66
+ 'Please install pavi to load checkpoint from modelcloud.')
67
+ model_path = filename[7:]
68
+ root = modelcloud.Folder()
69
+ model_dir, model_name = osp.split(model_path)
70
+ try:
71
+ model = modelcloud.get(model_dir)
72
+ except NodeNotFoundError:
73
+ model = root.create_training_model(model_dir)
74
+ with TemporaryDirectory() as tmp_dir:
75
+ checkpoint_file = osp.join(tmp_dir, model_name)
76
+ with open(checkpoint_file, 'wb') as f:
77
+ torch.save(checkpoint, f)
78
+ f.flush()
79
+ model.create_file(checkpoint_file, name=model_name)
80
+ else:
81
+ mmcv.mkdir_or_exist(osp.dirname(filename))
82
+ # immediately flush buffer
83
+ with open(filename, 'wb') as f:
84
+ torch.save(checkpoint, f)
85
+ f.flush()
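A minimal sketch of calling this `save_checkpoint` helper outside of an MMCV runner (the toy module, optimizer, and output path are assumptions):

```python
import torch.nn as nn
from torch.optim import SGD

# Illustrative only: persist a toy module with the helper above (requires mmcv installed).
model = nn.Conv2d(3, 8, kernel_size=3)
optimizer = SGD(model.parameters(), lr=0.1)

save_checkpoint(model, 'work_dirs/demo/latest.pth',
                optimizer=optimizer, meta=dict(epoch=1, iter=100))
# the saved file then contains 'meta', 'state_dict' and 'optimizer' entries
```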
spark/downstream_mmdet/mmdet/models/backbones/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from .darknet import Darknet
2
+ from .detectors_resnet import DetectoRS_ResNet
3
+ from .detectors_resnext import DetectoRS_ResNeXt
4
+ from .hourglass import HourglassNet
5
+ from .hrnet import HRNet
6
+ from .regnet import RegNet
7
+ from .res2net import Res2Net
8
+ from .resnest import ResNeSt
9
+ from .resnet import ResNet, ResNetV1d
10
+ from .resnext import ResNeXt
11
+ from .ssd_vgg import SSDVGG
12
+ from .trident_resnet import TridentResNet
13
+ from .swin_transformer import SwinTransformer
14
+ from .convnext import ConvNeXt
15
+
16
+ __all__ = [
17
+ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
18
+ 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
19
+ 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'ConvNeXt'
20
+ ]
spark/downstream_mmdet/mmdet/models/backbones/convnext.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ from functools import partial
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from timm.models.layers import trunc_normal_, DropPath
14
+
15
+ from mmcv_custom import load_checkpoint
16
+ from mmdet.utils import get_root_logger
17
+ from ..builder import BACKBONES
18
+
19
+ class Block(nn.Module):
20
+ r""" ConvNeXt Block. There are two equivalent implementations:
21
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
22
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
23
+ We use (2) as we find it slightly faster in PyTorch
24
+
25
+ Args:
26
+ dim (int): Number of input channels.
27
+ drop_path (float): Stochastic depth rate. Default: 0.0
28
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
29
+ """
30
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
31
+ super().__init__()
32
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
33
+ self.norm = LayerNorm(dim, eps=1e-6)
34
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
35
+ self.act = nn.GELU()
36
+ self.pwconv2 = nn.Linear(4 * dim, dim)
37
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
38
+ requires_grad=True) if layer_scale_init_value > 0 else None
39
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
40
+
41
+ def forward(self, x):
42
+ input = x
43
+ x = self.dwconv(x)
44
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
45
+ x = self.norm(x)
46
+ x = self.pwconv1(x)
47
+ x = self.act(x)
48
+ x = self.pwconv2(x)
49
+ if self.gamma is not None:
50
+ x = self.gamma * x
51
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
52
+
53
+ x = input + self.drop_path(x)
54
+ return x
55
+
56
+ @BACKBONES.register_module()
57
+ class ConvNeXt(nn.Module):
58
+ r""" ConvNeXt
59
+ A PyTorch impl of : `A ConvNet for the 2020s` -
60
+ https://arxiv.org/pdf/2201.03545.pdf
61
+
62
+ Args:
63
+ in_chans (int): Number of input image channels. Default: 3
64
+ num_classes (int): Number of classes for classification head. Default: 1000
65
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
66
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
67
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
68
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
69
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
70
+ """
71
+ def __init__(self, in_chans=3, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
72
+ drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3],
73
+ ):
74
+ super().__init__()
75
+
76
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
77
+ stem = nn.Sequential(
78
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
79
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
80
+ )
81
+ self.downsample_layers.append(stem)
82
+ for i in range(3):
83
+ downsample_layer = nn.Sequential(
84
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
85
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
86
+ )
87
+ self.downsample_layers.append(downsample_layer)
88
+
89
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
90
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
91
+ cur = 0
92
+ for i in range(4):
93
+ stage = nn.Sequential(
94
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
95
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
96
+ )
97
+ self.stages.append(stage)
98
+ cur += depths[i]
99
+
100
+ self.out_indices = out_indices
101
+
102
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
103
+ for i_layer in range(4):
104
+ layer = norm_layer(dims[i_layer])
105
+ layer_name = f'norm{i_layer}'
106
+ self.add_module(layer_name, layer)
107
+
108
+ self.apply(self._init_weights)
109
+
110
+ def _init_weights(self, m):
111
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
112
+ trunc_normal_(m.weight, std=.02)
113
+ nn.init.constant_(m.bias, 0)
114
+
115
+ def init_weights(self, pretrained=None):
116
+ """Initialize the weights in backbone.
117
+ Args:
118
+ pretrained (str, optional): Path to pre-trained weights.
119
+ Defaults to None.
120
+ """
121
+
122
+ def _init_weights(m):
123
+ if isinstance(m, nn.Linear):
124
+ trunc_normal_(m.weight, std=.02)
125
+ if isinstance(m, nn.Linear) and m.bias is not None:
126
+ nn.init.constant_(m.bias, 0)
127
+ elif isinstance(m, nn.LayerNorm):
128
+ nn.init.constant_(m.bias, 0)
129
+ nn.init.constant_(m.weight, 1.0)
130
+
131
+ if isinstance(pretrained, str):
132
+ self.apply(_init_weights)
133
+ logger = get_root_logger()
134
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
135
+ elif pretrained is None:
136
+ self.apply(_init_weights)
137
+ else:
138
+ raise TypeError('pretrained must be a str or None')
139
+
140
+ def forward_features(self, x):
141
+ outs = []
142
+ for i in range(4):
143
+ x = self.downsample_layers[i](x)
144
+ x = self.stages[i](x)
145
+ if i in self.out_indices:
146
+ norm_layer = getattr(self, f'norm{i}')
147
+ x_out = norm_layer(x)
148
+ outs.append(x_out)
149
+
150
+ return tuple(outs)
151
+
152
+ def forward(self, x):
153
+ x = self.forward_features(x)
154
+ return x
155
+
156
+ class LayerNorm(nn.Module):
157
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
158
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
159
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
160
+ with shape (batch_size, channels, height, width).
161
+ """
162
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
163
+ super().__init__()
164
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
165
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
166
+ self.eps = eps
167
+ self.data_format = data_format
168
+ if self.data_format not in ["channels_last", "channels_first"]:
169
+ raise NotImplementedError
170
+ self.normalized_shape = (normalized_shape, )
171
+
172
+ def forward(self, x):
173
+ if self.data_format == "channels_last":
174
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
175
+ elif self.data_format == "channels_first":
176
+ u = x.mean(1, keepdim=True)
177
+ s = (x - u).pow(2).mean(1, keepdim=True)
178
+ x = (x - u) / torch.sqrt(s + self.eps)
179
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
180
+ return x
spark/pretrain/README.md ADDED
@@ -0,0 +1,118 @@
1
+ ## Preparation for ImageNet-1k pretraining
2
+
3
+ See [/INSTALL.md](/INSTALL.md) to prepare `pip` dependencies and the ImageNet dataset.
4
+
5
+ **Note: for neural network definitions, we directly use `timm.models.ResNet` and [official ConvNeXt](https://github.com/facebookresearch/ConvNeXt/blob/048efcea897d999aed302f2639b6270aedf8d4c8/models/convnext.py).**
6
+
7
+
8
+ ## Tutorial for pretraining your own CNN model
9
+
10
+ See [/pretrain/models/custom.py](/pretrain/models/custom.py). Your todo list is:
11
+
12
+ - implement `get_downsample_ratio` in [/pretrain/models/custom.py line20](/pretrain/models/custom.py#L20).
13
+ - implement `get_feature_map_channels` in [/pretrain/models/custom.py line29](/pretrain/models/custom.py#L29).
14
+ - implement `forward` in [/pretrain/models/custom.py line38](/pretrain/models/custom.py#L38).
15
+ - define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py line54](/pretrain/models/custom.py#L53-L54).
16
+ - add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34).
17
+ - **Note: see [#54](/../../issues/54) if your CNN contains SE module or global average pooling layer, and see [#56](/../../issues/56) if it contains GroupNorm**.
18
+
19
+ Then run the experiment with `--model=your_convnet`.
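+ 
+ For reference, here is a minimal sketch of what such a model could look like (the names `TinyConvNet` / `tiny_convnet` are hypothetical and purely illustrative; the authoritative template is [/pretrain/models/custom.py](/pretrain/models/custom.py)):
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ from typing import List
+ from timm.models.registry import register_model
+ 
+ 
+ class TinyConvNet(nn.Module):  # hypothetical example, not part of this repo
+     def __init__(self, in_chans=3, dims=(64, 128, 256, 512), **kwargs):
+         super().__init__()
+         self.dims = list(dims)
+         strides = (4, 2, 2, 2)  # total downsample ratio = 4 * 2 * 2 * 2 = 32
+         cins = (in_chans,) + tuple(dims[:-1])
+         self.stages = nn.ModuleList(
+             nn.Sequential(nn.Conv2d(ci, co, kernel_size=s, stride=s), nn.BatchNorm2d(co), nn.ReLU(inplace=True))
+             for ci, co, s in zip(cins, dims, strides)
+         )
+ 
+     def get_downsample_ratio(self) -> int:
+         return 32  # product of the stage strides above
+ 
+     def get_feature_map_channels(self) -> List[int]:
+         return self.dims  # one entry per feature map returned by forward(..., hierarchical=True)
+ 
+     def forward(self, x: torch.Tensor, hierarchical=False):
+         feats = []
+         for stage in self.stages:
+             x = stage(x)
+             feats.append(x)
+         return feats if hierarchical else x  # hierarchical=True is what SparseEncoder calls
+ 
+ 
+ @register_model
+ def tiny_convnet(pretrained=False, **kwargs):
+     return TinyConvNet(**kwargs)
+ ```
+ 
+ After also adding its default kwargs in `models/__init__.py`, `--model=tiny_convnet` would select this backbone.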
20
+
21
+
22
+ ## Tutorial for pretraining on your own dataset
23
+
24
+ See the comment of `build_dataset_to_pretrain` in [line55 of /pretrain/utils/imagenet.py](/pretrain/utils/imagenet.py#L55). Your todo list:
25
+
26
+ - Define a subclass of `torch.utils.data.Dataset` for your own unlabeled dataset, to replace our `ImageNetDataset` (see the minimal sketch below).
27
+ - Use `args.data_path` and `args.input_size` to help build your dataset, with `--data_path=... --input_size=...` to specify them.
28
+ - Note that the batch size `--bs` is the total batch size over all GPUs, and may need to be adjusted based on your dataset size. FYI: we use `--bs=4096` for ImageNet, which contains 1.28 million images.
29
+
30
+ **If your dataset is relatively small**, you can try `--init_weight=/path/to/res50_withdecoder_1kpretrained_spark_style.pth` to do your pretraining *from our pretrained weights*, rather than *from scratch*.
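+ 
+ As a minimal sketch (hypothetical code, assuming your images sit in one flat folder; check the transform in `build_dataset_to_pretrain` if you want to match our ImageNet augmentation exactly), such a dataset could look like:
+ 
+ ```python
+ import os
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ 
+ 
+ class MyUnlabeledDataset(Dataset):  # hypothetical example
+     def __init__(self, data_path: str, input_size: int):
+         self.files = sorted(
+             os.path.join(data_path, f) for f in os.listdir(data_path)
+             if f.lower().endswith(('.jpg', '.jpeg', '.png'))
+         )
+         self.transform = transforms.Compose([
+             transforms.RandomResizedCrop(input_size),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+         ])
+ 
+     def __len__(self):
+         return len(self.files)
+ 
+     def __getitem__(self, idx):
+         img = Image.open(self.files[idx]).convert('RGB')
+         return self.transform(img)  # pretraining is self-supervised: return the image only, no label
+ ```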
31
+
32
+ ## Debug on 1 GPU (without DistributedDataParallel)
33
+
34
+ Use a small batch size such as `--bs=32` to avoid OOM.
35
+
36
+ ```shell script
37
+ python3 main.py --exp_name=debug --data_path=/path/to/imagenet --model=resnet50 --bs=32
38
+ ```
39
+
40
+
41
+ ## Pretraining Any Model on ImageNet-1k (224x224)
42
+
43
+ For pretraining, run [/pretrain/main.py](/pretrain/main.py) with `torchrun`.
44
+ **It is required to specify** the ImageNet data folder (`--data_path`), your experiment name & log dir (`--exp_name` and `--exp_dir`, created automatically if they do not exist), and the model name (`--model`; for valid choices, see the keys of `pretrain_default_model_kwargs` in [/pretrain/models/\_\_init\_\_.py line34](/pretrain/models/__init__.py#L34)).
45
+
46
+ We use the **same** pretraining configurations (lr, batch size, etc.) for all models (ResNets and ConvNeXts) in 224 pretraining.
47
+ Their **names** and **default values** are in [/pretrain/utils/arg_util.py line23-44](/pretrain/utils/arg_util.py#L23-L44).
48
+ All of these default configurations (like batch size 4096) will be used unless you override them, e.g. `--bs=512`.
49
+
50
+ **Note: the batch size `--bs` is the total batch size over all GPUs, and the learning rate `--base_lr` is the base lr. The actual lr is `lr = base_lr * bs / 256`, as in [/pretrain/utils/arg_util.py line131](/pretrain/utils/arg_util.py#L131). So do not use `--lr` to specify an lr (it will be ignored).**
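+ 
+ For instance (pure arithmetic with placeholder values):
+ 
+ ```python
+ base_lr, bs = 2e-4, 512  # placeholder values, passed via --base_lr and --bs
+ lr = base_lr * bs / 256  # = 4e-4, the lr the optimizer actually uses
+ ```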
51
+
52
+ Here is an example of pretraining a ResNet50 on a single 8-GPU machine (we use DistributedDataParallel), overriding the default batch size with 512:
53
+ ```shell script
54
+ $ cd /path/to/SparK/pretrain
55
+ $ torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=<some_port> main.py \
56
+ --data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
57
+ --model=resnet50 --bs=512
58
+ ```
59
+
60
+ For multiple machines, change `--nnodes`, `--node_rank`, `--master_addr`, and `--master_port` to match your configuration. E.g.:
61
+ ```shell script
62
+ $ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_from_0> --master_addr=<some_address> --master_port=<some_port> main.py \
63
+ ...
64
+ ```
65
+
66
+ ## Pretraining ConvNeXt-Large on ImageNet-1k (384x384)
67
+
68
+ For 384 pretraining we use a larger mask ratio (0.75), half the batch size (2048), and double the base learning rate (4e-4):
69
+
70
+ ```shell script
71
+ $ cd /path/to/SparK/pretrain
72
+ $ torchrun --nproc_per_node=8 --nnodes=<your_nnodes> --node_rank=<rank_starts_from_0> --master_addr=<some_address> --master_port=<some_port> main.py \
73
+ --data_path=/path/to/imagenet --exp_name=<your_exp_name> --exp_dir=/path/to/logdir \
74
+ --model=convnext_large --input_size=384 --mask=0.75 --bs=2048 --base_lr=4e-4
75
+ ```
76
+
77
+ ## Logging
78
+
79
+ See files in your `--exp_dir` to track your experiment:
80
+
81
+ - `<model>_withdecoder_1kpretrained_spark_style.pth`: saves model and optimizer states, current epoch, current reconstruction loss, etc.; can be used to resume pretraining; can also be used for visualization in [/pretrain/viz_reconstruction.ipynb](/pretrain/viz_reconstruction.ipynb)
82
+ - `<model>_1kpretrained_timm_style.pth`: can be used for downstream finetuning
83
+ - `pretrain_log.txt`: records some important information such as:
84
+ - `git_commit_id`: git version
85
+ - `cmd`: the command of this experiment
86
+
87
+ It also reports the loss and remaining pretraining time.
88
+
89
+ - `tensorboard_log/`: saves many tensorboard logs including loss values, learning rates, gradient norms, and more. Use `tensorboard --logdir /path/to/this/tensorboard_log/ --port 23333` to view them.
90
+ - `stdout_backup.txt` and `stderr_backup.txt`: backups of stdout/stderr.
91
+
92
+ ## Resuming
93
+
94
+ Specify `--resume_from=path/to/<model>_withdecoder_1kpretrained_spark_style.pth` to resume pretraining. Note this is different from `--init_weight`:
95
+
96
+ - `--resume_from` loads three things: model weights, optimizer states, and the current epoch, so it is used to resume an interrupted experiment (it will start from that 'current epoch').
97
+ - `--init_weight` ONLY loads the model weights, so it's just like a model initialization (will start from epoch 0).
98
+
99
+
100
+ ## Regarding sparse convolution
101
+
102
+ We do not use sparse convolutions in this PyTorch implementation, due to their limited optimization on modern hardware.
103
+ As can be found in [/pretrain/encoder.py](/pretrain/encoder.py), we use masked dense convolution to simulate submanifold sparse convolution.
104
+ We also define some sparse pooling or normalization layers in [/pretrain/encoder.py](/pretrain/encoder.py).
105
+ All these "sparse" layers are implemented with PyTorch built-in operators.
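+ 
+ The idea can be illustrated with a few lines of plain PyTorch (a toy sketch with made-up shapes, not the repo's code; cf. `sp_conv_forward` and `sp_bn_forward` in [/pretrain/encoder.py](/pretrain/encoder.py)):
+ 
+ ```python
+ import torch
+ import torch.nn as nn
+ 
+ B, C, H, W = 2, 8, 32, 32
+ x = torch.randn(B, C, H, W)
+ active = torch.rand(B, 1, H // 8, W // 8) > 0.6                        # patch-wise mask (True = kept)
+ mask = active.repeat_interleave(8, dim=2).repeat_interleave(8, dim=3)  # upsampled to this resolution
+ 
+ # "sparse" conv: run a dense conv, then re-zero the masked positions (cf. sp_conv_forward)
+ y = nn.Conv2d(C, C, kernel_size=3, padding=1)(x) * mask
+ 
+ # "sparse" BN: normalize only the kept positions with a plain BatchNorm1d (cf. sp_bn_forward)
+ ii = mask.squeeze(1).nonzero(as_tuple=True)  # (batch, h, w) indices of kept positions
+ nc = y.permute(0, 2, 3, 1)[ii]               # (N_kept, C) flattened features
+ nc = nn.BatchNorm1d(C)(nc)
+ ```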
106
+
107
+
108
+ ## Some details: how we mask images and how to set the patch size
109
+
110
+ In SparK, the mask patch size **equals** the downsample ratio of the CNN model (so there is no configuration like `--patch_size=32`).
111
+
112
+ Here is the reason: when masking, we:
113
+
114
+ 1. first generate the binary mask for the **smallest** resolution feature map, i.e., generate the `_cur_active` or `active_b1ff` in [/pretrain/spark.py line86-87](/pretrain/spark.py#L86-L87), which is a `torch.BoolTensor` shaped as `[B, 1, fmap_h, fmap_w]` and is used to mask the smallest feature map.
115
+ 2. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., dim=2)` and `repeat_interleave(..., dim=3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py#L16)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py#L21)) with larger resolutions.
116
+
117
+ So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
118
+ See [Tutorial for pretraining your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain/#tutorial-for-pretraining-your-own-cnn-model).
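+ 
+ A toy sketch of these two steps (hypothetical numbers; the actual mask generation lives in [/pretrain/spark.py](/pretrain/spark.py)):
+ 
+ ```python
+ import torch
+ 
+ input_size, downsample_ratio, mask_ratio = 224, 32, 0.6  # placeholder values
+ fmap_h = fmap_w = input_size // downsample_ratio          # 7: the smallest feature map is 7x7
+ 
+ # step 1: a boolean mask on the smallest feature map (True = kept / "active" patch)
+ B, L = 2, fmap_h * fmap_w
+ n_keep = round(L * (1 - mask_ratio))
+ active_b1ff = (torch.rand(B, L).argsort(dim=1) < n_keep).view(B, 1, fmap_h, fmap_w)
+ 
+ # step 2: progressively upsample it to mask the higher-resolution feature maps
+ for ratio in (4, 8, 16, 32):                              # typical ResNet/ConvNeXt stage resolutions
+     r = (input_size // ratio) // fmap_h
+     m = active_b1ff.repeat_interleave(r, dim=2).repeat_interleave(r, dim=3)
+     print(ratio, tuple(m.shape))                          # e.g. 4 (2, 1, 56, 56)
+ ```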
spark/pretrain/decoder.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import List
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from timm.models.layers import trunc_normal_
13
+
14
+ from utils.misc import is_pow2n
15
+
16
+
17
+ class UNetBlock(nn.Module):
18
+ def __init__(self, cin, cout, bn2d):
19
+ """
20
+ a UNet block with 2x up sampling
21
+ """
22
+ super().__init__()
23
+ self.up_sample = nn.ConvTranspose2d(cin, cin, kernel_size=4, stride=2, padding=1, bias=True)
24
+ self.conv = nn.Sequential(
25
+ nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cin), nn.ReLU6(inplace=True),
26
+ nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cout),
27
+ )
28
+
29
+ def forward(self, x):
30
+ x = self.up_sample(x)
31
+ return self.conv(x)
32
+
33
+
34
+ class LightDecoder(nn.Module):
35
+ def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
36
+ super().__init__()
37
+ self.width = width
38
+ assert is_pow2n(up_sample_ratio)
39
+ n = round(math.log2(up_sample_ratio))
40
+ channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halving rule; you can change it to any other rule
41
+ bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d
42
+ self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])])
43
+ self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True)
44
+
45
+ self.initialize()
46
+
47
+ def forward(self, to_dec: List[torch.Tensor]):
48
+ x = 0
49
+ for i, d in enumerate(self.dec):
50
+ if i < len(to_dec) and to_dec[i] is not None:
51
+ x = x + to_dec[i]
52
+ x = self.dec[i](x)
53
+ return self.proj(x)
54
+
55
+ def extra_repr(self) -> str:
56
+ return f'width={self.width}'
57
+
58
+ def initialize(self):
59
+ for m in self.modules():
60
+ if isinstance(m, nn.Linear):
61
+ trunc_normal_(m.weight, std=.02)
62
+ if m.bias is not None:
63
+ nn.init.constant_(m.bias, 0)
64
+ elif isinstance(m, nn.Conv2d):
65
+ trunc_normal_(m.weight, std=.02)
66
+ if m.bias is not None:
67
+ nn.init.constant_(m.bias, 0)
68
+ elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
69
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
70
+ if m.bias is not None:
71
+ nn.init.constant_(m.bias, 0.)
72
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)):
73
+ nn.init.constant_(m.bias, 0)
74
+ nn.init.constant_(m.weight, 1.0)
spark/pretrain/dist.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from typing import List
9
+ from typing import Union
10
+
11
+ import sys
12
+ import torch
13
+ import torch.distributed as tdist
14
+ import torch.multiprocessing as mp
15
+
16
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu'
17
+ __initialized = False
18
+
19
+
20
+ def initialized():
21
+ return __initialized
22
+
23
+
24
+ def initialize(backend='nccl'):
25
+ global __device
26
+ if not torch.cuda.is_available():
27
+ print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
28
+ return
29
+ elif 'RANK' not in os.environ:
30
+ __device = torch.empty(1).cuda().device
31
+ print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr)
32
+ return
33
+
34
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
35
+ if mp.get_start_method(allow_none=True) is None:
36
+ mp.set_start_method('spawn')
37
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
38
+ local_rank = global_rank % num_gpus
39
+ torch.cuda.set_device(local_rank)
40
+ tdist.init_process_group(backend=backend)
41
+
42
+ global __rank, __local_rank, __world_size, __initialized
43
+ __local_rank = local_rank
44
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
45
+ __device = torch.empty(1).cuda().device
46
+ __initialized = True
47
+
48
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
49
+
50
+
51
+ def get_rank():
52
+ return __rank
53
+
54
+
55
+ def get_local_rank():
56
+ return __local_rank
57
+
58
+
59
+ def get_world_size():
60
+ return __world_size
61
+
62
+
63
+ def get_device():
64
+ return __device
65
+
66
+
67
+ def is_master():
68
+ return __rank == 0
69
+
70
+
71
+ def is_local_master():
72
+ return __local_rank == 0
73
+
74
+
75
+ def barrier():
76
+ if __initialized:
77
+ tdist.barrier()
78
+
79
+
80
+ def parallelize(net, syncbn=False):
81
+ if syncbn:
82
+ net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
83
+ net = net.cuda()
84
+ net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
85
+ return net
86
+
87
+
88
+ def allreduce(t: torch.Tensor) -> None:
89
+ if __initialized:
90
+ if not t.is_cuda:
91
+ cu = t.detach().cuda()
92
+ tdist.all_reduce(cu)
93
+ t.copy_(cu.cpu())
94
+ else:
95
+ tdist.all_reduce(t)
96
+
97
+
98
+ def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
99
+ if __initialized:
100
+ if not t.is_cuda:
101
+ t = t.cuda()
102
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
103
+ tdist.all_gather(ls, t)
104
+ else:
105
+ ls = [t]
106
+ if cat:
107
+ ls = torch.cat(ls, dim=0)
108
+ return ls
109
+
110
+
111
+ def broadcast(t: torch.Tensor, src_rank) -> None:
112
+ if __initialized:
113
+ if not t.is_cuda:
114
+ cu = t.detach().cuda()
115
+ tdist.broadcast(cu, src=src_rank)
116
+ t.copy_(cu.cpu())
117
+ else:
118
+ tdist.broadcast(t, src=src_rank)
spark/pretrain/encoder.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from timm.models.layers import DropPath
10
+
11
+
12
+ _cur_active: torch.Tensor = None # B1ff
13
+ # todo: try to use `gather` for speed?
14
+ def _get_active_ex_or_ii(H, W, returning_active_ex=True):
15
+ h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1]
16
+ active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3)
17
+ return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True) # ii: bi, hi, wi
18
+
19
+
20
+ def sp_conv_forward(self, x: torch.Tensor):
21
+ x = super(type(self), self).forward(x)
22
+ x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) # (BCHW) *= (B1HW), mask the output of conv
23
+ return x
24
+
25
+
26
+ def sp_bn_forward(self, x: torch.Tensor):
27
+ ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
28
+
29
+ bhwc = x.permute(0, 2, 3, 1)
30
+ nc = bhwc[ii] # select the features on non-masked positions to form a flatten feature `nc`
31
+ nc = super(type(self), self).forward(nc) # use BN1d to normalize this flatten feature `nc`
32
+
33
+ bchw = torch.zeros_like(bhwc)
34
+ bchw[ii] = nc
35
+ bchw = bchw.permute(0, 3, 1, 2)
36
+ return bchw
37
+
38
+
39
+ class SparseConv2d(nn.Conv2d):
40
+ forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
41
+
42
+
43
+ class SparseMaxPooling(nn.MaxPool2d):
44
+ forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
45
+
46
+
47
+ class SparseAvgPooling(nn.AvgPool2d):
48
+ forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
49
+
50
+
51
+ class SparseBatchNorm2d(nn.BatchNorm1d):
52
+ forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details
53
+
54
+
55
+ class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
56
+ forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details
57
+
58
+
59
+ class SparseConvNeXtLayerNorm(nn.LayerNorm):
60
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
61
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
62
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
63
+ with shape (batch_size, channels, height, width).
64
+ """
65
+
66
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
67
+ if data_format not in ["channels_last", "channels_first"]:
68
+ raise NotImplementedError
69
+ super().__init__(normalized_shape, eps, elementwise_affine=True)
70
+ self.data_format = data_format
71
+ self.sparse = sparse
72
+
73
+ def forward(self, x):
74
+ if x.ndim == 4: # BHWC or BCHW
75
+ if self.data_format == "channels_last": # BHWC
76
+ if self.sparse:
77
+ ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False)
78
+ nc = x[ii]
79
+ nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
80
+
81
+ x = torch.zeros_like(x)
82
+ x[ii] = nc
83
+ return x
84
+ else:
85
+ return super(SparseConvNeXtLayerNorm, self).forward(x)
86
+ else: # channels_first, BCHW
87
+ if self.sparse:
88
+ ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
89
+ bhwc = x.permute(0, 2, 3, 1)
90
+ nc = bhwc[ii]
91
+ nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
92
+
93
+ x = torch.zeros_like(bhwc)
94
+ x[ii] = nc
95
+ return x.permute(0, 3, 1, 2)
96
+ else:
97
+ u = x.mean(1, keepdim=True)
98
+ s = (x - u).pow(2).mean(1, keepdim=True)
99
+ x = (x - u) / torch.sqrt(s + self.eps)
100
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
101
+ return x
102
+ else: # BLC or BC
103
+ if self.sparse:
104
+ raise NotImplementedError
105
+ else:
106
+ return super(SparseConvNeXtLayerNorm, self).forward(x)
107
+
108
+ def __repr__(self):
109
+ return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
110
+
111
+
112
+ class SparseConvNeXtBlock(nn.Module):
113
+ r""" ConvNeXt Block. There are two equivalent implementations:
114
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
115
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
116
+ We use (2) as we find it slightly faster in PyTorch
117
+
118
+ Args:
119
+ dim (int): Number of input channels.
120
+ drop_path (float): Stochastic depth rate. Default: 0.0
121
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
122
+ """
123
+
124
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
125
+ super().__init__()
126
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim) # depthwise conv
127
+ self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
128
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
129
+ self.act = nn.GELU()
130
+ self.pwconv2 = nn.Linear(4 * dim, dim)
131
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
132
+ requires_grad=True) if layer_scale_init_value > 0 else None
133
+ self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
134
+ self.sparse = sparse
135
+
136
+ def forward(self, x):
137
+ input = x
138
+ x = self.dwconv(x)
139
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
140
+ x = self.norm(x)
141
+ x = self.pwconv1(x)
142
+ x = self.act(x) # GELU(0) == 0, so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
143
+ x = self.pwconv2(x)
144
+ if self.gamma is not None:
145
+ x = self.gamma * x
146
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
147
+
148
+ if self.sparse:
149
+ x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)
150
+
151
+ x = input + self.drop_path(x)
152
+ return x
153
+
154
+ def __repr__(self):
155
+ return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
156
+
157
+
158
+ class SparseEncoder(nn.Module):
159
+ def __init__(self, cnn, input_size, sbn=False, verbose=False):
160
+ super(SparseEncoder, self).__init__()
161
+ self.sp_cnn = SparseEncoder.dense_model_to_sparse(m=cnn, verbose=verbose, sbn=sbn)
162
+ self.input_size, self.downsample_raito, self.enc_feat_map_chs = input_size, cnn.get_downsample_ratio(), cnn.get_feature_map_channels()
163
+
164
+ @staticmethod
165
+ def dense_model_to_sparse(m: nn.Module, verbose=False, sbn=False):
166
+ oup = m
167
+ if isinstance(m, nn.Conv2d):
168
+ m: nn.Conv2d
169
+ bias = m.bias is not None
170
+ oup = SparseConv2d(
171
+ m.in_channels, m.out_channels,
172
+ kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
173
+ dilation=m.dilation, groups=m.groups, bias=bias, padding_mode=m.padding_mode,
174
+ )
175
+ oup.weight.data.copy_(m.weight.data)
176
+ if bias:
177
+ oup.bias.data.copy_(m.bias.data)
178
+ elif isinstance(m, nn.MaxPool2d):
179
+ m: nn.MaxPool2d
180
+ oup = SparseMaxPooling(m.kernel_size, stride=m.stride, padding=m.padding, dilation=m.dilation, return_indices=m.return_indices, ceil_mode=m.ceil_mode)
181
+ elif isinstance(m, nn.AvgPool2d):
182
+ m: nn.AvgPool2d
183
+ oup = SparseAvgPooling(m.kernel_size, m.stride, m.padding, ceil_mode=m.ceil_mode, count_include_pad=m.count_include_pad, divisor_override=m.divisor_override)
184
+ elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
185
+ m: nn.BatchNorm2d
186
+ oup = (SparseSyncBatchNorm2d if sbn else SparseBatchNorm2d)(m.weight.shape[0], eps=m.eps, momentum=m.momentum, affine=m.affine, track_running_stats=m.track_running_stats)
187
+ oup.weight.data.copy_(m.weight.data)
188
+ oup.bias.data.copy_(m.bias.data)
189
+ oup.running_mean.data.copy_(m.running_mean.data)
190
+ oup.running_var.data.copy_(m.running_var.data)
191
+ oup.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
192
+ if hasattr(m, "qconfig"):
193
+ oup.qconfig = m.qconfig
194
+ elif isinstance(m, nn.LayerNorm) and not isinstance(m, SparseConvNeXtLayerNorm):
195
+ m: nn.LayerNorm
196
+ oup = SparseConvNeXtLayerNorm(m.weight.shape[0], eps=m.eps)
197
+ oup.weight.data.copy_(m.weight.data)
198
+ oup.bias.data.copy_(m.bias.data)
199
+ elif isinstance(m, (nn.Conv1d,)):
200
+ raise NotImplementedError
201
+
202
+ for name, child in m.named_children():
203
+ oup.add_module(name, SparseEncoder.dense_model_to_sparse(child, verbose=verbose, sbn=sbn))
204
+ del m
205
+ return oup
206
+
207
+ def forward(self, x):
208
+ return self.sp_cnn(x, hierarchical=True)
spark/pretrain/main.py ADDED
@@ -0,0 +1,191 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import datetime
8
+ import math
9
+ import sys
10
+ import time
11
+ from functools import partial
12
+ from typing import List
13
+
14
+ import torch
15
+ from torch.nn.parallel import DistributedDataParallel
16
+ from torch.utils.data import DataLoader
17
+
18
+ import dist
19
+ import encoder
20
+ from decoder import LightDecoder
21
+ from models import build_sparse_encoder
22
+ from sampler import DistInfiniteBatchSampler, worker_init_fn
23
+ from spark import SparK
24
+ from utils import arg_util, misc, lamb
25
+ from utils.imagenet import build_dataset_to_pretrain
26
+ from utils.lr_control import lr_wd_annealing, get_param_groups
27
+
28
+
29
+ class LocalDDP(torch.nn.Module):
30
+ def __init__(self, module):
31
+ super(LocalDDP, self).__init__()
32
+ self.module = module
33
+
34
+ def forward(self, *args, **kwargs):
35
+ return self.module(*args, **kwargs)
36
+
37
+
38
+ def main_pt():
39
+ args: arg_util.Args = arg_util.init_dist_and_get_args()
40
+ print(f'initial args:\n{str(args)}')
41
+ args.log_epoch()
42
+
43
+ # build data
44
+ print(f'[build data for pre-training] ...\n')
45
+ dataset_train = build_dataset_to_pretrain(args.data_path, args.input_size)
46
+ data_loader_train = DataLoader(
47
+ dataset=dataset_train, num_workers=args.dataloader_workers, pin_memory=True,
48
+ batch_sampler=DistInfiniteBatchSampler(
49
+ dataset_len=len(dataset_train), glb_batch_size=args.glb_batch_size,
50
+ shuffle=True, filling=True, rank=dist.get_rank(), world_size=dist.get_world_size(),
51
+ ), worker_init_fn=worker_init_fn
52
+ )
53
+ itrt_train, iters_train = iter(data_loader_train), len(data_loader_train)
54
+ print(f'[dataloader] gbs={args.glb_batch_size}, lbs={args.batch_size_per_gpu}, iters_train={iters_train}')
55
+
56
+ # build encoder and decoder
57
+ enc: encoder.SparseEncoder = build_sparse_encoder(args.model, input_size=args.input_size, sbn=args.sbn, drop_path_rate=args.dp, verbose=False)
58
+ dec = LightDecoder(enc.downsample_raito, sbn=args.sbn)
59
+ model_without_ddp = SparK(
60
+ sparse_encoder=enc, dense_decoder=dec, mask_ratio=args.mask,
61
+ densify_norm=args.densify_norm, sbn=args.sbn,
62
+ ).to(args.device)
63
+ print(f'[PT model] model = {model_without_ddp}\n')
64
+
65
+ # the model has been randomly initialized in their construction time
66
+ # now try to load some checkpoint as model weight initialization; this ONLY loads the model weights
67
+ misc.initialize_weight(args.init_weight, model_without_ddp)
68
+
69
+ if dist.initialized():
70
+ model: DistributedDataParallel = DistributedDataParallel(model_without_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, broadcast_buffers=False)
71
+ else:
72
+ model = LocalDDP(model_without_ddp)
73
+
74
+ # build optimizer and lr_scheduler
75
+ param_groups: List[dict] = get_param_groups(model_without_ddp, nowd_keys={'cls_token', 'pos_embed', 'mask_token', 'gamma'})
76
+ opt_clz = {
77
+ 'sgd': partial(torch.optim.SGD, momentum=0.9, nesterov=True),
78
+ 'adamw': partial(torch.optim.AdamW, betas=(0.9, args.ada)),
79
+ 'lamb': partial(lamb.TheSameAsTimmLAMB, betas=(0.9, args.ada), max_grad_norm=5.0),
80
+ }[args.opt]
81
+ optimizer = opt_clz(params=param_groups, lr=args.lr, weight_decay=0.0)
82
+ print(f'[optimizer] optimizer({opt_clz}) ={optimizer}\n')
83
+
84
+ # try to resume the experiment from some checkpoint.pth; this will load model weights, optimizer states, and last epoch (ep_start)
85
+ # if loaded, ep_start will be greater than 0
86
+ ep_start, performance_desc = misc.load_checkpoint(args.resume_from, model_without_ddp, optimizer)
87
+ if ep_start >= args.ep: # load from a complete checkpoint file
88
+ print(f' [*] [PT already done] Min/Last Recon Loss: {performance_desc}')
89
+ else: # perform pre-training
90
+ tb_lg = misc.TensorboardLogger(args.tb_lg_dir, is_master=dist.is_master(), prefix='pt')
91
+ min_loss = 1e9
92
+ print(f'[PT start] from ep{ep_start}')
93
+
94
+ pt_start_time = time.time()
95
+ for ep in range(ep_start, args.ep):
96
+ ep_start_time = time.time()
97
+ tb_lg.set_step(ep * iters_train)
98
+ if hasattr(itrt_train, 'set_epoch'):
99
+ itrt_train.set_epoch(ep)
100
+
101
+ stats = pre_train_one_ep(ep, args, tb_lg, itrt_train, iters_train, model, optimizer)
102
+ last_loss = stats['last_loss']
103
+ min_loss = min(min_loss, last_loss)
104
+ performance_desc = f'{min_loss:.4f} {last_loss:.4f}'
105
+ misc.save_checkpoint_with_meta_info_and_opt_state(f'{args.model}_withdecoder_1kpretrained_spark_style.pth', args, ep, performance_desc, model_without_ddp.state_dict(with_config=True), optimizer.state_dict())
106
+ misc.save_checkpoint_model_weights_only(f'{args.model}_1kpretrained_timm_style.pth', args, model_without_ddp.sparse_encoder.sp_cnn.state_dict())
107
+
108
+ ep_cost = round(time.time() - ep_start_time, 2) + 1 # +1s: approximate the following logging cost
109
+ remain_secs = (args.ep-1 - ep) * ep_cost
110
+ remain_time = datetime.timedelta(seconds=round(remain_secs))
111
+ finish_time = time.strftime("%m-%d %H:%M", time.localtime(time.time() + remain_secs))
112
+ print(f' [*] [ep{ep}/{args.ep}] Min/Last Recon Loss: {performance_desc}, Cost: {ep_cost}s, Remain: {remain_time}, Finish @ {finish_time}')
113
+
114
+ args.cur_ep = f'{ep + 1}/{args.ep}'
115
+ args.remain_time, args.finish_time = str(remain_time), str(finish_time)
116
+ args.last_loss = last_loss
117
+ args.log_epoch()
118
+
119
+ tb_lg.update(min_loss=min_loss, head='train', step=ep)
120
+ tb_lg.update(rest_hours=round(remain_secs/60/60, 2), head='z_burnout', step=ep)
121
+ tb_lg.flush()
122
+
123
+ # finish pre-training
124
+ tb_lg.update(min_loss=min_loss, head='result', step=ep_start)
125
+ tb_lg.update(min_loss=min_loss, head='result', step=args.ep)
126
+ tb_lg.flush()
127
+ print(f'final args:\n{str(args)}')
128
+ print('\n\n')
129
+ print(f' [*] [PT finished] Min/Last Recon Loss: {performance_desc}, Total Cost: {(time.time() - pt_start_time) / 60 / 60:.1f}h\n')
130
+ print('\n\n')
131
+ tb_lg.close()
132
+ time.sleep(10)
133
+
134
+ args.remain_time, args.finish_time = '-', time.strftime("%m-%d %H:%M", time.localtime(time.time()))
135
+ args.log_epoch()
136
+
137
+
138
+ def pre_train_one_ep(ep, args: arg_util.Args, tb_lg: misc.TensorboardLogger, itrt_train, iters_train, model: DistributedDataParallel, optimizer):
139
+ model.train()
140
+ me = misc.MetricLogger(delimiter=' ')
141
+ me.add_meter('max_lr', misc.SmoothedValue(window_size=1, fmt='{value:.5f}'))
142
+ header = f'[PT] Epoch {ep}:'
143
+
144
+ optimizer.zero_grad()
145
+ early_clipping = args.clip > 0 and not hasattr(optimizer, 'global_grad_norm')
146
+ late_clipping = hasattr(optimizer, 'global_grad_norm')
147
+ if early_clipping:
148
+ params_req_grad = [p for p in model.parameters() if p.requires_grad]
149
+
150
+ for it, inp in enumerate(me.log_every(iters_train, itrt_train, 3, header)):
151
+ # adjust lr and wd
152
+ min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(optimizer, args.lr, args.wd, args.wde, it + ep * iters_train, args.wp_ep * iters_train, args.ep * iters_train)
153
+
154
+ # forward and backward
155
+ inp = inp.to(args.device, non_blocking=True)
156
+ SparK.forward  # no-op attribute reference; the actual forward pass is invoked via model(inp, ...) on the next line
157
+ loss = model(inp, active_b1ff=None, vis=False)
158
+ optimizer.zero_grad()
159
+ loss.backward()
160
+ loss = loss.item()
161
+ if not math.isfinite(loss):
162
+ print(f'[rk{dist.get_rank():02d}] Loss is {loss}, stopping training!', force=True, flush=True)
163
+ sys.exit(-1)
164
+
165
+ # optimize
166
+ grad_norm = None
167
+ if early_clipping: grad_norm = torch.nn.utils.clip_grad_norm_(params_req_grad, args.clip).item()
168
+ optimizer.step()
169
+ if late_clipping: grad_norm = optimizer.global_grad_norm
170
+ torch.cuda.synchronize()
171
+
172
+ # log
173
+ me.update(last_loss=loss)
174
+ me.update(max_lr=max_lr)
175
+ tb_lg.update(loss=me.meters['last_loss'].global_avg, head='train_loss')
176
+ tb_lg.update(sche_lr=max_lr, head='train_hp/lr_max')
177
+ tb_lg.update(sche_lr=min_lr, head='train_hp/lr_min')
178
+ tb_lg.update(sche_wd=max_wd, head='train_hp/wd_max')
179
+ tb_lg.update(sche_wd=min_wd, head='train_hp/wd_min')
180
+
181
+ if grad_norm is not None:
182
+ me.update(orig_norm=grad_norm)
183
+ tb_lg.update(orig_norm=grad_norm, head='train_hp')
184
+ tb_lg.set_step()
185
+
186
+ me.synchronize_between_processes()
187
+ return {k: meter.global_avg for k, meter in me.meters.items()}
188
+
189
+
190
+ if __name__ == '__main__':
191
+ main_pt()
spark/pretrain/models/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from timm import create_model
9
+ from timm.loss import SoftTargetCrossEntropy
10
+ from timm.models.layers import drop
11
+
12
+
13
+ from models.convnext import ConvNeXt
14
+ from models.resnet import ResNet
15
+ from models.custom import YourConvNet
16
+ _import_resnets_for_timm_registration = (ResNet,)
17
+
18
+
19
+ # log more
20
+ def _ex_repr(self):
21
+ return ', '.join(
22
+ f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v))
23
+ for k, v in vars(self).items()
24
+ if not k.startswith('_') and k != 'training'
25
+ and not isinstance(v, (torch.nn.Module, torch.Tensor))
26
+ )
27
+ for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath):
28
+ if hasattr(clz, 'extra_repr'):
29
+ clz.extra_repr = _ex_repr
30
+ else:
31
+ clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})'
32
+
33
+
34
+ pretrain_default_model_kwargs = {
35
+ 'V9back': dict(),
36
+
37
+ 'resnet50': dict(drop_path_rate=0.05),
38
+ 'resnet101': dict(drop_path_rate=0.08),
39
+ 'resnet152': dict(drop_path_rate=0.10),
40
+ 'resnet200': dict(drop_path_rate=0.15),
41
+ 'convnext_small': dict(sparse=True, drop_path_rate=0.2),
42
+ 'convnext_base': dict(sparse=True, drop_path_rate=0.3),
43
+ 'convnext_large': dict(sparse=True, drop_path_rate=0.4),
44
+
45
+ }
46
+ for kw in pretrain_default_model_kwargs.values():
47
+ kw['pretrained'] = False
48
+ kw['num_classes'] = 0
49
+ kw['global_pool'] = ''
50
+
51
+
52
+ def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False):
53
+ from encoder import SparseEncoder
54
+
55
+ kwargs = pretrain_default_model_kwargs[name]
56
+ if drop_path_rate != 0:
57
+ kwargs['drop_path_rate'] = drop_path_rate
58
+ print(f'[build_sparse_encoder] model kwargs={kwargs}')
59
+ cnn = create_model(name, **kwargs)
60
+
61
+ return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose)
62
+
spark/pretrain/models/convnext.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This file is basically a copy of: https://github.com/facebookresearch/ConvNeXt/blob/06f7b05f922e21914916406141f50f82b4a15852/models/convnext.py
8
+ from typing import List
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from timm.models.layers import trunc_normal_
13
+ from timm.models.registry import register_model
14
+
15
+ from encoder import SparseConvNeXtBlock, SparseConvNeXtLayerNorm
16
+
17
+
18
+ class ConvNeXt(nn.Module):
19
+ r""" ConvNeXt
20
+ A PyTorch impl of : `A ConvNet for the 2020s` -
21
+ https://arxiv.org/pdf/2201.03545.pdf
22
+ Args:
23
+ in_chans (int): Number of input image channels. Default: 3
24
+ num_classes (int): Number of classes for classification head. Default: 1000
25
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
26
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
27
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
28
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
29
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
30
+ """
31
+
32
+ def __init__(self, in_chans=3, num_classes=1000,
33
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
34
+ layer_scale_init_value=1e-6, head_init_scale=1., global_pool='avg',
35
+ sparse=True,
36
+ ):
37
+ super().__init__()
38
+ self.dims: List[int] = dims
39
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
40
+ stem = nn.Sequential(
41
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
42
+ SparseConvNeXtLayerNorm(dims[0], eps=1e-6, data_format="channels_first", sparse=sparse)
43
+ )
44
+ self.downsample_layers.append(stem)
45
+ for i in range(3):
46
+ downsample_layer = nn.Sequential(
47
+ SparseConvNeXtLayerNorm(dims[i], eps=1e-6, data_format="channels_first", sparse=sparse),
48
+ nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
49
+ )
50
+ self.downsample_layers.append(downsample_layer)
51
+
52
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
53
+ self.drop_path_rate = drop_path_rate
54
+ self.layer_scale_init_value = layer_scale_init_value
55
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
56
+ cur = 0
57
+ for i in range(4):
58
+ stage = nn.Sequential(
59
+ *[SparseConvNeXtBlock(dim=dims[i], drop_path=dp_rates[cur + j],
60
+ layer_scale_init_value=layer_scale_init_value, sparse=sparse) for j in range(depths[i])]
61
+ )
62
+ self.stages.append(stage)
63
+ cur += depths[i]
64
+ self.depths = depths
65
+
66
+ self.apply(self._init_weights)
67
+ if num_classes > 0:
68
+ self.norm = SparseConvNeXtLayerNorm(dims[-1], eps=1e-6, sparse=False) # final norm layer for LE/FT; should not be sparse
69
+ self.fc = nn.Linear(dims[-1], num_classes)
70
+ else:
71
+ self.norm = nn.Identity()
72
+ self.fc = nn.Identity()
73
+
74
+ def _init_weights(self, m):
75
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
76
+ trunc_normal_(m.weight, std=.02)
77
+ nn.init.constant_(m.bias, 0)
78
+
79
+ def get_downsample_ratio(self) -> int:
80
+ return 32
81
+
82
+ def get_feature_map_channels(self) -> List[int]:
83
+ return self.dims
84
+
85
+ def forward(self, x, hierarchical=False):
86
+ if hierarchical:
87
+ ls = []
88
+ for i in range(4):
89
+ x = self.downsample_layers[i](x)
90
+ x = self.stages[i](x)
91
+ ls.append(x)
92
+ return ls
93
+ else:
94
+ return self.fc(self.norm(x.mean([-2, -1]))) # (B, C, H, W) =mean=> (B, C) =norm&fc=> (B, NumCls)
95
+
96
+ def get_classifier(self):
97
+ return self.fc
98
+
99
+ def extra_repr(self):
100
+ return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'
101
+
102
+
103
+ @register_model
104
+ def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
105
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
106
+ return model
107
+
108
+
109
+ @register_model
110
+ def convnext_small(pretrained=False, in_22k=False, **kwargs):
111
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
112
+ return model
113
+
114
+
115
+ @register_model
116
+ def convnext_base(pretrained=False, in_22k=False, **kwargs):
117
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
118
+ return model
119
+
120
+
121
+ @register_model
122
+ def convnext_large(pretrained=False, in_22k=False, **kwargs):
123
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
124
+ return model
125
+
spark/pretrain/models/custom.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import List
10
+ from timm.models.registry import register_model
11
+ import torch
12
+ from torch import nn
13
+ import sys
14
+ from HG.HGBlock import HGStem,HGBlock
15
+ from HG.block import DWConv
16
+ from v9back.common import *
17
+
18
+
19
+ class YourConvNet(nn.Module):
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__()
22
+
23
+ self.mlist=nn.ModuleList(
24
+ [Silence(),
25
+ Bbackbone(),
26
+ ]
27
+ )
28
+ self.d0= Down0(64)
29
+ self.d1 = Down1(128)
30
+ self.d2 = Down2(256)
31
+ self.d3 = Down3(512)
32
+ self.d4 = Down4(1024)
33
+ self.alld = [self.d0,self.d1,self.d2,self.d3,self.d4]
34
+ self.cblinear1 = CBLinear(64,[64])
35
+ self.cblinear3 = CBLinear(128, [64, 128])
36
+ self.cblinear5 = CBLinear(256, [64, 128, 256])
37
+ self.cblinear7 = CBLinear(512, [64, 128, 256, 512])
38
+ self.cblinear9 = CBLinear(1024, [64, 128, 256, 512, 1024])
39
+ self.allcblinear = [self.cblinear1,self.cblinear3,self.cblinear5,self.cblinear7,self.cblinear9]
40
+ # # conv down 1
41
+ self.conv1 = Conv(3, 64, 3, 2 )
42
+ self.cbfuse1 = CBFuse([0, 0, 0, 0, 0])
43
+
44
+ ## conv down 2
45
+ self.conv2= Conv(64, 128, 3, 2)
46
+ self.cbfuse2 = CBFuse([1, 1, 1, 1])
47
+ self.rep2 = RepNCSPELAN4(128, 256, 128, 64, 2)
48
+ ## avg-conv down fuse 1
49
+ self.adown3 = ADown(256, 256)
50
+ self.cbfuse3 = CBFuse([2, 2, 2])
51
+ self.rep3 = RepNCSPELAN4(256, 512, 256, 128, 2)
52
+
53
+ ## avg-conv down fuse 2
54
+ self.adown4 = ADown(512, 512)
55
+ self.cbfuse4 = CBFuse([3,3])
56
+ self.rep4 = RepNCSPELAN4(512, 1024, 512, 256, 2)
57
+
58
+ ## avg-conv down fuse 3
59
+ self.adown5 = ADown(1024, 1024)
60
+ self.cbfuse5 = CBFuse([4])
61
+ self.rep5 = RepNCSPELAN4(1024, 1024, 512, 256, 2)
62
+
63
+ def get_downsample_ratio(self) -> int:
64
+ return 32
65
+
66
+ def get_feature_map_channels(self) -> List[int]:
67
+ return [ 256,512,1024,1024]
68
+
69
+ def forward(self, x: torch.Tensor, hierarchical=False):
70
+ if hierarchical:
71
+ origin = x.clone()
72
+ ls = []
73
+ tmp = []
74
+ bx = None
75
+ for index,modules in enumerate( self.mlist):
76
+ x = modules(x)
77
+ if index ==1:
78
+ bx = x
79
+ for i in range(5):
80
+ tmp.append(self.allcblinear[i](self.alld[i](bx)))
81
+
82
+ fuse1 = self.cbfuse1([tmp[0],tmp[1],tmp[2],tmp[3],tmp[4],self.conv1(origin)])
83
+ fuse2 = self.cbfuse2([tmp[1],tmp[2],tmp[3],tmp[4],self.conv2(fuse1)])
84
+ fuse2 = self.rep2(fuse2)
85
+
86
+ fuse3= self.cbfuse3([ tmp[2], tmp[3], tmp[4], self.adown3(fuse2)])
87
+ fuse3 = self.rep3(fuse3)
88
+
89
+ fuse4 = self.cbfuse4([tmp[3], tmp[4], self.adown4(fuse3)])
90
+ fuse4 = self.rep4(fuse4)
91
+
92
+ fuse5 = self.cbfuse5([tmp[4], self.adown5(fuse4)])
93
+ fuse5 = self.rep5(fuse5)
94
+
95
+ ls.append(fuse2)
96
+ ls.append(fuse3)
97
+ ls.append(fuse4)
98
+ ls.append(fuse5)
99
+ return ls
100
+ else:
101
+ for modules in self.mlist:
102
+ x = modules(x)
103
+ return x
104
+
105
+
106
+ @register_model
107
+ def V9back(pretrained=False, **kwargs):
108
+ return YourConvNet(**kwargs)
109
+
110
+
111
+ @torch.no_grad()
112
+ def convnet_test():
113
+ from timm.models import create_model
114
+ cnn = create_model('V9back')
115
+ print('get_downsample_ratio:', cnn.get_downsample_ratio())
116
+ print('get_feature_map_channels:', cnn.get_feature_map_channels())
117
+
118
+ downsample_ratio = cnn.get_downsample_ratio()
119
+ feature_map_channels = cnn.get_feature_map_channels()
120
+
121
+ # check the forward function
122
+ B, C, H, W = 4, 3, 224, 224
123
+ inp = torch.rand(B, C, H, W)
124
+ feats = cnn(inp, hierarchical=True)
125
+ assert isinstance(feats, list)
126
+ assert len(feats) == len(feature_map_channels)
127
+ print([tuple(t.shape) for t in feats])
128
+
129
+ # check the downsample ratio
130
+ feats = cnn(inp, hierarchical=True)
131
+ assert feats[-1].shape[-2] == H // downsample_ratio
132
+ assert feats[-1].shape[-1] == W // downsample_ratio
133
+
134
+ # check the channel number
135
+ for feat, ch in zip(feats, feature_map_channels):
136
+ assert feat.ndim == 4
137
+ assert feat.shape[1] == ch
138
+
139
+
140
+ if __name__ == '__main__':
141
+ convnet_test()
spark/pretrain/models/custom_detr.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import List
10
+ from timm.models.registry import register_model
11
+ import torch
12
+ from torch import nn
13
+ import sys
14
+ from HG.HGBlock import HGStem,HGBlock
15
+ from HG.block import DWConv
16
+
17
+
18
+ class YourConvNet(nn.Module):
19
+ def __init__(self, *args, **kwargs):
20
+ super().__init__()
21
+
22
+ self.mlist=nn.ModuleList(
23
+ [HGStem(3, 32, 64),
24
+ HGBlock(64, 64, 128, 3, n=6),
25
+
26
+ DWConv(128, 128, 3, 2, 1, False),
27
+ HGBlock(128, 128, 512, 3, n=6),
28
+ HGBlock(512, 128, 512, 3, lightconv=False,shortcut=True,n=6),
29
+
30
+
31
+ DWConv(512, 512, 3, 2, 1, False),
32
+ HGBlock(512, 256, 1024, 5,lightconv=True,shortcut=False,n=6),
33
+ HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
34
+ HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
35
+ HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
36
+ HGBlock(1024, 256, 1024, 5, lightconv=True, shortcut=True, n=6),
37
+
38
+
39
+
40
+ DWConv(1024, 1024, 3, 2, 1, False),
41
+ HGBlock(1024, 512, 2048, 5, lightconv=True, shortcut=False, n=6),
42
+ HGBlock(2048, 512, 2048, 5, lightconv=True, shortcut=True, n=6)
43
+ ]
44
+ )
45
+
46
+
47
+ def get_downsample_ratio(self) -> int:
48
+ return 32
49
+
50
+ def get_feature_map_channels(self) -> List[int]:
51
+ return [128,512,1024,2048]
52
+
53
+ def forward(self, x: torch.Tensor, hierarchical=False):
54
+ if hierarchical:
55
+ ls = []
56
+ for index,modules in enumerate( self.mlist):
57
+ x = modules(x)
58
+ if index in [1,4,10,13]:
59
+ ls.append(x)
60
+ return ls
61
+ else:
62
+ for modules in self.mlist:
63
+ x = modules(x)
64
+ return x
65
+
66
+
67
+ @register_model
68
+ def HGNetv2(pretrained=False, **kwargs):
69
+ return YourConvNet(**kwargs)
70
+
71
+
72
+ @torch.no_grad()
73
+ def convnet_test():
74
+ from timm.models import create_model
75
+ cnn = create_model('HGNetv2')
76
+ print('get_downsample_ratio:', cnn.get_downsample_ratio())
77
+ print('get_feature_map_channels:', cnn.get_feature_map_channels())
78
+
79
+ downsample_ratio = cnn.get_downsample_ratio()
80
+ feature_map_channels = cnn.get_feature_map_channels()
81
+
82
+ # check the forward function
83
+ B, C, H, W = 4, 3, 224, 224
84
+ inp = torch.rand(B, C, H, W)
85
+ feats = cnn(inp, hierarchical=True)
86
+ assert isinstance(feats, list)
87
+ assert len(feats) == len(feature_map_channels)
88
+ print([tuple(t.shape) for t in feats])
89
+
90
+ # check the downsample ratio
91
+ feats = cnn(inp, hierarchical=True)
92
+ assert feats[-1].shape[-2] == H // downsample_ratio
93
+ assert feats[-1].shape[-1] == W // downsample_ratio
94
+
95
+ # check the channel number
96
+ for feat, ch in zip(feats, feature_map_channels):
97
+ assert feat.ndim == 4
98
+ assert feat.shape[1] == ch
99
+
100
+
101
+ if __name__ == '__main__':
102
+ convnet_test()
spark/pretrain/models/custom_origin.py ADDED
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from typing import List
10
+ from timm.models.registry import register_model
11
+
12
+
13
+ class YourConvNet(nn.Module):
14
+ """
15
+ This is a template for your custom ConvNet.
16
+ It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
17
+ You can refer to the implementations in `pretrain/models/resnet.py` for an example.
18
+ """
19
+
20
+ def get_downsample_ratio(self) -> int:
21
+ """
22
+ This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
23
+
24
+ :return: the TOTAL downsample ratio of the ConvNet.
25
+ E.g., for a ResNet-50, this should return 32.
26
+ """
27
+ raise NotImplementedError
28
+
29
+ def get_feature_map_channels(self) -> List[int]:
30
+ """
31
+ This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
32
+
33
+ :return: a list of the number of channels of each feature map.
34
+ E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
35
+ """
36
+ raise NotImplementedError
37
+
38
+ def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
39
+ """
40
+ The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
41
+
42
+ :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
43
+ :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
44
+ :return:
45
+ - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
46
+ - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
47
+ E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
48
+ for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
49
+ """
50
+ raise NotImplementedError
51
+
52
+
53
+ @register_model
54
+ def your_convnet_small(pretrained=False, **kwargs):
55
+ raise NotImplementedError
56
+ return YourConvNet(**kwargs)
57
+
58
+
59
+ @torch.no_grad()
60
+ def convnet_test():
61
+ from timm.models import create_model
62
+ cnn = create_model('your_convnet_small')
63
+ print('get_downsample_ratio:', cnn.get_downsample_ratio())
64
+ print('get_feature_map_channels:', cnn.get_feature_map_channels())
65
+
66
+ downsample_ratio = cnn.get_downsample_ratio()
67
+ feature_map_channels = cnn.get_feature_map_channels()
68
+
69
+ # check the forward function
70
+ B, C, H, W = 4, 3, 224, 224
71
+ inp = torch.rand(B, C, H, W)
72
+ feats = cnn(inp, hierarchical=True)
73
+ assert isinstance(feats, list)
74
+ assert len(feats) == len(feature_map_channels)
75
+ print([tuple(t.shape) for t in feats])
76
+
77
+ # check the downsample ratio
78
+ feats = cnn(inp, hierarchical=True)
79
+ assert feats[-1].shape[-2] == H // downsample_ratio
80
+ assert feats[-1].shape[-1] == W // downsample_ratio
81
+
82
+ # check the channel number
83
+ for feat, ch in zip(feats, feature_map_channels):
84
+ assert feat.ndim == 4
85
+ assert feat.shape[1] == ch
86
+
87
+
88
+ if __name__ == '__main__':
89
+ convnet_test()