JAX는 Google Research에서 개발한 Python 라이브러리로, 고성능 수치 계산 능력을 제공하는 것을 목표로 합니다. NumPy의 사용 편의성과 자동 미분, 즉시 컴파일 등 고급 기능을 결합하여 연구원과 엔지니어가 복잡한 머신러닝 모델을 더 쉽게 구축하고 훈련하며 과학 계산을 수행할 수 있도록 합니다. JAX의 핵심 이념은 함수형 변환으로, 사용자가 간결한 방식으로 수치 함수에 대한 미분, 벡터화, 병렬화 등의 작업을 수행할 수 있도록 합니다.
머신러닝 및 과학 계산 분야에서 고성능 수치 계산은 매우 중요합니다. 기존의 NumPy는 사용하기 쉽지만 자동 미분 및 GPU/TPU 가속 측면에서 한계가 있습니다. TensorFlow 및 PyTorch와 같은 딥러닝 프레임워크는 이러한 기능을 제공하지만 학습 곡선이 가파르고 특정 과학 계산 시나리오에서는 유연성이 부족합니다.
JAX의 출현은 이러한 부족함을 메우는 것을 목표로 합니다. 머신러닝 모델 훈련 요구 사항을 충족하고 다양한 과학 계산 작업을 지원할 수 있는 통합 플랫폼을 제공합니다. 함수형 변환을 통해 JAX는 자동 미분, 병렬화 등의 작업을 단순화하여 사용자가 알고리즘 설계 및 구현에 더 집중할 수 있도록 합니다.
vmap
함수를 제공하여 스칼라 함수를 자동으로 벡터화하여 배열에서 효율적인 병렬 계산을 수행할 수 있습니다.pmap
함수를 제공하여 여러 장치에서 함수를 병렬로 실행하여 대규모 계산 작업을 가속화할 수 있습니다.JAX는 다음과 같은 분야에서 널리 사용됩니다.