diff --git a/.gitignore b/.gitignore index e08cddb..4f09084 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ tmp/ video_database/ .git.bak/ +.gradio/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index bdb2e7d..332dd9c 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [](https://arxiv.org/abs/2505.18079) [](https://opensource.org/licenses/MIT) +[](https://87fc7dc81d4b38ed01.gradio.live) This repository contains the official implementation of the paper [Deep Video Discovery: Agentic Search with Tool Use for Long-form Video Understanding](https://arxiv.org/abs/2505.18079), which achieves the state-of-the-art performance by a large margin on multiple long video benchmarks including the challenging [LVBench](https://lvbench.github.io/). @@ -9,6 +10,17 @@ This repository contains the official implementation of the paper [Deep Video Di  +## Update + +- **2025/10/15**: Add a markdown to help to reproduce the LVBench results: [REPRODUCE.md](reproduce/README.md). Email me (xiaoyizhang [/at/] microsoft.com) if your issue gets no response in 24 hours. +- **2025/09/19**: Accepted by NeurIPS 2025 🎉 +- **2025/08/04**: Upload captions on benchmarks for reproduction: [LVBench](https://1drv.ms/u/c/f029f6f5a52c17c4/ETR7ogx7YCtBgtDu66a4R14B7RKLZJoz20D4Z5I1KD6HTg?e=404kKg), [LVBench w/ transcripts](https://1drv.ms/u/c/f029f6f5a52c17c4/EcqO2lC_hRxGn-0t0IBNKZcBts3HDCEg8mZo4ltN6kXFUQ?e=XmabUn), [Video-MME](https://1drv.ms/u/c/f029f6f5a52c17c4/EVKjXQnPjeZGi-onOxEMb8UBxqI9NexKzccHuYEe8-0Lig?e=a4SxCU), [LongVideoBench](https://1drv.ms/u/c/f029f6f5a52c17c4/EQp_PABeb3ZIiysjIn-_5gEBbkhtfcBwCM1pel9xl3JHPg?e=TLpQXQ) and [EgoSchema](https://1drv.ms/u/c/f029f6f5a52c17c4/Ec0oEX3tO5pIknRdEqT9LDQB0hbS9vR9fUJaVbRfCQPJKg?e=bszgh6). +- **2025/08/02**: Support auto subtitle in the demo. +- **2025/07/17**: Add gradio demo. +- **2025/07/16**: Add `lite_mode` to enable a lightweight version of the agent that uses only subtitles. Good for Youtube podcast analysis! +- **2025/07/14**: Support OpenAI API and Azure OpenAI API. +- **2025/07/08**: Initial release of the Deep Video Discovery codebase. + ## Introduction **Deep Video Discovery (DVD)** is a deep-research style question answering agent designed for understanding extra-long videos. Leveraging the powerful capabilities of large language models (LLMs), DVD effectively interprets and processes extensive video content to answer complex user queries. @@ -45,6 +57,11 @@ The core design of DVD includes: pip install -r requirements.txt ``` +3. (Optional) **Install gradio for demo:** + ```bash + pip install gradio + ``` + ## Usage Note: Set up your configuration by updating the variables in `config.py`. @@ -54,12 +71,12 @@ Note: Set up your configuration by updating the variables in `config.py`. The `local_run.py` script provides an example of how to run the Deep Video Discovery agent by providing a youtube url and question about it. ```bash - python local_run.py https://www.youtube.com/watch?v=ktbGziZlt3c "how many animals appear in this video" + python local_run.py https://www.youtube.com/watch?v=PQFQ-3d2J-8 "what did the main speaker talk about in the last part of video?" ``` ## TODO -- [ ] Support OpenAI API key configuration. +- [x] Support OpenAI API key configuration. - [ ] Implement MCP server. - [ ] Release evaluation trajectory data on long video benchmarks. diff --git a/app.py b/app.py new file mode 100644 index 0000000..d7fcf0a --- /dev/null +++ b/app.py @@ -0,0 +1,279 @@ +import json +import os, argparse, gradio as gr +from dvd import config +from dvd.dvd_core import DVDCoreAgent +from dvd.video_utils import load_video, decode_video_to_frames, download_srt_subtitle +from dvd.frame_caption import process_video, process_video_lite +from dvd.utils import extract_answer + +######################################################################## +# Helper functions +######################################################################## +def get_youtube_thumbnail(video_url: str): + """Extract YouTube video ID and return thumbnail URL.""" + if not video_url: + return None + + # Extract video ID from YouTube URL + video_id = None + if "youtube.com/watch?v=" in video_url: + video_id = video_url.split("v=")[1].split("&")[0] + elif "youtu.be/" in video_url: + video_id = video_url.split("youtu.be/")[1].split("?")[0] + + if video_id: + # YouTube provides several thumbnail qualities + # maxresdefault > hqdefault > mqdefault > default + return f"https://img.youtube.com/vi/{video_id}/hqdefault.jpg" + + return None + +def _prepare_video_assets(video_url: str): + """Download / decode / caption the video exactly as in local_run.py, + returning (video_id, caption_file, video_db_path).""" + # --- reuse logic from local_run.py (trimmed for brevity) ------------- + if "v=" in video_url: # YouTube URL + video_id = video_url.split("v=")[1] + else: # local file or misc. + video_id = os.path.splitext(os.path.basename(video_url))[0] + + video_path = os.path.join(config.VIDEO_DATABASE_FOLDER, "raw", f"{video_id}.mp4") + frames_dir = os.path.join(config.VIDEO_DATABASE_FOLDER, video_id, "frames") + captions_dir = os.path.join(config.VIDEO_DATABASE_FOLDER, video_id, "captions") + video_db_path= os.path.join(config.VIDEO_DATABASE_FOLDER, video_id, "database.json") + srt_path = os.path.join(config.VIDEO_DATABASE_FOLDER, video_id, "subtitles.srt") + os.makedirs(os.path.join(config.VIDEO_DATABASE_FOLDER, "raw"), exist_ok=True) + os.makedirs(frames_dir, exist_ok=True) + os.makedirs(captions_dir, exist_ok=True) + + if config.LITE_MODE: + if not os.path.exists(srt_path): + download_srt_subtitle(video_url, srt_path) + process_video_lite(captions_dir, srt_path) + caption_file = os.path.join(captions_dir, "captions.json") + else: + if not os.path.exists(video_path): + load_video(video_url, video_path) + if not os.path.exists(frames_dir) or not os.listdir(frames_dir): + decode_video_to_frames(video_path) + caption_file = os.path.join(captions_dir, "captions.json") + if not os.path.exists(caption_file): + process_video(frames_dir, captions_dir) + + return video_id, caption_file, video_db_path + +def solve(video_url: str, question: str): + """Streamed inference function used by Gradio.""" + if not video_url or not question: + yield "❗ Please provide both a video URL and a question." + return + + try: + yield "🔄 **Processing video...**" + _, caption_file, video_db_path = _prepare_video_assets(video_url) + + yield "🤖 **Initializing DVD agent...**" + agent = DVDCoreAgent(video_db_path, caption_file, config.MAX_ITERATIONS) + + accumulated_text = "### 🎯 Analysis Process:\n" + final_answer = None + + for msg in agent.stream_run(question): + # Only process messages with a role attribute + if not isinstance(msg, dict) or "role" not in msg: + continue + + # Show assistant's thinking process + if msg.get("role") == "assistant": + content = msg.get("content", "") + if content: + accumulated_text += f"\n\n**🤔 Assistant Thinking:**\n{content}" + yield accumulated_text + + # Check if assistant called the finish function + tool_calls = msg.get("tool_calls", []) + for tc in tool_calls: + if tc.get("function", {}).get("name") == "finish": + try: + args = json.loads(tc.get("function", {}).get("arguments", "{}")) + final_answer = args.get("answer", "") + except: + pass + + # Show when a tool is being called + elif msg.get("role") == "tool_call": + tool_name = msg.get("name", "unknown") + tool_args = msg.get("arguments", "{}") + try: + args_dict = json.loads(tool_args) + args_dict.pop("database", None) + # Format arguments nicely + args_str = json.dumps(args_dict, indent=2) + except: + args_str = tool_args + if tool_name != "finish": + accumulated_text += f"\n\n**🔄 Calling Tool:** `{tool_name}`\n```json\n{args_str}\n```" + yield accumulated_text + + # Show tool observations + elif msg.get("role") == "tool": + tool_name = msg.get("name", "unknown") + tool_result = msg.get("content", "") + + # Truncate long results for display + if len(tool_result) > 2000: + tool_result = tool_result[:2000] + "..." + + accumulated_text += f"\n\n**✅ Tool Result `{tool_name}`:**\n```\n{tool_result}\n```" + yield accumulated_text + + # Add final answer if found + if final_answer: + accumulated_text += f"\n### 📃✅ **Final Answer:**\n\n{final_answer}" + else: + accumulated_text += "\n\n---\n### ✅ **Analysis Complete!**" + + yield accumulated_text + + except Exception as e: + import traceback + yield f"### ⚠️ Error Occurred\n\n```\n{e}\n```\n\nDetails:\n```\n{traceback.format_exc()}\n```" + +######################################################################## +# Gradio UI +######################################################################## +def launch(args): + # Custom CSS for better styling + custom_css = """ + .gradio-container { + font-family: 'Inter', sans-serif; + } + .markdown-text { + font-size: 16px; + } + #answer-box { + border: 2px solid #e5e7eb; + border-radius: 8px; + padding: 20px; + background-color: #f9fafb; + min-height: 400px; + max-height: 600px; + overflow-y: auto; + } + .button-primary { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + font-weight: bold; + font-size: 18px; + padding: 12px 24px; + } + #video-thumbnail { + border-radius: 8px; + box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); + } + """ + + with gr.Blocks(title="DVD Video Q&A Demo", css=custom_css, theme=gr.themes.Soft()) as demo: + gr.Markdown( + """ + # 🎬 Deep Video Discovery: Agentic Search with Tool Use for Long-form Video Understanding + +
+ Provide a YouTube URL, then ask any question about the video content. + The system will analyze the video and provide detailed answers. + Note that this online demo only provides lite mode of DVD where only subtitles are used. + To use full DVD capabilities, please deploy it locally. +
+ """ + ) + + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("### 📹 Video Input") + video_url = gr.Textbox( + label="Video URL / Path", + placeholder="e.g. https://www.youtube.com/watch?v=dQw4w9WgXcQ", + lines=1, + info="Support YouTube URLs or local video paths" + ) + + # Add video thumbnail + video_thumbnail = gr.Image( + label="Video Thumbnail", + elem_id="video-thumbnail", + height=200, + visible=False, + interactive=False + ) + + gr.Markdown("### ❓ Your Question") + question = gr.Textbox( + label="Question about the video", + placeholder="What happens in this video? Who are the main characters?", + lines=3, + info="Ask anything about the video content" + ) + + with gr.Row(): + run_btn = gr.Button("🔍 Analyze Video", variant="primary", elem_classes=["button-primary"]) + clear_btn = gr.ClearButton([video_url, question, video_thumbnail], value="🗑️ Clear") + + gr.Markdown("### 💡 Example Questions") + examples = gr.Examples( + examples=[ + ["https://www.youtube.com/watch?v=i2qSxMVeVLI", "What is the main topic discussed in this video?"], + ["https://www.youtube.com/watch?v=nOxKexn3iBo", "Who are the speakers and what are their key points?"], + ], + inputs=[video_url, question], + label="" + ) + + with gr.Column(scale=2): + gr.Markdown("### 📊 Analysis Results") + answer_box = gr.Markdown( + value="*Results will appear here after clicking 'Analyze Video'...*", + elem_id="answer-box", + label="" + ) + + gr.Markdown( + """ + --- ++ DVD: Powered by advanced video understanding and language models | + GitHub +
+ """ + ) + + # Event handlers + def update_thumbnail(url): + """Update thumbnail when URL changes.""" + thumbnail_url = get_youtube_thumbnail(url) + if thumbnail_url: + return gr.update(value=thumbnail_url, visible=True) + else: + return gr.update(value=None, visible=False) + + video_url.change( + fn=update_thumbnail, + inputs=[video_url], + outputs=[video_thumbnail] + ) + + import inspect + click_kwargs = dict(fn=solve, inputs=[video_url, question], outputs=answer_box) + if "stream" in inspect.signature(gr.Button.click).parameters: + click_kwargs["stream"] = True + run_btn.click(**click_kwargs) + + demo.launch(share=args.share) + +######################################################################## +# CLI entry-point (optional) +######################################################################## +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--share", action="store_true", help="Gradio share flag") + args = parser.parse_args() + launch(args) \ No newline at end of file diff --git a/dvd/build_database.py b/dvd/build_database.py index 5361f11..51a3db5 100644 --- a/dvd/build_database.py +++ b/dvd/build_database.py @@ -102,6 +102,7 @@ def frame_inspect_tool( messages=input_msgs, endpoints=config.AOAI_TOOL_VLM_ENDPOINT_LIST, model_name=config.AOAI_TOOL_VLM_MODEL_NAME, + api_key=config.OPENAI_API_KEY, image_paths=files, temperature=0, max_tokens=512, @@ -130,6 +131,7 @@ def clip_search_tool( endpoints=config.AOAI_EMBEDDING_RESOURCE_LIST, model_name=config.AOAI_EMBEDDING_LARGE_MODEL_NAME, input_text=[event_description], + api_key=config.OPENAI_API_KEY, )[0]['embedding'] results = database.query( query_emb, @@ -164,6 +166,7 @@ def global_browse_tool( endpoints=config.AOAI_EMBEDDING_RESOURCE_LIST, model_name=config.AOAI_EMBEDDING_LARGE_MODEL_NAME, input_text=[query], + api_key=config.OPENAI_API_KEY, )[0]['embedding'] results = database.query( query_emb, @@ -202,6 +205,7 @@ def global_browse_tool( messages=input_msgs, endpoints=config.AOAI_TOOL_VLM_ENDPOINT_LIST, model_name=config.AOAI_TOOL_VLM_MODEL_NAME, + api_key=config.OPENAI_API_KEY, temperature=0, max_tokens=512, ) @@ -318,6 +322,7 @@ def single_batch_embedding_task(data): endpoints=config.AOAI_EMBEDDING_RESOURCE_LIST, model_name=config.AOAI_EMBEDDING_LARGE_MODEL_NAME, input_text=captions, + api_key=config.OPENAI_API_KEY, ) max_tries = 3 while embs is None or len(embs) != len(captions): @@ -329,6 +334,7 @@ def single_batch_embedding_task(data): endpoints=config.AOAI_EMBEDDING_RESOURCE_LIST, model_name=config.AOAI_EMBEDDING_LARGE_MODEL_NAME, input_text=captions, + api_key=config.OPENAI_API_KEY, ) return list(zip(timestamps, cap_infos, [d['embedding'] for d in embs])) diff --git a/dvd/config.py b/dvd/config.py index e5e00fa..0659d71 100644 --- a/dvd/config.py +++ b/dvd/config.py @@ -1,3 +1,5 @@ +import os + # ------------------ video download and segmentation configuration ------------------ # VIDEO_DATABASE_FOLDER = "./video_database/" VIDEO_RESOLUTION = "360" # denotes the height of the video @@ -5,6 +7,8 @@ CLIP_SECS = 10 # seconds # ------------------ model configuration ------------------ # +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None) # will overwrite Azure OpenAI setting + AOAI_CAPTION_VLM_ENDPOINT_LIST = [""] AOAI_CAPTION_VLM_MODEL_NAME = "gpt-4.1-mini" @@ -20,6 +24,7 @@ AOAI_EMBEDDING_LARGE_DIM = 3072 # ------------------ agent and tool setting ------------------ # +LITE_MODE = True # if True, only leverage srt subtitle, no pixel downloaded or pixel captioning GLOBAL_BROWSE_TOPK = 300 OVERWRITE_CLIP_SEARCH_TOPK = 0 # 0 means no overwrite and let agent decide diff --git a/dvd/dvd_core.py b/dvd/dvd_core.py index 656d6b8..db5f014 100644 --- a/dvd/dvd_core.py +++ b/dvd/dvd_core.py @@ -9,6 +9,7 @@ from dvd.func_call_shema import as_json_schema from dvd.func_call_shema import doc as D from dvd.utils import call_openai_model_with_tools +from concurrent.futures import ThreadPoolExecutor, as_completed TOPK = 16 @@ -26,6 +27,8 @@ def finish(answer: A[str, D("Answer to the user's question.")]) -> None: class DVDCoreAgent: def __init__(self, video_db_path, video_caption_path, max_iterations): self.tools = [frame_inspect_tool, clip_search_tool, global_browse_tool, finish] + if config.LITE_MODE: + self.tools.remove(frame_inspect_tool) self.name_to_function_map = {tool.__name__: tool for tool in self.tools} self.function_schemas = [ {"function": as_json_schema(func), "type": "function"} @@ -139,6 +142,7 @@ def run(self, question) -> list[dict]: model_name=config.AOAI_ORCHESTRATOR_LLM_MODEL_NAME, tools=self.function_schemas, temperature=0.0, + api_key=config.OPENAI_API_KEY, ) if response is None: return None @@ -155,6 +159,82 @@ def run(self, question) -> list[dict]: return msgs + def parallel_run(self, questions, max_workers=4) -> list[list[dict]]: + """ + Run multiple questions in parallel. + """ + results = [] + results = [None] * len(questions) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_index = { + executor.submit(self.run, q): idx + for idx, q in enumerate(questions) + } + for future in as_completed(future_to_index): + idx = future_to_index[future] + try: + results[idx] = future.result() + except Exception as e: + print(f"Error processing question: {e}") + results[idx] = None + return results + + # ------------------------------------------------------------------ # + # Streaming (generator) loop + # ------------------------------------------------------------------ # + def stream_run(self, question): + """ + A generator version of `run`. + Yields: + dict: every assistant / tool message produced during reasoning. + """ + msgs = copy.deepcopy(self.messages) + msgs[-1]["content"] = msgs[-1]["content"].replace("QUESTION_PLACEHOLDER", question) + + for i in range(self.max_iterations): + # Force a final `finish` on the last iteration + if i == self.max_iterations - 1: + final_usr_msg = { + "role": "user", + "content": "Please call the `finish` function to finish the task.", + } + msgs.append(final_usr_msg) + # Don't yield user messages to the UI + + response = call_openai_model_with_tools( + msgs, + endpoints=config.AOAI_ORCHESTRATOR_LLM_ENDPOINT_LIST, + model_name=config.AOAI_ORCHESTRATOR_LLM_MODEL_NAME, + tools=self.function_schemas, + temperature=0.0, + api_key=config.OPENAI_API_KEY, + ) + if response is None: + return + + response.setdefault("role", "assistant") + msgs.append(response) + yield response # ← stream assistant reply + + # Execute any requested tool calls + try: + for tool_call in response.get("tool_calls", []): + # Yield a formatted message about the tool being called + tool_name = tool_call.get("function", {}).get("name", "unknown") + tool_args = tool_call.get("function", {}).get("arguments", "{}") + yield { + "role": "tool_call", + "name": tool_name, + "arguments": tool_args + } + + self._exec_tool(tool_call, msgs) + # Only yield the tool result message + if msgs[-1].get("role") == "tool": + yield msgs[-1] # ← stream tool observation + except StopException: + return + def single_run_wrapper(info) -> dict: qid, video_db_path, video_caption_path, question = info diff --git a/dvd/frame_caption.py b/dvd/frame_caption.py index eea8684..9ed7e0a 100644 --- a/dvd/frame_caption.py +++ b/dvd/frame_caption.py @@ -233,12 +233,15 @@ def _caption_clip(task: Tuple[str, Dict], caption_ckpt_folder) -> Tuple[str, dic model_name=config.AOAI_CAPTION_VLM_MODEL_NAME, return_json=True, image_paths=files, + api_key=config.OPENAI_API_KEY, )["content"] if resp is None: continue try: assert isinstance(resp, str), f"Response must be a JSON string instead of {type(resp)}:{resp}." parsed = json.loads(resp) + parsed["clip_description"] += f"\n\nTranscript during this video clip: {transcript}." # add transcript to description + resp = json.dumps(parsed) with open(os.path.join(caption_ckpt_folder, f"{timestamp}.json"), "w") as f: f.write(resp) return timestamp, parsed @@ -269,6 +272,7 @@ def merge_subject_registries(registries: List[dict]) -> dict: endpoints=config.AOAI_CAPTION_VLM_ENDPOINT_LIST, model_name=config.AOAI_CAPTION_VLM_MODEL_NAME, return_json=True, + api_key=config.OPENAI_API_KEY, )["content"] if resp is None: continue @@ -327,6 +331,25 @@ def process_video( json.dump(frame_captions, f, indent=4) +def process_video_lite( + output_caption_folder: str, + subtitle_file_path: str, +): + """ + Process video in LITE_MODE using SRT subtitles. + """ + captions = parse_srt_to_dict(subtitle_file_path) + frame_captions = {} + for key, text in captions.items(): + frame_captions[key] = { + "caption": f"\n\nTranscript during this video clip: {text}.", + } + frame_captions["subject_registry"] = {} + with open( + os.path.join(output_caption_folder, "captions.json"), "w" + ) as f: + json.dump(frame_captions, f, indent=4) + # --------------------------------------------------------------------------- # # main # # --------------------------------------------------------------------------- # diff --git a/dvd/utils.py b/dvd/utils.py index 0233a99..663bbf4 100644 --- a/dvd/utils.py +++ b/dvd/utils.py @@ -76,6 +76,7 @@ def call_openai_model_with_tools( messages, endpoints, model_name, + api_key: str = None, tools: list = [], # List of tool definitions image_paths: list = [], max_tokens: int = 4096, @@ -83,22 +84,32 @@ def call_openai_model_with_tools( tool_choice: str = "auto", # Can be "auto", "none", or a specific tool return_json: bool = False, ) -> dict: - credential = AzureCliCredential() - token = credential.get_token('https://cognitiveservices.azure.com/') - headers = { - "Content-Type": "application/json", - 'Authorization': 'Bearer ' + token.token - } - if isinstance(endpoints, str): - endpoint = endpoints - elif isinstance(endpoints, list): - endpoint = random.choice(endpoints) + if api_key: + headers = { + "Content-Type": "application/json", + 'Authorization': 'Bearer ' + api_key + } + endpoint = "https://api.openai.com/v1" + url = f"{endpoint}/chat/completions" else: - raise ValueError("Endpoints must be a string or a list of strings.") + credential = AzureCliCredential() + token = credential.get_token('https://cognitiveservices.azure.com/') + headers = { + "Content-Type": "application/json", + 'Authorization': 'Bearer ' + token.token + } + if isinstance(endpoints, str): + endpoint = endpoints + elif isinstance(endpoints, list): + endpoint = random.choice(endpoints) + else: + raise ValueError("Endpoints must be a string or a list of strings.") + url = f"{endpoint}/openai/deployments/{model_name}/chat/completions?api-version=2025-03-01-preview" + model = model_name - url = f"{endpoint}/openai/deployments/{model}/chat/completions?api-version=2025-03-01-preview" payload = { + "model": model, "messages": copy.deepcopy(messages), # "reasoning_effort": reasoning_effort, } @@ -140,7 +151,7 @@ def call_openai_model_with_tools( class AzureOpenAIEmbeddingService: @staticmethod @retry_with_exponential_backoff - def get_embeddings(endpoints, model_name, input_text): + def get_embeddings(endpoints, model_name, input_text, api_key: str = None): """ Call Azure OpenAI Embedding service and get embeddings for the input text. @@ -150,27 +161,35 @@ def get_embeddings(endpoints, model_name, input_text): :param input_text: The text for which you want to generate embeddings. :return: The embeddings as a JSON response. """ - if isinstance(endpoints, str): - endpoint = endpoints - elif isinstance(endpoints, list): - endpoint = random.choice(endpoints) + if api_key: + headers = { + "Content-Type": "application/json", + 'Authorization': 'Bearer ' + api_key + } + endpoint = "https://api.openai.com/v1" + url = f"{endpoint}/embeddings" else: - raise ValueError("Endpoints must be a string or a list of strings.") + if isinstance(endpoints, str): + endpoint = endpoints + elif isinstance(endpoints, list): + endpoint = random.choice(endpoints) + else: + raise ValueError("Endpoints must be a string or a list of strings.") + # Define the URL for the embeddings endpoint + url = f"{endpoint}/openai/deployments/{model_name}/embeddings?api-version=2023-05-15" + + credential = AzureCliCredential() + token = credential.get_token('https://cognitiveservices.azure.com/') + headers = { + "Content-Type": "application/json", + 'Authorization': 'Bearer ' + token.token + } + model = model_name - if isinstance(endpoint, list): - endpoint = random.choice(endpoint) - # Define the URL for the embeddings endpoint - url = f"{endpoint}/openai/deployments/{model}/embeddings?api-version=2023-05-15" - - credential = AzureCliCredential() - token = credential.get_token('https://cognitiveservices.azure.com/') - headers = { - "Content-Type": "application/json", - 'Authorization': 'Bearer ' + token.token - } # Set up the payload for the request payload = { - "input": input_text + "input": input_text, + "model": model } # Make the request to the Azure OpenAI service @@ -203,28 +222,49 @@ def extract_answer(message: dict) -> str | None: str | None The extracted answer, or ``None`` if no answer could be found. """ - # Direct text response - if (content := message.get("content")): - return content.strip() - # Tool-based response for call in message.get("tool_calls", []): args_json = call["function"]["arguments"] args = json.loads(args_json) if (answer := args.get("answer")): return answer + + # Direct text response + if (content := message.get("content")): + return content.strip() + return None if __name__ == "__main__": - call_openai_model_with_tools( - messages=[{"role": "user", "content": "Hello, how are you?"}], - endpoints=["https://msra-im-openai-eus2.openai.azure.com"], - model_name="o3", - tools=[], - image_paths=[], - max_tokens=4096, - temperature=0.0, - tool_choice="auto", - return_json=False, - ) \ No newline at end of file + # Example for Azure + # call_openai_model_with_tools( + # messages=[{"role": "user", "content": "Hello, how are you?"}], + # endpoints=["https://msra-im-openai-eus2.openai.azure.com"], + # model_name="o3", + # tools=[], + # image_paths=[], + # max_tokens=4096, + # temperature=0.0, + # tool_choice="auto", + # return_json=False, + # ) + + # Example for OpenAI + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + response = call_openai_model_with_tools( + messages=[{"role": "user", "content": "Hello, how are you?"}], + endpoints=None, # Not used for OpenAI + model_name="gpt-4o", + api_key=api_key, + tools=[], + image_paths=[], + max_tokens=4096, + temperature=0.0, + tool_choice="auto", + return_json=False, + ) + print(response) + else: + print("OPENAI_API_KEY environment variable not set.") \ No newline at end of file diff --git a/dvd/video_utils.py b/dvd/video_utils.py index 2b3840f..7f6c891 100644 --- a/dvd/video_utils.py +++ b/dvd/video_utils.py @@ -99,10 +99,39 @@ def load_video( ) shutil.copy2(subtitle_source, subtitle_destination) - return os.path.abspath(destination_path) +def download_srt_subtitle(video_url: str, output_path: str): + """Downloads an SRT subtitle from a YouTube URL.""" + if not _is_youtube_url(video_url): + raise ValueError("Provided URL is not a valid YouTube link.") + + output_dir = os.path.dirname(output_path) + os.makedirs(output_dir, exist_ok=True) + + ydl_opts = { + 'writesubtitles': True, + 'subtitlesformat': 'srt', + 'skip_download': True, + 'writeautomaticsub': True, + 'outtmpl': os.path.join(output_dir, '%(id)s.%(ext)s'), + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=False) + video_id = info['id'] + ydl.download([video_url]) + + # Locate the downloaded subtitle file (yt-dlp names them as