| | """ |
| | Example script for running dreaming on a dataset. |
| | The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context. |
| | |
| | After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players), |
| | should be identical if model was ideal. |
| | """ |
| |
|
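# Example invocation (script name and paths below are illustrative):
#   python run_dreaming.py --model_path checkpoints/model.ckpt --data_path ground_truth_data/ \
#       --output dreaming_output --protocol base --context_length 10 --steps_to_dream 10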
import argparse
import os
import subprocess
from pathlib import Path

import cv2
import ffmpegcv
import numpy as np
import torch as th
from PIL import Image
from tensordict import TensorDict
from tqdm import tqdm

import wham.utils as utils


parser = argparse.ArgumentParser(description="Run dreaming.")
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")
parser.add_argument(
    "--protocol",
    type=str,
    default="base",
    choices=["base", "comprehensive"],
    help="What protocol to use for the dreaming. base = action-conditioned, comprehensive = dream actions as well.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. A higher batch_size uses more VRAM but is faster overall.")
parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use as the initial context.")
parser.add_argument("--steps_to_dream", type=int, default=10, help="Number of steps to dream past the initial context.")
parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")

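# Array layout assumed throughout (inferred from the asserts and slicing below):
#   images:  (batch, time, channels, height, width), with CHW frames
#   actions: (batch, time, action_dim)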
def get_context_data(image_context, action_context):
    """Pack numpy context arrays into a TensorDict and move them to the GPU."""
    assert image_context.shape[-3] == 3, "Image context should be CHW"
    image_context = th.from_numpy(image_context).cuda()
    action_data = th.from_numpy(action_context).float().cuda()
    return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])


def add_video_metadata(file_path, metadata_config):
    """Stamp the output video with the program name via exiftool."""
    # subprocess.run does not go through a shell, so the tag value needs no extra quoting.
    cmd = [
        'exiftool',
        '-config', metadata_config,
        f'-ProgramName={utils.PROGRAM_NAME}',
        '-overwrite_original',
        file_path,
    ]
    try:
        subprocess.run(cmd, check=True)
        print("Metadata modified successfully.")
        # Echo the resulting tags so the change can be verified in the logs.
        subprocess.run(['exiftool', file_path], check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error modifying metadata: {e}")


@th.no_grad()
def do_dreaming(model, image_context, action_context, args, action_sequences=None):
    """
    image_context and action_context provide the initial context for the model to dream from.

    If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, the model will be prompted with these actions.
    """
    context_data = get_context_data(image_context, action_context)
    encoded_context_data = model.encode_context(context_data)

    encoded_action_sequences = None
    if action_sequences is not None:
        assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
        action_sequences = TensorDict(
            {"actions_output": th.from_numpy(action_sequences).float()}, batch_size=action_sequences.shape[:2]
        ).cuda()
        encoded_action_sequences = model.encode_context(action_sequences)

    encoded_dreamt_steps = []
    for dream_step in range(args.steps_to_dream):
        encoded_predicted_step, _ = model.predictor.predict_next_step(
            encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
        )

        # Keep the context at a fixed length by dropping the oldest step.
        if encoded_context_data.shape[1] == args.context_length:
            encoded_context_data = encoded_context_data[:, 1:]

        append_step = encoded_predicted_step
        if encoded_action_sequences is not None:
            # Action-conditioned dreaming: replace the dreamt action with the prompted one.
            append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
        encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)

        encoded_dreamt_steps.append(encoded_predicted_step)

    # Decode the dreamt steps back into images and actions.
    dreamed_images = []
    actions_during_dream = []
    for seq_i in range(args.steps_to_dream):
        decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
        dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
        actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())

    dreamed_images = np.concatenate(dreamed_images, axis=1)
    actions_during_dream = np.concatenate(actions_during_dream, axis=1)

    return dreamed_images, actions_during_dream


@th.no_grad()
def encode_decode_images(model, images):
    """
    Pass ground_truth images through the encoding/decoding process of the model.
    """
    context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2])
    output_images = []
    for seq_i in range(images.shape[1]):
        encoded_images = model.encode_context(context[:, [seq_i]])
        decoded_images = model.decode_context(encoded_images)
        output_images.append(decoded_images["images"].cpu().numpy())
    return np.concatenate(output_images, axis=1)


def main(args):
    total_video_length = args.context_length + args.steps_to_dream

    model_path = Path(args.model_path)
    assert model_path.is_file(), "Could not find the model!"
    model = utils.load_model_from_checkpoint(model_path).cuda()

    data_path = Path(args.data_path)
    ground_truth_files = list(data_path.rglob("*.npz"))
    num_dreams = len(ground_truth_files)

    if args.max_files is not None:
        # Sort so that limiting the file count yields a deterministic subset.
        ground_truth_files = sorted(ground_truth_files)
        ground_truth_files = ground_truth_files[: args.max_files]
        num_dreams = len(ground_truth_files)

    output_path = Path(args.output)
    os.makedirs(output_path, exist_ok=True)

    print("=" * 100)
    print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
    print(f"WRITING TO {args.output}")
    print("=" * 100)

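    # Episodes are popped off the work list and dreamt in batches, so each dreaming step runs one batched forward pass.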
    dreams_created = 0
    with tqdm(total=num_dreams, desc="Dreams") as pbar:
        while ground_truth_files:
            batches = min(args.batch_size, len(ground_truth_files))
            batched_image_context = []
            batched_image_sequence = []
            batched_action_context = []
            batched_action_sequence = []
            episode_names = []
            for _ in range(batches):
                episode = ground_truth_files.pop()
                try:
                    data = np.load(episode)
                    images = data["images"]
                    actions = data["actions"]
                except Exception:
                    print(f"Failed to load episode {episode} - skipping.")
                    continue
                # Record the episode only after a successful load, so names stay aligned with the batched data.
                episode_names.append(episode)

                if actions.shape[0] < total_video_length:
                    raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
                # Split each episode into the conditioning context and the sequence to be dreamt.
                batched_image_context.append(images[: args.context_length])
                batched_image_sequence.append(images[args.context_length: total_video_length])
                batched_action_context.append(actions[: args.context_length])
                batched_action_sequence.append(actions[args.context_length: total_video_length])

            if not episode_names:
                # Every episode in this batch failed to load.
                continue

            image_context = np.array(batched_image_context)
            image_sequences = np.array(batched_image_sequence)
            action_context = np.array(batched_action_context)
            action_sequences = np.array(batched_action_sequence)

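            # "base" keeps the ground-truth actions as prompts; "comprehensive" drops them so the model dreams actions too.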
            if args.protocol == "comprehensive":
                action_sequences = None

            full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)

            dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
            encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)

            pbar.update(len(episode_names))
            dreams_created += len(episode_names)

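            # Save each dream next to its ground-truth counterpart, mirroring the input directory structure.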
            for i, dream in enumerate(dreamt_images):
                episode = episode_names[i]
                output_file = output_path / episode.relative_to(data_path)
                output_file.parent.mkdir(parents=True, exist_ok=True)
                np.savez(
                    output_file,
                    context_length=args.context_length,
                    steps_to_dream=args.steps_to_dream,
                    raw_context=image_context[i],
                    dreamt_images=dream,
                    all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
                    encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
                )

                # Also render the context plus the dream as an .mp4 for quick visual inspection.
                video_file = str(output_file.with_suffix(".mp4"))
                writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
                full_sequence = np.concatenate((image_context[i], dream), axis=0)
                for frame in full_sequence:
                    # CHW -> HWC for OpenCV.
                    img = frame.transpose(1, 2, 0).astype(np.uint8).copy()

                    # Watermark the bottom-right corner of every frame.
                    (text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
                    x = img.shape[1] - text_width - 10
                    y = img.shape[0] - 10
                    cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)

                    # Attach program provenance fields (mirroring the EXIF Software tag, 0x0131) to the in-memory image.
                    pil_image = Image.fromarray(img)
                    pil_image.info['Id'] = 0x0131
                    pil_image.info['Type'] = 2
                    pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
                    pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1

                    # ffmpegcv expects BGR frames by default.
                    cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
                    writer.write(cv_image)
                writer.release()
                add_video_metadata(video_file, args.metadata_config)

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)