Flax 是一個由 Google 開發的開源神經網路庫,專為 JAX 設計。它旨在提供一個靈活、高性能且易於使用的平台,用於研究和開發深度學習模型。Flax 專注於函數式編程範式,鼓勵模組化和可重用的代碼。
GitHub 倉庫: https://github.com/google/flax
flax.linen
: 核心模組,提供用於定義神經網路層的 API。它包含各種預定義的層,如卷積層、全連接層、循環層等。flax.training
: 提供用於訓練模型的實用工具,包括優化器、學習率調度器、檢查點管理等。flax.optim
: 包含各種優化器,如 Adam、SGD 等,以及學習率調度器。flax.struct
: 定義用於表示模型狀態的不可變數據結構。flax.core
: 提供底層函數式編程工具。import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
# 定義模型
class LinearRegression(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Dense(features=1)(x)
# 定義訓練步驟
def train_step(state, batch):
def loss_fn(params):
y_pred = state.apply_fn({'params': params}, batch['x'])
loss = jnp.mean((y_pred - batch['y'])**2)
return loss
grad_fn = jax.grad(loss_fn)
grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state
# 初始化模型和優化器
key = jax.random.PRNGKey(0)
model = LinearRegression()
params = model.init(key, jnp.ones((1, 1)))['params']
tx = optax.sgd(learning_rate=0.01)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
# 生成數據
x = jnp.linspace(0, 1, 100)[:, None]
y = 2 * x + 1 + jax.random.normal(key, (100, 1)) * 0.1
batch = {'x': x, 'y': y}
# 訓練模型
for i in range(100):
state = train_step(state, batch)
# 打印結果
print(f"Trained parameters: {state.params}")
Flax 是一個強大的神經網路庫,專為 JAX 設計。它提供了靈活性、高性能和易用性,使其成為研究和開發深度學習模型的理想選擇。如果您正在尋找一個基於 JAX 的框架,Flax 絕對值得考慮。