# Source: MSRNet / main_for_image.py (uploaded by linaa98, commit 58278de, verified; 16.5 kB)
# -*- coding: utf-8 -*-
import argparse
import datetime
import inspect
import logging
import os
import shutil
import time
import albumentations as A
import colorlog
import cv2
import numpy as np
import torch
import yaml
from mmengine import Config
from torch.utils import data
from tqdm import tqdm
import methods as model_zoo
from utils import io, ops, pipeline, pt_utils, py_utils, recorder
# Console handler: colourised "[<file>] <message>" lines, DEBUG and above.
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(
    colorlog.ColoredFormatter("%(log_color)s[%(filename)s] %(reset)s%(message)s")
)
stream_handler.setLevel(logging.DEBUG)

# Script-wide logger. ``propagate = False`` stops its records from also being
# passed to the handlers of ancestor loggers.
LOGGER = logging.getLogger("main")
LOGGER.setLevel(level=logging.DEBUG)
LOGGER.propagate = False
LOGGER.addHandler(stream_handler)
class ImageTestDataset(data.Dataset):
    """Evaluation dataset yielding three resized views of an image plus its mask path.

    A sample is kept only when an image file and a mask file share the same
    name once their respective suffixes have been stripped.
    """

    def __init__(self, dataset_info: dict, shape: dict):
        """Index the image/mask pairs described by ``dataset_info``.

        Args:
            dataset_info: Mapping with a ``"root"`` directory and ``"image"`` /
                ``"mask"`` entries, each holding a ``"path"`` (joined onto the
                root) and a file-name ``"suffix"``.
            shape: Mapping with the base ``"h"`` and ``"w"`` forwarded to
                ``ops.ms_resize``.
        """
        super().__init__()
        self.shape = shape
        self.total_data_paths = self._collect_pairs(dataset_info)

    @staticmethod
    def _collect_pairs(dataset_info):
        """Return the name-sorted ``(image_path, mask_path)`` tuples present in both folders."""
        image_dir = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
        image_suffix = dataset_info["image"]["suffix"]
        mask_dir = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
        mask_suffix = dataset_info["mask"]["suffix"]

        def stems(directory, suffix):
            # Slice by explicit length: ``name[:-len(suffix)]`` would turn every
            # name into "" for an empty suffix, because ``name[:-0] == name[:0]``.
            return {name[: len(name) - len(suffix)] for name in os.listdir(directory) if name.endswith(suffix)}

        valid_names = sorted(stems(image_dir, image_suffix) & stems(mask_dir, mask_suffix))
        return [
            (os.path.join(image_dir, name) + image_suffix, os.path.join(mask_dir, name) + mask_suffix)
            for name in valid_names
        ]

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            A dict with ``"data"`` (tensors ``image_s``, ``image_m``,
            ``image_l``) and ``"info"`` (the sample's ``mask_path`` and the
            constant ``group_name`` ``"image"``).
        """
        image_path, mask_path = self.total_data_paths[index]
        image = io.read_color_array(image_path)

        # NOTE(review): the views named s/m/l use scales 1.5/1.0/2.0 in that
        # order -- confirm this ordering is intended.
        images = ops.ms_resize(image, scales=(1.5, 1.0, 2.0), base_h=self.shape["h"], base_w=self.shape["w"])

        # Divide by 255 and move the last axis to the front
        # (presumably HWC uint8 -> CHW float; verify against ops.ms_resize).
        image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1)
        image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1)
        image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1)

        return dict(
            data={"image_s": image_s, "image_m": image_m, "image_l": image_l},
            info=dict(mask_path=mask_path, group_name="image"),
        )

    def __len__(self):
        """Return the number of indexed image/mask pairs."""
        return len(self.total_data_paths)
class ImageTrainDataset(data.Dataset):
    """Training dataset merging several image/mask folders with augmentation.

    For every configured dataset, a sample is kept only when an image file and
    a mask file share the same name once their suffixes have been stripped.
    """

    def __init__(self, dataset_infos: dict, shape: dict):
        """Index all image/mask pairs and build the augmentation pipeline.

        Args:
            dataset_infos: Mapping of dataset name to a dict with a ``"root"``
                directory and ``"image"`` / ``"mask"`` entries, each holding a
                ``"path"`` (joined onto the root) and a file-name ``"suffix"``.
            shape: Mapping with the base ``"h"`` and ``"w"`` used when resizing
                images and masks.
        """
        super().__init__()
        self.shape = shape

        self.total_data_paths = []
        for dataset_name, dataset_info in dataset_infos.items():
            data_paths = self._collect_pairs(dataset_info)
            LOGGER.info(f"Length of {dataset_name}: {len(data_paths)}")
            self.total_data_paths.extend(data_paths)

        # Built once, after all datasets are indexed; applied jointly to image and mask.
        self.trains = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=90, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REPLICATE),
                A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
                A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.5),
            ]
        )

    @staticmethod
    def _collect_pairs(dataset_info):
        """Return the name-sorted ``(image_path, mask_path)`` tuples present in both folders."""
        image_dir = os.path.join(dataset_info["root"], dataset_info["image"]["path"])
        image_suffix = dataset_info["image"]["suffix"]
        mask_dir = os.path.join(dataset_info["root"], dataset_info["mask"]["path"])
        mask_suffix = dataset_info["mask"]["suffix"]

        def stems(directory, suffix):
            # Slice by explicit length: ``name[:-len(suffix)]`` would turn every
            # name into "" for an empty suffix, because ``name[:-0] == name[:0]``.
            return {name[: len(name) - len(suffix)] for name in os.listdir(directory) if name.endswith(suffix)}

        valid_names = sorted(stems(image_dir, image_suffix) & stems(mask_dir, mask_suffix))
        return [
            (os.path.join(image_dir, name) + image_suffix, os.path.join(mask_dir, name) + mask_suffix)
            for name in valid_names
        ]

    def __getitem__(self, index):
        """Load, augment and resize sample ``index``.

        Returns:
            A dict whose ``"data"`` entry holds the tensors ``image_s``,
            ``image_m``, ``image_l`` and ``mask``.
        """
        image_path, mask_path = self.total_data_paths[index]
        image = io.read_color_array(image_path)
        # presumably ``thr=0`` binarises the mask; verify against io.read_gray_array
        mask = io.read_gray_array(mask_path, thr=0)

        # Bring the image to the mask's spatial size so both can be augmented together.
        if image.shape[:2] != mask.shape:
            h, w = mask.shape
            image = ops.resize(image, height=h, width=w)

        transformed = self.trains(image=image, mask=mask)
        image = transformed["image"]
        mask = transformed["mask"]

        base_h = self.shape["h"]
        base_w = self.shape["w"]
        # NOTE(review): the views named s/m/l use scales 1.5/1.0/2.0 in that
        # order -- confirm this ordering is intended.
        images = ops.ms_resize(image, scales=(1.5, 1.0, 2.0), base_h=base_h, base_w=base_w)

        # Divide by 255 and move the last axis to the front
        # (presumably HWC uint8 -> CHW float; verify against ops.ms_resize).
        image_s = torch.from_numpy(images[0]).div(255).permute(2, 0, 1)
        image_m = torch.from_numpy(images[1]).div(255).permute(2, 0, 1)
        image_l = torch.from_numpy(images[2]).div(255).permute(2, 0, 1)

        mask = ops.resize(mask, height=base_h, width=base_w)
        mask = torch.from_numpy(mask).unsqueeze(0)  # add a leading axis in front of the mask

        return dict(
            data={
                "image_s": image_s,
                "image_m": image_m,
                "image_l": image_l,
                "mask": mask,
            }
        )

    def __len__(self):
        """Return the total number of indexed image/mask pairs."""
        return len(self.total_data_paths)
class Evaluator:
    """Runs a model over a test loader and scores its predictions against on-disk masks."""

    def __init__(self, device, metric_names, clip_range=None):
        """Store the evaluation settings.

        Args:
            device: Device the input batches are moved to before inference.
            metric_names: Metric names handed to ``recorder.GroupedMetricRecorder``.
            clip_range: Optional range passed to ``ops.clip_to_normalize`` for
                each prediction; ``None`` skips that step.
        """
        self.device = device
        self.clip_range = clip_range
        self.metric_names = metric_names

    @torch.no_grad()
    def eval(self, model, data_loader, save_path=""):
        """Evaluate ``model`` on ``data_loader``.

        Args:
            model: Network called as ``model(data=...)`` to produce logits.
            data_loader: Loader whose batches carry ``"data"`` and ``"info"``
                (with ``mask_path`` and ``group_name`` per sample).
            save_path: Directory under which predictions are written, in one
                sub-folder per group name. An empty string disables saving.

        Returns:
            Whatever ``GroupedMetricRecorder.show()`` returns for the
            accumulated metrics.

        Raises:
            FileNotFoundError: If a ground-truth mask cannot be read from disk.
        """
        model.eval()
        all_metrics = recorder.GroupedMetricRecorder(metric_names=self.metric_names)

        for batch in tqdm(data_loader, total=len(data_loader), ncols=79, desc="[EVAL]"):
            batch_images = pt_utils.to_device(batch["data"], device=self.device)
            logits = model(data=batch_images)  # B,1,H,W
            probs = logits.sigmoid().squeeze(1).cpu().detach().numpy()
            # Min-max rescale using the extrema of the whole batch.
            # NOTE(review): this makes a sample's prediction depend on the other
            # samples in its batch -- confirm per-batch (not per-sample) scaling is intended.
            probs = probs - probs.min()
            probs = probs / (probs.max() + 1e-8)

            mask_paths = batch["info"]["mask_path"]
            group_names = batch["info"]["group_name"]
            for pred_idx, pred in enumerate(probs):
                mask_path = mask_paths[pred_idx]
                mask_array = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask_array is None:
                    # cv2.imread signals failure by returning None instead of raising;
                    # fail here with the offending path rather than with an opaque
                    # TypeError on the item assignment below.
                    raise FileNotFoundError(f"Cannot read mask image: {mask_path}")
                mask_array[mask_array > 0] = 255  # any non-zero pixel counts as foreground

                # Score at the ground-truth resolution.
                mask_h, mask_w = mask_array.shape
                pred = ops.resize(pred, height=mask_h, width=mask_w)

                if self.clip_range is not None:
                    pred = ops.clip_to_normalize(pred, clip_range=self.clip_range)

                group_name = group_names[pred_idx]
                if save_path:
                    ops.save_array_as_image(
                        data_array=pred,
                        save_name=os.path.basename(mask_path),
                        save_dir=os.path.join(save_path, group_name),
                    )

                pred = (pred * 255).astype(np.uint8)
                all_metrics.step(group_name=group_name, pre=pred, gt=mask_array, gt_path=mask_path)
        return all_metrics.show()
def test(model, cfg):
    """Evaluate ``model`` on every test set listed in ``cfg.test.data.names``.

    For each name, a dataset and loader are built from the matching entry of
    ``cfg.dataset_infos``, the model is evaluated, and the formatted metrics
    are logged. Predictions are written below ``cfg.path.save`` only when
    ``cfg.save_results`` is truthy.
    """
    evaluator = Evaluator(device=cfg.device, metric_names=cfg.metric_names, clip_range=cfg.test.clip_range)

    for dataset_name in cfg.test.data.names:
        dataset_info = cfg.dataset_infos[dataset_name]
        dataset = ImageTestDataset(dataset_info=dataset_info, shape=cfg.test.data.shape)
        loader = data.DataLoader(
            dataset=dataset,
            batch_size=cfg.test.batch_size,
            num_workers=cfg.test.num_workers,
            pin_memory=True,
        )
        LOGGER.info(f"Testing with testset: {dataset_name}: {len(dataset)}")

        # An empty save path tells the evaluator not to write predictions.
        save_path = ""
        if cfg.save_results:
            save_path = os.path.join(cfg.path.save, dataset_name)
            LOGGER.info(f"Results will be saved into {save_path}")

        metrics = evaluator.eval(model=model, data_loader=loader, save_path=save_path)
        metrics_text = ", ".join(f"{key}: {value:.03f}" for key, value in metrics.items())
        LOGGER.info(f"({dataset_name}): {py_utils.mapping_to_str(dataset_info)}\n{metrics_text}")
def train(model, cfg):
    """Train ``model`` on the datasets named in ``cfg.train.data.names``.

    Builds the training dataset and loader, the iteration counter, optimizer,
    LR scheduler and gradient scaler, then runs the epoch/iteration loop.
    Progress is logged to ``LOGGER`` and ``cfg.tb_logger``, result images are
    plotted under ``cfg.path.pth_log``, and the weights are saved to
    ``cfg.path.final_state_net`` after every epoch and once more at the end.

    Args:
        model: Network to train. It is called as
            ``model(data=..., iter_percentage=...)`` and its output is indexed
            with ``"loss"``, ``"loss_str"`` and ``"vis"``.
        cfg: Experiment configuration (reads ``cfg.train.*``, ``cfg.path.*``,
            ``cfg.device``, ``cfg.log_interval``, ``cfg.tb_logger`` and others).
    """
    tr_dataset = ImageTrainDataset(
        dataset_infos={data_name: cfg.dataset_infos[data_name] for data_name in cfg.train.data.names},
        shape=cfg.train.data.shape,
    )
    LOGGER.info(f"Total Length of Image Trainset: {len(tr_dataset)}")

    tr_loader = data.DataLoader(
        dataset=tr_dataset,
        batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        worker_init_fn=pt_utils.customized_worker_init_fn if cfg.use_custom_worker_init else None,
    )

    # Tracks the current iteration/epoch; the queries used below (every_n_iters,
    # is_last_total_iter, ...) all come from this object.
    counter = recorder.TrainingCounter(
        epoch_length=len(tr_loader),
        epoch_based=cfg.train.epoch_based,
        num_epochs=cfg.train.num_epochs,
        num_total_iters=cfg.train.num_iters,
    )
    optimizer = pipeline.construct_optimizer(
        model=model,
        initial_lr=cfg.train.lr,
        mode=cfg.train.optimizer.mode,
        group_mode=cfg.train.optimizer.group_mode,
        cfg=cfg.train.optimizer.cfg,
    )
    scheduler = pipeline.Scheduler(
        optimizer=optimizer,
        num_iters=counter.num_total_iters,
        epoch_length=counter.num_inner_iters,
        scheduler_cfg=cfg.train.scheduler,
        step_by_batch=cfg.train.sche_usebatch,
    )
    scheduler.record_lrs(param_groups=optimizer.param_groups)
    scheduler.plot_lr_coef_curve(save_path=cfg.path.pth_log)
    scaler = pipeline.Scaler(optimizer, cfg.train.use_amp, set_to_none=cfg.train.optimizer.set_to_none)

    LOGGER.info(f"Scheduler:\n{scheduler}\nOptimizer:\n{optimizer}")

    loss_recorder = recorder.HistoryBuffer()
    iter_time_recorder = recorder.HistoryBuffer()

    LOGGER.info(f"Image Mean: {model.normalizer.mean.flatten()}, Image Std: {model.normalizer.std.flatten()}")
    if cfg.train.bn.freeze_encoder:
        LOGGER.info(" >>> Freeze Backbone !!! <<< ")
        # Turn off gradient tracking for every encoder parameter.
        model.encoder.requires_grad_(False)

    train_start_time = time.perf_counter()
    for _ in range(counter.num_epochs):
        LOGGER.info(f"Exp_Name: {cfg.exp_name}")

        model.train()
        # Re-applied every epoch, right after model.train().
        if cfg.train.bn.freeze_status:
            pt_utils.frozen_bn_stats(model.encoder, freeze_affine=cfg.train.bn.freeze_affine)

        # an epoch starts
        for batch_idx, batch in enumerate(tr_loader):
            iter_start_time = time.perf_counter()
            scheduler.step(curr_idx=counter.curr_iter)

            data_batch = pt_utils.to_device(data=batch["data"], device=cfg.device)
            with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
                outputs = model(data=data_batch, iter_percentage=counter.curr_percent)

            loss = outputs["loss"]
            loss_str = outputs["loss_str"]
            # Divide by the accumulation step count before computing gradients.
            loss = loss / cfg.train.grad_acc_step
            scaler.calculate_grad(loss=loss)
            # presumably update_grad applies the optimizer step; verify against pipeline.Scaler
            if counter.every_n_iters(cfg.train.grad_acc_step):
                scaler.update_grad()

            # Note: the recorded value is the loss after the division above.
            item_loss = loss.item()
            data_shape = tuple(data_batch["mask"].shape)
            loss_recorder.update(value=item_loss, num=data_shape[0])

            if cfg.log_interval > 0 and (
                counter.every_n_iters(cfg.log_interval)
                or counter.is_first_inner_iter()
                or counter.is_last_inner_iter()
                or counter.is_last_total_iter()
            ):
                gpu_mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"
                # Remaining time estimate: average recorded iteration time x iterations left.
                eta_seconds = iter_time_recorder.avg * (counter.num_total_iters - counter.curr_iter - 1)
                eta_string = f"ETA: {datetime.timedelta(seconds=int(eta_seconds))}"
                progress = (
                    f"{counter.curr_iter}:{counter.num_total_iters} "
                    f"{batch_idx}/{counter.num_inner_iters} "
                    f"{counter.curr_epoch}/{counter.num_epochs}"
                )
                loss_info = f"{loss_str} (M:{loss_recorder.global_avg:.5f}/C:{item_loss:.5f})"
                lr_info = f"LR: {optimizer.lr_string()}"
                LOGGER.info(f"{eta_string}({gpu_mem}) | {progress} | {lr_info} | {loss_info} | {data_shape}")
                cfg.tb_logger.write_to_tb("lr", optimizer.lr_groups(), counter.curr_iter)
                cfg.tb_logger.write_to_tb("iter_loss", item_loss, counter.curr_iter)
                cfg.tb_logger.write_to_tb("avg_loss", loss_recorder.global_avg, counter.curr_iter)

            # Plot the batch while the iteration index is still below 3.
            if counter.curr_iter < 3:
                recorder.plot_results(
                    dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
                    save_path=os.path.join(cfg.path.pth_log, "img", f"iter_{counter.curr_iter}.png"),
                )

            iter_time_recorder.update(value=time.perf_counter() - iter_start_time)

            # Leave the inner loop on the final iteration, before the counter is advanced.
            if counter.is_last_total_iter():
                break
            counter.update_iter_counter()

        # an epoch ends
        # NOTE(review): ``data_batch`` and ``outputs`` are only bound inside the
        # inner loop; if the loader yields no batch in the first epoch this call
        # raises NameError -- confirm the loader can never be empty.
        recorder.plot_results(
            dict(img=data_batch["image_m"], msk=data_batch["mask"], **outputs["vis"]),
            save_path=os.path.join(cfg.path.pth_log, "img", f"epoch_{counter.curr_epoch}.png"),
        )
        io.save_weight(model=model, save_path=cfg.path.final_state_net)
        counter.update_epoch_counter()

    cfg.tb_logger.close_tb()
    io.save_weight(model=model, save_path=cfg.path.final_state_net)

    total_train_time = time.perf_counter() - train_start_time
    # Wall-clock time not covered by the recorded per-iteration timings.
    total_other_time = datetime.timedelta(seconds=int(total_train_time - iter_time_recorder.global_sum))
    LOGGER.info(
        f"Total Training Time: {datetime.timedelta(seconds=int(total_train_time))} ({total_other_time} on others)"
    )
def parse_cfg():
    """Parse the command line and assemble the experiment configuration.

    Loads the config file given by ``--config``, merges every parsed
    command-line argument into it, attaches the dataset descriptions from
    ``--data-cfg``, derives the experiment name and output paths, creates the
    output directories, stores a copy of the config and of this script, adds a
    file handler to ``LOGGER`` and creates the TensorBoard logger.

    Returns:
        The populated ``Config`` object.
    """
    parser = argparse.ArgumentParser("Training and evaluation script")
    parser.add_argument("--config", required=True, type=str)
    parser.add_argument("--data-cfg", type=str, default="./dataset.yaml")
    parser.add_argument("--model-name", type=str, choices=model_zoo.__dict__.keys())
    parser.add_argument("--output-dir", type=str, default="outputs")
    parser.add_argument("--load-from", type=str)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument(
        "--metric-names",
        nargs="+",
        type=str,
        default=["sm", "wfm", "mae", "em", "fmeasure"],
        choices=recorder.GroupedMetricRecorder.supported_metrics,
    )
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--save-results", action="store_true")
    parser.add_argument("--info", type=str)
    args = parser.parse_args()

    cfg = Config.fromfile(args.config)
    # Every parsed argument (including unset ones, which are None) is merged into the config.
    cfg.merge_from_dict(vars(args))

    # Read explicitly as UTF-8, matching the config copy written below, so that
    # non-ASCII dataset paths do not depend on the platform's default encoding.
    with open(cfg.data_cfg, mode="r", encoding="utf-8") as f:
        cfg.dataset_infos = yaml.safe_load(f)

    cfg.proj_root = os.path.dirname(os.path.abspath(__file__))
    cfg.exp_name = py_utils.construct_exp_name(model_name=cfg.model_name, cfg=cfg)
    cfg.output_dir = os.path.join(cfg.proj_root, cfg.output_dir)
    cfg.path = py_utils.construct_path(output_dir=cfg.output_dir, exp_name=cfg.exp_name)
    cfg.device = "cuda:0"

    py_utils.pre_mkdir(cfg.path)
    # Keep a snapshot of the effective config and of this script next to the results.
    with open(cfg.path.cfg_copy, encoding="utf-8", mode="w") as f:
        f.write(cfg.pretty_text)
    shutil.copy(__file__, cfg.path.trainer_copy)

    # Mirror INFO-and-above log records into the experiment's log file.
    file_handler = logging.FileHandler(cfg.path.log)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("[%(filename)s] %(message)s"))
    LOGGER.addHandler(file_handler)
    LOGGER.info(cfg.pretty_text)

    cfg.tb_logger = recorder.TBLogger(tb_root=cfg.path.tb)
    return cfg
def main():
    """Script entry point: build the model, then train and/or evaluate it.

    Training runs unless ``cfg.evaluate`` is set. Evaluation runs when
    ``cfg.evaluate`` or ``cfg.has_test`` is set, and the weights are saved to
    ``cfg.path.final_state_net`` right before it.

    Raises:
        ValueError: If ``cfg.model_name`` does not name an attribute of the
            ``methods`` module.
    """
    cfg = parse_cfg()
    pt_utils.initialize_seed_cudnn(seed=cfg.base_seed, deterministic=cfg.deterministic)

    model_class = model_zoo.__dict__.get(cfg.model_name)
    if model_class is None:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError("Please check your --model-name")

    # Log the model's source code so the run records exactly what was trained.
    model_code = inspect.getsource(model_class)
    model = model_class(num_frames=1, pretrained=cfg.pretrained)
    LOGGER.info(model_code)
    model.to(cfg.device)

    if cfg.load_from:
        io.load_weight(model=model, load_path=cfg.load_from, strict=True)

    LOGGER.info(f"Number of Parameters: {sum((v.numel() for v in model.parameters(recurse=True)))}")

    if not cfg.evaluate:
        train(model=model, cfg=cfg)

    if cfg.evaluate or cfg.has_test:
        io.save_weight(model=model, save_path=cfg.path.final_state_net)
        test(model=model, cfg=cfg)

    LOGGER.info("End training...")


if __name__ == "__main__":
    main()