diff --git a/src/_LocalTranscribe.py b/src/_LocalTranscribe.py
index 94c8f8f..5d78562 100644
--- a/src/_LocalTranscribe.py
+++ b/src/_LocalTranscribe.py
@@ -2,7 +2,7 @@ import os
 import datetime
 from glob import glob
 import whisper
-from torch import cuda, Generator
+from torch import backends, cuda, Generator
 import colorama
 from colorama import Back,Fore
 colorama.init(autoreset=True)
@@ -39,12 +39,15 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False):
     - The transcribed text files will be saved in a "transcriptions" folder within the specified path.
 
-    """
+    """
     # Check for GPU acceleration
-    if cuda.is_available():
+    if backends.mps.is_available():
+        Generator('mps').manual_seed(42)
+    elif cuda.is_available():
         Generator('cuda').manual_seed(42)
     else:
         Generator().manual_seed(42)
+
     # Load model
     model = whisper.load_model(model)
     # Start main loop