"""Build a JSON-lines caption manifest for a locally downloaded cc12m image set.

Usage:
    python <script> ANNOTATIONS.tsv IMAGES_DIR OUTPUT_FILE

ANNOTATIONS.tsv must be tab-separated with a header row that contains
``caption`` and ``url`` columns. For the row at 1-based position ``i``, the
image is looked up at ``IMAGES_DIR/{i:08d}---{stem}.jpg``. Here ``stem`` is the
basename of the row's URL with its extension removed.

For every row whose image file exists, one JSON object is written to
OUTPUT_FILE:

    {"image_path": ..., "captions": [caption]}

Objects are separated by newlines, with no trailing newline. Paths of missing
images are logged at ERROR level to ``download.log``.
"""
import json
import logging
import os.path
import sys
from typing import Iterable, Iterator, Optional, Sequence

import contexttimer
import pandas as pd

LOG_FILE = "download.log"


def _usage(prog: str) -> str:
    """Return the command-line usage message for this script."""
    return (
        "Provide .tsv file name, images dir, output file name. "
        f"e.g. python {prog} cc12m.tsv /mnt/disks/data-1/cc12m/images "
        "cc12m_dataset.json"
    )


def _image_filename(row_number: int, url: str) -> str:
    """Return the on-disk image file name for the 1-based ``row_number``."""
    stem = os.path.splitext(os.path.basename(url))[0]
    return f"{row_number:08d}---{stem}.jpg"


def _load_annotations(annotation_file: str) -> pd.DataFrame:
    """Read the ``caption`` and ``url`` columns of the annotation TSV."""
    # contexttimer only reports when `output` is given; without it the
    # prefix below would never be emitted.
    with contexttimer.Timer(prefix="Loading from tsv", output=logging.info):
        df = pd.read_csv(
            annotation_file,
            delimiter="\t",
            usecols=["caption", "url"],
            # Keep every cell as the literal string from the file. By default
            # pandas turns captions such as "NA", "null" or "None" (and empty
            # cells) into float NaN, which json.dumps writes as invalid `NaN`.
            dtype=str,
            keep_default_na=False,
            # NOTE(review): default CSV quoting treats a caption starting with
            # '"' as a quoted field. quoting=csv.QUOTE_NONE may be more robust,
            # but row numbering must match whatever produced the image files,
            # so confirm before changing.
        )
    # usecols preserves file order; fix the order the loop below unpacks.
    return df[["caption", "url"]]


def _iter_records(df: pd.DataFrame, images_dir: str) -> Iterator[str]:
    """Yield one JSON line per annotation row whose image file exists.

    Rows whose image is missing are logged (full path, ERROR level) and skipped.
    """
    for index, caption, url in df.itertuples(name=None):
        # read_csv gives a 0-based RangeIndex; image files are numbered from 1.
        full_image_path = os.path.join(images_dir, _image_filename(index + 1, url))
        if os.path.isfile(full_image_path):
            yield json.dumps({"image_path": full_image_path, "captions": [caption]})
        else:
            logging.error(full_image_path)


def _write_lines(output_file: str, lines: Iterable[str]) -> int:
    """Stream ``lines`` to ``output_file`` joined by "\\n" (no trailing newline).

    Returns the number of lines written. Streaming avoids holding every record,
    plus a joined copy, in memory at once.
    """
    count = 0
    with open(output_file, "w", encoding="utf-8") as f:
        for line in lines:
            if count:
                f.write("\n")
            f.write(line)
            count += 1
    return count


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the manifest build; return the process exit status.

    Args:
        argv: Full argument vector including the program name. Defaults to
            ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    logging.basicConfig(filename=LOG_FILE, filemode="w", level=logging.INFO)

    if len(argv) != 4:
        print(_usage(os.path.basename(argv[0]) if argv else "cc12m.py"),
              file=sys.stderr)
        return 1
    annotation_file, images_dir, output_file = argv[1:4]

    if not os.path.isdir(images_dir):
        # Otherwise every row would be logged as missing and the output empty.
        message = f"Images dir not found: {images_dir}"
        logging.error(message)
        print(message, file=sys.stderr)
        return 1

    logging.info("Processing cc12m dataset")

    df = _load_annotations(annotation_file)
    print(f"Loaded {len(df)} images.")

    count = _write_lines(output_file, _iter_records(df, images_dir))

    logging.info("Processing cc12m dataset done. %d images processed.", count)
    return 0


if __name__ == "__main__":
    sys.exit(main())