Merge pull request #11 from soderstromkr/copilot/update-whisper-device-parameter
Pass explicit device parameter to whisper.load_model() for MPS acceleration
This commit is contained in:
25
.gitignore
vendored
Normal file
25
.gitignore
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
# Python cache
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Build artifacts
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
@@ -40,16 +40,19 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False):
|
||||
within the specified path.
|
||||
|
||||
"""
|
||||
# Check for GPU acceleration
|
||||
# Check for GPU acceleration and set device
|
||||
if backends.mps.is_available():
|
||||
device = 'mps'
|
||||
Generator('mps').manual_seed(42)
|
||||
elif cuda.is_available():
|
||||
device = 'cuda'
|
||||
Generator('cuda').manual_seed(42)
|
||||
else:
|
||||
device = 'cpu'
|
||||
Generator().manual_seed(42)
|
||||
|
||||
# Load model
|
||||
model = whisper.load_model(model)
|
||||
# Load model on the correct device
|
||||
model = whisper.load_model(model, device=device)
|
||||
# Start main loop
|
||||
files_transcripted=[]
|
||||
for file in glob_file:
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user