侧边栏壁纸
博主头像
MobotStone AI

行动起来,活在当下

  • 累计撰写 62 篇文章
  • 累计创建 9 个标签
  • 累计收到 0 条评论

目 录CONTENT

文章目录

JAX 到底是什么?为什么它能像 NumPy 一样写,却能跑出 GPU 级速度?

Administrator
2026-04-24 / 0 评论 / 0 点赞 / 0 阅读 / 0 字

JAX 的本质,不是“另一个深度学习框架”,而是一个让你用 NumPy 风格写出可编译、可求导、可加速计算的高性能工具。

69eacdc33423b.webp

如果你用过 PyTorch 或 TensorFlow,可以把它们理解成“深度学习全家桶”:

  • 常用功能都帮你打包好了

  • 上手快

  • 做工程开发很顺手

JAX 不太一样。

它更像一个“高性能计算底座”:

  • 不是大而全的框架

  • 更像给科研和数值计算准备的基础工具

  • 写法像 NumPy,学习门槛低

  • 但背后可以通过编译器和硬件加速跑得非常快

一句话概括:

PyTorch / TensorFlow 更像成品厨房,JAX 更像高性能实验室。

先说结论:JAX 值不值得学?

如果你只想记住一个判断标准,那就是:

当你既想写得像 Python,又想跑得像编译语言,JAX 就很适合。

它把三件事合在了一起:

  1. 像 NumPy 一样好写

  2. 像编译语言一样快

  3. 像自动化机器一样能求导、能批量、能并行

这也是为什么 JAX 在机器学习研究、优化算法、科学计算里越来越常见。

1)安装:CPU、GPU、TPU 都能跑

JAX 的一个很大优势是:同一套代码可以在不同硬件上运行,不需要你重写逻辑。

CPU 版本:最简单,适合先入门

pip install -U jax

这个版本最省心,基本哪里都能装。


GPU 版本:适合训练和大规模计算

NVIDIA CUDA

pip install -U "jax[cuda12]"

AMD ROCm

pip install -U "jax[rocm]"

CUDA 要注意版本匹配

如果你是 NVIDIA 显卡,先看一下当前 CUDA 版本:

nvidia-smi

如果你看到类似:

CUDA Version: 12.9

那就说明你的环境属于 CUDA 12.x 系列,通常可以使用 cuda12 对应的 JAX 版本。

这一步非常关键。
JAX 装不上、GPU 用不了,很多时候就是版本没对齐。


TPU 版本

pip install -U "jax[tpu]"

TPU 更多出现在云端或研究环境,一般普通电脑不会直接用到。


怎么确认 JAX 现在到底在用什么硬件?

import jax
print(jax.devices())

如果输出类似:

[CudaDevice(id=0)]

说明 JAX 已经识别到 GPU,并且会把计算交给它。


2)JAX 像 NumPy,但不是 NumPy

JAX 的设计目标之一,就是让学过 NumPy 的人能快速上手。

import jax.numpy as jnp

你会发现很多写法几乎一样:

# NumPy
y_np = np.sin(x) + x**2

# JAX
y_jax = jnp.sin(x) + x**2

函数名、运算符、广播规则、索引方式,都非常像。

但注意:

语法像,不代表行为完全一样。


一个最重要的区别:JAX 数组是不可变的

NumPy 里可以直接改

x[0] = 10

JAX 里不能这么写,要这样:

x = x.at[0].set(10)

这不是“故意添麻烦”,而是 JAX 的设计选择。

你可以把它理解成:

  • NumPy:直接修改原数组

  • JAX:返回一个“修改后的新结果”

这样做的好处是:

  • 更适合编译优化

  • 更适合并行计算

  • 更适合自动求导

  • 更适合可复现随机数


JAX 更偏“函数式”风格

普通 Python 往往是“写一行,执行一行”。

JAX 在 jitgradvmap 等变换下,会先把你的计算“看懂”,再交给编译器优化。

可以把它理解为:

  • 普通 Python:边走边算

  • JAX:先画地图,再高速通行


3)JAX 的三大核心:JIT、自动微分、XLA

很多人会把 JAX 记成三个关键词:

  • JIT

  • Automatic Differentiation

  • XLA

这三个东西,基本就是 JAX 的灵魂。


3.1 JIT:把 Python 代码翻译成更快的机器代码

Python 好写,但做大规模数值计算时,解释执行开销比较大。

比如这句:

c = a + b

看起来只是加法,但 Python 背后要做很多事:

  1. 找到 a

  2. 找到 b

  3. 判断类型

  4. 决定调用哪个加法方法

  5. 创建新对象

  6. 管理引用计数

  7. 返回结果

如果只是一次还好。

但如果你有一亿次循环:

for i in range(100_000_000):
    x = x * 1.000001

那 Python 大部分时间都花在“解释和调度”,而不是“真正计算”。


NumPy 为什么快?

因为 NumPy 的思路是:

  • Python 负责发号施令

  • 真正的计算交给底层 C / 向量化内核

  • 一次处理整批数据

所以 NumPy 本质上是:

