JAX 的本质,不是“另一个深度学习框架”,而是一个让你用 NumPy 风格写出可编译、可求导、可加速计算的高性能工具。


如果你用过 PyTorch 或 TensorFlow,可以把它们理解成“深度学习全家桶”:
常用功能都帮你打包好了
上手快
做工程开发很顺手
而 JAX 不太一样。
它更像一个“高性能计算底座”:
不是大而全的框架
更像给科研和数值计算准备的基础工具
写法像 NumPy,学习门槛低
但背后可以通过编译器和硬件加速跑得非常快
一句话概括:
PyTorch / TensorFlow 更像成品厨房,JAX 更像高性能实验室。
先说结论:JAX 值不值得学?
如果你只想记住一个判断标准,那就是:
当你既想写得像 Python,又想跑得像编译语言,JAX 就很适合。
它把三件事合在了一起:
像 NumPy 一样好写
像编译语言一样快
像自动化机器一样能求导、能批量、能并行
这也是为什么 JAX 在机器学习研究、优化算法、科学计算里越来越常见。
1)安装:CPU、GPU、TPU 都能跑
JAX 的一个很大优势是:同一套代码可以在不同硬件上运行,不需要你重写逻辑。
CPU 版本:最简单,适合先入门
pip install -U jax
这个版本最省心,基本哪里都能装。
GPU 版本:适合训练和大规模计算
NVIDIA CUDA
pip install -U "jax[cuda12]"
AMD ROCm
pip install -U "jax[rocm]"
CUDA 要注意版本匹配
如果你是 NVIDIA 显卡,先看一下当前 CUDA 版本:
nvidia-smi
如果你看到类似:
CUDA Version: 12.9
那就说明你的环境属于 CUDA 12.x 系列,通常可以使用 cuda12 对应的 JAX 版本。
这一步非常关键。
JAX 装不上、GPU 用不了,很多时候就是版本没对齐。
TPU 版本
pip install -U "jax[tpu]"
TPU 更多出现在云端或研究环境,一般普通电脑不会直接用到。
怎么确认 JAX 现在到底在用什么硬件?
import jax
print(jax.devices())
如果输出类似:
[CudaDevice(id=0)]
说明 JAX 已经识别到 GPU,并且会把计算交给它。
2)JAX 像 NumPy,但不是 NumPy
JAX 的设计目标之一,就是让学过 NumPy 的人能快速上手。
import jax.numpy as jnp
你会发现很多写法几乎一样:
# NumPy
y_np = np.sin(x) + x**2
# JAX
y_jax = jnp.sin(x) + x**2
函数名、运算符、广播规则、索引方式,都非常像。
但注意:
语法像,不代表行为完全一样。
一个最重要的区别:JAX 数组是不可变的
NumPy 里可以直接改
x[0] = 10
JAX 里不能这么写,要这样:
x = x.at[0].set(10)
这不是“故意添麻烦”,而是 JAX 的设计选择。
你可以把它理解成:
NumPy:直接修改原数组
JAX:返回一个“修改后的新结果”
这样做的好处是:
更适合编译优化
更适合并行计算
更适合自动求导
更适合可复现随机数
JAX 更偏“函数式”风格
普通 Python 往往是“写一行,执行一行”。
JAX 在 jit、grad、vmap 等变换下,会先把你的计算“看懂”,再交给编译器优化。
可以把它理解为:
普通 Python:边走边算
JAX:先画地图,再高速通行
3)JAX 的三大核心:JIT、自动微分、XLA
很多人会把 JAX 记成三个关键词:
JIT
Automatic Differentiation
XLA
这三个东西,基本就是 JAX 的灵魂。
3.1 JIT:把 Python 代码翻译成更快的机器代码
Python 好写,但做大规模数值计算时,解释执行开销比较大。
比如这句:
c = a + b
看起来只是加法,但 Python 背后要做很多事:
找到
a找到
b判断类型
决定调用哪个加法方法
创建新对象
管理引用计数
返回结果
如果只是一次还好。
但如果你有一亿次循环:
for i in range(100_000_000):
x = x * 1.000001
那 Python 大部分时间都花在“解释和调度”,而不是“真正计算”。
NumPy 为什么快?
因为 NumPy 的思路是:
Python 负责发号施令
真正的计算交给底层 C / 向量化内核
一次处理整批数据
所以 NumPy 本质上是:
Python 做指挥,底层编译代码做体力活。
JAX 更进一步:把整个函数编译掉
JAX 的思路是:
不只是调用一个个底层函数,而是把整个计算过程编译成一个优化后的程序。
比如:
from jax import jit
@jit
def f(x):
return x * x + 2 * x + 1
第一次调用这个函数时,JAX 会做三件事:
1)追踪函数
把真实数值换成“抽象符号”,记录计算流程。
2)构建计算图
把这段计算整理成编译器能理解的结构。
3)交给 XLA 编译
XLA 会做优化,比如:
算子融合
减少中间数组
生成更适合硬件执行的代码
JIT 的关键现象:第一次慢,后面快
这点非常重要。
第一次调用:要追踪 + 编译,所以偏慢
后续调用:直接跑编译后的机器代码,非常快
所以评价 JAX 的性能时,不能只看第一次。
要看“编译完成后”的稳定运行速度。
3.2 自动微分:JAX 会替你求导
训练神经网络时,最常做的一件事就是:求梯度。
梯度可以理解成:
参数往哪个方向调整,损失下降得最快。
JAX 的自动微分,就是把链式法则自动应用到你的计算上。
一个最简单的例子
import jax.numpy as jnp
from jax import grad
def f(x):
return jnp.sin(x) * jnp.exp(-x) + x**3
f1 = grad(f) # 一阶导
f2 = grad(f1) # 二阶导
f3 = grad(f2) # 三阶导
这说明:
grad(f)给你一阶导数再套一层
grad,就能得到二阶导继续套下去,还能得到更高阶导数
多输入函数也可以分别求导
from jax import grad
import jax.numpy as jnp
def g(x, y):
return x**2 * y + jnp.sin(y) + jnp.exp(x)
dg_dx = grad(g, argnums=0)
dg_dy = grad(g, argnums=1)
JAX 会帮你分别算出:
对
x的偏导对
y的偏导
为什么这很重要?
因为机器学习、优化、科学计算里,大量核心问题本质上都是:
算函数值
算导数
再更新参数
JAX 把“求导”这件原本很数学、很繁琐的事,变成了程序自动处理。
3.3 XLA:背后真正干活的优化器
XLA 可以理解成 JAX 背后的“执行引擎”。
它会把计算图变成更适合硬件执行的形式。
在这个过程中,最重要的一件事就是:
把多个小操作合并成更少的大操作。
比如:
y = jnp.sin(x) * jnp.exp(-x) + x * x + 3 * x
如果按传统方式执行,可能会产生很多中间数组:
算
sin(x)算
exp(-x)做乘法
算
x*x算
3*x最后再加起来
而 XLA 会尽量把这些步骤融合成更少的 kernel 执行。
这样可以减少:
内存读写
中间结果开销
调度成本
所以你看到的“JAX JIT 很快”,本质上不是魔法,而是编译器优化。
4)vmap:自动把“单个样本计算”变成“批量计算”
JAX 还有一个特别实用的功能:vmap。
它的作用是:
你写一个处理单个样本的函数,JAX 自动把它扩展成处理一批样本的版本。
这叫 自动向量化。
为什么它有用?
比如你写了一个函数,只会算一张图片、一个样本、一个向量。
如果有 10 万个样本,你可以自己写循环;
也可以交给 vmap。
这样好处是:
代码更简洁
少写循环
更适合 GPU / TPU
更利于编译优化
谁负责什么?
这里要分清楚:
vmap:负责“改写程序结构”XLA:负责“真正优化并执行”
也就是说:
函数 → vmap 改写 → jit 编译 → XLA 优化 → 硬件执行
5)JAX 的随机数:不用全局状态,而是显式 key
NumPy 的随机数,常见写法是全局状态:
import numpy as np
np.random.seed(0)
print(np.random.rand())
print(np.random.rand())
这很方便,但有个问题:
它依赖全局状态,不太适合并行、复现和函数式编程。
JAX 的做法更严格
JAX 不靠“隐藏的全局随机数池”,而是让你显式传入随机 key:
import jax.random as jr
key = jr.PRNGKey(0)
x = jr.normal(key)
如果你重复使用同一个 key,结果是一样的:
print(jr.normal(key))
print(jr.normal(key))
想要新的随机数?必须 split
key = jr.PRNGKey(0)
key, subkey = jr.split(key)
x1 = jr.normal(subkey)
每次 split,都会得到新的子 key。
这样做的好处是:
可复现
并行安全
适合
jit/vmap不容易出现“共享随机状态”的混乱问题
6)性能到底快在哪?不要只看语法,要看执行方式
很多人以为 JAX 快,是因为它“写法更高级”。
其实不是。
JAX 快的核心原因,是它把计算尽可能交给:
编译器
向量化
GPU / TPU
算子融合
一组很有代表性的对比
在作者给出的 CUDA GPU 环境里,测试结果大致是这样的:
纯 Python 循环:
0.021sNumPy 向量化:
0.000274sJAX eager:
0.000110sJAX JIT 第一次调用:
0.044s左右JAX JIT 稳定运行:
13–30 微秒
怎么理解这组数据?
1)纯 Python 最慢
原因很直接:循环和解释器开销太大。
2)NumPy 快很多
因为它把计算交给了底层编译代码。
3)JAX eager 已经能跑 GPU kernel
所以在某些环境下会很快。
4)JAX JIT 第一次慢
因为要先编译。
5)JAX JIT 后面最快
因为它跑的是“编译后的、融合过的、适合硬件的机器代码”。
但有一个重要提醒
这类性能对比,一定要看运行环境。
NumPy 通常跑在 CPU 上
JAX 可能跑在 GPU 上
JAX eager / JIT 的性能也会受数据是否已经在设备上、数组大小、编译是否完成等影响
所以更准确的说法是:
JAX 的优势不只是“快”,而是“同样的代码,可以被编译成更适合硬件的执行方式”。
7)JAX 不只是做计算,也能直接写神经网络
JAX 当然也可以训练模型。
一个最典型的例子,就是用它从零写一个小型多层感知机(MLP)。
一个典型训练流程
加载数据
做标准化和 one-hot 编码
初始化参数
前向传播
计算损失
自动求导
参数更新
评估准确率
核心代码思路
参数初始化
def initialize_params(input_dim, hidden_dim1, hidden_dim2, output_dim, key):
keys = jax.random.split(key, 3)
W1 = jax.random.normal(keys[0], (input_dim, hidden_dim1))
b1 = jnp.zeros((hidden_dim1,))
W2 = jax.random.normal(keys[1], (hidden_dim1, hidden_dim2))
b2 = jnp.zeros((hidden_dim2,))
W3 = jax.random.normal(keys[2], (hidden_dim2, output_dim))
b3 = jnp.zeros((output_dim,))
return W1, b1, W2, b2, W3, b3
前向传播
def forward(params, X):
W1, b1, W2, b2, W3, b3 = params
Z1 = jnp.dot(X, W1) + b1
A1 = jax.nn.relu(Z1)
Z2 = jnp.dot(A1, W2) + b2
A2 = jax.nn.relu(Z2)
logits = jnp.dot(A2, W3) + b3
return logits
损失函数
工程里更推荐用 log_softmax,数值更稳:
def loss_fn(params, x, y, l2_reg=1e-4):
logits = forward(params, x)
log_probs = jax.nn.log_softmax(logits)
l2_loss = l2_reg * sum(jnp.sum(w ** 2) for w in params[::2])
ce = -jnp.mean(jnp.sum(y * log_probs, axis=1))
return ce + l2_loss
一步训练
@jax.jit
def train_step(params, x, y, lr=0.01):
grads = jax.grad(loss_fn)(params, x, y)
return tuple(p - lr * g for p, g in zip(params, grads))
这里很关键:
grad负责自动求导jit负责把训练步骤编译掉
这也是 JAX 训练任务里最经典的组合。
准确率
def accuracy(params, x, y):
logits = forward(params, x)
preds = jnp.argmax(logits, axis=1)
targets = jnp.argmax(y, axis=1)
return jnp.mean(preds == targets)
训练结果说明了什么?
以 Iris 这种三分类数据集为例,随机猜的准确率大约是 33%。
如果训练后测试准确率能到 96%+,说明这个小模型已经学到了有效规律。
这说明两件事:
JAX 完全可以写完整的训练闭环
jit + grad + vmap这套组合,对训练任务非常有价值
8)这篇文章最该记住的几个结论
如果你只想记住最关键的点,可以记这几条:
结论 1:JAX 不是“大而全框架”,而是“高性能计算底座”
它更适合研究、实验和性能敏感的任务。
结论 2:JAX 像 NumPy,但更严格
数组不可变、随机数显式 key、执行更偏函数式。
结论 3:JIT 是 JAX 快的核心原因之一
第一次慢是编译成本,后面快是因为跑的是机器代码。
结论 4:自动微分让“求导”变成现成能力
这对机器学习和优化问题特别重要。
结论 5:XLA 是幕后真正的优化器
算子融合、减少中间结果、提升硬件利用率,都是它在做。
结论 6:vmap 和 GPU / TPU 是 JAX 的加速利器
它们让批量计算更自然,也更容易跑满硬件。
9)什么时候适合用 JAX?
JAX 特别适合这些场景:
做机器学习研究
写需要大量求导的算法
追求高性能数值计算
需要在 GPU / TPU 上跑同一套逻辑
想把代码写得既像数学公式,又能高效执行
如果你做的是传统业务开发、快速工程落地,PyTorch 往往更省心。
如果你重视研究灵活性、编译优化和性能,JAX 非常值得学。
评论区