Testing Whisper-JAX on $100 GPU

whisper-jax: optimized, GPU-accelerated speech-to-text

  • Make sure you have the NVIDIA drivers for Linux installed.
  • Make sure pipenv is installed:
    • sudo apt install pipenv
  • Make sure ffmpeg is installed:
    • sudo apt install ffmpeg
  • Check your GPU with nvidia-smi:
    • nvidia-smi
      +-----------------------------------------------------------------------------------------+
      | NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
      |-----------------------------------------+------------------------+----------------------+
      | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
      | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
      |                                         |                        |               MIG M. |
      |=========================================+========================+======================|
      |   0  NVIDIA GeForce RTX 2060        Off |   00000000:01:00.0  On |                  N/A |
      |  0%   48C    P8             20W /  183W |     523MiB /   6144MiB |      4%      Default |
      |                                         |                        |                  N/A |
      +-----------------------------------------+------------------------+----------------------+
      
  • Install cuDNN.
  • Create a new pipenv project (new virtual environment and Pipfile):
    • pipenv --python 3.12
  • Install JAX with CUDA support (NVIDIA GPU):
    • This may take a while.
    • pipenv install "jax[cuda12-local]"
  • Install the whisper-jax Python library from GitHub (forked to fix issue 199).
  • Let's hope it works!
    • python3 main.py

Hugging Face Transformers library info:

Default cache directories (overridable with env vars):

  • Linux/macOS: ~/.cache/huggingface/
  • Windows: C:\Users\<username>\.cache\huggingface\
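To move the cache off a small root partition, the `HF_HOME` environment variable can be set before anything from the Hugging Face libraries is imported; the path below is just an example:

```python
import os

# Must be set before importing transformers / huggingface_hub,
# otherwise the default ~/.cache/huggingface/ location is used.
os.environ["HF_HOME"] = "/tmp/hf-cache"  # example path, pick your own

print(os.environ["HF_HOME"])
```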

Results:

  • whisper-large-v2: OOM crash (the host has only 8GB of RAM; the GPU has 6GB of VRAM)
  • whisper-medium.en: Noticeably better transcription; ~7x real-time speed (1 minute transcribed in ~9 seconds, 5 seconds in ~0.7 seconds)
  • whisper-small.en: Noticeably worse transcription, no better than FUTO Keyboard; ~20x real-time speed (1 minute transcribed in ~3 seconds)
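The speed multipliers above are just audio duration divided by transcription time; a quick sanity check using the timings from this run:

```python
def realtime_factor(audio_seconds: float, transcribe_seconds: float) -> float:
    """How many seconds of audio are transcribed per second of wall time."""
    return audio_seconds / transcribe_seconds

# whisper-medium.en: 60 s of audio in ~9 s, 5 s in ~0.7 s
print(round(realtime_factor(60, 9), 1))   # ~6.7x
print(round(realtime_factor(5, 0.7), 1))  # ~7.1x

# whisper-small.en: 60 s of audio in ~3 s
print(round(realtime_factor(60, 3), 1))   # ~20.0x
```

Both medium.en measurements land near 7x, so the two timings are consistent with each other.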