diff --git a/app.py b/app.py index ec23a5c..f54aae6 100644 --- a/app.py +++ b/app.py @@ -4,7 +4,7 @@ import tkinter as tk from tkinter import ttk from tkinter import filedialog from tkinter import messagebox -from src._LocalTranscribe import transcribe, get_path +from src._LocalTranscribe import transcribe, get_path, detect_backend import customtkinter import threading @@ -46,11 +46,93 @@ HF_MODEL_MAP = { 'KB Swedish (large)': 'KBLab/kb-whisper-large', } +# Per-model info shown in the UI description label +# (speed, size, quality stars, suggested use) +MODEL_INFO = { + 'tiny': ('Very fast', '~75 MB', '★★☆☆☆', 'Quick drafts & testing'), + 'tiny.en': ('Very fast', '~75 MB', '★★☆☆☆', 'Quick drafts & testing (English only)'), + 'base': ('Fast', '~145 MB', '★★★☆☆', 'Notes & short podcasts'), + 'base.en': ('Fast', '~145 MB', '★★★☆☆', 'Notes & short podcasts (English only)'), + 'small': ('Balanced', '~485 MB', '★★★★☆', 'Everyday use'), + 'small.en': ('Balanced', '~485 MB', '★★★★☆', 'Everyday use (English only)'), + 'medium': ('Accurate', '~1.5 GB', '★★★★☆', 'Professional content'), + 'medium.en': ('Accurate', '~1.5 GB', '★★★★☆', 'Professional content (English only)'), + 'large-v2': ('Slow', '~3 GB', '★★★★★', 'Maximum accuracy'), + 'large-v3': ('Slow', '~3 GB', '★★★★★', 'Maximum accuracy (recommended)'), + 'KB Swedish (tiny)': ('Very fast', '~75 MB', '★★★☆☆', 'Swedish — optimised by KBLab'), + 'KB Swedish (base)': ('Fast', '~145 MB', '★★★☆☆', 'Swedish — optimised by KBLab'), + 'KB Swedish (small)': ('Balanced', '~485 MB', '★★★★☆', 'Swedish — optimised by KBLab'), + 'KB Swedish (medium)': ('Accurate', '~1.5 GB', '★★★★☆', 'Swedish — optimised by KBLab'), + 'KB Swedish (large)': ('Slow', '~3 GB', '★★★★★', 'Swedish — KBLab, best accuracy'), +} + customtkinter.set_appearance_mode("System") customtkinter.set_default_color_theme("blue") # Themes: blue (default), dark-blue, green -firstclick = True + +# All languages supported by Whisper (display label → ISO code; None = auto-detect) +WHISPER_LANGUAGES = { + 'Auto-detect': None, + 'Afrikaans (af)': 'af', 'Albanian (sq)': 'sq', + 'Amharic (am)': 'am', 'Arabic (ar)': 'ar', + 'Armenian (hy)': 'hy', 'Assamese (as)': 'as', + 'Azerbaijani (az)': 'az', 'Bashkir (ba)': 'ba', + 'Basque (eu)': 'eu', 'Belarusian (be)': 'be', + 'Bengali (bn)': 'bn', 'Bosnian (bs)': 'bs', + 'Breton (br)': 'br', 'Bulgarian (bg)': 'bg', + 'Catalan (ca)': 'ca', 'Chinese (zh)': 'zh', + 'Croatian (hr)': 'hr', 'Czech (cs)': 'cs', + 'Danish (da)': 'da', 'Dutch (nl)': 'nl', + 'English (en)': 'en', 'Estonian (et)': 'et', + 'Faroese (fo)': 'fo', 'Finnish (fi)': 'fi', + 'French (fr)': 'fr', 'Galician (gl)': 'gl', + 'Georgian (ka)': 'ka', 'German (de)': 'de', + 'Greek (el)': 'el', 'Gujarati (gu)': 'gu', + 'Haitian Creole (ht)': 'ht', 'Hausa (ha)': 'ha', + 'Hawaiian (haw)': 'haw', 'Hebrew (he)': 'he', + 'Hindi (hi)': 'hi', 'Hungarian (hu)': 'hu', + 'Icelandic (is)': 'is', 'Indonesian (id)': 'id', + 'Italian (it)': 'it', 'Japanese (ja)': 'ja', + 'Javanese (jw)': 'jw', 'Kannada (kn)': 'kn', + 'Kazakh (kk)': 'kk', 'Khmer (km)': 'km', + 'Korean (ko)': 'ko', 'Lao (lo)': 'lo', + 'Latin (la)': 'la', 'Latvian (lv)': 'lv', + 'Lingala (ln)': 'ln', 'Lithuanian (lt)': 'lt', + 'Luxembourgish (lb)': 'lb', 'Macedonian (mk)': 'mk', + 'Malagasy (mg)': 'mg', 'Malay (ms)': 'ms', + 'Malayalam (ml)': 'ml', 'Maltese (mt)': 'mt', + 'Maori (mi)': 'mi', 'Marathi (mr)': 'mr', + 'Mongolian (mn)': 'mn', 'Myanmar (my)': 'my', + 'Nepali (ne)': 'ne', 'Norwegian (no)': 'no', + 'Occitan (oc)': 'oc', 'Pashto (ps)': 'ps', + 'Persian (fa)': 'fa', 'Polish (pl)': 'pl', + 'Portuguese (pt)': 'pt', 'Punjabi (pa)': 'pa', + 'Romanian (ro)': 'ro', 'Russian (ru)': 'ru', + 'Sanskrit (sa)': 'sa', 'Serbian (sr)': 'sr', + 'Shona (sn)': 'sn', 'Sindhi (sd)': 'sd', + 'Sinhala (si)': 'si', 'Slovak (sk)': 'sk', + 'Slovenian (sl)': 'sl', 'Somali (so)': 'so', + 'Spanish (es)': 'es', 'Sundanese (su)': 'su', + 'Swahili (sw)': 'sw', 'Swedish (sv)': 'sv', + 'Tagalog (tl)': 'tl', 'Tajik (tg)': 'tg', + 'Tamil (ta)': 'ta', 'Tatar (tt)': 'tt', + 'Telugu (te)': 'te', 'Thai (th)': 'th', + 'Tibetan (bo)': 'bo', 'Turkish (tr)': 'tr', + 'Turkmen (tk)': 'tk', 'Ukrainian (uk)': 'uk', + 'Urdu (ur)': 'ur', 'Uzbek (uz)': 'uz', + 'Vietnamese (vi)': 'vi', 'Welsh (cy)': 'cy', + 'Yiddish (yi)': 'yi', 'Yoruba (yo)': 'yo', +} + + +def _language_options_for_model(model_name): + """Return (values, default, state) for the language combobox given a model name.""" + if model_name.endswith('.en'): + return ['English (en)'], 'English (en)', 'disabled' + if model_name.startswith('KB Swedish'): + return ['Swedish (sv)'], 'Swedish (sv)', 'disabled' + return list(WHISPER_LANGUAGES.keys()), 'Auto-detect', 'readonly' def _set_app_icon(root): @@ -94,22 +176,16 @@ class App: self.path_entry.insert(0, os.path.join(os.getcwd(), 'sample_audio')) self.path_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) customtkinter.CTkButton(path_frame, text="Browse", command=self.browse, font=font).pack(side=tk.LEFT, padx=5) - # Language frame - #thanks to pommicket from Stackoverflow for this fix - def on_entry_click(event): - """function that gets called whenever entry is clicked""" - global firstclick - if firstclick: # if this is the first time they clicked it - firstclick = False - self.language_entry.delete(0, "end") # delete all the text in the entry + # Language frame language_frame = customtkinter.CTkFrame(master) language_frame.pack(fill=tk.BOTH, padx=10, pady=10) customtkinter.CTkLabel(language_frame, text="Language:", font=font).pack(side=tk.LEFT, padx=5) - self.language_entry = customtkinter.CTkEntry(language_frame, width=50, font=('Roboto', 12, 'italic')) - self.default_language_text = "Enter language (or ignore to auto-detect)" - self.language_entry.insert(0, self.default_language_text) - self.language_entry.bind('', on_entry_click) - self.language_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + _lang_values, _lang_default, _lang_state = _language_options_for_model('medium') + self.language_combobox = customtkinter.CTkComboBox( + language_frame, width=50, state=_lang_state, + values=_lang_values, font=font_b) + self.language_combobox.set(_lang_default) + self.language_combobox.pack(side=tk.LEFT, fill=tk.X, expand=True) # Model frame models = ['tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', @@ -124,9 +200,16 @@ class App: # ComboBox frame self.model_combobox = customtkinter.CTkComboBox( model_frame, width=50, state="readonly", - values=models, font=font_b) + values=models, font=font_b, + command=self._on_model_change) self.model_combobox.set('medium') # Set the default value self.model_combobox.pack(side=tk.LEFT, fill=tk.X, expand=True) + # Model description label + self.model_desc_label = customtkinter.CTkLabel( + master, text=self._model_desc_text('medium'), + font=('Roboto', 11), text_color=('#555555', '#aaaaaa'), + anchor='w') + self.model_desc_label.pack(fill=tk.X, padx=14, pady=(0, 4)) # Timestamps toggle ts_frame = customtkinter.CTkFrame(master) ts_frame.pack(fill=tk.BOTH, padx=10, pady=10) @@ -137,11 +220,17 @@ class App: self.timestamps_switch.pack(side=tk.LEFT, padx=5) # Progress Bar self.progress_bar = ttk.Progressbar(master, length=200, mode='indeterminate') + # Stop event for cancellation + self._stop_event = threading.Event() # Button actions frame button_frame = customtkinter.CTkFrame(master) button_frame.pack(fill=tk.BOTH, padx=10, pady=10) self.transcribe_button = customtkinter.CTkButton(button_frame, text="Transcribe", command=self.start_transcription, font=font) self.transcribe_button.pack(side=tk.LEFT, padx=5, pady=10, fill=tk.X, expand=True) + self.stop_button = customtkinter.CTkButton( + button_frame, text="Stop", command=self._stop_transcription, font=font, + fg_color="#c0392b", hover_color="#922b21", state=tk.DISABLED) + self.stop_button.pack(side=tk.LEFT, padx=5, pady=10, fill=tk.X, expand=True) customtkinter.CTkButton(button_frame, text="Quit", command=master.quit, font=font).pack(side=tk.RIGHT, padx=5, pady=10, fill=tk.X, expand=True) # ── Embedded console / log panel ────────────────────────────────── @@ -156,11 +245,40 @@ class App: sys.stdout = _ConsoleRedirector(self.log_box) sys.stderr = _ConsoleRedirector(self.log_box) + # Backend indicator + _bi = detect_backend() + backend_label = customtkinter.CTkLabel( + master, + text=f"Backend: {_bi['label']}", + font=('Roboto', 11), + text_color=("#555555", "#aaaaaa"), + anchor='e', + ) + backend_label.pack(fill=tk.X, padx=12, pady=(0, 2)) + # Welcome message (shown after redirect so it appears in the panel) print("Welcome to Local Transcribe with Whisper! \U0001f600") print("Transcriptions will be saved automatically.") print("─" * 46) # Helper functions + def _stop_transcription(self): + self._stop_event.set() + self.stop_button.configure(state=tk.DISABLED) + print("⛔ Stop requested — finishing current file…") + + def _model_desc_text(self, model_name): + info = MODEL_INFO.get(model_name) + if not info: + return '' + speed, size, stars, use = info + return f'{stars} {speed} · {size} · {use}' + + def _on_model_change(self, selected): + self.model_desc_label.configure(text=self._model_desc_text(selected)) + values, default, state = _language_options_for_model(selected) + self.language_combobox.configure(values=values, state=state) + self.language_combobox.set(default) + # Browsing def browse(self): initial_dir = os.getcwd() @@ -169,10 +287,10 @@ class App: self.path_entry.insert(0, folder_path) # Start transcription def start_transcription(self): - # Disable transcribe button + self._stop_event.clear() self.transcribe_button.configure(state=tk.DISABLED) - # Start a new thread for the transcription process - threading.Thread(target=self.transcribe_thread).start() + self.stop_button.configure(state=tk.NORMAL) + threading.Thread(target=self.transcribe_thread, daemon=True).start() # Threading def transcribe_thread(self): path = self.path_entry.get() @@ -183,14 +301,8 @@ class App: self.transcribe_button.configure(state=tk.NORMAL) return model = HF_MODEL_MAP.get(model_display, model_display) - language = self.language_entry.get() - # Auto-set Swedish for KB models - is_kb_model = model_display.startswith('KB Swedish') - # Check if the language field has the default text or is empty - if is_kb_model: - language = 'sv' - elif language == self.default_language_text or not language.strip(): - language = None # This is the same as passing nothing + lang_label = self.language_combobox.get() + language = WHISPER_LANGUAGES.get(lang_label, lang_label) if lang_label else None verbose = True # always show transcription progress in the console panel timestamps = self.timestamps_var.get() # Show progress bar @@ -201,16 +313,17 @@ class App: #messagebox.showinfo("Message", "Starting transcription!") # Start transcription try: - output_text = transcribe(path, glob_file, model, language, verbose, timestamps) + output_text = transcribe(path, glob_file, model, language, verbose, timestamps, stop_event=self._stop_event) except UnboundLocalError: messagebox.showinfo("Files not found error!", 'Nothing found, choose another folder.') pass - except ValueError: - messagebox.showinfo("Invalid language name, you might have to clear the default text to continue!") + except ValueError as e: + messagebox.showinfo("Error", str(e)) # Hide progress bar self.progress_bar.stop() self.progress_bar.pack_forget() - # Enable transcribe button + # Restore buttons + self.stop_button.configure(state=tk.DISABLED) self.transcribe_button.configure(state=tk.NORMAL) # Recover output text try: diff --git a/requirements.txt b/requirements.txt index b4144dc..5884552 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ faster-whisper +mlx-whisper customtkinter diff --git a/src/_LocalTranscribe.py b/src/_LocalTranscribe.py index 116ad89..6daf35d 100644 --- a/src/_LocalTranscribe.py +++ b/src/_LocalTranscribe.py @@ -1,5 +1,6 @@ import os import sys +import platform import datetime import time import site @@ -66,16 +67,124 @@ SUPPORTED_EXTENSIONS = { } -def _detect_device(): - """Return (device, compute_type) for the best available backend.""" +# --------------------------------------------------------------------------- +# MLX model map (Apple Silicon only) +# --------------------------------------------------------------------------- +_MLX_MODEL_MAP = { + "tiny": "mlx-community/whisper-tiny-mlx", + "base": "mlx-community/whisper-base-mlx", + "small": "mlx-community/whisper-small-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", +} + + +def detect_backend(): + """Return the best available inference backend. + + Returns a dict with keys: + backend : "mlx" | "cuda" | "cpu" + device : device string for WhisperModel (cuda / cpu) + compute_type : compute type string for WhisperModel + label : human-readable label for UI display + """ + # Apple Silicon → try MLX (GPU + Neural Engine via Apple MLX) + if sys.platform == "darwin" and platform.machine() == "arm64": + try: + import mlx_whisper # noqa: F401 + return { + "backend": "mlx", + "device": "cpu", + "compute_type": "int8", + "label": "MLX · Apple GPU/NPU", + } + except ImportError: + pass + + # NVIDIA CUDA try: import ctranslate2 cuda_types = ctranslate2.get_supported_compute_types("cuda") if "float16" in cuda_types: - return "cuda", "float16" + return { + "backend": "cuda", + "device": "cuda", + "compute_type": "float16", + "label": "CUDA · GPU", + } except Exception: pass - return "cpu", "int8" + + return { + "backend": "cpu", + "device": "cpu", + "compute_type": "int8", + "label": "CPU · int8", + } + + +def _decode_audio_pyav(file_path): + """Decode any audio/video file to a float32 mono 16 kHz numpy array. + + Uses PyAV (bundled FFmpeg) — no external ffmpeg binary required. + Returns (audio_array, duration_seconds). + """ + import av + import numpy as np + + with av.open(file_path) as container: + duration = float(container.duration) / 1_000_000 # microseconds → seconds + stream = container.streams.audio[0] + resampler = av.AudioResampler(format="fltp", layout="mono", rate=16000) + chunks = [] + for frame in container.decode(stream): + for out in resampler.resample(frame): + if out: + chunks.append(out.to_ndarray()[0]) + # Flush resampler + for out in resampler.resample(None): + if out: + chunks.append(out.to_ndarray()[0]) + + if not chunks: + return np.zeros(0, dtype=np.float32), duration + return np.concatenate(chunks, axis=0), duration + + +def _transcribe_mlx_file(file, mlx_model_id, language, timestamps, verbose): + """Transcribe a single file with mlx-whisper (Apple GPU/NPU). + + Decodes audio via PyAV (no system ffmpeg needed), then runs MLX inference. + Returns (segments_as_dicts, audio_duration_seconds). + Segments have dict keys: 'start', 'end', 'text'. + """ + import mlx_whisper + + audio_array, duration = _decode_audio_pyav(file) + + decode_opts = {} + if language: + decode_opts["language"] = language + + result = mlx_whisper.transcribe( + audio_array, + path_or_hf_repo=mlx_model_id, + verbose=(True if verbose else None), + **decode_opts, + ) + segments = result["segments"] + audio_duration = segments[-1]["end"] if segments else duration + return segments, audio_duration + + +def _srt_timestamp(seconds): + """Convert seconds (float) to SRT timestamp format HH:MM:SS,mmm.""" + ms = round(seconds * 1000) + h, ms = divmod(ms, 3_600_000) + m, ms = divmod(ms, 60_000) + s, ms = divmod(ms, 1000) + return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" # Get the path @@ -91,7 +200,7 @@ def get_path(path): return sorted(media_files) # Main function -def transcribe(path, glob_file, model=None, language=None, verbose=False, timestamps=True): +def transcribe(path, glob_file, model=None, language=None, verbose=False, timestamps=True, stop_event=None): """ Transcribes audio files in a specified folder using faster-whisper (CTranslate2). @@ -122,10 +231,98 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest SEP = "─" * 46 # ── Step 1: Detect hardware ────────────────────────────────────── - device, compute_type = _detect_device() - print(f"⚙ Device: {device} | Compute: {compute_type}") + backend_info = detect_backend() + backend = backend_info["backend"] + device = backend_info["device"] + compute_type = backend_info["compute_type"] + print(f"⚙ Backend: {backend_info['label']}") - # ── Step 2: Load model ─────────────────────────────────────────── + # ── Step 1b: MLX path (Apple GPU/NPU) ─────────────────────────── + if backend == "mlx": + mlx_model_id = _MLX_MODEL_MAP.get(model) + if mlx_model_id is None: + print(f"⚠ Model '{model}' is not available in MLX format.") + print(" Falling back to faster-whisper on CPU (int8).") + backend = "cpu" + device, compute_type = "cpu", "int8" + else: + # ── Step 2 (MLX): load + transcribe ───────────────────── + print(f"⏳ Loading MLX model '{model}' — downloading if needed...") + print("✅ Model ready!") + print(SEP) + + total_files = len(glob_file) + print(f"📂 Found {total_files} supported media file(s) in folder") + print(SEP) + + if total_files == 0: + output_text = '⚠ No supported media files found — try another folder.' + print(output_text) + print(SEP) + return output_text + + files_transcripted = [] + file_num = 0 + for file in glob_file: + if stop_event and stop_event.is_set(): + print("⛔ Transcription stopped by user.") + break + title = os.path.basename(file).split('.')[0] + file_num += 1 + print(f"\n{'─' * 46}") + print(f"📄 File {file_num}/{total_files}: {title}") + try: + t_start = time.time() + segments, audio_duration = _transcribe_mlx_file( + file, mlx_model_id, language, timestamps, verbose + ) + os.makedirs('{}/transcriptions'.format(path), exist_ok=True) + segment_list = [] + txt_path = "{}/transcriptions/{}.txt".format(path, title) + srt_path = "{}/transcriptions/{}.srt".format(path, title) + with open(txt_path, 'w', encoding='utf-8') as f, \ + open(srt_path, 'w', encoding='utf-8') as srt_f: + f.write(title) + f.write('\n' + '─' * 40 + '\n') + for idx, seg in enumerate(segments, start=1): + if stop_event and stop_event.is_set(): + break + text = seg["text"].strip() + if timestamps: + start_ts = str(datetime.timedelta(seconds=seg["start"])) + end_ts = str(datetime.timedelta(seconds=seg["end"])) + f.write('\n[{} --> {}] {}'.format(start_ts, end_ts, text)) + else: + f.write('\n{}'.format(text)) + srt_f.write(f'{idx}\n{_srt_timestamp(seg["start"])} --> {_srt_timestamp(seg["end"])}\n{text}\n\n') + f.flush() + srt_f.flush() + if verbose: + print(" [%.2fs → %.2fs] %s" % (seg["start"], seg["end"], seg["text"])) + else: + print(" Transcribed up to %.0fs..." % seg["end"], end='\r') + segment_list.append(seg) + elapsed = time.time() - t_start + elapsed_min = elapsed / 60.0 + audio_min = audio_duration / 60.0 + ratio = audio_duration / elapsed if elapsed > 0 else float('inf') + print(f"✅ Done — saved to transcriptions/{title}.txt") + print(f"⏱ Transcribed {audio_min:.1f} min of audio in {elapsed_min:.1f} min ({ratio:.1f}x realtime)") + files_transcripted.append(segment_list) + except Exception as exc: + print(f"⚠ Could not decode '{os.path.basename(file)}', skipping.") + print(f" Reason: {exc}") + + print(f"\n{SEP}") + if files_transcripted: + output_text = f"✅ Finished! {len(files_transcripted)} file(s) transcribed.\n Saved in: {path}/transcriptions" + else: + output_text = '⚠ No files eligible for transcription — try another folder.' + print(output_text) + print(SEP) + return output_text + + # ── Step 2: Load model (faster-whisper / CTranslate2) ─────────── print(f"⏳ Loading model '{model}' — downloading if needed...") try: whisper_model = WhisperModel(model, device=device, compute_type=compute_type) @@ -164,6 +361,9 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest files_transcripted = [] file_num = 0 for file in glob_file: + if stop_event and stop_event.is_set(): + print("⛔ Transcription stopped by user.") + break title = os.path.basename(file).split('.')[0] file_num += 1 print(f"\n{'─' * 46}") @@ -180,10 +380,15 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest os.makedirs('{}/transcriptions'.format(path), exist_ok=True) # Stream segments as they are decoded segment_list = [] - with open("{}/transcriptions/{}.txt".format(path, title), 'w', encoding='utf-8') as f: + txt_path = "{}/transcriptions/{}.txt".format(path, title) + srt_path = "{}/transcriptions/{}.srt".format(path, title) + with open(txt_path, 'w', encoding='utf-8') as f, \ + open(srt_path, 'w', encoding='utf-8') as srt_f: f.write(title) f.write('\n' + '─' * 40 + '\n') - for seg in segments: + for idx, seg in enumerate(segments, start=1): + if stop_event and stop_event.is_set(): + break text = seg.text.strip() if timestamps: start_ts = str(datetime.timedelta(seconds=seg.start)) @@ -191,7 +396,9 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False, timest f.write('\n[{} --> {}] {}'.format(start_ts, end_ts, text)) else: f.write('\n{}'.format(text)) + srt_f.write(f'{idx}\n{_srt_timestamp(seg.start)} --> {_srt_timestamp(seg.end)}\n{text}\n\n') f.flush() + srt_f.flush() if verbose: print(" [%.2fs → %.2fs] %s" % (seg.start, seg.end, seg.text)) else: