feat: enhance transcription capabilities with MLX support and backend detection

This commit is contained in:
2026-04-04 00:32:36 +02:00
parent f7d621e510
commit e29572420e
3 changed files with 362 additions and 41 deletions

View File

@@ -1,5 +1,6 @@
import os
import sys
import platform
import datetime
import time
import site
@@ -66,16 +67,124 @@ SUPPORTED_EXTENSIONS = {
}
def _detect_device():
"""Return (device, compute_type) for the best available backend."""
# ---------------------------------------------------------------------------
# MLX model map (Apple Silicon only)
# ---------------------------------------------------------------------------
# Maps the model names accepted by transcribe() to MLX-converted Whisper
# checkpoints hosted on the Hugging Face Hub (mlx-community org).
# Model names missing from this map trigger a fallback to faster-whisper
# on CPU — presumably no MLX conversion exists for them; confirm on the Hub.
_MLX_MODEL_MAP = {
    "tiny": "mlx-community/whisper-tiny-mlx",
    "base": "mlx-community/whisper-base-mlx",
    "small": "mlx-community/whisper-small-mlx",
    "medium": "mlx-community/whisper-medium-mlx",
    "large-v2": "mlx-community/whisper-large-v2-mlx",
    "large-v3": "mlx-community/whisper-large-v3-mlx",
}
def detect_backend():
    """Return the best available inference backend.

    Probes, in priority order: Apple MLX (Apple Silicon GPU/Neural
    Engine), NVIDIA CUDA via CTranslate2, then quantized CPU.

    Returns a dict with keys:
        backend      : "mlx" | "cuda" | "cpu"
        device       : device string for WhisperModel (cuda / cpu)
        compute_type : compute type string for WhisperModel
        label        : human-readable label for UI display
    """
    # Apple Silicon → try MLX (GPU + Neural Engine via Apple MLX).
    if sys.platform == "darwin" and platform.machine() == "arm64":
        try:
            import mlx_whisper  # noqa: F401
            return {
                "backend": "mlx",
                # MLX does its own inference; device/compute_type here are
                # only meaningful for the faster-whisper fallback path.
                "device": "cpu",
                "compute_type": "int8",
                "label": "MLX · Apple GPU/NPU",
            }
        except ImportError:
            pass

    # NVIDIA CUDA via CTranslate2. Broad except on purpose: probing CUDA
    # support can raise more than ImportError on driverless machines.
    try:
        import ctranslate2
        if "float16" in ctranslate2.get_supported_compute_types("cuda"):
            return {
                "backend": "cuda",
                "device": "cuda",
                "compute_type": "float16",
                "label": "CUDA · GPU",
            }
    except Exception:
        pass

    # Last resort: int8-quantized CPU inference.
    return {
        "backend": "cpu",
        "device": "cpu",
        "compute_type": "int8",
        "label": "CPU · int8",
    }
def _decode_audio_pyav(file_path):
    """Decode any audio/video file to a float32 mono 16 kHz numpy array.

    Uses PyAV (bundled FFmpeg) — no external ffmpeg binary required.
    Returns (audio_array, duration_seconds).
    """
    # Imported lazily so the module loads even when PyAV/numpy are absent
    # and this code path is never taken.
    import av
    import numpy as np
    with av.open(file_path) as container:
        # container.duration is in microseconds (FFmpeg AV_TIME_BASE).
        # NOTE(review): raises TypeError if the container reports no
        # duration (None) — presumably handled by the caller's except.
        duration = float(container.duration) / 1_000_000  # microseconds → seconds
        # Only the first audio stream is decoded.
        stream = container.streams.audio[0]
        # fltp/mono/16 kHz is the input format Whisper models expect.
        resampler = av.AudioResampler(format="fltp", layout="mono", rate=16000)
        chunks = []
        for frame in container.decode(stream):
            # The resampler may buffer internally and emit zero or more
            # frames per input frame.
            for out in resampler.resample(frame):
                if out:
                    # Planar layout: index [0] is the single (mono) channel.
                    chunks.append(out.to_ndarray()[0])
        # Flush resampler
        for out in resampler.resample(None):
            if out:
                chunks.append(out.to_ndarray()[0])
    if not chunks:
        # Nothing decoded (e.g. empty stream) — return an empty buffer
        # rather than failing on concatenate.
        return np.zeros(0, dtype=np.float32), duration
    return np.concatenate(chunks, axis=0), duration
