Home
Login

JAX est une bibliothèque Python pour le calcul numérique haute performance et l'apprentissage automatique. Elle combine la facilité d'utilisation de NumPy avec la puissance de la différenciation automatique, et peut fonctionner sur CPU, GPU et TPU.

Apache-2.0Python 32.5kjax-ml Last Updated: 2025-06-14

JAX : Transformations fonctionnelles composables pour le calcul numérique haute performance

Aperçu du projet

JAX est une bibliothèque Python développée par Google Research, conçue pour fournir des capacités de calcul numérique haute performance. Elle combine la facilité d'utilisation de NumPy avec des fonctionnalités avancées telles que la différentiation automatique et la compilation à la volée (JIT), permettant aux chercheurs et aux ingénieurs de construire et d'entraîner plus facilement des modèles d'apprentissage automatique complexes et d'effectuer des calculs scientifiques. Le principe fondamental de JAX est la transformation fonctionnelle, qui permet aux utilisateurs d'effectuer des opérations telles que la dérivation, la vectorisation et la parallélisation de fonctions numériques de manière concise.

Contexte

Dans les domaines de l'apprentissage automatique et du calcul scientifique, le calcul numérique haute performance est essentiel. Bien que NumPy soit facile à utiliser, il présente des limitations en termes de différentiation automatique et d'accélération GPU/TPU. Les frameworks d'apprentissage profond tels que TensorFlow et PyTorch offrent ces fonctionnalités, mais leur courbe d'apprentissage est plus abrupte et ils manquent de flexibilité dans certains scénarios de calcul scientifique.

L'objectif de JAX est de combler ces lacunes. Il fournit une plateforme unifiée qui répond à la fois aux besoins de l'entraînement des modèles d'apprentissage automatique et prend en charge diverses tâches de calcul scientifique. Grâce à la transformation fonctionnelle, JAX simplifie les opérations telles que la différentiation automatique et la parallélisation, permettant aux utilisateurs de se concentrer davantage sur la conception et la mise en œuvre des algorithmes.

Caractéristiques principales

  • Différentiation automatique (Autodiff) : JAX offre une puissante fonctionnalité de différentiation automatique, capable de calculer des dérivées d'ordre arbitraire. Il prend en charge les modes de différentiation avant et arrière, permettant aux utilisateurs de choisir le mode approprié en fonction de leurs besoins spécifiques.
  • Compilation à la volée (JIT Compilation) : JAX peut compiler le code Python en code XLA (Accelerated Linear Algebra) optimisé, permettant ainsi un calcul haute performance sur CPU, GPU et TPU.
  • Vectorisation (Vectorization) : JAX fournit la fonction vmap, qui peut automatiquement vectoriser les fonctions scalaires, permettant ainsi un calcul parallèle efficace sur les tableaux.
  • Parallélisation (Parallelization) : JAX fournit la fonction pmap, qui peut exécuter des fonctions en parallèle sur plusieurs appareils, accélérant ainsi les tâches de calcul à grande échelle.
  • Transformations fonctionnelles composables : Le principe fondamental de JAX est la transformation fonctionnelle. Les utilisateurs peuvent combiner différentes fonctions de transformation pour implémenter diverses fonctionnalités avancées, telles que la différentiation automatique, la vectorisation et la parallélisation.
  • Compatibilité NumPy : JAX fournit une API similaire à NumPy, permettant aux utilisateurs de migrer facilement leur code NumPy existant vers JAX.
  • Contrôle explicite du PRNG : JAX oblige les utilisateurs à gérer explicitement les générateurs de nombres aléatoires, évitant ainsi les problèmes liés à l'état global et rendant le code plus facile à déboguer et à reproduire.

Scénarios d'application

JAX est largement utilisé dans les domaines suivants :

  • Apprentissage automatique : JAX peut être utilisé pour construire et entraîner divers modèles d'apprentissage automatique, notamment les réseaux neuronaux profonds, les modèles probabilistes, etc.
  • Calcul scientifique : JAX peut être utilisé pour résoudre divers problèmes de calcul scientifique, tels que la simulation numérique, l'optimisation, l'analyse statistique, etc.
  • Apprentissage par renforcement : JAX peut être utilisé pour implémenter divers algorithmes d'apprentissage par renforcement, tels que la politique de gradient, le Q-learning, etc.
  • Calcul haute performance : JAX peut utiliser des accélérateurs matériels tels que les GPU et les TPU pour réaliser un calcul haute performance.
  • Développement de prototypes de recherche : La flexibilité et la facilité d'utilisation de JAX en font un choix idéal pour le développement de prototypes de recherche.

Pour tous les détails, veuillez vous référer au site officiel (https://github.com/jax-ml/jax)