Merge pull request #11 from soderstromkr/copilot/update-whisper-device-parameter
Pass explicit device parameter to whisper.load_model() for MPS acceleration
This commit is contained in:
25
.gitignore
vendored
Normal file
25
.gitignore
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Python cache
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
*.egg-info/
|
||||||
@@ -40,16 +40,19 @@ def transcribe(path, glob_file, model=None, language=None, verbose=False):
|
|||||||
within the specified path.
|
within the specified path.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Check for GPU acceleration
|
# Check for GPU acceleration and set device
|
||||||
if backends.mps.is_available():
|
if backends.mps.is_available():
|
||||||
|
device = 'mps'
|
||||||
Generator('mps').manual_seed(42)
|
Generator('mps').manual_seed(42)
|
||||||
elif cuda.is_available():
|
elif cuda.is_available():
|
||||||
|
device = 'cuda'
|
||||||
Generator('cuda').manual_seed(42)
|
Generator('cuda').manual_seed(42)
|
||||||
else:
|
else:
|
||||||
|
device = 'cpu'
|
||||||
Generator().manual_seed(42)
|
Generator().manual_seed(42)
|
||||||
|
|
||||||
# Load model
|
# Load model on the correct device
|
||||||
model = whisper.load_model(model)
|
model = whisper.load_model(model, device=device)
|
||||||
# Start main loop
|
# Start main loop
|
||||||
files_transcripted=[]
|
files_transcripted=[]
|
||||||
for file in glob_file:
|
for file in glob_file:
|
||||||
|
|||||||
Binary file not shown.
Reference in New Issue
Block a user