From 60ed1567ff098391cc158966e5fa8af004cedb63 Mon Sep 17 00:00:00 2001 From: anonymous626 <131758638+anonymous626@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:41:09 +0800 Subject: [PATCH 1/4] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 20eae2d..58aaebf 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ This repository contains the official implementation of the paper [Deep Video Di ## Update +- **2025/09/19**: Accepted by NeurIPS 2025 🎉 - **2025/08/04**: Upload captions on benchmarks for reproduction: [LVBench](https://1drv.ms/u/c/f029f6f5a52c17c4/ETR7ogx7YCtBgtDu66a4R14B7RKLZJoz20D4Z5I1KD6HTg?e=404kKg), [LVBench w/ transcripts](https://1drv.ms/u/c/f029f6f5a52c17c4/EcqO2lC_hRxGn-0t0IBNKZcBts3HDCEg8mZo4ltN6kXFUQ?e=XmabUn), [Video-MME](https://1drv.ms/u/c/f029f6f5a52c17c4/EVKjXQnPjeZGi-onOxEMb8UBxqI9NexKzccHuYEe8-0Lig?e=a4SxCU), [LongVideoBench](https://1drv.ms/u/c/f029f6f5a52c17c4/EQp_PABeb3ZIiysjIn-_5gEBbkhtfcBwCM1pel9xl3JHPg?e=TLpQXQ) and [EgoSchema](https://1drv.ms/u/c/f029f6f5a52c17c4/Ec0oEX3tO5pIknRdEqT9LDQB0hbS9vR9fUJaVbRfCQPJKg?e=bszgh6). - **2025/08/02**: Support auto subtitle in the demo. - **2025/07/17**: Add gradio demo. 
From 68ad71428545b690224094bcbd6e66faa0ce18f1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 Oct 2025 14:14:16 +0000 Subject: [PATCH 2/4] update reproduce guide --- README.md | 1 + dvd/dvd_core.py | 21 ++++ dvd/utils.py | 9 +- reproduce/REPRODUCE.md | 29 +++++ reproduce/decode_frames.py | 221 ++++++++++++++++++++++++++++++++++ reproduce/download_lvbench.sh | 5 + reproduce/prepare_database.py | 52 ++++++++ reproduce/run_benchmark.py | 57 +++++++++ 8 files changed, 391 insertions(+), 4 deletions(-) create mode 100644 reproduce/REPRODUCE.md create mode 100644 reproduce/decode_frames.py create mode 100644 reproduce/download_lvbench.sh create mode 100644 reproduce/prepare_database.py create mode 100644 reproduce/run_benchmark.py diff --git a/README.md b/README.md index 58aaebf..805df16 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ This repository contains the official implementation of the paper [Deep Video Di ## Update +- **2025/10/15**: Add a markdown to help to reproduce the LVBench results: [REPRODUCE.md](reproduce/REPRODUCE.md). Email me (xiaoyizhang@microsoft) if your issue get no response in 24 hours. - **2025/09/19**: Accepted by NeurIPS 2025 🎉 - **2025/08/04**: Upload captions on benchmarks for reproduction: [LVBench](https://1drv.ms/u/c/f029f6f5a52c17c4/ETR7ogx7YCtBgtDu66a4R14B7RKLZJoz20D4Z5I1KD6HTg?e=404kKg), [LVBench w/ transcripts](https://1drv.ms/u/c/f029f6f5a52c17c4/EcqO2lC_hRxGn-0t0IBNKZcBts3HDCEg8mZo4ltN6kXFUQ?e=XmabUn), [Video-MME](https://1drv.ms/u/c/f029f6f5a52c17c4/EVKjXQnPjeZGi-onOxEMb8UBxqI9NexKzccHuYEe8-0Lig?e=a4SxCU), [LongVideoBench](https://1drv.ms/u/c/f029f6f5a52c17c4/EQp_PABeb3ZIiysjIn-_5gEBbkhtfcBwCM1pel9xl3JHPg?e=TLpQXQ) and [EgoSchema](https://1drv.ms/u/c/f029f6f5a52c17c4/Ec0oEX3tO5pIknRdEqT9LDQB0hbS9vR9fUJaVbRfCQPJKg?e=bszgh6). - **2025/08/02**: Support auto subtitle in the demo. 
diff --git a/dvd/dvd_core.py b/dvd/dvd_core.py index 630d181..db5f014 100644 --- a/dvd/dvd_core.py +++ b/dvd/dvd_core.py @@ -9,6 +9,7 @@ from dvd.func_call_shema import as_json_schema from dvd.func_call_shema import doc as D from dvd.utils import call_openai_model_with_tools +from concurrent.futures import ThreadPoolExecutor, as_completed TOPK = 16 @@ -158,6 +159,26 @@ def run(self, question) -> list[dict]: return msgs + def parallel_run(self, questions, max_workers=4) -> list[list[dict]]: + """ + Run multiple questions in parallel. + """ + results = [] + results = [None] * len(questions) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_index = { + executor.submit(self.run, q): idx + for idx, q in enumerate(questions) + } + for future in as_completed(future_to_index): + idx = future_to_index[future] + try: + results[idx] = future.result() + except Exception as e: + print(f"Error processing question: {e}") + results[idx] = None + return results + # ------------------------------------------------------------------ # # Streaming (generator) loop # ------------------------------------------------------------------ # diff --git a/dvd/utils.py b/dvd/utils.py index 077cb28..663bbf4 100644 --- a/dvd/utils.py +++ b/dvd/utils.py @@ -222,16 +222,17 @@ def extract_answer(message: dict) -> str | None: str | None The extracted answer, or ``None`` if no answer could be found. """ - # Direct text response - if (content := message.get("content")): - return content.strip() - # Tool-based response for call in message.get("tool_calls", []): args_json = call["function"]["arguments"] args = json.loads(args_json) if (answer := args.get("answer")): return answer + + # Direct text response + if (content := message.get("content")): + return content.strip() + return None diff --git a/reproduce/REPRODUCE.md b/reproduce/REPRODUCE.md new file mode 100644 index 0000000..53d547d --- /dev/null +++ b/reproduce/REPRODUCE.md @@ -0,0 +1,29 @@ +# Reproduce + +0. 
Set up the database root path with `export DATABASE_DIR=/path/to/your/database/folder`. + +1. Download the pre-built database from [here](https://huggingface.co/datasets/xyzhang626/LongVideoBenchmarkCaptions/tree/main), or fetch it directly with `wget https://huggingface.co/datasets/xyzhang626/LongVideoBenchmarkCaptions/resolve/main/LVBench_4.1.zip`. + +2. Prepare the database JSON files with the script `prepare_database.py`, passing the path to the zip file you downloaded. It writes the database JSON files into `$DATABASE_DIR/LVBench_4.1`. + +```bash +python -m reproduce.prepare_database /path/to/your/zipfile $DATABASE_DIR +``` + +3. Download the LVBench dataset. These third-party assets are hosted [here](https://huggingface.co/datasets/AIWinter/LVBench/tree/main), or you can use the script below to download them. + +```bash +export TARGET_DIR=$DATABASE_DIR/LVBench_4.1 +bash reproduce/download_lvbench.sh +``` + +4. Decode the videos into raw frames with the script `decode_frames.py`; adjust the `--part` path to point at your downloaded LVBench archive. + +```bash +python -m reproduce.decode_frames --part $DATABASE_DIR/LVBench_4.1/all_videos_split.zip.001 --out $TARGET_DIR --fps 2 +``` + +5. Run the benchmark with the script `run_benchmark.py`, passing the paths to your prepared database folder and the benchmark metadata file. 
+```bash +python -m reproduce.run_benchmark $TARGET_DIR $TARGET_DIR/video_info.meta.jsonl +``` diff --git a/reproduce/decode_frames.py b/reproduce/decode_frames.py new file mode 100644 index 0000000..6788b49 --- /dev/null +++ b/reproduce/decode_frames.py @@ -0,0 +1,221 @@ +import argparse +import os +import re +import sys +import tempfile +import zipfile +import shutil +import logging +from pathlib import Path +from typing import List, Iterable +import cv2 +from tqdm import tqdm +import multiprocessing as mp + +VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.flv', '.mpg', '.mpeg', '.m4v'} + + +def find_all_parts(part_path: Path) -> List[Path]: + """ + Given any split archive file, find all parts in the same directory and return them in order. + Supports *.zip.001 / *.zip.002 or *.z01 / *.z02 + .zip formats (compatible with common naming). + """ + name = part_path.name + parent = part_path.parent + + # Match something like something.zip.001 or something.zip.002 + m = re.match(r'(.+\.zip)\.(\d{3})$', name) + if m: + base = m.group(1) + parts = sorted(parent.glob(base + ".???")) + return parts + + # Match WinZip style: something.z01, something.z02 ... + something.zip + m2 = re.match(r'(.+)\.z(\d{2})$', name, re.IGNORECASE) + if m2: + base_prefix = m2.group(1) + zparts = sorted(parent.glob(base_prefix + ".z??"), key=lambda p: p.suffix.lower()) + main_zip = parent / (base_prefix + ".zip") + if main_zip.exists(): + return zparts + [main_zip] + + # If it's a complete zip file + if name.endswith(".zip"): + return [part_path] + + raise ValueError(f"Unrecognized split archive naming: {name}") + + +def assemble_zip(parts: List[Path]) -> Path: + """ + Concatenate all parts in order into a temporary zip file and return its path. + If there is only one .zip file, return its path directly (no copy). 
+ """ + if len(parts) == 1 and parts[0].suffix == ".zip": + return parts[0] + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip", prefix="merged_zip_") + os.close(tmp_fd) + with open(tmp_path, "wb") as w: + for p in parts: + logging.info(f"Merging part: {p.name}") + with open(p, "rb") as r: + shutil.copyfileobj(r, w, length=1024 * 1024) + return Path(tmp_path) + + +def iter_video_members(zf: zipfile.ZipFile) -> Iterable[zipfile.ZipInfo]: + for info in zf.infolist(): + if info.is_dir(): + continue + ext = Path(info.filename).suffix.lower() + if ext in VIDEO_EXTS: + yield info + + +def ensure_dir(path: Path): + path.mkdir(parents=True, exist_ok=True) + + +def decode_video(temp_video_path: Path, out_root: Path, fps: float, overwrite: bool = False, video_stem: str | None = None): + """ + Extract frames at the given fps and save them. + video_stem: Pass the original video filename (without extension) to avoid using the temporary filename. + """ + cap = cv2.VideoCapture(str(temp_video_path)) + if not cap.isOpened(): + logging.error(f"Cannot open video: {temp_video_path}") + return + + orig_fps = cap.get(cv2.CAP_PROP_FPS) or 0 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) + if orig_fps <= 0: + logging.warning(f"Original FPS abnormal ({orig_fps}), will estimate by frame time.") + orig_fps = fps # fallback + + interval = 1.0 / fps + next_t = 0.0 + frame_index = 0 + saved_index = 0 + + # Use original filename (if provided) + video_stem = video_stem or temp_video_path.stem + frames_dir = out_root / video_stem / "frames" + ensure_dir(frames_dir) + + if not overwrite: + # Count existing frames, auto-continue + existing = sorted(frames_dir.glob("frames_n*.jpg")) + if existing: + last = existing[-1].stem + m = re.search(r'frames_n(\d+)', last) + if m: + saved_index = int(m.group(1)) + logging.info(f"{video_stem}: Append mode, {saved_index} frames already exist.") + + pbar = tqdm(total=total_frames if total_frames > 0 else None, + desc=f"Decoding {video_stem}", + 
unit="f", + dynamic_ncols=True) + + while True: + ret, frame = cap.read() + if not ret: + break + current_t = frame_index / orig_fps + if current_t + 1e-6 >= next_t: + saved_index += 1 + out_path = frames_dir / f"frames_n{saved_index:06d}.jpg" + if overwrite or not out_path.exists(): + cv2.imwrite(str(out_path), frame, [int(cv2.IMWRITE_JPEG_QUALITY), 95]) + next_t += interval + frame_index += 1 + pbar.update(1) + pbar.close() + cap.release() + logging.info(f"{video_stem}: Frame extraction complete, total (including existing) {saved_index} frames.") + + +def process_archive(part_path: Path, out_root: Path, fps: float, overwrite: bool): + parts = find_all_parts(part_path) + logging.info("Found parts in order: " + ", ".join(p.name for p in parts)) + merged_zip = assemble_zip(parts) + cleanup_needed = merged_zip not in parts # If we generated a temporary file + try: + with zipfile.ZipFile(merged_zip, 'r') as zf: + video_members = list(iter_video_members(zf)) + logging.info(f"Number of video files in archive: {len(video_members)}") + temp_files = [] # (process, tmp_path) + max_workers = min(len(video_members), mp.cpu_count() or 1) + logging.info(f"Parallel decoding: using {max_workers} processes") + + def wait_one(): + # Wait for the earliest started process to finish and clean up temp file + proc, tpath = temp_files.pop(0) + proc.join() + try: + tpath.unlink(missing_ok=True) + except Exception: + pass + + for info in video_members: + # Write this video to a temporary file + suffix = Path(info.filename).suffix + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + tmp.write(zf.read(info)) + tmp_path = Path(tmp.name) + + original_stem = Path(info.filename).stem + # Start a separate process for decoding, passing the original filename + p = mp.Process(target=decode_video, args=(tmp_path, out_root, fps), kwargs={'overwrite': overwrite, 'video_stem': original_stem}) + p.start() + temp_files.append((p, tmp_path)) + + # If reached parallel limit, wait for 
the earliest one to finish + if len(temp_files) >= max_workers: + wait_one() + + # Wait for all remaining to finish + while temp_files: + wait_one() + finally: + if cleanup_needed: + try: + merged_zip.unlink(missing_ok=True) + except Exception: + pass + + +def parse_args(): + ap = argparse.ArgumentParser(description="Extract video frames from split zip archives at a given fps") + ap.add_argument("--part", required=True, help="Path to any split archive file (e.g. all_videos_split.zip.002)") + ap.add_argument("--out", required=True, help="Output root directory") + ap.add_argument("--fps", type=float, required=True, help="Target frame extraction rate (e.g. 5)") + ap.add_argument("--overwrite", action="store_true", help="Overwrite existing frames") + ap.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + return ap.parse_args() + +def main(): + args = parse_args() + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s [%(levelname)s] %(message)s", + ) + + part_path = Path(args.part).expanduser().resolve() + out_root = Path(args.out).expanduser().resolve() + ensure_dir(out_root) + + if args.fps <= 0: + logging.error("fps must be > 0") + sys.exit(1) + + if not part_path.exists(): + logging.error(f"File does not exist: {part_path}") + sys.exit(1) + + process_archive(part_path, out_root, args.fps, overwrite=args.overwrite) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/reproduce/download_lvbench.sh b/reproduce/download_lvbench.sh new file mode 100644 index 0000000..999d999 --- /dev/null +++ b/reproduce/download_lvbench.sh @@ -0,0 +1,5 @@ +mkdir -p "$TARGET_DIR" +printf '%03d\n' {3..14} | \ +wget --continue -P "$TARGET_DIR" "https://huggingface.co/datasets/zai-org/LVBench/resolve/main/video_info.meta.jsonl" +xargs -P 8 -I{} wget --continue -P "$TARGET_DIR" \ + "https://huggingface.co/datasets/AIWinter/LVBench/resolve/main/all_videos_split.zip.{}" diff --git 
a/reproduce/prepare_database.py b/reproduce/prepare_database.py new file mode 100644 index 0000000..d37230e --- /dev/null +++ b/reproduce/prepare_database.py @@ -0,0 +1,52 @@ +import os +import json +import zipfile +from pathlib import Path +import argparse + +def replace_root_path(zip_file_path, database_dir): + """ + Read a zip file, replace 'video_file_root' in JSON files, and save to the specified directory. + + Args: + zip_file_path: Path to the zip file. + database_dir: Directory for the database. + """ + zip_file_name = Path(zip_file_path).stem + + with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: + # Iterate through all files in the zip + for file_name in zip_ref.namelist(): + if file_name.endswith('.json'): + # Read the JSON file + with zip_ref.open(file_name) as json_file: + data = json.load(json_file) + + # Replace video_file_root + new_root = os.path.join(database_dir, zip_file_name) + data['video_file_root'] = new_root + + # Create output directory + json_name = Path(file_name).stem + output_dir = os.path.join(database_dir, zip_file_name, json_name) + os.makedirs(output_dir, exist_ok=True) + + # Save the JSON file + output_path = os.path.join(output_dir, 'database.json') + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + print(f"Processed: {file_name} -> {output_path}") + +if __name__ == "__main__": + # Example usage + parser = argparse.ArgumentParser(description='Replace video_file_root path in JSON files inside a zip archive') + parser.add_argument('zip_file', type=str, help='Path to the zip file') + parser.add_argument('database_dir', type=str, help='Path to the database directory') + + args = parser.parse_args() + + zip_file = args.zip_file + database_dir = args.database_dir + + replace_root_path(zip_file, database_dir) \ No newline at end of file diff --git a/reproduce/run_benchmark.py b/reproduce/run_benchmark.py new file mode 100644 index 0000000..36d630d --- /dev/null +++ 
b/reproduce/run_benchmark.py @@ -0,0 +1,57 @@ +import dvd.config as config +import os +import argparse +import json +from dvd.dvd_core import DVDCoreAgent +from dvd.video_utils import load_video, decode_video_to_frames, download_srt_subtitle +from dvd.frame_caption import process_video, process_video_lite +from dvd.utils import extract_answer + +def main(): + parser = argparse.ArgumentParser(description="Run DVDCoreAgent on a video.") + parser.add_argument("benchmark_database_folder", help="The path to the benchmark database folder.") + parser.add_argument("benchmark_metadata", help="The path to the benchmark metadata file.") + args = parser.parse_args() + + benchmark_database_folder = args.benchmark_database_folder + + with open(args.benchmark_metadata, "r") as f: + lines = f.readlines() + + total_data = [] + results = {} + for line in lines: + # one line for one video instance containing multiple questions + video_info = json.loads(line) + video_id = video_info["key"] + qa_list = video_info["qa"] + + qids = [qa["uid"] for qa in qa_list] + questions = [qa["question"] for qa in qa_list] + + frames_dir = os.path.join(benchmark_database_folder, video_id, "frames") + if not os.path.exists(frames_dir) or len(os.listdir(frames_dir)) == 0: + print(f"Frames for video {frames_dir} not found, skipping...") + continue + video_db_path = os.path.join(benchmark_database_folder, video_id, "database.json") + + print(f"Initializing DVDCoreAgent from database {video_db_path}...") + agent = DVDCoreAgent(video_db_path, video_caption_path=None, max_iterations=15) + agent.messages[-1]['content'] += "\nSelect the best option that accurately addresses the question.\nAnswer with the option\'s letter from the given choices directly and only give the best option." 
+ print("Agent initialized.") + # Run with questions + msgs = agent.parallel_run(questions, max_workers=4) + for qid, question, msg in zip(qids, questions, msgs): + answer = extract_answer(msg[-1]) + results[qid] = { + "question": question, + "answer": answer, + "reasoning": msg + } + + with open("benchmark_results.json", "w") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + +if __name__ == "__main__": + main() + From 6e0247a4e30baf143e5dba5dd7135cf7c3318058 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 Oct 2025 14:22:34 +0000 Subject: [PATCH 3/4] update reproduce README --- README.md | 2 +- reproduce/{REPRODUCE.md => README.md} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename reproduce/{REPRODUCE.md => README.md} (100%) diff --git a/README.md b/README.md index 805df16..332dd9c 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ This repository contains the official implementation of the paper [Deep Video Di ## Update -- **2025/10/15**: Add a markdown to help to reproduce the LVBench results: [REPRODUCE.md](reproduce/REPRODUCE.md). Email me (xiaoyizhang@microsoft) if your issue get no response in 24 hours. +- **2025/10/15**: Add a markdown to help to reproduce the LVBench results: [REPRODUCE.md](reproduce/README.md). Email me (xiaoyizhang [/at/] microsoft.com) if your issue gets no response in 24 hours. 
- **2025/09/19**: Accepted by NeurIPS 2025 🎉 - **2025/08/04**: Upload captions on benchmarks for reproduction: [LVBench](https://1drv.ms/u/c/f029f6f5a52c17c4/ETR7ogx7YCtBgtDu66a4R14B7RKLZJoz20D4Z5I1KD6HTg?e=404kKg), [LVBench w/ transcripts](https://1drv.ms/u/c/f029f6f5a52c17c4/EcqO2lC_hRxGn-0t0IBNKZcBts3HDCEg8mZo4ltN6kXFUQ?e=XmabUn), [Video-MME](https://1drv.ms/u/c/f029f6f5a52c17c4/EVKjXQnPjeZGi-onOxEMb8UBxqI9NexKzccHuYEe8-0Lig?e=a4SxCU), [LongVideoBench](https://1drv.ms/u/c/f029f6f5a52c17c4/EQp_PABeb3ZIiysjIn-_5gEBbkhtfcBwCM1pel9xl3JHPg?e=TLpQXQ) and [EgoSchema](https://1drv.ms/u/c/f029f6f5a52c17c4/Ec0oEX3tO5pIknRdEqT9LDQB0hbS9vR9fUJaVbRfCQPJKg?e=bszgh6). - **2025/08/02**: Support auto subtitle in the demo. diff --git a/reproduce/REPRODUCE.md b/reproduce/README.md similarity index 100% rename from reproduce/REPRODUCE.md rename to reproduce/README.md From 000ac61ad38fde9d4f4e3ef2399282f037abe228 Mon Sep 17 00:00:00 2001 From: anonymous626 <131758638+anonymous626@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:40:40 +0800 Subject: [PATCH 4/4] Add audio transcription script for reproduction This script transcribes audio files from a specified directory using the WhisperX model, aligns the output, and saves the results in JSON format. --- reproduce/transcribe.py | 53 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 reproduce/transcribe.py diff --git a/reproduce/transcribe.py b/reproduce/transcribe.py new file mode 100644 index 0000000..790f72d --- /dev/null +++ b/reproduce/transcribe.py @@ -0,0 +1,53 @@ +import json +import os +import whisperx +import gc +from whisperx.alignment import DEFAULT_ALIGN_MODELS_TORCH, DEFAULT_ALIGN_MODELS_HF + +device = "cuda" +batch_size = 64 # reduce if low on GPU mem +compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy) + +# 1. Transcribe with original whisper (batched) +model = whisperx.load_model("large-v3", device, compute_type=compute_type) +# 3. 
Assign speaker labels +HF_TOKEN = "YOUR_HF_TOKEN" +# diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device) + +root = "./lvbench_vdb" +for file in os.listdir(root): + # save model to local path (optional) + # model_dir = "/path/" + # model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir) + + if not file.endswith(".mp3"): + continue + + audio_file = os.path.join(root, file) + + if os.path.exists(audio_file.replace(".mp3", ".json")): + print(f"File {audio_file.replace('.mp3', '.json')} already exists, skipping...") + with open(audio_file.replace(".mp3", ".json"), "r") as f: + legacy_result = json.load(f) + else: + legacy_result = None + + audio = whisperx.load_audio(audio_file) + result = model.transcribe(audio, batch_size=batch_size) + + if result["language"] in DEFAULT_ALIGN_MODELS_TORCH or \ + result["language"] in DEFAULT_ALIGN_MODELS_HF: + lang = result["language"] + else: + lang = 'en' + print(f"Language {result['language']} not supported, using English instead for {audio_file}.") + + # 2. Align whisper output + model_a, metadata = whisperx.load_align_model(language_code=lang, device=device) + result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False) + + with open(audio_file.replace(".mp3", ".json"), "w") as f: + json.dump(result, f, indent=4) + print(f"saved as {audio_file.replace('.mp3', '.json')}") + +