Flax は、Google によって開発されたオープンソースのニューラルネットワークライブラリで、JAX 専用に設計されています。柔軟性、高性能、使いやすさを提供し、深層学習モデルの研究開発のためのプラットフォームとなることを目指しています。Flax は関数型プログラミングパラダイムに焦点を当て、モジュール化された再利用可能なコードを推奨しています。
GitHub リポジトリ: https://github.com/google/flax
flax.linen
: コアモジュールで、ニューラルネットワーク層を定義するための API を提供します。畳み込み層、全結合層、再帰層など、さまざまな定義済みの層が含まれています。flax.training
: モデルのトレーニングに使用するユーティリティを提供します。オプティマイザ、学習率スケジューラ、チェックポイント管理などが含まれます。flax.optim
: Adam、SGD などのさまざまなオプティマイザと、学習率スケジューラが含まれています。flax.struct
: モデルの状態を表すために使用される不変のデータ構造を定義します。flax.core
: 低レベルの関数型プログラミングツールを提供します。import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
# モデルの定義
class LinearRegression(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Dense(features=1)(x)
# トレーニングステップの定義
def train_step(state, batch):
def loss_fn(params):
y_pred = state.apply_fn({'params': params}, batch['x'])
loss = jnp.mean((y_pred - batch['y'])**2)
return loss
grad_fn = jax.grad(loss_fn)
grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
# モデルとオプティマイザの初期化
key = jax.random.PRNGKey(0)
model = LinearRegression()
params = model.init(key, jnp.ones((1, 1)))['params']
tx = optax.sgd(learning_rate=0.01)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
# データの生成
x = jnp.linspace(0, 1, 100)[:, None]
y = 2 * x + 1 + jax.random.normal(key, (100, 1)) * 0.1
batch = {'x': x, 'y': y}
# モデルのトレーニング
for i in range(100):
state = train_step(state, batch)
# 結果の出力
print(f"Trained parameters: {state.params}")
Flax は、JAX 専用に設計された強力なニューラルネットワークライブラリです。柔軟性、高性能、使いやすさを提供し、深層学習モデルの研究開発に最適な選択肢となっています。JAX ベースのフレームワークを探しているなら、Flax は間違いなく検討する価値があります。