Python 做指挥,底层编译代码做体力活。


JAX 更进一步:把整个函数编译掉

JAX 的思路是:

不只是调用一个个底层函数,而是把整个计算过程编译成一个优化后的程序。

比如:

from jax import jit

@jit
def f(x):
    return x * x + 2 * x + 1

第一次调用这个函数时,JAX 会做三件事:

1)追踪函数

把真实数值换成“抽象符号”,记录计算流程。

2)构建计算图

把这段计算整理成编译器能理解的结构。

3)交给 XLA 编译

XLA 会做优化,比如:

  • 算子融合

  • 减少中间数组

  • 生成更适合硬件执行的代码


JIT 的关键现象:第一次慢,后面快

这点非常重要。

  • 第一次调用:要追踪 + 编译,所以偏慢

  • 后续调用:直接跑编译后的机器代码,非常快

所以评价 JAX 的性能时,不能只看第一次。

要看“编译完成后”的稳定运行速度。


3.2 自动微分:JAX 会替你求导

训练神经网络时,最常做的一件事就是:求梯度

梯度可以理解成:

参数往哪个方向调整,损失下降得最快。

JAX 的自动微分,就是把链式法则自动应用到你的计算上。


一个最简单的例子

import jax.numpy as jnp
from jax import grad

def f(x):
    return jnp.sin(x) * jnp.exp(-x) + x**3

f1 = grad(f)   # 一阶导
f2 = grad(f1)  # 二阶导
f3 = grad(f2)  # 三阶导

这说明:

  • grad(f) 给你一阶导数

  • 再套一层 grad,就能得到二阶导

  • 继续套下去,还能得到更高阶导数


多输入函数也可以分别求导

from jax import grad
import jax.numpy as jnp

def g(x, y):
    return x**2 * y + jnp.sin(y) + jnp.exp(x)

dg_dx = grad(g, argnums=0)
dg_dy = grad(g, argnums=1)

JAX 会帮你分别算出:

  • x 的偏导

  • y 的偏导


为什么这很重要?

因为机器学习、优化、科学计算里,大量核心问题本质上都是:

  • 算函数值

  • 算导数

  • 再更新参数

JAX 把“求导”这件原本很数学、很繁琐的事,变成了程序自动处理。


3.3 XLA:背后真正干活的优化器

XLA 可以理解成 JAX 背后的“执行引擎”。

它会把计算图变成更适合硬件执行的形式。

在这个过程中,最重要的一件事就是:

把多个小操作合并成更少的大操作。

比如:

y = jnp.sin(x) * jnp.exp(-x) + x * x + 3 * x

如果按传统方式执行,可能会产生很多中间数组:

  • sin(x)

  • exp(-x)

  • 做乘法

  • x*x

  • 3*x

  • 最后再加起来

而 XLA 会尽量把这些步骤融合成更少的 kernel 执行。

这样可以减少:

  • 内存读写

  • 中间结果开销

  • 调度成本

所以你看到的“JAX JIT 很快”,本质上不是魔法,而是编译器优化


4)vmap:自动把“单个样本计算”变成“批量计算”

JAX 还有一个特别实用的功能:vmap

它的作用是:

你写一个处理单个样本的函数,JAX 自动把它扩展成处理一批样本的版本。

这叫 自动向量化

为什么它有用?

比如你写了一个函数,只会算一张图片、一个样本、一个向量。

如果有 10 万个样本,你可以自己写循环;

也可以交给 vmap

这样好处是:

  • 代码更简洁

  • 少写循环

  • 更适合 GPU / TPU

  • 更利于编译优化

谁负责什么?

这里要分清楚:

  • vmap:负责“改写程序结构”

  • XLA:负责“真正优化并执行”

也就是说:

函数 → vmap 改写 → jit 编译 → XLA 优化 → 硬件执行

5)JAX 的随机数:不用全局状态,而是显式 key

NumPy 的随机数,常见写法是全局状态:

import numpy as np

np.random.seed(0)
print(np.random.rand())
print(np.random.rand())

这很方便,但有个问题:

它依赖全局状态,不太适合并行、复现和函数式编程。


JAX 的做法更严格

JAX 不靠“隐藏的全局随机数池”,而是让你显式传入随机 key:

import jax.random as jr

key = jr.PRNGKey(0)
x = jr.normal(key)

如果你重复使用同一个 key,结果是一样的:

print(jr.normal(key))
print(jr.normal(key))

想要新的随机数?必须 split

key = jr.PRNGKey(0)
key, subkey = jr.split(key)

x1 = jr.normal(subkey)

每次 split,都会得到新的子 key。

这样做的好处是:

  • 可复现

  • 并行安全

  • 适合 jit / vmap

  • 不容易出现“共享随机状态”的混乱问题


6)性能到底快在哪?不要只看语法,要看执行方式

很多人以为 JAX 快,是因为它“写法更高级”。

其实不是。

JAX 快的核心原因,是它把计算尽可能交给:

  • 编译器

  • 向量化

  • GPU / TPU

  • 算子融合


一组很有代表性的对比

