Spaces:
Running on Zero
Running on Zero
| import logging | |
| from collections import defaultdict | |
| import numpy as np | |
| from tensorboardX import SummaryWriter | |
# Number of most recent samples averaged by MetricCounter.loss_message().
WINDOW_SIZE = 100


class MetricCounter:
    """Accumulate training metrics and images and report them.

    Scalars are stored per name in ``self.metrics`` and images per tag in
    ``self.images``. They can be summarised as a progress string
    (``loss_message``), written to the summary writer
    (``write_to_tensorboard``) and used to track the best mean PSNR seen so
    far (``update_best_model``).
    """

    def __init__(self, exp_name):
        """Create the summary writer, configure logging and reset state.

        Args:
            exp_name: Passed as the first argument to ``SummaryWriter`` and
                used as the stem of the ``<exp_name>.log`` file handed to
                ``logging.basicConfig``.
        """
        self.writer = SummaryWriter(exp_name)
        # NOTE: basicConfig configures the root logger and is a no-op if the
        # root logger already has handlers.
        logging.basicConfig(filename=f'{exp_name}.log', level=logging.DEBUG)
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)
        self.best_metric = 0

    @staticmethod
    def _mean(values):
        """Return the mean of ``values``, or NaN when ``values`` is empty.

        ``np.mean`` of an empty sequence also yields NaN but emits a
        ``RuntimeWarning``; handling the empty case here keeps the result
        and drops the warning.
        """
        if len(values) == 0:
            return float('nan')
        return np.mean(values)

    def add_image(self, x: np.ndarray, tag: str):
        """Queue image ``x`` under ``tag`` for the next tensorboard write."""
        self.images[tag].append(x)

    def clear(self):
        """Drop all accumulated metrics and images (``best_metric`` is kept)."""
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G):
        """Record one generator-loss sample under the ``'G_loss'`` key."""
        # Only 'G_loss' is tracked; nothing is stored under any other key.
        self.metrics['G_loss'].append(l_G)

    def add_metrics(self, psnr, ssim):
        """Record one ``'PSNR'`` sample and one ``'SSIM'`` sample."""
        self.metrics['PSNR'].append(psnr)
        self.metrics['SSIM'].append(ssim)

    def loss_message(self):
        """Return ``'G_loss=…; PSNR=…; SSIM=…'`` for the recent window.

        Each value is the mean of the last ``WINDOW_SIZE`` samples of that
        metric, formatted with four decimals. A metric with no samples is
        reported as ``nan``.
        """
        parts = []
        for name in ('G_loss', 'PSNR', 'SSIM'):
            # .get() avoids inserting empty keys into the defaultdict on read.
            recent = self.metrics.get(name, [])[-WINDOW_SIZE:]
            parts.append(f'{name}={self._mean(recent):.4f}')
        return '; '.join(parts)

    def write_to_tensorboard(self, epoch_num, validation=False):
        """Write mean scalars and all queued images to the summary writer.

        Args:
            epoch_num: Value passed as ``global_step`` for every record.
            validation: Selects the ``'Validation'`` scalar prefix instead
                of ``'Train'``.

        Scalars are the mean over everything accumulated since the last
        ``clear()``. Each non-empty image list is written as one batch and
        then emptied.
        """
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'SSIM', 'PSNR'):
            value = self._mean(self.metrics.get(tag, []))
            self.writer.add_scalar(f'{scalar_prefix}_{tag}', value, global_step=epoch_num)

        for tag, imgs in self.images.items():
            if not imgs:
                continue
            batch = np.array(imgs)
            # Reverse the last (channel) axis and scale by 1/255.
            # NOTE(review): presumably BGR uint8 input converted to RGB in
            # [0, 1] -- confirm against the callers of add_image().
            self.writer.add_images(tag, batch[:, :, :, ::-1].astype('float32') / 255,
                                   dataformats='NHWC', global_step=epoch_num)
            # Rebinding an existing key does not resize the dict, so this is
            # safe while iterating over it.
            self.images[tag] = []

    def update_best_model(self):
        """Update ``best_metric`` from the current mean PSNR.

        Returns:
            True if the mean of the accumulated ``'PSNR'`` samples is
            strictly greater than ``best_metric`` (which is then updated),
            False otherwise, including when no PSNR samples exist.
        """
        cur_metric = self._mean(self.metrics.get('PSNR', []))
        if self.best_metric < cur_metric:
            self.best_metric = cur_metric
            return True
        return False