JAX هي مكتبة Python تم تطويرها بواسطة Google Research، تهدف إلى توفير قدرات حسابية عددية عالية الأداء. تجمع بين سهولة استخدام NumPy والميزات المتقدمة مثل التفاضل التلقائي والترجمة الفورية (JIT)، مما يمكّن الباحثين والمهندسين من بناء وتدريب نماذج تعلم آلي معقدة وإجراء حسابات علمية بسهولة أكبر. الفكرة الأساسية لـ JAX هي التحويلات الوظيفية، والتي تسمح للمستخدمين بإجراء عمليات مثل الاشتقاق والتحويل إلى متجه والتوازي على الدوال العددية بطريقة موجزة.
في مجالات التعلم الآلي والحساب العلمي، يعتبر الحساب العددي عالي الأداء أمرًا بالغ الأهمية. على الرغم من سهولة استخدام NumPy التقليدية، إلا أنها تعاني من قيود في التفاضل التلقائي وتسريع GPU/TPU. في حين أن أطر التعلم العميق مثل TensorFlow و PyTorch توفر هذه الوظائف، إلا أن منحنى التعلم الخاص بها حاد، وهي ليست مرنة بما يكفي في بعض سيناريوهات الحساب العلمي.
يهدف ظهور JAX إلى سد هذه الثغرات. فهو يوفر منصة موحدة يمكنها تلبية احتياجات تدريب نماذج التعلم الآلي ودعم مهام الحساب العلمي المختلفة. من خلال التحويلات الوظيفية، يبسط JAX عمليات مثل التفاضل التلقائي والتوازي، مما يسمح للمستخدمين بالتركيز بشكل أكبر على تصميم وتنفيذ الخوارزميات.
vmap
، والتي يمكنها تحويل الدوال العددية إلى متجه تلقائيًا، وبالتالي إجراء حسابات متوازية فعالة على المصفوفات.pmap
، والتي يمكنها تنفيذ الدوال بالتوازي على أجهزة متعددة، وبالتالي تسريع مهام الحساب واسعة النطاق.يستخدم JAX على نطاق واسع في المجالات التالية: