| | import os |
| | from huggingface_hub import snapshot_download, delete_repo, metadata_update |
| | import uuid |
| | import json |
| | import yaml |
| | import subprocess |
| |
|
# Runtime configuration comes from the environment; each value is None when
# the corresponding variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face access token used for all Hub calls
HF_DATASET = os.getenv("DATA_PATH")  # Hub dataset repo id to download and train on
| |
|
| |
|
def download_dataset(hf_dataset_path: str):
    """Download a Hugging Face dataset repo into a fresh directory under /tmp.

    Args:
        hf_dataset_path: Repo id of the dataset on the Hub.

    Returns:
        Path of the local directory (``/tmp/<uuid4>``) holding the snapshot.
    """
    # A random UUID keeps concurrent/repeated downloads from colliding.
    target_dir = "/tmp/" + str(uuid.uuid4())
    snapshot_download(
        repo_id=hf_dataset_path,
        repo_type="dataset",
        local_dir=target_dir,
        token=HF_TOKEN,
    )
    return target_dir
| |
|
| |
|
def process_dataset(dataset_dir: str):
    """Prepare a downloaded dataset directory for ai-toolkit training.

    Requires ``config.yaml`` at the top of *dataset_dir*. If ``metadata.jsonl``
    is present, each non-blank line is parsed as a JSON object and its
    ``prompt`` is written to a ``.txt`` caption file next to the file named by
    ``file_name`` (same path, extension replaced by ``.txt``). The metadata
    file is then deleted. Finally, the ``folder_path`` of the first dataset of
    the first process in ``config.yaml`` is set to *dataset_dir*, and the
    config is written back in place.

    Args:
        dataset_dir: Local directory containing the dataset files.

    Returns:
        *dataset_dir*, unchanged, so calls can be chained.

    Raises:
        ValueError: If ``config.yaml`` does not exist in *dataset_dir*.
        KeyError, IndexError: If a metadata record or the config lacks the
            expected keys/entries.
    """
    config_path = os.path.join(dataset_dir, "config.yaml")
    if not os.path.exists(config_path):
        raise ValueError("config.yaml does not exist")

    metadata_path = os.path.join(dataset_dir, "metadata.jsonl")
    if os.path.exists(metadata_path):
        # JSON Lines is UTF-8 by definition; don't depend on the locale default.
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = [json.loads(line) for line in f if line.strip()]
        for item in metadata:
            # splitext replaces only the final extension; rsplit(".") would cut
            # into a dotted directory name when the file has no extension.
            stem, _ = os.path.splitext(os.path.join(dataset_dir, item["file_name"]))
            with open(stem + ".txt", "w", encoding="utf-8") as f:
                f.write(item["prompt"])
        # Captions now live in per-file .txt files; drop the consolidated file.
        os.remove(metadata_path)

    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Point training at the local copy of the dataset.
    config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_dir

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    return dataset_dir
| |
|
| |
|
def run_training(hf_dataset_path: str):
    """Download and prepare a dataset, then launch ai-toolkit training on it.

    The dataset is fetched with ``download_dataset`` and prepared with
    ``process_dataset``. ai-toolkit is cloned into ``./ai-toolkit`` (unless
    that directory already exists). It is then checked out at a pinned
    commit, its submodules are initialised, and ``run.py`` is started on the
    dataset's ``config.yaml``.

    Args:
        hf_dataset_path: Hugging Face dataset repo id to train on.

    Returns:
        A ``(process, dataset_dir)`` tuple: the running training
        ``subprocess.Popen`` (not waited on) and the local dataset directory.

    Raises:
        subprocess.CalledProcessError: If cloning, checking out the pinned
            commit, or updating submodules fails.
    """
    dataset_dir = download_dataset(hf_dataset_path)
    dataset_dir = process_dataset(dataset_dir)

    toolkit_dir = "ai-toolkit"
    toolkit_commit = "bc693488eb3cf48ded8bc2af845059d80f4cf7d0"

    # Argument lists (no shell) plus check=True: a failed clone/checkout must
    # stop here instead of silently training on a missing or wrong revision.
    if not os.path.isdir(toolkit_dir):
        subprocess.run(
            ["git", "clone", "https://github.com/ostris/ai-toolkit.git", toolkit_dir],
            check=True,
        )
    subprocess.run(["git", "checkout", toolkit_commit], cwd=toolkit_dir, check=True)
    subprocess.run(
        ["git", "submodule", "update", "--init", "--recursive"],
        cwd=toolkit_dir,
        check=True,
    )

    # Popen rather than run: the caller decides when to wait for training.
    process = subprocess.Popen(
        ["python", "run.py", os.path.join(dataset_dir, "config.yaml")],
        cwd=toolkit_dir,
        env=os.environ,
    )
    return process, dataset_dir
| |
|
| |
|
if __name__ == "__main__":
    # DATA_PATH is required; fail fast with a clear message instead of passing
    # None on to the Hub download.
    if not HF_DATASET:
        raise SystemExit("DATA_PATH environment variable is not set")

    process, dataset_dir = run_training(HF_DATASET)
    return_code = process.wait()
    # Only tag the model and delete the source dataset after a successful run;
    # on failure keep the dataset repo so the job can be inspected or retried.
    if return_code != 0:
        raise SystemExit(f"Training failed with exit code {return_code}")

    # The destination model repo is read back from the training config.
    with open(os.path.join(dataset_dir, "config.yaml"), "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    repo_id = config["config"]["process"][0]["save"]["hf_repo_id"]

    metadata = {
        "tags": [
            "autotrain",
            "spacerunner",
            "text-to-image",
            "flux",
            "lora",
            "diffusers",
            "template:sd-lora",
        ]
    }
    metadata_update(repo_id, metadata, token=HF_TOKEN, repo_type="model", overwrite=True)
    delete_repo(HF_DATASET, token=HF_TOKEN, repo_type="dataset", missing_ok=True)
| |
|