在JVM中,JIT (Just-in-Time) 即时编译指的是在Java程序运行过程中JVM优化部分指令为本地指令,从而大幅提升性能。在上一篇文章写一个玩具Java虚拟机中实现了一个基本可以运行Java字节码的JVM。本篇文章描述我是如何在这个玩具JVM中实现JIT的。
推荐文章“How to JIT - an introduction”,介绍了JIT的基本实现原理。作者把JIT分为两个阶段:
- 运行期生成机器代码(本地指令)
- 执行机器代码
生成机器代码很好理解,就是一个JVM指令到机器指令的翻译;而执行机器代码,原理上是利用了OS提供了API可以分配可以执行的内存,然后往这块内存中写入机器码,从而实现运行期可以执行动态生成的机器码功能。
我们可以利用这个原理来实现JIT,但是未免太底层了点,需要做很多工作来完成这件事情。我们可以利用libjit来简化实现。这个作者博客里还有些libjit的教程,其中part 1值得阅读。 简单来说,libjit对机器指令做了抽象,利用它的API来描述一个函数包含了哪些指令,实现了什么功能。然后具体的指令生成以及指令执行则交给libjit完成。
例如以下使用libjit的代码:
1
2
3
4
5
6
7
8
// t = u
jit_insn_store(F, t, u); // 类似 mov 指令
// u = v
jit_insn_store(F, u, v);
// v = t % v
jit_value_t rem = jit_insn_rem(F, t, v); // 求余指令
jit_insn_store(F, v, rem);
所以,我们需要做的,就是将JVM的字节码,翻译成一堆libjit的API调用。但是我希望能够稍微做点抽象,我们写个翻译器,能够将JVM这种基于栈的指令,翻译成基于寄存器的指令,才方便后面无论是使用libjit还是直接翻译成机器码。
指令翻译
要将基于栈的指令翻译成基于寄存器的指令(类似),仔细想想主要解决两个问题:
- 去除操作数栈
- 跳转指令所需要的标签
去除操作数栈,我使用了一个简单办法,因为JVM中执行字节码时,我们是可以知道每条指令执行时栈的具体情况的,也就是每条指令执行时,它依赖的操作数在栈的哪个位置是清楚的。例如,假设某个函数开头有以下指令:
1
2
3
4
5
opcode [04] - 0000: iconst_1 # [1]
opcode [3C] - 0001: istore_1 # []
opcode [1B] - 0002: iload_1 # [1]
opcode [1A] - 0003: iload_0 # [1, N]
opcode [68] - 0004: imul # [1 * N]
当执行imul指令时,就可以知道该指令使用栈s[0]、s[1]的值,做完计算后写回s[0]。所以,类似JVM中局部变量用数字来编号,我也为栈元素编号,这些编号的元素全部被视为局部变量,所以这些指令全部可以转换为基于局部变量的指令。为了和JVM中本身的局部变量统一,我们将栈元素编号从局部变量后面开始。假设以上函数有2个局部变量,那么栈元素从编号2开始,局部变量编号从0开始。以上指令可以翻译为:
1
2
3
4
5
mov 1, $2 # 常量1写入变量2
lod $2, $1 # 变量2写入变量1
lod $1, $2 # 变量1写回变量2
lod $0, $3 # 变量0写入变量3
mul $3, $2 # 变量3与变量2相乘,写回变量2
这里,我们定义了自己的中间指令集(IR),这个中间指令集存在的意义在于,在将来翻译为某个平台的机器码时,它比JVM的指令集更容易理解。中间指令集是一种抽象,方便基于它们使用libjit或其他手段翻译成机器码。
不过,我们看到上面的指令非常冗余。要优化掉这种冗余相对比较复杂,所以本文暂时不讨论这个问题。
这个中间指令基于局部变量的方式,是利于JIT下游做各种具体实现的,例如是否直接转换为通用寄存器,即一定范围的局部变量数是可以直接使用寄存器实现,超出该范围的局部变量则放在栈上,用栈模拟;或者全部用栈模拟。注意在机器指令中栈元素是可以直接偏移访问的,不同于“基于栈的虚拟机”中的栈。
以上指令,我们可以简单地为每条指令设定如何翻译为libjit的调用,例如mov指令:
1
2
3
4
static void build_mov(BuildContext* context, const Instruction* inst) {
jit_value_t c = jit_value_create_nint_constant(context->F, jit_type_int, inst->op1);
jit_insn_store(context->F, context->vars[inst->op2], c);
}
例如mul指令:
1
2
3
4
5
static void build_mul(BuildContext* context, const Instruction* inst) {
// context->vars就是前面说的局部变量表,包含了JVM中的局部变量及操作数栈
jit_value_t tmp = jit_insn_mul(context->F, context->vars[inst->op1], context->vars[inst->op2]);
jit_insn_store(context->F, context->vars[inst->op1], tmp);
}
接下来说另一个问题:跳转指令的标签。在机器指令中,跳转指令跳转的目标位置是一个绝对地址,或者像JVM中一样,是一个相对地址。但是在我们的中间指令集中,是没有地址的概念的,在翻译为机器指令时,也无法获取地址。所以,我们一般是增加了一个特殊指令label
,用于打上一个标签,设置一个标签编号,相当于是一个地址。在后面的跳转指令中,则跳转的是这个标签编号。
所以,我们需要在翻译JVM指令到我们的中间指令时,识别出哪些地方需要打标签;并且在翻译跳转类指令时,翻译为跳转到某个编号的标签。
例如以下指令:
1
2
3
4
5
6
7
opcode [04] - 0000: iconst_1
opcode [3C] - 0001: istore_1
opcode [1B] - 0002: iload_1 # 会被调整,需要在此打标签
opcode [1A] - 0003: iload_0
...
opcode [1A] - 0010: iload_0
opcode [9D] - 0011: ifgt -9 # pc-9,也就是跳转到0002位置
为了打上标签,我们的翻译需要遍历两遍指令,第一遍用来找出所有标签,第二遍才做真正的翻译。
1
2
3
4
5
6
7
8
9
10
11
12
// 该函数遍历所有指令,找出所有需要打标签的指令位置
private List<Integer> createLabels(List<InstParser.Instruction> jbytecode) {
List<Integer> labels = new LinkedList<>();
for (InstParser.Instruction i : jbytecode) {
LabelParser labelParser = labelParsers.get(i.opcode);
if (labelParser != null) { // 不为空时表示是跳转指令
int pc = labelParser.parse(i); // 不同的跳转指令地址解析不同,解析得到跳转的目标地址
labels.add(pc); // 保存起来返回
}
}
return labels;
}
然后在翻译指令的过程中,发现当前翻译的指令地址是跳转的目标位置时,则生成标签指令:
1
2
3
4
5
6
7
8
9
List<Integer> labels = createLabels(jbytecode);
...
Iterator<InstParser.Instruction> it = jbytecode.iterator();
while (it.hasNext()) {
InstParser.Instruction inst = it.next();
int label = labels.indexOf(inst.pc);
if (label >= 0) {
state.addIR(new Inst(op_label, label)); // 生成标签指令,label就是标签编号
}
在处理跳转指令时,则填入标签编号:
1
2
3
4
5
6
7
translators.put(Opcode.op_ifgt, (state, inst, iterator) -> {
short offset = (short)((inst.op1 << 8) + inst.op2);
int pc = inst.pc + offset;
int label = state.findLabel(pc); // 找到标签编号
int var = state.popStack();
state.addIR(new Inst(op_jmp_gt, var, label));
});
我们的中间指令集中,跳转指令和标签指令就为:
1
2
label #N // 打上标签N
jmp_gt $var, #N // 如果$var>0,跳转到标签#N
看下使用libjit如何翻译以上两条指令:
1
2
3
4
5
6
7
8
9
10
11
12
static void build_label(BuildContext* context, const Instruction* inst) {
// 打上标签,inst->op1为标签编号N,对应写到context->labels[N]中
jit_insn_label(context->F, &context->labels[inst->op1]);
}
static void build_jmp_gt(BuildContext* context, const Instruction* inst) {
jit_value_t const0 = jit_value_create_nint_constant(context->F, jit_type_int, 0);
// 是否>0
jit_value_t cmp_v_0 = jit_insn_gt(context->F, context->vars[inst->op1], const0);
// 大于0则跳转到标签inst->op2
jit_insn_branch_if(context->F, cmp_v_0, &context->labels[inst->op2]);
}
代码贴得有点多,大概懂原理就行了。
在JIT中还有个很重要的过程,就是判定哪些代码需要被JIT。这里只是简单地尝试对每一个函数进行JIT,发现所有指令都能够被JIT时就JIT。
指令执行
在上一篇文章中,执行每个JVM函数时,都会有一个Frame与之关联。所以,在这里只要函数被JIT了,对应的帧就会包含被编译的代码,也就是libjit中的jit_function_t
。在该Frame被执行时,就调用libjit执行该函数:
1
2
3
4
5
6
7
8
9
10
11
12
13
private void runNative() {
int arg_cnt = getArgsCount();
int[] args = new int[arg_cnt];
for (int i = 0; i < arg_cnt; ++i) {
if (mLocals[i].type != Slot.Type.NUM) throw new RuntimeException("only supported number arg in jit");
args[i] = mLocals[i].i;
}
int ret = mJIT.invoke(args); // mJIT后面会看到,主要就是将参数以数组形式传递到libjit中,并做JIT函数调用
mThread.popFrame();
if (hasReturnType() && mThread.topFrame() != null) {
mThread.topFrame().pushInt(ret); // 目前只支持int返回类型
}
}
实现
以上就是整个JIT的过程,主要工作集中于JVM指令到中间指令,中间指令到libjit API调用。整个实现包含以下模块:
1
2
3
4
5
6
7
8
9
10
11
+-----------+ +----------+
| ASM | | libjit |
| | <-----+ API call |
+-----------+ +----+-----+
^
|
+-----------+ +----+-----+
| JVM | | IR code |
| bytecode +-----> | |
+-----------+ +----------+
JVM byte code及IR code的处理是在Java中完成的;处理完后将IR code输出为byte[],通过JNI调用包装好的C API。这个C API则是基于libjit,将IR code翻译为libjit的API调用。指令翻译完后调用libjit的API得到最终的ASM机器指令。
同样,要执行指令时,也是通过JNI调用这个C API。JNI交互全部包装在以下类中:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public class ToyJIT {
private long jitPtr = 0;
public void initialize(byte[] bytes, int maxLocals, int maxLabels, int argCnt, int retType) {
jitPtr = compile(bytes, maxLocals, maxLabels, argCnt, retType);
}
public int invoke(int... args) {
return invoke(jitPtr, args);
}
static {
System.loadLibrary("toyjit");
}
private static native long compile(byte[] bytes, int maxLocals, int maxLabels, int argCnt, int retType);
private static native int invoke(long jitPtr, int[] args);
即,libtoyjit.so
主要提供翻译接口 compile
及执行接口 invoke
。
性能对比
简单测试了下一个阶乘计算函数:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public static int fac2(int n) {
int r = 1;
do {
r = r * n;
n = n - 1;
} while (n > 0);
return r;
}
...
int i = 0;
for (; i < 10000; ++i) {
fac2(100);
}
...
fac2
函数会被JIT,测试发现不开启JIT时需要16秒,开启后1秒,差距还是很明显的。
最后奉上代码,toy_jit,就是前面说的C API部分,翻译IR到libjit API call,包装接口用于JNI调用。redhat 7.2下编译,需要先编译出libjit,我是直接clone的libjit master编译的。Java部分还是在toy_jvm中。