import random
import os
import time
from pathlib import Path

import torch
import pandas as pd
import wandb
from tqdm import trange
from torch.utils.data import IterableDataset

from datasets.dummy import DummyVideoDataset
from datasets.openx_base import OpenXVideoDataset
from datasets.droid import DroidVideoDataset
from datasets.something_something import SomethingSomethingDataset
from datasets.epic_kitchen import EpicKitchenDataset
from datasets.pandas import PandasVideoDataset
from datasets.ego4d import Ego4DVideoDataset
from datasets.mixture import MixtureDataset
from datasets.agibot_world import AgibotWorldDataset

from .exp_base import BaseExperiment
from utils.gemini_utils import GeminiCaptionProcessor


class ProcessDataExperiment(BaseExperiment):
    """
    An experiment class for processing an existing dataset into a new one,
    by creating a new CSV metadata file and new data files.

    e.g. The `cache_prompt_embed` method illustrates caching the prompt embeddings and
    adding a `prompt_embed_path` field to a copy of the metadata CSV.

    e.g. The `visualize_dataset` method illustrates visualizing a sample of videos
    from the dataset with their captions.

    Add your processing methods here, and follow README.md to run.
    """

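    # A new processing task typically follows the same pattern as the methods
    # below (hypothetical sketch; `my_task` is not wired to any existing config):
    #
    #     def my_task(self):
    #         cfg = self.cfg.my_task                  # per-task config block
    #         dataset = self._build_dataset()         # filtering disabled, split="all"
    #         save_dir = self._get_save_dir(dataset)  # new_data_root or output_dir
    #         ...                                     # write new files / metadata CSV
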
    compatible_datasets = dict(
        mixture=MixtureDataset,
        mixture_robot=MixtureDataset,
        dummy=DummyVideoDataset,
        something_something=SomethingSomethingDataset,
        epic_kitchen=EpicKitchenDataset,
        pandas=PandasVideoDataset,
        ego4d=Ego4DVideoDataset,
        bridge=OpenXVideoDataset,
        droid=DroidVideoDataset,
        agibot_world=AgibotWorldDataset,
        language_table=OpenXVideoDataset,
    )

    def _build_dataset(
        self, disable_filtering: bool = True, split: str = "all"
    ) -> torch.utils.data.Dataset:
        if disable_filtering:
            self.root_cfg.dataset.filtering.disable = True
        return self.compatible_datasets[self.root_cfg.dataset._name](
            self.root_cfg.dataset, split=split
        )

    def _get_save_dir(self, dataset: torch.utils.data.Dataset):
        if self.cfg.new_data_root is None:
            save_dir = self.output_dir / dataset.data_root.name
        else:
            save_dir = Path(self.cfg.new_data_root)
        save_dir.mkdir(parents=True, exist_ok=True)
        return save_dir

    def benchmark_dataloader(self):
        """Benchmark the speed of the dataloader."""
        cfg = self.cfg.benchmark_dataloader
        dataset = self._build_dataset()
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=False,
        )
        # Iterate over the dataloader and report throughput. The short sleep
        # emulates a lightweight consumer so worker prefetching can keep up.
        start = time.time()
        num_batches = 0
        for _ in dataloader:
            num_batches += 1
            time.sleep(0.001)
            if num_batches % 100 == 0:
                elapsed = time.time() - start
                print(
                    f"{num_batches} batches in {elapsed:.1f}s "
                    f"({num_batches / elapsed:.2f} batches/s)"
                )
        elapsed = time.time() - start
        print(f"Finished: {num_batches} batches in {elapsed:.1f}s")

    def visualize_dataset(self):
        """Visualize a sample of videos from the dataset with their captions.

        This method:
        1. Creates a dataloader for the dataset
        2. Logs the videos and their captions to wandb

        Sample command:
        python main.py +name=process_data experiment=process_data dataset=video_openx experiment.tasks=[visualize_dataset]
        """
        cfg = self.cfg.visualize_dataset
        dataset = self._build_dataset(
            disable_filtering=cfg.disable_filtering, split="training"
        )
        shuffle = not isinstance(dataset, IterableDataset)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=1, num_workers=0, shuffle=shuffle
        )

        log_dict = {}
        self._build_logger()

        samples_seen = 0
        for batch in dataloader:
            if samples_seen >= cfg.n_samples:
                break

            for i in range(len(batch["videos"])):
                if samples_seen >= cfg.n_samples:
                    break

                prompts = None
                if "prompts" in batch:
                    prompts = batch["prompts"][i]

                if cfg.use_processed:
                    # Processed videos are in [-1, 1]; convert to uint8 for wandb.
                    video = batch["videos"][i]
                    video = ((video + 1) / 2 * 255).clamp(0, 255)
                    video = video.to(torch.uint8).numpy()
                    log_dict[f"sample_{samples_seen}"] = wandb.Video(
                        video, caption=prompts, fps=16
                    )
                else:
                    # Log the raw video file straight from disk.
                    video_path = str(dataset.data_root / batch["video_path"][i])
                    log_dict[f"sample_{samples_seen}"] = wandb.Video(
                        video_path, caption=prompts, fps=16
                    )

                samples_seen += 1
                # Flush to wandb every 8 samples to keep each log payload small.
                if samples_seen % 8 == 0:
                    wandb.log(log_dict)
                    log_dict = {}

        if log_dict:
            wandb.log(log_dict)

    def cache_prompt_embed(self):
        """Cache prompt embeddings for all captions in the dataset.

        This method:
        1. Takes captions from the dataset metadata
        2. Generates T5 embeddings for each caption using CogVideo's T5 encoder
        3. Saves embeddings as .pt files alongside the videos
        4. Creates a new metadata CSV with an added 'prompt_embed_path' column

        Sample commands:
        # Cache embeddings for OpenX dataset:
        python main.py +name=process_data experiment=process_data dataset=video_openx experiment.tasks=[cache_prompt_embed]

        # Specify custom output directory:
        python main.py +name=process_data experiment=process_data dataset=video_openx experiment.tasks=[cache_prompt_embed] experiment.new_data_root=data/processed

        # Adjust batch size:
        python main.py +name=process_data experiment=process_data dataset=video_openx experiment.tasks=[cache_prompt_embed] experiment.cache_prompt_embed.batch_size=64
        """
        cfg = self.cfg.cache_prompt_embed
        batch_size = cfg.batch_size

        if self.cfg.num_nodes != 1:
            raise ValueError("This script only supports 1 node.")

        from algorithms.cogvideo.t5 import T5Encoder

        t5_encoder = T5Encoder(self.root_cfg.algorithm).cuda()
        dataset = self._build_dataset()
        records = dataset.records

        save_dir = self._get_save_dir(dataset)
        metadata_path = save_dir / dataset.metadata_path
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Saving prompt embeddings and new metadata to {save_dir}")

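        # Encode prompts in batches; each embedding is saved as a .pt file whose
        # path mirrors the video's relative path under save_dir / "prompt_embeds".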
        new_records = []
        for i in trange(0, len(records), batch_size):
            batch = records[i : i + batch_size]
            prompts = [dataset.id_token + r["caption"] for r in batch]
            embeds = t5_encoder.predict(prompts).cpu()
            for r, embed in zip(batch, embeds):
                video_path = Path(r["video_path"])
                prompt_embed_path = (
                    save_dir / "prompt_embeds" / video_path.with_suffix(".pt")
                )
                prompt_embed_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(embed.clone(), prompt_embed_path)
                r["prompt_embed_path"] = str(prompt_embed_path.relative_to(save_dir))
                new_records.append(r)

        df = pd.DataFrame.from_records(new_records)
        df.to_csv(metadata_path, index=False)

        print(f"To review the prompt embeddings, go to {save_dir}")
        print(
            "If everything looks good, you can merge the new dataset into the old "
            "one with the following command:"
        )
        print(f"rsync -av {save_dir}/* {dataset.data_root} && rm -rf {save_dir}")

    def create_gemini_caption(self):
        """
        Create a Gemini caption for each video in the dataset.

        1. Init the dataset and load all raw records.
        2. Init the GeminiCaptionProcessor with two params: output_file and num_workers.
        3. Start the processor and process each record. It will write to the output file.

        Each record in the dataset must have "video_path" as an absolute path.
        If a record has additional keys, such as duration, fps, height, width, n_frames,
        youtube_key_segment, etc., they will be added to the output file.
        Check the `VideoEntry` class for more details.

        Sample command:
        python main.py +name=create_gemini_caption experiment=process_data dataset=pandas experiment.tasks=[create_gemini_caption]
        """
        cfg = self.cfg.create_gemini_caption
        num_workers = cfg.n_workers

        dataset = self._build_dataset()
        records = dataset.records

        save_dir = self._get_save_dir(dataset)
        metadata_path = dataset.metadata_path.with_suffix(".json")
        metadata_path = metadata_path.parent / ("gemini_" + metadata_path.name)
        output_file = save_dir / metadata_path

        # Resolve each video path to an absolute path before handing records
        # to the caption processor.
        for r in records:
            r["video_path"] = str((dataset.data_root / r["video_path"]).absolute())

        if not os.path.exists(records[0]["video_path"]):
            raise ValueError(
                f"Resolved video_path does not exist: {records[0]['video_path']}"
            )

        processor = GeminiCaptionProcessor(output_file, num_workers=num_workers)
        processor.process_entries(records)
        print(f"To review the captions, go to {output_file}")
        print(
            "If everything looks good, you can merge the new dataset into the old "
            "one with the following command:"
        )
        print(f"rsync -av {save_dir}/* {dataset.data_root} && rm -rf {save_dir}")

    def run_hand_pose_estimation(self):
        """Run Sapiens hand pose estimation on a small random subset of videos."""
        import queue
        import threading

        import decord
        from sapiens_inference import SapiensPoseEstimation, SapiensPoseEstimationType

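        # Hand-related keypoint names from the Sapiens whole-body pose output;
        # only these keys are kept from each detected person's keypoint dict.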
        hand_keypoints_keys_list = [
            # Right hand.
            "right_wrist",
            "right_thumb4",
            "right_thumb3",
            "right_thumb2",
            "right_thumb_third_joint",
            "right_forefinger4",
            "right_forefinger3",
            "right_forefinger2",
            "right_forefinger_third_joint",
            "right_middle_finger4",
            "right_middle_finger3",
            "right_middle_finger2",
            "right_middle_finger_third_joint",
            "right_ring_finger4",
            "right_ring_finger3",
            "right_ring_finger2",
            "right_ring_finger_third_joint",
            "right_pinky_finger4",
            "right_pinky_finger3",
            "right_pinky_finger2",
            "right_pinky_finger_third_joint",
            # Left hand.
            "left_wrist",
            "left_thumb4",
            "left_thumb3",
            "left_thumb2",
            "left_thumb_third_joint",
            "left_forefinger4",
            "left_forefinger3",
            "left_forefinger2",
            "left_forefinger_third_joint",
            "left_middle_finger4",
            "left_middle_finger3",
            "left_middle_finger2",
            "left_middle_finger_third_joint",
            "left_ring_finger4",
            "left_ring_finger3",
            "left_ring_finger2",
            "left_ring_finger_third_joint",
            "left_pinky_finger4",
            "left_pinky_finger3",
            "left_pinky_finger2",
            "left_pinky_finger_third_joint",
        ]

        cfg = self.cfg.run_hand_pose_estimation

        dataset = self._build_dataset()
        records = dataset.records

        # Only run pose estimation on a random subset of 50 videos.
        records = random.sample(records, 50)
        save_dir = self._get_save_dir(dataset)
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        print(f"Saving hand pose estimation results to {save_dir}")

        frame_queue = queue.Queue(maxsize=100)
        STOP_TOKEN = "DONE"

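        # Decode frames on a background thread while the main thread runs the
        # GPU pose estimator, so video decoding and inference overlap.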
        def producer(records, data_root):
            for record in records:
                try:
                    video_path = Path(data_root) / record["video_path"]
                    vr = decord.VideoReader(str(video_path))
                    n_frames = len(vr)

                    if n_frames == 0:
                        print(f"No frames found in {record['video_path']}")
                        continue

                    # Sample the first, middle, and last frames of the video.
                    frame_indices = [0, n_frames // 2, n_frames - 1]
                    frames = vr.get_batch(frame_indices).asnumpy()

                    frame_queue.put(
                        {
                            "frames": frames,
                            "video_path": str(record["video_path"]),
                            "frame_indices": frame_indices,
                        }
                    )
                except Exception as e:
                    print(f"Error processing {record['video_path']}: {e}")
                    continue

            frame_queue.put(STOP_TOKEN)

        start_time = time.time()

        producer_thread = threading.Thread(
            target=producer, args=(records, dataset.data_root), daemon=True
        )
        producer_thread.start()

        dtype = torch.float16
        estimator = SapiensPoseEstimation(
            SapiensPoseEstimationType.POSE_ESTIMATION_03B, dtype=dtype
        )

        results = []

        while True:
            item = frame_queue.get()
            if item == STOP_TOKEN:
                break

            frames = item["frames"]
            video_path = item["video_path"]
            frame_indices = item.get("frame_indices", [0, 1, 2])

            ret_per_video = {
                "video_path": video_path,
                "frame_indices": frame_indices,
                "keypoints_list": [],
            }
            for idx, frame in zip(frame_indices, frames):
                try:
                    # decord returns RGB frames, which is what the estimator expects.
                    frame_rgb = frame
                    result_img, keypoints = estimator(frame_rgb)

                    # As consumed below, `keypoints` holds one dict per detected
                    # person, mapping keypoint name -> (x, y, confidence).
                    keypoints_flat = keypoints

                    # Keep only the hand-related keypoints.
                    keypoints_flat = [
                        {
                            k: kp_dict[k]
                            for k in hand_keypoints_keys_list
                            if k in kp_dict
                        }
                        for kp_dict in keypoints_flat
                    ]

                    # Drop low-confidence keypoints (confidence <= 0.3).
                    keypoints_flat = [
                        {k: v for k, v in kp_dict.items() if v[2] > 0.3}
                        for kp_dict in keypoints_flat
                    ]
                    result_entry = {
                        "frame_index": idx,
                        "keypoints_list": keypoints_flat,
                        "num_keypoints": sum(len(kp) for kp in keypoints_flat),
                    }

                    ret_per_video["keypoints_list"].append(result_entry)

                except Exception as e:
                    print(
                        f"Error running pose estimation for frame {idx} of {video_path}: {e}"
                    )
                    continue

            # Only keep videos where at least one hand keypoint was detected.
            num_keypoints = sum(
                entry.get("num_keypoints", 0)
                for entry in ret_per_video["keypoints_list"]
            )
            if num_keypoints > 0:
                results.append(ret_per_video)
            frame_queue.task_done()

        producer_thread.join()

        end_time = time.time()
        print(f"Time taken: {end_time - start_time} seconds")
        print(f"Total number of videos processed with keypoints: {len(results)}")

        if results:
            json_path = Path(save_dir) / "hand_pose_results.json"
            import json

            with open(json_path, "w") as f:
                json.dump(results, f, indent=2)
            print(f"Results saved to {json_path}")
        else:
            print("No results to save.")

    def run_human_detection(self):
        """Run human detection on one shard of the dataset's videos."""
        import queue
        import threading

        import decord
        from utils.detector_utils import Detector

        detector = Detector()

        cfg = self.cfg.run_human_detection

        dataset = self._build_dataset()
        records = dataset.records

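        # Shard the records across independent jobs: this worker processes every
        # cfg.total_workers-th record, starting at offset cfg.job_id.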
        num_workers = cfg.total_workers
        job_id = cfg.job_id

        records = records[job_id::num_workers]

        save_dir = self._get_save_dir(dataset)
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        print(f"Saving human detection results to {save_dir}")

        # Bounded queue between the decoding producer thread and the consumer below.
        frame_queue = queue.Queue(maxsize=100)
        STOP_TOKEN = "DONE"

        def producer(records, data_root):
            for record in records:
                try:
                    video_path = Path(data_root) / record["video_path"]
                    vr = decord.VideoReader(str(video_path))
                    n_frames = len(vr)

                    if n_frames == 0:
                        print(f"No frames found in {record['video_path']}")
                        continue

                    # Sample roughly one frame per second of video.
                    fps = vr.get_avg_fps()
                    frame_indices = [int(i * fps) for i in range(int(n_frames // fps))]
                    frames = vr.get_batch(frame_indices).asnumpy()

                    frame_queue.put(
                        {
                            "frames": frames,
                            "video_path": str(record["video_path"]),
                            "frame_indices": frame_indices,
                        }
                    )
                except Exception as e:
                    print(f"Error processing {record['video_path']}: {e}")
                    continue

            frame_queue.put(STOP_TOKEN)

        start_time = time.time()

        producer_thread = threading.Thread(
            target=producer, args=(records, dataset.data_root), daemon=True
        )
        producer_thread.start()

        results = []

        while True:
            item = frame_queue.get()
            if item == STOP_TOKEN:
                break

            frames = item["frames"]
            video_path = item["video_path"]
            frame_indices = item.get("frame_indices", [0, 1, 2])

            ret_per_video = {
                "video_path": video_path,
                "frame_indices": frame_indices,
                "bbox_list": [],
            }
            num_detections = 0
            for idx, frame in zip(frame_indices, frames):
                try:
                    # One list of detected person bounding boxes per sampled frame.
                    bboxes = detector.detect(frame).tolist()
                    ret_per_video["bbox_list"].append(bboxes)
                    num_detections += len(bboxes)
                except Exception as e:
                    print(
                        f"Error running human detection for frame {idx} of {video_path}: {e}"
                    )
                    continue

            results.append(ret_per_video)
            frame_queue.task_done()

        producer_thread.join()

        end_time = time.time()
        print(f"Time taken: {end_time - start_time} seconds")
        print(f"Total number of videos processed with human detections: {len(results)}")

        if results:
            json_path = Path(save_dir) / f"human_detection_results_{job_id}.json"
            import json

            with open(json_path, "w") as f:
                json.dump(results, f, indent=2)
            print(f"Results saved to {json_path}")
        else:
            print("No results to save.")