| | import json |
| | import collections |
| | import logging |
| | import sys |
| |
|
| | if len(sys.argv) != 4: |
| | print("Provide .tsv file name, images dir, output file name. e.g. python coco.py coco_captions_train2017.json /mnt/disks/data-1/flickr8k/coco_train.json coco_dataset_train.json") |
| | exit(1) |
| |
|
| | annotation_file = sys.argv[1] |
| | images_dir = sys.argv[2] |
| | output_file = sys.argv[3] |
| |
|
| | logging.info("Processing COCO dataset") |
| |
|
| | with open(annotation_file, "r") as f: |
| | annotations = json.load(f)["annotations"] |
| |
|
| | image_path_to_caption = collections.defaultdict(list) |
| | for element in annotations: |
| | caption = f"{element['caption'].lower().rstrip('.')}" |
| | image_path = images_dir + "/%012d.jpg" % (element["image_id"]) |
| | image_path_to_caption[image_path].append(caption) |
| |
|
| | lines = [] |
| | for image_path, captions in image_path_to_caption.items(): |
| | lines.append(json.dumps({"image_path": image_path, "captions": captions})) |
| |
|
| | train_lines = lines[:-10_001] |
| | valid_lines = lines[-10_001:] |
| |
|
| | with open(output_file+"_train.json", "w") as f: |
| | f.write("\n".join(train_lines)) |
| |
|
| | with open(output_file+"_val.json", "w") as f: |
| | f.write("\n".join(valid_lines)) |
| |
|
| | logging.info(f"Processing COCO dataset done. {len(lines)} images processed.") |
| |
|
| | |
| |
|