在作者给出的 CUDA GPU 环境里,测试结果大致是这样的:

  • 纯 Python 循环0.021s

  • NumPy 向量化0.000274s

  • JAX eager0.000110s

  • JAX JIT 第一次调用0.044s 左右

  • JAX JIT 稳定运行13–30 微秒


怎么理解这组数据?

1)纯 Python 最慢

原因很直接:循环和解释器开销太大。

2)NumPy 快很多

因为它把计算交给了底层编译代码。

3)JAX eager 已经能跑 GPU kernel

所以在某些环境下会很快。

4)JAX JIT 第一次慢

因为要先编译。

5)JAX JIT 后面最快

因为它跑的是“编译后的、融合过的、适合硬件的机器代码”。


但有一个重要提醒

这类性能对比,一定要看运行环境

  • NumPy 通常跑在 CPU 上

  • JAX 可能跑在 GPU 上

  • JAX eager / JIT 的性能也会受数据是否已经在设备上、数组大小、编译是否完成等影响

所以更准确的说法是:

JAX 的优势不只是“快”,而是“同样的代码,可以被编译成更适合硬件的执行方式”。


7)JAX 不只是做计算,也能直接写神经网络

JAX 当然也可以训练模型。

一个最典型的例子,就是用它从零写一个小型多层感知机(MLP)。


一个典型训练流程

  1. 加载数据

  2. 做标准化和 one-hot 编码

  3. 初始化参数

  4. 前向传播

  5. 计算损失

  6. 自动求导

  7. 参数更新

  8. 评估准确率


核心代码思路

参数初始化

def initialize_params(input_dim, hidden_dim1, hidden_dim2, output_dim, key):
    keys = jax.random.split(key, 3)
    W1 = jax.random.normal(keys[0], (input_dim, hidden_dim1))
    b1 = jnp.zeros((hidden_dim1,))
    W2 = jax.random.normal(keys[1], (hidden_dim1, hidden_dim2))
    b2 = jnp.zeros((hidden_dim2,))
    W3 = jax.random.normal(keys[2], (hidden_dim2, output_dim))
    b3 = jnp.zeros((output_dim,))
    return W1, b1, W2, b2, W3, b3

前向传播

def forward(params, X):
    W1, b1, W2, b2, W3, b3 = params
    Z1 = jnp.dot(X, W1) + b1
    A1 = jax.nn.relu(Z1)
    Z2 = jnp.dot(A1, W2) + b2
    A2 = jax.nn.relu(Z2)
    logits = jnp.dot(A2, W3) + b3
    return logits

损失函数

工程里更推荐用 log_softmax,数值更稳:

def loss_fn(params, x, y, l2_reg=1e-4):
    logits = forward(params, x)
    log_probs = jax.nn.log_softmax(logits)
    l2_loss = l2_reg * sum(jnp.sum(w ** 2) for w in params[::2])
    ce = -jnp.mean(jnp.sum(y * log_probs, axis=1))
    return ce + l2_loss

一步训练

@jax.jit
def train_step(params, x, y, lr=0.01):
    grads = jax.grad(loss_fn)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))

这里很关键:

  • grad 负责自动求导

  • jit 负责把训练步骤编译掉

这也是 JAX 训练任务里最经典的组合。


准确率

def accuracy(params, x, y):
    logits = forward(params, x)
    preds = jnp.argmax(logits, axis=1)
    targets = jnp.argmax(y, axis=1)
    return jnp.mean(preds == targets)

训练结果说明了什么?

以 Iris 这种三分类数据集为例,随机猜的准确率大约是 33%

如果训练后测试准确率能到 96%+,说明这个小模型已经学到了有效规律。

这说明两件事:

  1. JAX 完全可以写完整的训练闭环

  2. jit + grad + vmap 这套组合,对训练任务非常有价值


8)这篇文章最该记住的几个结论

如果你只想记住最关键的点,可以记这几条:

结论 1:JAX 不是“大而全框架”,而是“高性能计算底座”

它更适合研究、实验和性能敏感的任务。

结论 2:JAX 像 NumPy,但更严格

数组不可变、随机数显式 key、执行更偏函数式。

结论 3:JIT 是 JAX 快的核心原因之一

第一次慢是编译成本,后面快是因为跑的是机器代码。

结论 4:自动微分让“求导”变成现成能力

这对机器学习和优化问题特别重要。

结论 5:XLA 是幕后真正的优化器

算子融合、减少中间结果、提升硬件利用率,都是它在做。

结论 6:vmap 和 GPU / TPU 是 JAX 的加速利器

它们让批量计算更自然,也更容易跑满硬件。


9)什么时候适合用 JAX?

JAX 特别适合这些场景:

  • 做机器学习研究

  • 写需要大量求导的算法

  • 追求高性能数值计算

  • 需要在 GPU / TPU 上跑同一套逻辑

  • 想把代码写得既像数学公式,又能高效执行

如果你做的是传统业务开发、快速工程落地,PyTorch 往往更省心。

如果你重视研究灵活性、编译优化和性能,JAX 非常值得学。

0

评论区