# learn2refocus / extra / compute_metrics.py
# Author: tedlasai — commit 199f9c2
import torchmetrics
import os
import torch
from PIL import Image
import numpy as np
import csv
import sys
# Number of focal positions per burst; the metric grid is num_positions x num_positions.
num_positions = 9
# Root directory holding the ground-truth renders and each model's outputs.
output_dir_path = "/datasets/sai/focal-burst-learning/metrics_output"
gt = "gt"
# Usage: compute_metrics.py <model_dir_name> <torch_device>
model = sys.argv[1]
gt_path = os.path.join(output_dir_path, gt)
model_path = os.path.join(output_dir_path, model)
device = sys.argv[2]  # e.g. "cuda:0" or "cpu"
def _make_metric_set():
    """Return one fresh dict of image-quality metrics, moved to `device`.

    Frames are /255-normalized float32 in [0, 1] before any update, so
    data_range=1.0 is set explicitly on SSIM as well as PSNR — leaving SSIM's
    data_range unset makes torchmetrics infer it per-batch, which skews results.
    """
    return {
        "psnr": torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0).to(device),
        "ssim": torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0).to(device),
        "lpips": torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type='vgg', normalize=True).to(device),
        "fid": torchmetrics.image.fid.FrechetInceptionDistance(normalize=True).to(device),
        "vif": torchmetrics.image.VisualInformationFidelity().to(device),
    }

# Cell [i][j] accumulates scores for bursts that start at focal position i and
# are evaluated at end position j; every cell gets independent metric state.
metrics_grid = []
for i in range(num_positions):
    metrics_grid.append([_make_metric_set() for _ in range(num_positions)])
    print("Created metrics for position", i)
def _load_image_batch(base_path, dir_name, frame_names):
    """Load frames as float32 tensors in [0, 1], shaped (1, C, H, W), on `device`.

    Files are opened with a context manager so PIL's file handles are released
    (the original left them open, leaking handles over thousands of images).
    A trailing alpha channel (RGBA input) is discarded so every tensor is RGB.
    """
    batch = []
    for frame_name in frame_names:
        with Image.open(os.path.join(base_path, dir_name, "images", frame_name)) as img:
            # np.array() forces the pixel data to load before the file closes.
            tensor = torch.tensor(np.array(img) / 255).to(torch.float32)
        tensor = tensor.to(device).permute(2, 0, 1).unsqueeze(0)
        if tensor.shape[1] == 4:  # drop alpha channel if present
            tensor = tensor[:, :3, :, :]
        batch.append(tensor)
    return batch

# Loop over the per-start-position directories under gt_path; the model's
# outputs mirror the same directory layout under model_path.
position_dirs = os.listdir(gt_path)
position_dirs = sorted(
    [d for d in position_dirs if os.path.isdir(os.path.join(gt_path, d))]
)[0:num_positions]

for gt_dir in position_dirs:
    # Directory names look like "<prefix>_<position>"; the position index
    # selects the row of metrics_grid to update.
    position_number = int(gt_dir.split("_")[1])
    gt_pngs = sorted(os.listdir(os.path.join(gt_path, gt_dir, "images")))
    # Each source image contributes one frame per focal position. Derive the
    # image count instead of hard-coding 164, and raise (not assert, which is
    # stripped under -O) on a partial/corrupt directory.
    if len(gt_pngs) % num_positions != 0:
        raise ValueError(
            f"{gt_dir}: expected a multiple of {num_positions} frames, got {len(gt_pngs)}"
        )
    num_images = len(gt_pngs) // num_positions
    for i in range(num_images):
        # The i-th burst: num_positions consecutive frames.
        frame_names = gt_pngs[i * num_positions:(i + 1) * num_positions]
        gt_frames = _load_image_batch(gt_path, gt_dir, frame_names)
        model_frames = _load_image_batch(model_path, gt_dir, frame_names)
        for j in range(num_positions):
            for key, metric in metrics_grid[position_number][j].items():
                if key == "fid":
                    # FID accumulates the generated and real distributions
                    # separately rather than comparing pairs directly.
                    metric.update(model_frames[j], real=False)
                    metric.update(gt_frames[j], real=True)
                else:
                    metric(gt_frames[j], model_frames[j])
        print("Computed metrics for position", position_number, "frame", i)
#write the metrics to a csv (each metric as a csv)
def write_metrics_to_csv(metrics_grid, metric_names, formatting_options=None, output_dir="metrics_output"):
    """Write each metric in metrics_grid to its own CSV file.

    Args:
        metrics_grid (list): N x N list of dicts mapping metric name -> metric
            object (anything exposing a ``compute()`` method that returns a
            0-dim tensor; entries without ``compute`` are written as 0.0).
        metric_names (list): Metric names to export (e.g. ["psnr", "lpips", "fid"]).
        formatting_options (dict, optional): Maps metric name -> callable that
            formats a float as a string. Unlisted metrics use plain str formatting.
        output_dir (str): Directory where the CSV files will be saved.
    """
    os.makedirs(output_dir, exist_ok=True)  # create output directory if needed
    # Derive the position count from the grid itself, so the function works for
    # any N x N grid instead of being coupled to the module-level num_positions.
    positions = list(range(1, len(metrics_grid) + 1))
    default_format = lambda x: f"{x}"
    for metric_name in metric_names:
        output_file = os.path.join(output_dir, f"{metric_name}.csv")
        # Per-metric formatter, falling back to plain str formatting.
        format_fn = (formatting_options or {}).get(metric_name, default_format)
        with open(output_file, mode='w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            header = ["Starting Position/End Position"] + [f"Position {i}" for i in positions]
            writer.writerow(header)
            for i, row in enumerate(metrics_grid):
                # First column labels the starting position of this row.
                csv_row = [f"Position {positions[i]}"]
                for cell in row:
                    metric = cell[metric_name]
                    # Accumulated torchmetrics objects expose compute(); write
                    # 0.0 for placeholder cells that don't.
                    value = metric.compute().item() if hasattr(metric, "compute") else 0.0
                    csv_row.append(format_fn(value))
                writer.writerow(csv_row)
                print(f"Wrote row for position {positions[i]} with metric {metric_name}")
        print(f"Saved {metric_name} metrics to {output_file}")
def _fixed_point(precision):
    """Return a formatter that renders a float with `precision` decimals."""
    return lambda value, p=precision: f"{value:.{p}f}"

# Per-metric CSV formatters: PSNR and FID report two decimal places, the
# remaining similarity metrics report four.
formatting_options = {
    "psnr": _fixed_point(2),
    "lpips": _fixed_point(4),
    "fid": _fixed_point(2),
    "ssim": _fixed_point(4),
    "vif": _fixed_point(4),
}
# Export one CSV per metric into a per-model subdirectory of the output root.
write_metrics_to_csv(metrics_grid, ["psnr", "ssim", "lpips", "fid", "vif"], formatting_options=formatting_options, output_dir=f"{output_dir_path}/metrics_output/{model}")