def _transcribe_mlx_file(file, mlx_model_id, language, timestamps, verbose):
    """Transcribe a single file with mlx-whisper (Apple GPU/NPU).

    Decodes audio via PyAV (no system ffmpeg needed), then runs MLX inference.

    Args:
        file: path to the media file to transcribe.
        mlx_model_id: Hugging Face repo id of the MLX model (see _MLX_MODEL_MAP).
        language: language code to force-decode with; if falsy, the option
            is omitted entirely and left to mlx-whisper's own handling.
        timestamps: accepted for signature symmetry but unused here —
            segment start/end times are always returned; formatting is the
            caller's concern.
        verbose: when truthy, mlx_whisper is asked to print progress.

    Returns (segments_as_dicts, audio_duration_seconds).
    Segments have dict keys: 'start', 'end', 'text'.
    """
    import mlx_whisper
    audio_array, duration = _decode_audio_pyav(file)
    decode_opts = {}
    if language:
        decode_opts["language"] = language
    result = mlx_whisper.transcribe(
        audio_array,
        path_or_hf_repo=mlx_model_id,
        # NOTE(review): None (not False) appears to be the "silent" value
        # for mlx_whisper's verbose flag — confirm against its docs.
        verbose=(True if verbose else None),
        **decode_opts,
    )
    segments = result["segments"]
    # Prefer the end time of the last segment; fall back to the container
    # duration when nothing was transcribed.
    audio_duration = segments[-1]["end"] if segments else duration
    return segments, audio_duration
