Testing Whisper-JAX on a $100 GPU
whisper-jax: optimized, GPU-accelerated speech-to-text

  • Make sure you have the NVIDIA drivers for Linux installed.
  • Make sure pipenv is installed:
    • sudo apt install pipenv
  • Make sure ffmpeg is installed:
    • sudo apt install ffmpeg
  • Check your GPU with nvidia-smi:
    • nvidia-smi
      +-----------------------------------------------------------------------------------------+
      | NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
      |-----------------------------------------+------------------------+----------------------+
      | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
      | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
      |                                         |                        |               MIG M. |
      |=========================================+========================+======================|
      |   0  NVIDIA GeForce RTX 2060        Off |   00000000:01:00.0  On |                  N/A |
      |  0%   48C    P8             20W /  183W |     523MiB /   6144MiB |      4%      Default |
      |                                         |                        |                  N/A |
      +-----------------------------------------+------------------------+----------------------+
      
  • Install cuDNN.
  • Create a new project with pipenv (new virtual environment and Pipfile):
    • pipenv --python 3.12
  • Install jaxlib for CUDA (NVIDIA GPU); this may take a while:
    • pipenv install "jax[cuda12-local]"
  • Install the whisper-jax Python library from GitHub (forked to fix issue 199).
  • Run it, and let's hope it works!
    • python3 main.py
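For reference, a minimal main.py might look like the sketch below. This is hypothetical, not the repo's actual file; FlaxWhisperPipline (the missing "e" is the library's own spelling), the dtype argument, and the "audio.mp3" filename are taken from the whisper-jax README, not from this repo.

```python
import importlib.util

# Check up front whether the CUDA-enabled install above actually succeeded.
HAVE_WHISPER_JAX = importlib.util.find_spec("whisper_jax") is not None

if HAVE_WHISPER_JAX:
    from whisper_jax import FlaxWhisperPipline
    import jax.numpy as jnp

    # Half precision roughly halves VRAM use -- relevant on a 6 GB card.
    pipeline = FlaxWhisperPipline("openai/whisper-medium.en", dtype=jnp.float16)
    result = pipeline("audio.mp3")  # the first call JIT-compiles, so it is slow
    print(result["text"])
else:
    print("whisper-jax is not installed; run the pipenv steps above first")
```

Note that JAX compiles the model on the first call, so expect the first transcription to take much longer than subsequent ones.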

Hugging Face transformers library info:

Default cache directories (overridable with env vars):

  • Linux/macOS: ~/.cache/huggingface/
  • Windows: C:\Users\<username>\.cache\huggingface\
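A simplified sketch of how that cache root can be resolved and overridden (the HF_HOME environment variable is real; the helper function name is made up here):

```python
import os

def hf_cache_dir(env=os.environ):
    """Resolve the Hugging Face cache root (simplified):
    the HF_HOME environment variable wins, else the platform default."""
    if "HF_HOME" in env:
        return env["HF_HOME"]
    return os.path.join(os.path.expanduser("~"), ".cache", "huggingface")

print(hf_cache_dir({}))                               # platform default
print(hf_cache_dir({"HF_HOME": "/mnt/big-disk/hf"}))  # overridden
```

Pointing HF_HOME at a bigger disk is useful here, since the whisper model weights run to several gigabytes.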

Results:

  • whisper-large-v2: OOM crash (the host only has 8 GB of RAM and the GPU 6 GB of VRAM)
  • whisper-medium.en: noticeably better transcription at ~7x real-time (1 minute transcribed in ~9 seconds; 5 seconds in ~0.7 seconds)
  • whisper-small.en: noticeably worse transcription, no better than FUTO Keyboard, at ~20x real-time (1 minute transcribed in ~3 seconds)
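The speed multipliers above are just audio duration divided by wall-clock transcription time; a quick sanity check of the reported numbers:

```python
def speedup(audio_seconds, wall_seconds):
    """Real-time speed multiplier: how many seconds of audio
    are transcribed per second of wall-clock time."""
    return audio_seconds / wall_seconds

print(round(speedup(60, 9)))  # medium.en: 60 s audio in ~9 s  -> 7
print(round(speedup(60, 3)))  # small.en:  60 s audio in ~3 s  -> 20
```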