Add explicit device parameter to whisper.load_model()
Co-authored-by: soderstromkr <23003509+soderstromkr@users.noreply.github.com>
This commit is contained in:
@@ -40,16 +40,19 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False):
|
|||||||
within the specified path.
|
within the specified path.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Check for GPU acceleration
|
# Check for GPU acceleration and set device
|
||||||
if backends.mps.is_available():
|
if backends.mps.is_available():
|
||||||
|
device = 'mps'
|
||||||
Generator('mps').manual_seed(42)
|
Generator('mps').manual_seed(42)
|
||||||
elif cuda.is_available():
|
elif cuda.is_available():
|
||||||
|
device = 'cuda'
|
||||||
Generator('cuda').manual_seed(42)
|
Generator('cuda').manual_seed(42)
|
||||||
else:
|
else:
|
||||||
|
device = 'cpu'
|
||||||
Generator().manual_seed(42)
|
Generator().manual_seed(42)
|
||||||
|
|
||||||
# Load model
|
# Load model on the correct device
|
||||||
model = whisper.load_model(model)
|
model = whisper.load_model(model, device=device)
|
||||||
# Start main loop
|
# Start main loop
|
||||||
files_transcripted=[]
|
files_transcripted=[]
|
||||||
for file in glob_file:
|
for file in glob_file:
|
||||||
|
|||||||
BIN
src/__pycache__/_LocalTranscribe.cpython-312.pyc
Normal file
BIN
src/__pycache__/_LocalTranscribe.cpython-312.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user