Upload 3 files
Browse files- main_for_image.py +402 -0
- pyproject.toml +72 -0
- requirements.txt +15 -0
main_for_image.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import argparse
|
| 3 |
+
import datetime
|
| 4 |
+
import inspect
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import albumentations as A
|
| 11 |
+
import colorlog
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import yaml
|
| 16 |
+
from mmengine import Config
|
| 17 |
+
from torch.utils import data
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
import methods as model_zoo
|
| 21 |
+
from utils import io, ops, pipeline, pt_utils, py_utils, recorder
|
| 22 |
+
|
| 23 |
+
# Script-wide logger with colourised console output.
LOGGER = logging.getLogger("main")
LOGGER.setLevel(level=logging.DEBUG)
# Do not hand records to ancestor loggers; only handlers attached here emit them.
LOGGER.propagate = False

console_formatter = colorlog.ColoredFormatter("%(log_color)s[%(filename)s] %(reset)s%(message)s")
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(console_formatter)
stream_handler.setLevel(logging.DEBUG)
LOGGER.addHandler(stream_handler)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ImageTestDataset(data.Dataset):
    """Evaluation dataset: three resized views of each image plus the path of its mask.

    Only files whose stem (file name minus the configured suffix) is present in both the
    image directory and the mask directory are indexed.
    """

    def __init__(self, dataset_info: dict, shape: dict):
        """Index the image/mask pairs described by ``dataset_info``.

        Args:
            dataset_info: Mapping with a ``root`` directory plus ``image`` and ``mask``
                entries, each holding a ``path`` (joined onto ``root``) and a ``suffix``.
            shape: Mapping with the base height ``h`` and base width ``w`` forwarded to
                ``ops.ms_resize``.
        """
        super().__init__()
        self.shape = shape
        self.total_data_paths = self._collect_pairs(dataset_info)

    @staticmethod
    def _collect_pairs(dataset_info: dict) -> list:
        """Return ``(image_path, mask_path)`` tuples, sorted by stem, for stems found in both dirs."""

        def stems(directory: str, suffix: str) -> set:
            # Slice by length instead of ``name[:-len(suffix)]``: the latter evaluates to
            # ``name[:0] == ""`` for every file when ``suffix`` is empty.
            return {name[: len(name) - len(suffix)] for name in os.listdir(directory) if name.endswith(suffix)}

        image_dir = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
        image_suffix = dataset_info["image"]["suffix"]
        mask_dir = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
        mask_suffix = dataset_info["mask"]["suffix"]

        valid_names = sorted(stems(image_dir, image_suffix) & stems(mask_dir, mask_suffix))
        return [
            (os.path.join(image_dir, n) + image_suffix, os.path.join(mask_dir, n) + mask_suffix) for n in valid_names
        ]

    def __getitem__(self, index):
        """Load sample ``index`` and return its three views together with the mask path.

        Returns:
            dict with ``data`` (tensors ``image_s``, ``image_m``, ``image_l``) and ``info``
            (``mask_path`` and the constant ``group_name`` ``"image"``).
        """
        image_path, mask_path = self.total_data_paths[index]
        image = io.read_color_array(image_path)

        base_h = self.shape["h"]
        base_w = self.shape["w"]

        # NOTE(review): the view named ``image_s`` uses scale 1.5, i.e. it is larger than
        # ``image_m`` (1.0) -- confirm this ordering is what the model expects.
        images = ops.ms_resize(image, scales=(1.5, 1.0, 2.0), base_h=base_h, base_w=base_w)
        # Divide by 255 and move the last axis to the front
        # (presumably uint8 H,W,C -> float C,H,W in [0, 1]; verify against ``io``/``ops``).
        image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1)
        image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1)
        image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1)

        return dict(
            data={"image_s": image_s, "image_m": image_m, "image_l": image_l},
            info=dict(mask_path=mask_path, group_name="image"),
        )

    def __len__(self):
        """Return the number of indexed image/mask pairs."""
        return len(self.total_data_paths)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ImageTrainDataset(data.Dataset):
    """Training dataset: augmented image/mask pairs gathered from one or more datasets.

    Each sample provides three resized views of the augmented image and the mask resized
    to the base shape.
    """

    def __init__(self, dataset_infos: dict, shape: dict):
        """Index every image/mask pair of every dataset and build the augmentation pipeline.

        Args:
            dataset_infos: Mapping ``dataset_name -> dataset_info``; each ``dataset_info``
                has a ``root`` directory plus ``image`` and ``mask`` entries holding a
                ``path`` (joined onto ``root``) and a ``suffix``.
            shape: Mapping with the base height ``h`` and base width ``w``.
        """
        super().__init__()
        self.shape = shape

        self.total_data_paths = []
        for dataset_name, dataset_info in dataset_infos.items():
            data_paths = self._collect_pairs(dataset_info)
            LOGGER.info(f"Length of {dataset_name}: {len(data_paths)}")
            self.total_data_paths.extend(data_paths)

        # Joint image/mask augmentation; each transform is applied with probability 0.5.
        self.trains = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=90, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REPLICATE),
                A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
                A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.5),
            ]
        )

    @staticmethod
    def _collect_pairs(dataset_info: dict) -> list:
        """Return ``(image_path, mask_path)`` tuples, sorted by stem, for stems found in both dirs."""

        def stems(directory: str, suffix: str) -> set:
            # Slice by length instead of ``name[:-len(suffix)]``: the latter evaluates to
            # ``name[:0] == ""`` for every file when ``suffix`` is empty.
            return {name[: len(name) - len(suffix)] for name in os.listdir(directory) if name.endswith(suffix)}

        image_dir = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
        image_suffix = dataset_info["image"]["suffix"]
        mask_dir = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
        mask_suffix = dataset_info["mask"]["suffix"]

        valid_names = sorted(stems(image_dir, image_suffix) & stems(mask_dir, mask_suffix))
        return [
            (os.path.join(image_dir, n) + image_suffix, os.path.join(mask_dir, n) + mask_suffix) for n in valid_names
        ]

    def __getitem__(self, index):
        """Load, augment and resize sample ``index``.

        Returns:
            dict with a single ``data`` entry holding the tensors ``image_s``, ``image_m``,
            ``image_l`` and ``mask``.
        """
        image_path, mask_path = self.total_data_paths[index]
        image = io.read_color_array(image_path)
        mask = io.read_gray_array(mask_path, thr=0)
        # Bring the image to the mask's size so both are augmented on the same grid.
        if image.shape[:2] != mask.shape:
            h, w = mask.shape
            image = ops.resize(image, height=h, width=w)

        transformed = self.trains(image=image, mask=mask)
        image = transformed["image"]
        mask = transformed["mask"]

        base_h = self.shape["h"]
        base_w = self.shape["w"]

        # NOTE(review): the view named ``image_s`` uses scale 1.5, i.e. it is larger than
        # ``image_m`` (1.0) -- confirm this ordering is what the model expects.
        images = ops.ms_resize(image, scales=(1.5, 1.0, 2.0), base_h=base_h, base_w=base_w)
        # Divide by 255 and move the last axis to the front
        # (presumably uint8 H,W,C -> float C,H,W in [0, 1]; verify against ``io``/``ops``).
        image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1)
        image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1)
        image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1)

        mask = ops.resize(mask, height=base_h, width=base_w)
        mask = torch.from_numpy(mask).unsqueeze(0)  # add a leading axis of size 1

        return dict(
            data={
                "image_s": image_s,
                "image_m": image_m,
                "image_l": image_l,
                "mask": mask,
            }
        )

    def __len__(self):
        """Return the total number of indexed image/mask pairs."""
        return len(self.total_data_paths)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Evaluator:
    """Run a model over a test loader, optionally save the predictions, and collect metrics."""

    def __init__(self, device, metric_names, clip_range=None):
        """Store the evaluation settings.

        Args:
            device: Device the input batches are moved to before the forward pass.
            metric_names: Metric names handed to ``recorder.GroupedMetricRecorder``.
            clip_range: Optional range passed to ``ops.clip_to_normalize``; ``None``
                skips that step.
        """
        self.device = device
        self.clip_range = clip_range
        self.metric_names = metric_names

    @torch.no_grad()
    def eval(self, model, data_loader, save_path=""):
        """Evaluate ``model`` on ``data_loader`` and return the recorder's summary.

        Args:
            model: Network called as ``model(data=...)``; switched to eval mode here.
            data_loader: Loader yielding dicts with ``data`` and ``info`` (``mask_path``,
                ``group_name``) entries.
            save_path: When non-empty, predictions are written below
                ``save_path/<group_name>/`` using the mask's file name.

        Returns:
            Whatever ``GroupedMetricRecorder.show()`` returns.

        Raises:
            FileNotFoundError: If a ground-truth mask cannot be read by OpenCV.
        """
        model.eval()
        all_metrics = recorder.GroupedMetricRecorder(metric_names=self.metric_names)

        for batch in tqdm(data_loader, total=len(data_loader), ncols=79, desc="[EVAL]"):
            batch_images = pt_utils.to_device(batch["data"], device=self.device)
            logits = model(data=batch_images)  # B,1,H,W
            probs = logits.sigmoid().squeeze(1).cpu().detach().numpy()
            # NOTE(review): min-max normalisation uses the extrema of the whole batch, so a
            # prediction's scaling depends on the other samples when batch_size > 1 -- confirm intended.
            probs = probs - probs.min()
            probs = probs / (probs.max() + 1e-8)

            mask_paths = batch["info"]["mask_path"]
            group_names = batch["info"]["group_name"]
            for pred_idx, pred in enumerate(probs):
                mask_path = mask_paths[pred_idx]
                mask_array = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                # cv2.imread signals failure by returning None instead of raising.
                if mask_array is None:
                    raise FileNotFoundError(f"Cannot read mask image: {mask_path}")
                mask_array[mask_array > 0] = 255  # binarise: every non-zero pixel is foreground
                mask_h, mask_w = mask_array.shape
                # Score the prediction at the ground truth's resolution.
                pred = ops.resize(pred, height=mask_h, width=mask_w)

                if self.clip_range is not None:
                    pred = ops.clip_to_normalize(pred, clip_range=self.clip_range)

                group_name = group_names[pred_idx]
                if save_path:
                    ops.save_array_as_image(
                        data_array=pred,
                        save_name=os.path.basename(mask_path),
                        save_dir=os.path.join(save_path, group_name),
                    )

                pred = (pred * 255).astype(np.uint8)
                all_metrics.step(group_name=group_name, pre=pred, gt=mask_array, gt_path=mask_path)
        return all_metrics.show()
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def test(model, cfg):
    """Evaluate ``model`` on each dataset listed in ``cfg.test.data.names`` and log the metrics.

    Predictions are saved below ``cfg.path.save/<dataset name>`` only when
    ``cfg.save_results`` is truthy.
    """
    evaluator = Evaluator(device=cfg.device, metric_names=cfg.metric_names, clip_range=cfg.test.clip_range)

    for te_name in cfg.test.data.names:
        te_info = cfg.dataset_infos[te_name]
        te_dataset = ImageTestDataset(dataset_info=te_info, shape=cfg.test.data.shape)
        te_loader = data.DataLoader(
            dataset=te_dataset,
            batch_size=cfg.test.batch_size,
            num_workers=cfg.test.num_workers,
            pin_memory=True,
        )
        LOGGER.info(f"Testing with testset: {te_name}: {len(te_dataset)}")

        # An empty save path tells the evaluator not to write predictions to disk.
        save_path = ""
        if cfg.save_results:
            save_path = os.path.join(cfg.path.save, te_name)
            LOGGER.info(f"Results will be saved into {save_path}")

        metrics = evaluator.eval(model=model, data_loader=te_loader, save_path=save_path)
        seg_results_str = ", ".join(f"{k}: {v:.03f}" for k, v in metrics.items())
        LOGGER.info(f"({te_name}): {py_utils.mapping_to_str(te_info)}\n{seg_results_str}")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def train(model, cfg):
    """Train ``model`` as configured by ``cfg.train`` and save its final weights.

    Builds the training loader, counter, optimizer, scheduler and gradient scaler, then
    runs the epoch/iteration loop with periodic console/TensorBoard logging. Weights are
    written to ``cfg.path.final_state_net`` after every epoch and once more at the end.

    Args:
        model: Network exposing ``normalizer`` and ``encoder``; called as
            ``model(data=..., iter_percentage=...)`` and expected to return a mapping with
            ``loss``, ``loss_str`` and ``vis`` entries.
        cfg: Experiment configuration produced by ``parse_cfg``.

    Raises:
        ValueError: If the training loader yields no batches.
    """
    tr_dataset = ImageTrainDataset(
        dataset_infos={data_name: cfg.dataset_infos[data_name] for data_name in cfg.train.data.names},
        shape=cfg.train.data.shape,
    )
    LOGGER.info(f"Total Length of Image Trainset: {len(tr_dataset)}")

    tr_loader = data.DataLoader(
        dataset=tr_dataset,
        batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        worker_init_fn=pt_utils.customized_worker_init_fn if cfg.use_custom_worker_init else None,
    )
    # With drop_last=True a dataset smaller than one batch gives an empty loader; the
    # epoch-end plotting below would then fail on unbound ``data_batch``/``outputs``.
    if len(tr_loader) == 0:
        raise ValueError(
            f"Training loader is empty: {len(tr_dataset)} samples with batch_size={cfg.train.batch_size} "
            "and drop_last=True."
        )

    counter = recorder.TrainingCounter(
        epoch_length=len(tr_loader),
        epoch_based=cfg.train.epoch_based,
        num_epochs=cfg.train.num_epochs,
        num_total_iters=cfg.train.num_iters,
    )
    optimizer = pipeline.construct_optimizer(
        model=model,
        initial_lr=cfg.train.lr,
        mode=cfg.train.optimizer.mode,
        group_mode=cfg.train.optimizer.group_mode,
        cfg=cfg.train.optimizer.cfg,
    )
    scheduler = pipeline.Scheduler(
        optimizer=optimizer,
        num_iters=counter.num_total_iters,
        epoch_length=counter.num_inner_iters,
        scheduler_cfg=cfg.train.scheduler,
        step_by_batch=cfg.train.sche_usebatch,
    )
    scheduler.record_lrs(param_groups=optimizer.param_groups)
    scheduler.plot_lr_coef_curve(save_path=cfg.path.pth_log)
    scaler = pipeline.Scaler(optimizer, cfg.train.use_amp, set_to_none=cfg.train.optimizer.set_to_none)

    LOGGER.info(f"Scheduler:\n{scheduler}\nOptimizer:\n{optimizer}")

    loss_recorder = recorder.HistoryBuffer()
    iter_time_recorder = recorder.HistoryBuffer()

    LOGGER.info(f"Image Mean: {model.normalizer.mean.flatten()}, Image Std: {model.normalizer.std.flatten()}")
    if cfg.train.bn.freeze_encoder:
        LOGGER.info(" >>> Freeze Backbone !!! <<< ")
        model.encoder.requires_grad_(False)

    train_start_time = time.perf_counter()
    for _ in range(counter.num_epochs):
        LOGGER.info(f"Exp_Name: {cfg.exp_name}")

        # model.train() is called every epoch, so the BN freeze is re-applied after it each time.
        model.train()
        if cfg.train.bn.freeze_status:
            pt_utils.frozen_bn_stats(model.encoder, freeze_affine=cfg.train.bn.freeze_affine)

        for batch_idx, batch in enumerate(tr_loader):
            iter_start_time = time.perf_counter()
            scheduler.step(curr_idx=counter.curr_iter)

            data_batch = pt_utils.to_device(data=batch["data"], device=cfg.device)
            with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
                outputs = model(data=data_batch, iter_percentage=counter.curr_percent)

            loss = outputs["loss"]
            loss_str = outputs["loss_str"]
            # Scale the loss so gradients accumulated over grad_acc_step iterations average out.
            loss = loss / cfg.train.grad_acc_step
            scaler.calculate_grad(loss=loss)
            if counter.every_n_iters(cfg.train.grad_acc_step):
                scaler.update_grad()

            item_loss = loss.item()
            data_shape = tuple(data_batch["mask"].shape)
            loss_recorder.update(value=item_loss, num=data_shape[0])

            if cfg.log_interval > 0 and (
                counter.every_n_iters(cfg.log_interval)
                or counter.is_first_inner_iter()
                or counter.is_last_inner_iter()
                or counter.is_last_total_iter()
            ):
                gpu_mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"
                eta_seconds = iter_time_recorder.avg * (counter.num_total_iters - counter.curr_iter - 1)
                eta_string = f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}"
                progress = (
                    f"{counter.curr_iter}:{counter.num_total_iters} "
                    f"{batch_idx}/{counter.num_inner_iters} "
                    f"{counter.curr_epoch}/{counter.num_epochs}"
                )
                loss_info = f"{loss_str} (M:{loss_recorder.global_avg:.5f}/C:{item_loss:.5f})"
                lr_info = f"LR: {optimizer.lr_string()}"
                LOGGER.info(f"{eta_string}({gpu_mem}) | {progress} | {lr_info} | {loss_info} | {data_shape}")
                cfg.tb_logger.write_to_tb("lr", optimizer.lr_groups(), counter.curr_iter)
                cfg.tb_logger.write_to_tb("iter_loss", item_loss, counter.curr_iter)
                cfg.tb_logger.write_to_tb("avg_loss", loss_recorder.global_avg, counter.curr_iter)

            # Dump visualisations for the first three iterations as a quick sanity check.
            if counter.curr_iter < 3:
                recorder.plot_results(
                    dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
                    save_path=os.path.join(cfg.path.pth_log, "img", f"iter_{counter.curr_iter}.png"),
                )

            iter_time_recorder.update(value=time.perf_counter() - iter_start_time)
            if counter.is_last_total_iter():
                break
            counter.update_iter_counter()

        # Visualise the last batch of the epoch and checkpoint the current weights.
        recorder.plot_results(
            dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
            save_path=os.path.join(cfg.path.pth_log, "img", f"epoch_{counter.curr_epoch}.png"),
        )
        io.save_weight(model=model, save_path=cfg.path.final_state_net)
        counter.update_epoch_counter()

    cfg.tb_logger.close_tb()
    io.save_weight(model=model, save_path=cfg.path.final_state_net)

    total_train_time = time.perf_counter() - train_start_time
    # Time not covered by the per-iteration timer (data loading, plotting, checkpointing, ...).
    total_other_time = datetime.timedelta(seconds=int(total_train_time - iter_time_recorder.global_sum))
    LOGGER.info(
        f"Total Training Time: {datetime.timedelta(seconds=int(total_train_time))} ({total_other_time} on others)"
    )
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def parse_cfg():
    """Parse the command line, load the config files and prepare the output directory.

    Besides building the configuration this function creates the experiment directories,
    writes a copy of the resolved config and of this script into them, attaches a file
    handler to ``LOGGER`` and creates the TensorBoard logger.

    Returns:
        The ``mmengine`` ``Config`` loaded from ``--config``, merged with the parsed
        command-line arguments and extended with ``dataset_infos``, ``proj_root``,
        ``exp_name``, ``output_dir``, ``path``, ``device`` and ``tb_logger``.
    """
    parser = argparse.ArgumentParser("Training and evaluation script")
    parser.add_argument("--config", required=True, type=str)
    parser.add_argument("--data-cfg", type=str, default="./dataset.yaml")
    parser.add_argument("--model-name", type=str, choices=model_zoo.__dict__.keys())
    parser.add_argument("--output-dir", type=str, default="outputs")
    parser.add_argument("--load-from", type=str)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument(
        "--metric-names",
        nargs="+",
        type=str,
        default=["sm", "wfm", "mae", "em", "fmeasure"],
        choices=recorder.GroupedMetricRecorder.supported_metrics,
    )
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--save-results", action="store_true")
    parser.add_argument("--info", type=str)
    args = parser.parse_args()

    cfg = Config.fromfile(args.config)
    # Every parsed argument (including unset ones at their defaults) is merged into the config.
    cfg.merge_from_dict(vars(args))

    # Explicit encoding keeps YAML parsing independent of the platform's locale default.
    with open(cfg.data_cfg, mode="r", encoding="utf-8") as f:
        cfg.dataset_infos = yaml.safe_load(f)

    cfg.proj_root = os.path.dirname(os.path.abspath(__file__))
    cfg.exp_name = py_utils.construct_exp_name(model_name=cfg.model_name, cfg=cfg)
    cfg.output_dir = os.path.join(cfg.proj_root, cfg.output_dir)
    cfg.path = py_utils.construct_path(output_dir=cfg.output_dir, exp_name=cfg.exp_name)
    # Use the first GPU when present; fall back to CPU so CUDA-less hosts do not fail at model.to().
    cfg.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    py_utils.pre_mkdir(cfg.path)
    with open(cfg.path.cfg_copy, encoding="utf-8", mode="w") as f:
        f.write(cfg.pretty_text)
    shutil.copy(__file__, cfg.path.trainer_copy)

    # Mirror INFO-and-above records into the experiment's log file (without colour codes).
    file_handler = logging.FileHandler(cfg.path.log)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("[%(filename)s] %(message)s"))
    LOGGER.addHandler(file_handler)
    LOGGER.info(cfg.pretty_text)

    cfg.tb_logger = recorder.TBLogger(tb_root=cfg.path.tb)
    return cfg
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def main():
    """Script entry point: build the model, then train and/or evaluate it as configured.

    Raises:
        ValueError: If ``cfg.model_name`` does not name an attribute of ``methods``.
    """
    cfg = parse_cfg()
    pt_utils.initialize_seed_cudnn(seed=cfg.base_seed, deterministic=cfg.deterministic)

    model_class = model_zoo.__dict__.get(cfg.model_name)
    # Raise explicitly instead of asserting: assertions are stripped when Python runs with -O.
    if model_class is None:
        raise ValueError("Please check your --model-name")
    model_code = inspect.getsource(model_class)
    model = model_class(num_frames=1, pretrained=cfg.pretrained)
    # Log the model's source so the experiment log records the exact architecture used.
    LOGGER.info(model_code)
    model.to(cfg.device)

    if cfg.load_from:
        io.load_weight(model=model, load_path=cfg.load_from, strict=True)

    LOGGER.info(f"Number of Parameters: {sum((v.numel() for v in model.parameters(recurse=True)))}")
    if not cfg.evaluate:
        train(model=model, cfg=cfg)

    if cfg.evaluate or cfg.has_test:
        # NOTE(review): the weights are also written in evaluate-only runs -- confirm intended.
        io.save_weight(model=model, save_path=cfg.path.final_state_net)
        test(model=model, cfg=cfg)

    LOGGER.info("End training...")


