JAX Framework를 사용하다 보면 가벼운 프로그램 하나를 수행하는데도 GPU 메모리를 대부분 사용하는 상황을 확인할 수 있다. 이런 경우 한 GPU에서 가벼운 하나의 프로그램이 전체 GPU 메모리를 점유하여 병렬적으로 프로그램을 실행할 수 없는 문제가 발생한다.
이러한 문제가 발생하는 이유는 JAX Framework가 효율성을 위하여 기본적으로 GPU 메모리의 90%를 미리 할당해놓기 (Preallocate) 때문이다.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
Preallocate 옵션을 false로 바꿔줌으로써 preallocation을 수행하지 않게 할 수 있다.
export XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
혹은, preallocation 하는 메모리의 양(.XX)을 조절할 수 있다.
참고: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html