def _srt_timestamp(seconds):
"""Convert seconds (float) to SRT timestamp format HH:MM:SS,mmm."""
ms = round(seconds * 1000)
h, ms = divmod(ms, 3_600_000)
m, ms = divmod(ms, 60_000)
s, ms = divmod(ms, 1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
# Get the path
@@ -91,7 +200,7 @@ def get_path(path):
return sorted(media_files)
# Main function
def transcribe(path, glob_file, model=None, language=None, verbose=False, timestamps=True):
def transcribe(path, glob_file, model=None, language=None, verbose=False, timestamps=True, stop_event=None):
"""
Transcribes audio files in a specified folder using faster-whisper (CTranslate2).
@@ -122,10 +231,98 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest
SEP = "" * 46
# ── Step 1: Detect hardware ──────────────────────────────────────
device, compute_type = _detect_device()
print(f"⚙ Device: {device} | Compute: {compute_type}")
backend_info = detect_backend()
backend = backend_info["backend"]
device = backend_info["device"]
compute_type = backend_info["compute_type"]
print(f"⚙ Backend: {backend_info['label']}")
# ── Step 2: Load model ───────────────────────────────────────────
# ── Step 1b: MLX path (Apple GPU/NPU) ───────────────────────────
if backend == "mlx":
mlx_model_id = _MLX_MODEL_MAP.get(model)
if mlx_model_id is None:
print(f"⚠ Model '{model}' is not available in MLX format.")
print(" Falling back to faster-whisper on CPU (int8).")
backend = "cpu"
device, compute_type = "cpu", "int8"
else:
# ── Step 2 (MLX): load + transcribe ─────────────────────
print(f"⏳ Loading MLX model '{model}' — downloading if needed...")
print("✅ Model ready!")
print(SEP)
total_files = len(glob_file)
print(f"📂 Found {total_files} supported media file(s) in folder")
print(SEP)
if total_files == 0:
output_text = '⚠ No supported media files found — try another folder.'
print(output_text)
print(SEP)
return output_text
files_transcripted = []
file_num = 0
for file in glob_file:
if stop_event and stop_event.is_set():
print("⛔ Transcription stopped by user.")
break
title = os.path.basename(file).split('.')[0]
file_num += 1
print(f"\n{'' * 46}")
print(f"📄 File {file_num}/{total_files}: {title}")
try:
t_start = time.time()
segments, audio_duration = _transcribe_mlx_file(
file, mlx_model_id, language, timestamps, verbose
)
os.makedirs('{}/transcriptions'.format(path), exist_ok=True)
segment_list = []
txt_path = "{}/transcriptions/{}.txt".format(path, title)
srt_path = "{}/transcriptions/{}.srt".format(path, title)
with open(txt_path, 'w', encoding='utf-8') as f, \
open(srt_path, 'w', encoding='utf-8') as srt_f:
f.write(title)
f.write('\n' + '' * 40 + '\n')
for idx, seg in enumerate(segments, start=1):
if stop_event and stop_event.is_set():
break
text = seg["text"].strip()
if timestamps:
start_ts = str(datetime.timedelta(seconds=seg["start"]))
end_ts = str(datetime.timedelta(seconds=seg["end"]))
f.write('\n[{} --> {}] {}'.format(start_ts, end_ts, text))
else:
f.write('\n{}'.format(text))
srt_f.write(f'{idx}\n{_srt_timestamp(seg["start"])} --> {_srt_timestamp(seg["end"])}\n{text}\n\n')
f.flush()
srt_f.flush()
if verbose:
print(" [%.2fs → %.2fs] %s" % (seg["start"], seg["end"], seg["text"]))
else:
print(" Transcribed up to %.0fs..." % seg["end"], end='\r')
segment_list.append(seg)
elapsed = time.time() - t_start
elapsed_min = elapsed / 60.0
audio_min = audio_duration / 60.0
ratio = audio_duration / elapsed if elapsed > 0 else float('inf')
print(f"✅ Done — saved to transcriptions/{title}.txt")
print(f"⏱ Transcribed {audio_min:.1f} min of audio in {elapsed_min:.1f} min ({ratio:.1f}x realtime)")
files_transcripted.append(segment_list)
except Exception as exc:
print(f"⚠ Could not decode '{os.path.basename(file)}', skipping.")
print(f" Reason: {exc}")
print(f"\n{SEP}")
if files_transcripted:
output_text = f"✅ Finished! {len(files_transcripted)} file(s) transcribed.\n Saved in: {path}/transcriptions"
else:
output_text = '⚠ No files eligible for transcription — try another folder.'
print(output_text)
print(SEP)
return output_text
# ── Step 2: Load model (faster-whisper / CTranslate2) ───────────
print(f"⏳ Loading model '{model}' — downloading if needed...")
try:
whisper_model = WhisperModel(model, device=device, compute_type=compute_type)
@@ -164,6 +361,9 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest
files_transcripted = []
file_num = 0
for file in glob_file:
if stop_event and stop_event.is_set():
print("⛔ Transcription stopped by user.")
break
title = os.path.basename(file).split('.')[0]
file_num += 1
print(f"\n{'' * 46}")
@@ -180,10 +380,15 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest
os.makedirs('{}/transcriptions'.format(path), exist_ok=True)
# Stream segments as they are decoded
segment_list = []
with open("{}/transcriptions/{}.txt".format(path, title), 'w', encoding='utf-8') as f:
txt_path = "{}/transcriptions/{}.txt".format(path, title)
srt_path = "{}/transcriptions/{}.srt".format(path, title)
with open(txt_path, 'w', encoding='utf-8') as f, \
open(srt_path, 'w', encoding='utf-8') as srt_f:
f.write(title)
f.write('\n' + '' * 40 + '\n')
for seg in segments:
for idx, seg in enumerate(segments, start=1):
if stop_event and stop_event.is_set():
break
text = seg.text.strip()
if timestamps:
start_ts = str(datetime.timedelta(seconds=seg.start))
@@ -191,7 +396,9 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest
f.write('\n[{} --> {}] {}'.format(start_ts, end_ts, text))
else:
f.write('\n{}'.format(text))
srt_f.write(f'{idx}\n{_srt_timestamp(seg.start)} --> {_srt_timestamp(seg.end)}\n{text}\n\n')
f.flush()
srt_f.flush()
if verbose:
print(" [%.2fs → %.2fs] %s" % (seg.start, seg.end, seg.text))
else: