| | import cv2 |
| | import librosa |
| | import mediapipe as mp |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision.transforms.v2 as transforms |
| | from numpy.typing import NDArray |
| | from python_speech_features import logfbank |
| | from transformers import FeatureExtractionMixin |
| | from transformers.feature_extraction_utils import BatchFeature |
| |
|
# Handle to MediaPipe's face-mesh solution module; `_extract_mouth` builds a
# `FaceMesh` detector from it on every call.
mp_face_mesh = mp.solutions.face_mesh


class AVHubertFeatureExtractor(FeatureExtractionMixin):
    """Convert raw audio and/or video into padded, time-aligned model inputs.

    Audio is turned into log mel-filterbank features whose consecutive frames
    are stacked ``stack_order_audio`` at a time. Video is turned into grayscale
    frames (optionally cropped around the mouth) and re-sampled along time so
    that it has exactly as many steps as the audio features.

    Args:
        max_sample_size: If truthy, every sample (and its padding mask) is
            truncated to at most this many time steps after padding.
        normalize: Apply ``F.layer_norm`` over the feature dimension of the
            audio features.
        stack_order_audio: Number of consecutive filterbank frames concatenated
            into one audio time step.
        image_crop_size: Side length of the centre crop applied to the video,
            and of the mouth patches produced when ``extract_mouth=True``.
        image_mean: Mean used to normalise the (single-channel) frames.
        image_std: Standard deviation used to normalise the frames.
        sr: Sampling rate assumed for array audio and used when loading audio
            files.
        static_image_mode: Forwarded to ``mediapipe`` ``FaceMesh``.
        refine_landmarks: Forwarded to ``mediapipe`` ``FaceMesh``.
        min_detection_confidence: Forwarded to ``mediapipe`` ``FaceMesh``.
        min_tracking_confidence: Forwarded to ``mediapipe`` ``FaceMesh``.
        landmark_indices: Face-mesh landmark indices unpacked as
            ``(top, right, bottom, left)``; their bounding box defines the
            mouth patch.
        **kwargs: Passed through to ``FeatureExtractionMixin.__init__``.
    """

    model_input_names = ["input_values", "pixel_values"]

    # Filterbank channels per frame. Passed explicitly to `logfbank` so the
    # zero placeholder built for a missing audio stream always has the same
    # width as real audio features.
    _NUM_FBANK_BINS = 26

    def __init__(
        self,
        max_sample_size: int | None = None,
        normalize: bool = True,
        stack_order_audio: int = 4,
        image_crop_size: int = 88,
        image_mean: float = 0.421,
        image_std: float = 0.165,
        sr: int = 16_000,
        static_image_mode: bool = False,
        refine_landmarks: bool = False,
        min_detection_confidence: float = 0.5,
        min_tracking_confidence: float = 0.5,
        landmark_indices: tuple[int, ...] = (5, 411, 199, 187),
        **kwargs,
    ) -> None:
        # NOTE(review): keep this call first. The mixin presumably copies
        # `kwargs` onto the instance, so a serialized "transforms" entry coming
        # from a saved config must be overwritten by the real pipeline built
        # below -- confirm against the transformers implementation.
        super().__init__(**kwargs)
        self.max_sample_size = max_sample_size
        self.normalize = normalize
        self.stack_order_audio = stack_order_audio
        self.image_crop_size = image_crop_size
        # Stored so that non-default values are part of the instance state
        # (and hence of `to_dict`) instead of living only inside `transforms`.
        self.image_mean = image_mean
        self.image_std = image_std
        self.transforms = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.CenterCrop(image_crop_size),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize([image_mean], [image_std]),
            ]
        )
        self.sr = sr
        self.static_image_mode = static_image_mode
        self.refine_landmarks = refine_landmarks
        self.min_detection_confidence = min_detection_confidence
        self.min_tracking_confidence = min_tracking_confidence
        self.landmark_indices = landmark_indices

    def _load_video(self, video: str | NDArray[np.uint8], extract_mouth: bool = False) -> torch.Tensor:
        """Load a video as a stack of single-channel frames.

        Args:
            video: Path to a video file, or an array of RGB frames. Input video
                must be in RGB format if type is numpy array.
            extract_mouth: Crop every frame around the mouth (see
                `_extract_mouth`) instead of only converting it to grayscale.

        Returns:
            Tensor with a channel axis inserted at dim 1, i.e.
            ``(frames, 1, height, width)``. The dtype is whatever the frames
            had (no float conversion happens here).

        Raises:
            ValueError: If ``video`` is a path that cannot be opened or yields
                no frames.
        """
        if isinstance(video, str):
            # OpenCV decodes to BGR. Mouth extraction needs RGB frames,
            # otherwise we can go straight to grayscale.
            color_code = cv2.COLOR_BGR2RGB if extract_mouth else cv2.COLOR_BGR2GRAY
            frames = []
            cap = cv2.VideoCapture(video)
            try:
                if not cap.isOpened():
                    raise ValueError(f"Could not open video file: {video!r}")
                # At most the advertised number of frames is read; decoding
                # may stop earlier if the container over-reports its length.
                for _ in range(int(cap.get(cv2.CAP_PROP_FRAME_COUNT))):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(cv2.cvtColor(frame, color_code))
            finally:
                # Always free the decoder handle, even when decoding fails.
                cap.release()
            if not frames:
                raise ValueError(f"No frames could be read from video file: {video!r}")
            frames_np = np.stack(frames, axis=0)
        else:
            frames_np = video
            if not extract_mouth:
                # Array input is documented as RGB, so convert from RGB (the
                # mouth-extraction path does the same).
                frames_np = np.stack([cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames_np], axis=0)

        if extract_mouth:
            frames_np = self._extract_mouth(frames_np)

        return torch.from_numpy(frames_np).unsqueeze(dim=1)

    def _extract_mouth(self, frames: NDArray[np.uint8]) -> NDArray[np.uint8]:
        """Crop a square grayscale mouth patch out of every RGB frame.

        The patch is centred on the bounding box of the four configured
        landmarks, has the box's larger side as its side length, and is
        resized to ``image_crop_size`` x ``image_crop_size``. Frames in which
        no face is detected, or whose patch is empty, become all-zero patches.

        Args:
            frames: RGB frames, indexed ``[frame, row, column, channel]``.

        Returns:
            Array of shape ``(frames, image_crop_size, image_crop_size)``.
        """
        size = self.image_crop_size
        blank = np.zeros([size, size], dtype=np.uint8)
        mouth_frames = []
        with mp_face_mesh.FaceMesh(
            static_image_mode=self.static_image_mode,
            max_num_faces=1,
            refine_landmarks=self.refine_landmarks,
            min_detection_confidence=self.min_detection_confidence,
            min_tracking_confidence=self.min_tracking_confidence,
        ) as face_mesh:
            for frame in frames:
                res = face_mesh.process(frame)
                if not res.multi_face_landmarks:
                    mouth_frames.append(blank)
                    continue

                landmarks = res.multi_face_landmarks[0].landmark
                points = [landmarks[idx] for idx in self.landmark_indices]
                xs = [point.x for point in points]
                ys = [point.y for point in points]
                xmin, xmax = min(xs), max(xs)
                ymin, ymax = min(ys), max(ys)

                # Landmark coordinates are scaled by the frame size to get
                # pixel positions.
                height, width = frame.shape[:2]
                patch_size = max((xmax - xmin) * width, (ymax - ymin) * height)
                half = int(patch_size / 2)
                y_center = int(ymin * height) + int(((ymax - ymin) / 2) * height)
                x_center = int(xmin * width) + int(((xmax - xmin) / 2) * width)

                # Clamp to the image: a negative start would otherwise be
                # interpreted as "count from the end" and silently produce an
                # empty or unrelated crop.
                y_start = max(y_center - half, 0)
                y_stop = min(y_center + half, height)
                x_start = max(x_center - half, 0)
                x_stop = min(x_center + half, width)
                if y_stop <= y_start or x_stop <= x_start:
                    mouth_frames.append(blank)
                    continue

                lip = frame[y_start:y_stop, x_start:x_stop, :]
                try:
                    lip = cv2.resize(lip, (size, size))
                except cv2.error:
                    # Best effort: an unprocessable patch becomes a blank one.
                    mouth_frames.append(blank)
                    continue
                mouth_frames.append(cv2.cvtColor(lip, cv2.COLOR_RGB2GRAY))
        return np.stack(mouth_frames, axis=0)

    @staticmethod
    def _stack_frames(feats: NDArray[np.float32], stack_order: int) -> NDArray[np.float32]:
        """Concatenate every ``stack_order`` consecutive rows into a single row.

        The input is zero-padded at the end so that its row count is a
        multiple of ``stack_order``.
        """
        num_frames, feat_dim = feats.shape
        leftover = num_frames % stack_order
        if leftover:
            padding = np.zeros([stack_order - leftover, feat_dim], dtype=feats.dtype)
            feats = np.concatenate([feats, padding], axis=0)
        return feats.reshape(-1, stack_order * feat_dim)

    def _load_audio(self, audio: str | NDArray[np.float32]) -> torch.FloatTensor:
        """Compute stacked log-filterbank features for one audio sample.

        Args:
            audio: Path to an audio file (loaded and resampled to ``self.sr``)
                or a waveform array that is assumed to be sampled at
                ``self.sr`` already.

        Returns:
            Float32 tensor of shape
            ``(steps, _NUM_FBANK_BINS * stack_order_audio)``.
        """
        if isinstance(audio, str):
            # Load at the configured rate so that file and array inputs are
            # featurised under the same sampling-rate assumption.
            audio, _ = librosa.load(audio, sr=self.sr)
        fbank = logfbank(audio, samplerate=self.sr, nfilt=self._NUM_FBANK_BINS).astype(np.float32)
        return torch.from_numpy(self._stack_frames(fbank, self.stack_order_audio))

    def _align_time_steps(
        self, audio: list[torch.FloatTensor], video: list[torch.FloatTensor]
    ) -> tuple[list[torch.FloatTensor], list[torch.FloatTensor]]:
        """Re-sample every video along time to the length of its audio.

        For audio step ``i`` the video frame ``floor(i * n_video / n_audio)``
        is selected (clamped to the last frame). Audio is returned unchanged.
        """
        aligned_video = []
        for sample_audio, sample_video in zip(audio, video):
            num_audio, num_video = len(sample_audio), len(sample_video)
            if num_audio == num_video:
                # Already aligned; nothing to re-index.
                aligned_video.append(sample_video)
                continue
            positions = torch.arange(0, num_audio).float() * num_video / num_audio
            indices = torch.clamp(torch.floor(positions), max=num_video - 1).long()
            aligned_video.append(sample_video[indices])
        return audio, aligned_video

    def __call__(
        self,
        raw_audio: NDArray[np.float32] | str | list[NDArray[np.float32]] | list[str] | None = None,
        raw_video: NDArray[np.uint8] | str | list[NDArray[np.uint8]] | list[str] | None = None,
        extract_mouth: bool = False,
        **kwargs,
    ) -> BatchFeature:
        """Featurise a sample or a batch of samples.

        Args:
            raw_audio: One audio sample (file path or waveform), a list of
                them, or ``None`` for a video-only call. List entries may be
                ``None`` for individual samples without audio.
            raw_video: One video sample (file path or RGB frame array), a list
                of them, or ``None`` for an audio-only call. List entries may
                be ``None`` for individual samples without video.
            extract_mouth: Crop the mouth region out of the video frames.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            ``BatchFeature`` with ``input_values`` (audio features, layer
            normalised when ``normalize`` is set), ``pixel_values``
            (transformed video frames) and ``padding_mask`` (1 marks padded
            time steps, 0 marks real ones).

        Raises:
            ValueError: If the batch is empty, the two modalities have
                different batch sizes, or a sample has neither audio nor video.
        """
        if not isinstance(raw_audio, list):
            raw_audio = [raw_audio]
        if not isinstance(raw_video, list):
            raw_video = [raw_video]

        # A modality that is absent altogether is broadcast over the batch of
        # the other one, so every sample gets a placeholder rather than the
        # batch being cut down to its first element.
        if len(raw_audio) != len(raw_video):
            if len(raw_audio) == 1 and raw_audio[0] is None:
                raw_audio = [None] * len(raw_video)
            elif len(raw_video) == 1 and raw_video[0] is None:
                raw_video = [None] * len(raw_audio)
            else:
                raise ValueError(
                    f"raw_audio and raw_video must have the same batch size, "
                    f"got {len(raw_audio)} and {len(raw_video)}."
                )
        if not raw_audio:
            raise ValueError("At least one sample is required.")

        audio = []
        video = []
        for batch_idx, (audio_in, video_in) in enumerate(zip(raw_audio, raw_video)):
            if audio_in is None and video_in is None:
                raise ValueError(f"Sample {batch_idx} has neither audio nor video.")
            sample_a = self._load_audio(audio_in) if audio_in is not None else None
            sample_v = self._load_video(video_in, extract_mouth) if video_in is not None else None
            # The missing modality is replaced by zeros with as many time
            # steps as the modality that is present.
            if sample_a is None:
                sample_a = torch.zeros((sample_v.shape[0], self._NUM_FBANK_BINS * self.stack_order_audio))
            elif sample_v is None:
                sample_v = torch.zeros((sample_a.shape[0], 1, self.image_crop_size, self.image_crop_size))
            audio.append(sample_a)
            video.append(sample_v)

        audio, video = self._align_time_steps(audio, video)
        max_length = max(len(data) for data in audio)
        input_values = []
        pixel_values = []
        padding_mask = []
        for feat_audio, feat_video in zip(audio, video):
            # Zero-pad both modalities at the end up to the longest sample.
            remainder_length = max_length - len(feat_audio)
            audio_remainder = torch.zeros(
                size=(remainder_length,) + feat_audio.size()[1:],
                dtype=feat_audio.dtype,
            )
            video_remainder = torch.zeros(
                size=(remainder_length,) + feat_video.size()[1:],
                dtype=feat_video.dtype,
            )
            feat_audio = torch.cat((feat_audio, audio_remainder))
            feat_video = torch.cat((feat_video, video_remainder))

            pad_mask = torch.zeros(max_length)
            pad_mask[max_length - remainder_length :] = 1

            if self.max_sample_size:
                feat_audio = feat_audio[: self.max_sample_size]
                feat_video = feat_video[: self.max_sample_size]
                # Keep the mask the same length as the truncated features.
                pad_mask = pad_mask[: self.max_sample_size]

            input_values.append(feat_audio)
            pixel_values.append(feat_video)
            padding_mask.append(pad_mask)

        input_values = torch.stack(input_values)
        if self.normalize:
            input_values = F.layer_norm(input_values, input_values.shape[2:])
        return BatchFeature(
            {
                "input_values": input_values,
                "pixel_values": self.transforms(torch.stack(pixel_values)),
                "padding_mask": torch.stack(padding_mask),
            }
        )

    def to_dict(self):
        """Return the parent's dict with ``transforms`` replaced by a plain description.

        The ``Compose`` object under ``"transforms"`` is swapped for the list
        of dictionaries produced by `_transforms_to_dict`.
        """
        output = super().to_dict()
        output["transforms"] = self._transforms_to_dict(output["transforms"])
        return output

    def _transforms_to_dict(self, transforms: transforms.Compose):
        """Describe each transform of a ``Compose`` as a dict of strings.

        Every entry holds the transform's class name under
        ``"transforms_type"`` plus the ``str()`` of each instance attribute
        whose name does not start with an underscore.
        """
        output = []
        for component in transforms.transforms:
            component_dict = {"transforms_type": component.__class__.__name__}
            for key, value in vars(component).items():
                if key.startswith("_"):
                    continue
                component_dict[key] = str(value)
            output.append(component_dict)
        return output
| |
|