whisper-jax: optimized, GPU-accelerated speech-to-text
- Make sure the NVIDIA drivers for Linux are installed.
- Make sure pipenv is installed:
```sh
sudo apt install pipenv
```
- Make sure ffmpeg is installed:
```sh
sudo apt install ffmpeg
```
- Check that your GPU is detected using nvidia-smi:
```sh
nvidia-smi
```
Example output:
```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2060        Off |   00000000:01:00.0  On |                  N/A |
|  0%   48C    P8             20W /  183W |     523MiB /   6144MiB |      4%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
```
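Reading the table by eye works, but a script can ask nvidia-smi for just the numbers that matter when choosing a model size (free VRAM in particular). The sketch below assumes nvidia-smi's CSV query mode (`--query-gpu=name,memory.total,memory.used --format=csv,noheader,nounits`); `GpuInfo`, `parse_gpu_csv` and `query_gpus` are hypothetical helpers, not part of this project.

```python
import subprocess
from typing import List, NamedTuple


class GpuInfo(NamedTuple):
    """One row of nvidia-smi's CSV query output (memory values in MiB)."""

    name: str
    memory_total_mib: int
    memory_used_mib: int

    @property
    def memory_free_mib(self) -> int:
        return self.memory_total_mib - self.memory_used_mib


def parse_gpu_csv(text: str) -> List[GpuInfo]:
    """Parse `nvidia-smi --query-gpu=name,memory.total,memory.used
    --format=csv,noheader,nounits` output, one GPU per line."""
    gpus = []
    for line in text.strip().splitlines():
        # rsplit from the right keeps any commas inside the GPU name intact.
        name, total, used = line.rsplit(",", 2)
        # int() tolerates the surrounding spaces nvidia-smi emits.
        gpus.append(GpuInfo(name.strip(), int(total), int(used)))
    return gpus


def query_gpus() -> List[GpuInfo]:
    """Run nvidia-smi (requires the NVIDIA driver) and parse its output."""
    result = subprocess.run(
        [
            "nvidia-smi",
            "--query-gpu=name,memory.total,memory.used",
            "--format=csv,noheader,nounits",
        ],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_gpu_csv(result.stdout)
```

On the machine shown above, `query_gpus()` would be expected to return a single entry for the RTX 2060 with `memory_total_mib=6144`. The parser is kept separate from the subprocess call so it can be checked without a GPU.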
- Install cuDNN, using NVIDIA's local .deb installer for Debian 12:
  - https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Debian&target_version=12&target_type=deb_local
```sh
wget https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-debian12-9.8.0_1.0-1_amd64.deb
sudo dpkg -i cudnn-local-repo-debian12-9.8.0_1.0-1_amd64.deb
sudo cp /var/cuda-repo-debian12-9-8-local/cudnn-*-keyring.gpg /usr/share/keyrings/
sudo apt-get update
sudo apt-get -y install cudnn
```
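To confirm that the dynamic loader can actually find cuDNN after the install, one option is to load it with `ctypes` and call `cudnnGetVersion()`, a function in cuDNN's C API. This is a sketch that assumes the package registers `libcudnn.so` on the default library path; `cudnn_version` is a hypothetical helper name.

```python
import ctypes
import ctypes.util
from typing import Optional


def cudnn_version() -> Optional[int]:
    """Return the integer reported by cudnnGetVersion(), or None if
    libcudnn cannot be found or loaded."""
    # find_library("cudnn") searches the system loader paths for libcudnn.so*.
    libname = ctypes.util.find_library("cudnn")
    if libname is None:
        return None
    try:
        lib = ctypes.CDLL(libname)
        get_version = lib.cudnnGetVersion
    except (OSError, AttributeError):
        # Library present but not loadable, or the symbol is missing.
        return None
    # C signature: size_t cudnnGetVersion(void)
    get_version.argtypes = []
    get_version.restype = ctypes.c_size_t
    return int(get_version())
```

A `None` result means the library could not be found or loaded; a positive number means the loader sees a working cuDNN.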
- Create a new project with pipenv (a new virtual environment and Pipfile):
```sh
pipenv --python 3.12
```
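- Install jaxlib for CUDA (NVIDIA GPU). The `cuda12-local` extra uses the CUDA and cuDNN libraries already installed on the system. This may take a while.
```sh
pipenv install "jax[cuda12-local]"
```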
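Once the install finishes, it is worth checking that JAX actually picked up the GPU instead of silently falling back to the CPU. A minimal sketch, assuming it is run inside the pipenv environment (for example with `pipenv run python`); `jax_gpu_devices` is a hypothetical helper built on `jax.devices("gpu")`.

```python
import importlib.util


def jax_gpu_devices() -> list:
    """Return the GPU devices JAX can see, or [] if JAX or its GPU backend is unavailable."""
    if importlib.util.find_spec("jax") is None:
        return []
    import jax  # imported lazily so this helper also runs where JAX is absent

    try:
        return jax.devices("gpu")
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is not available.
        return []
```

An empty list means JAX is running on the CPU only, which often points at a CUDA or cuDNN library that was not found.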
- Install the whisper-jax Python library from a fork that fixes issue 199:
  - https://github.com/sanchit-gandhi/whisper-jax/issues/199
```sh
pipenv install git+https://git.sequentialread.com/forest/whisper-jax
```
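- Run the script inside the pipenv environment, and let's hope it works!
```sh
pipenv run python main.py
# or, from inside `pipenv shell`:
python3 main.py
```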
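`main.py` itself is not reproduced here. As a rough idea of what such a script can look like, the sketch below follows the usage shown in the upstream whisper-jax README (`FlaxWhisperPipline` with `dtype=jnp.float16`, the setting upstream suggests for most GPUs) and assumes the fork installed above keeps that API. The function name `transcribe` and the default checkpoint are illustrative choices, not taken from this project.

```python
import time
from typing import Tuple


def transcribe(audio_path: str, checkpoint: str = "openai/whisper-medium.en") -> Tuple[str, float]:
    """Transcribe one audio file with whisper-jax and return (text, seconds taken).

    Assumes the installed whisper-jax fork keeps upstream's FlaxWhisperPipline API.
    """
    # Imported lazily: these pull in JAX, Flax and the model code.
    import jax.numpy as jnp
    from whisper_jax import FlaxWhisperPipline  # (sic) upstream spells it "Pipline"

    # float16 halves weight memory compared with float32, which matters on a 6 GB GPU.
    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.float16)

    # Warm-up call: the first call JIT-compiles the model and is slow.
    pipeline(audio_path)

    start = time.perf_counter()
    outputs = pipeline(audio_path)  # audio files are decoded via ffmpeg
    elapsed = time.perf_counter() - start
    return outputs["text"], elapsed
```

For example, `text, seconds = transcribe("recording.mp3")`, where `recording.mp3` is a placeholder path. The warm-up call matters because timing the first call would mostly measure compilation rather than transcription.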
Hugging Face transformers library info:
Downloaded models are stored in these default cache directories (override them with environment variables such as `HF_HOME`):
- Linux/macOS: `~/.cache/huggingface/`
- Windows: `C:\Users\<username>\.cache\huggingface\`
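The lookup order can be sketched as follows. This mirrors how `huggingface_hub` documents its defaults (`HF_HOME` falling back to `$XDG_CACHE_HOME/huggingface` or `~/.cache/huggingface`, with downloaded model repos under `hub/` unless `HF_HUB_CACHE` is set) and ignores legacy variables such as `TRANSFORMERS_CACHE`; `hf_cache_dirs` is a hypothetical helper.

```python
import os
from typing import Dict, Mapping, Optional


def hf_cache_dirs(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Resolve the Hugging Face home and hub cache directories from environment variables."""
    env = os.environ if env is None else env
    # Base cache directory: $XDG_CACHE_HOME if set, otherwise ~/.cache.
    base = env.get("XDG_CACHE_HOME") or os.path.join(os.path.expanduser("~"), ".cache")
    # HF_HOME defaults to <base>/huggingface (~/.cache/huggingface on Linux/macOS).
    hf_home = os.path.expanduser(env.get("HF_HOME") or os.path.join(base, "huggingface"))
    # Downloaded model repositories live in HF_HUB_CACHE, defaulting to <HF_HOME>/hub.
    hub_cache = os.path.expanduser(env.get("HF_HUB_CACHE") or os.path.join(hf_home, "hub"))
    return {"HF_HOME": hf_home, "HF_HUB_CACHE": hub_cache}
```

Passing an explicit mapping makes it easy to see where models would go, for example `hf_cache_dirs({"HF_HOME": "/data/hf"})`.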
Results (GeForce RTX 2060 with 6 GB of VRAM, host with 8 GB of RAM):
- `whisper-large-v2`: crashed out of memory (the host has only 8 GB of RAM and the GPU has 6 GB of VRAM).
- `whisper-medium.en`: noticeably better transcription; about 7x real time (1 minute of audio transcribed in ~9 seconds, 5 seconds in ~0.7 seconds).
- `whisper-small.en`: noticeably worse transcription, no better than FUTO Keyboard; about 20x real time (1 minute of audio transcribed in ~3 seconds).
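The speed figures are audio duration divided by transcription time. Recomputing them from the timings listed gives a slightly more precise picture; `realtime_factor` is just an illustrative helper.

```python
def realtime_factor(audio_seconds: float, wall_seconds: float) -> float:
    """Seconds of audio transcribed per second of wall-clock time (higher is faster)."""
    if wall_seconds <= 0:
        raise ValueError("wall_seconds must be positive")
    return audio_seconds / wall_seconds


# (audio length, transcription time) in seconds, taken from the results above.
RESULTS = {
    "whisper-medium.en, 1 min clip": (60.0, 9.0),
    "whisper-medium.en, 5 s clip": (5.0, 0.7),
    "whisper-small.en, 1 min clip": (60.0, 3.0),
}

for label, (audio_s, wall_s) in RESULTS.items():
    print(f"{label}: {realtime_factor(audio_s, wall_s):.1f}x real time")
```

This prints roughly 6.7x and 7.1x for whisper-medium.en and 20.0x for whisper-small.en, consistent with the rounded 7x and 20x figures.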