| | """ |
| | https://huggingface.co/tomiwa1a/video-search |
| | """ |
| | from typing import Dict |
| |
|
| | from sentence_transformers import SentenceTransformer |
| | from tqdm import tqdm |
| | import whisper |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
| | import torch |
| | import pytube |
| | import time |
| |
|
| |
|
class EndpointHandler():
    """Custom Hugging Face Inference Endpoint handler for video search.

    Bundles four models:
      * whisper (``tiny.en``)              -- transcribes a YouTube video's audio
      * sentence-transformer               -- embeds transcript segments / queries
      * summarization pipeline             -- per-segment abstractive summaries
      * bart_lfqa seq2seq model            -- long-form question answering

    See https://huggingface.co/tomiwa1a/video-search
    """

    WHISPER_MODEL_NAME = "tiny.en"
    SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
    QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
    SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
    # `device` is the torch device object used for `.to()` / tensor moves;
    # `device_number` is the integer form the transformers `pipeline` API
    # expects (0 = first GPU, -1 = CPU). They are NOT interchangeable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device_number = 0 if torch.cuda.is_available() else -1

    def __init__(self, path=""):
        """Load all four models once at endpoint start-up.

        :param path: unused; kept to match the Inference Endpoints
            custom-handler constructor signature.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        device_number = 0 if torch.cuda.is_available() else -1
        print(f'whisper and question_answer_model will use: {device}')
        print(f'whisper and question_answer_model will use device_number: {device_number}')

        t0 = time.time()
        # BUG FIX: `.to()` takes a device string/object, not the pipeline-style
        # integer. `.to(-1)` is not a valid torch device and crashed on CPU hosts.
        self.whisper_model = whisper.load_model(self.WHISPER_MODEL_NAME).to(device)
        t1 = time.time()

        total = t1 - t0
        print(f'Finished loading whisper_model in {total} seconds')

        t0 = time.time()
        # SentenceTransformer picks its own device automatically.
        self.sentence_transformer_model = SentenceTransformer(self.SENTENCE_TRANSFORMER_MODEL_NAME)
        t1 = time.time()

        total = t1 - t0
        print(f'Finished loading sentence_transformer_model in {total} seconds')

        t0 = time.time()
        # The pipeline API is the one place the integer device index is correct.
        self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=device_number)
        t1 = time.time()

        total = t1 - t0
        print(f'Finished loading summarizer in {total} seconds')

        self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
        t0 = time.time()
        # BUG FIX: same as whisper above -- move the model with the device
        # string, not `device_number` (which is -1 on CPU).
        self.question_answer_model = AutoModelForSeq2SeqLM.from_pretrained \
            (self.QUESTION_ANSWER_MODEL_NAME).to(device)
        t1 = time.time()
        total = t1 - t0
        print(f'Finished loading question_answer_model in {total} seconds')

    def __call__(self, data: Dict[str, str]) -> Dict:
        """Dispatch an endpoint request.

        Args:
            data (:obj:):
                includes the URL to video for transcription
        Return:
            A :obj:`dict`:. transcribed dict

        Routing (first match wins):
          * ``video_url``          -> transcribe (and, unless
            ``encode_transcript`` is falsy, embed) the video
          * ``summarize``          -> summarize ``data["segments"]``
          * ``query`` + ``long_form_answer`` -> generate a long-form answer
            from ``data["context"]``
          * ``query`` alone        -> embed the query for retrieval
          * otherwise              -> error dict
        """
        print('data', data)

        # NOTE(review): "inputs" is required here but never read below --
        # it only satisfies the Inference Endpoints request contract.
        if "inputs" not in data:
            raise Exception(f"data is missing 'inputs' key which EndpointHandler expects. Received: {data}"
                            f" See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler#2-create-endpointhandler-cp")
        video_url = data.pop("video_url", None)
        query = data.pop("query", None)
        long_form_answer = data.pop("long_form_answer", None)
        summarize = data.pop("summarize", False)
        encoded_segments = {}
        if video_url:
            video_with_transcript = self.transcribe_video(video_url)
            video_with_transcript['transcript']['transcription_source'] = f"whisper_{self.WHISPER_MODEL_NAME}"
            encode_transcript = data.pop("encode_transcript", True)
            if encode_transcript:
                encoded_segments = self.combine_transcripts(video_with_transcript)
                encoded_segments = {
                    "encoded_segments": self.encode_sentences(encoded_segments)
                }
            return {
                **video_with_transcript,
                **encoded_segments
            }
        elif summarize:
            summary = self.summarize_video(data["segments"])
            return {"summary": summary}
        elif query:
            if long_form_answer:
                context = data.pop("context", None)
                answer = self.generate_answer(query, context)
                response = {
                    "answer": answer
                }

                return response
            else:
                # Normalize a bare query string into the segment-dict shape
                # that encode_sentences expects.
                query = [{"text": query, "id": ""}] if isinstance(query, str) else query
                encoded_segments = self.encode_sentences(query)

                response = {
                    "encoded_segments": encoded_segments
                }

                return response

        else:
            return {
                "error": "'video_url' or 'query' must be provided"
            }

    def transcribe_video(self, video_url):
        """Download a YouTube video's audio and transcribe it with whisper.

        :param video_url: full YouTube watch URL
        :return: ``{"transcript": <whisper result>, "video": <metadata dict>}``
        """
        decode_options = {
            # whisper language detection takes ~10s with the base model;
            # the models used here are English-only anyway.
            "language": "en",
            "verbose": True
        }
        yt = pytube.YouTube(video_url)
        video_info = {
            'id': yt.video_id,
            'thumbnail': yt.thumbnail_url,
            'title': yt.title,
            'views': yt.views,
            'length': yt.length,
            # reconstruct a canonical URL from the id rather than trusting
            # the (possibly playlist-/timestamp-laden) input URL
            'url': f"https://www.youtube.com/watch?v={yt.video_id}"
        }
        stream = yt.streams.filter(only_audio=True)[0]
        path_to_audio = f"{yt.video_id}.mp3"
        # TODO(review): the downloaded audio file is never deleted; consider
        # cleaning it up once transcription succeeds.
        stream.download(filename=path_to_audio)
        t0 = time.time()
        transcript = self.whisper_model.transcribe(path_to_audio, **decode_options)
        t1 = time.time()
        for segment in transcript['segments']:
            # token ids are bulky and useless to clients
            segment.pop('tokens', None)

        total = t1 - t0
        print(f'Finished transcription in {total} seconds')

        return {"transcript": transcript, 'video': video_info}

    def encode_sentences(self, transcripts, batch_size=64):
        """
        Encoding all of our segments at once or storing them locally would require too much compute or memory.
        So we do it in batches of 64
        :param transcripts: list of dicts, each with at least a 'text' key
        :param batch_size: number of segments embedded per model call
        :return: the input dicts, each extended with a 'vectors' key
        """
        all_batches = []
        for i in tqdm(range(0, len(transcripts), batch_size)):
            # clamp the final batch to the list length
            i_end = min(len(transcripts), i + batch_size)
            # shallow-copy each segment so the embedding is attached to a
            # fresh dict
            batch_meta = [{
                **row
            } for row in transcripts[i:i_end]]

            batch_text = [
                row['text'] for row in batch_meta
            ]

            batch_vectors = self.sentence_transformer_model.encode(batch_text).tolist()

            batch_details = [
                {
                    **batch_meta[x],
                    'vectors': batch_vectors[x]
                } for x in range(0, len(batch_meta))
            ]
            all_batches.extend(batch_details)

        return all_batches

    def summarize_video(self, segments):
        """Summarize each transcript segment in place.

        :param segments: list of dicts; each must have 'text' (and, per the
            debug prints below, apparently a 'length' key -- TODO confirm
            against the caller's schema)
        :return: the same list, each dict gaining a 'summary' string
        """
        for index, segment in enumerate(segments):
            # pipeline returns [{'summary_text': ...}]; unwrap to a string
            segment['summary'] = self.summarizer(segment['text'])
            segment['summary'] = segment['summary'][0]['summary_text']
            print('index', index)
            print('length', segment['length'])
            print('text', segment['text'])
            print('summary', segment['summary'])

        return segments

    def generate_answer(self, query, documents):
        """Generate a long-form answer to `query` grounded in `documents`.

        :param query: question string
        :param documents: iterable of context passage strings
        :return: list with the decoded answer string(s)
        """
        # bart_lfqa expects passages joined with "<P>" markers in a single
        # "question: ... context: ..." prompt
        conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

        model_input = self.question_answer_tokenizer(query_and_docs, truncation=False, padding=True,
                                                     return_tensors="pt")

        generated_answers_encoded = self.question_answer_model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            min_length=64,
            max_length=256,
            do_sample=False,
            early_stopping=True,
            num_beams=8,
            temperature=1.0,
            top_k=None,
            top_p=None,
            eos_token_id=self.question_answer_tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            num_return_sequences=1)
        answer = self.question_answer_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                                             clean_up_tokenization_spaces=True)
        return answer

    @staticmethod
    def combine_transcripts(video, window=6, stride=3):
        """Merge consecutive whisper segments into overlapping windows.

        :param video: dict with 'video' metadata and 'transcript'['segments']
        :param window: number of sentences to combine
        :param stride: number of sentences to 'stride' over, used to create overlap
        :return: list of combined segment dicts ready for embedding
        """
        new_transcript_segments = []

        video_info = video['video']
        transcript_segments = video['transcript']['segments']
        for i in tqdm(range(0, len(transcript_segments), stride)):
            i_end = min(len(transcript_segments), i + window)
            text = ' '.join(transcript['text']
                            for transcript in
                            transcript_segments[i:i_end])
            # NOTE(review): 'end' is the end of the window's FIRST segment,
            # not its last -- looks intentional for seek links, but confirm.
            start = int(transcript_segments[i]['start'])
            end = int(transcript_segments[i]['end'])
            new_transcript_segments.append({
                **video_info,
                **{
                    'start': start,
                    'end': end,
                    'title': video_info['title'],
                    'text': text,
                    'id': f"{video_info['id']}-t{start}",
                    'url': f"https://youtu.be/{video_info['id']}?t={start}",
                    'video_id': video_info['id'],
                }
            })
        return new_transcript_segments
| |
|