if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/LongTengDao/TOML/
|
| 2 |
+
|
| 3 |
+
[tool.isort]
|
| 4 |
+
# https://pycqa.github.io/isort/docs/configuration/options/
|
| 5 |
+
profile = "black"
|
| 6 |
+
multi_line_output = 3
|
| 7 |
+
filter_files = true
|
| 8 |
+
supported_extensions = "py"
|
| 9 |
+
|
| 10 |
+
[tool.black]
|
| 11 |
+
line-length = 119
|
| 12 |
+
include = '\.pyi?$'
|
| 13 |
+
exclude = '''
|
| 14 |
+
/(
|
| 15 |
+
\.eggs
|
| 16 |
+
| \.git
|
| 17 |
+
| \.idea
|
| 18 |
+
| \.vscode
|
| 19 |
+
| \.hg
|
| 20 |
+
| \.mypy_cache
|
| 21 |
+
| \.tox
|
| 22 |
+
| \.venv
|
| 23 |
+
| _build
|
| 24 |
+
| buck-out
|
| 25 |
+
| build
|
| 26 |
+
| dist
|
| 27 |
+
| output
|
| 28 |
+
)/
|
| 29 |
+
'''
|
| 30 |
+
|
| 31 |
+
[tool.ruff]
|
| 32 |
+
# Same as Black.
|
| 33 |
+
line-length = 119
|
| 34 |
+
indent-width = 4
|
| 35 |
+
# Exclude a variety of commonly ignored directories.
|
| 36 |
+
exclude = [
|
| 37 |
+
".bzr",
|
| 38 |
+
".direnv",
|
| 39 |
+
".eggs",
|
| 40 |
+
".git",
|
| 41 |
+
".git-rewrite",
|
| 42 |
+
".hg",
|
| 43 |
+
".ipynb_checkpoints",
|
| 44 |
+
".mypy_cache",
|
| 45 |
+
".nox",
|
| 46 |
+
".pants.d",
|
| 47 |
+
".pyenv",
|
| 48 |
+
".pytest_cache",
|
| 49 |
+
".pytype",
|
| 50 |
+
".ruff_cache",
|
| 51 |
+
".svn",
|
| 52 |
+
".tox",
|
| 53 |
+
".venv",
|
| 54 |
+
".vscode",
|
| 55 |
+
"__pypackages__",
|
| 56 |
+
"_build",
|
| 57 |
+
"buck-out",
|
| 58 |
+
"build",
|
| 59 |
+
"dist",
|
| 60 |
+
"node_modules",
|
| 61 |
+
"site-packages",
|
| 62 |
+
"venv",
|
| 63 |
+
]
|
| 64 |
+
[tool.ruff.format]
|
| 65 |
+
# Like Black, use double quotes for strings.
|
| 66 |
+
quote-style = "double"
|
| 67 |
+
# Like Black, indent with spaces, rather than tabs.
|
| 68 |
+
indent-style = "space"
|
| 69 |
+
# Like Black, respect magic trailing commas.
|
| 70 |
+
skip-magic-trailing-comma = false
|
| 71 |
+
# Like Black, automatically detect the appropriate line ending.
|
| 72 |
+
line-ending = "auto"
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatically generated by https://github.com/damnever/pigar.
|
| 2 |
+
|
| 3 |
+
adjustText==0.8
|
| 4 |
+
albumentations==1.3.1
|
| 5 |
+
colorlog==6.8.0
|
| 6 |
+
einops==0.7.0
|
| 7 |
+
matplotlib==3.8.2
|
| 8 |
+
mmengine==0.10.2
|
| 9 |
+
numpy==1.26.2
|
| 10 |
+
opencv-python-headless==4.8.1.78
|
| 11 |
+
pysodmetrics==1.4.2
|
| 12 |
+
PyYAML==6.0
|
| 13 |
+
scipy==1.11.4
|
| 14 |
+
timm==0.9.12
|
| 15 |
+
tqdm==4.66.1
|