@@ -73,6 +73,7 @@ struct jit_ctx {
const struct bpf_prog *prog;
int idx;
bool write;
+ int max_insns;
int epilogue_offset;
int *offset;
int exentry_idx;
@@ -90,11 +91,15 @@ struct bpf_plt {
#define PLT_TARGET_SIZE sizeof_field(struct bpf_plt, target)
#define PLT_TARGET_OFFSET offsetof(struct bpf_plt, target)
-static inline void emit(const u32 insn, struct jit_ctx *ctx)
+static inline void __emit(u32 insn, struct jit_ctx *ctx, int idx)
{
- if (ctx->image != NULL && ctx->write)
- ctx->image[ctx->idx] = cpu_to_le32(insn);
+ if (ctx->image != NULL && ctx->write && idx < ctx->max_insns)
+ ctx->image[idx] = cpu_to_le32(insn);
+}
+static inline void emit(u32 insn, struct jit_ctx *ctx)
+{
+ __emit(insn, ctx, ctx->idx);
ctx->idx++;
}
@@ -1544,6 +1549,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
build_epilogue(&ctx);
build_plt(&ctx);
+ ctx.max_insns = ctx.idx;
extable_align = __alignof__(struct exception_table_entry);
extable_size = prog->aux->num_exentries *
sizeof(struct exception_table_entry);
@@ -1603,7 +1609,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
build_plt(&ctx);
/* 3. Extra pass to validate JITed code. */
- if (validate_ctx(&ctx)) {
+ if (WARN_ON_ONCE(ctx.idx > ctx.max_insns) || validate_ctx(&ctx)) {
bpf_jit_binary_free(header);
prog = orig_prog;
goto out_off;
@@ -1687,7 +1693,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
int args_off, int retval_off, int run_ctx_off,
bool save_ret)
{
- __le32 *branch;
+ int bridx;
u64 enter_prog;
u64 exit_prog;
struct bpf_prog *p = l->link.prog;
@@ -1725,7 +1731,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
/* if (__bpf_prog_enter(prog) == 0)
* goto skip_exec_of_prog;
*/
- branch = ctx->image + ctx->idx;
+ bridx = ctx->idx;
emit(A64_NOP, ctx);
/* save return value to callee saved register x20 */
@@ -1740,10 +1746,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
if (save_ret)
emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
- if (ctx->image) {
- int offset = &ctx->image[ctx->idx] - branch;
- *branch = cpu_to_le32(A64_CBZ(1, A64_R(0), offset));
- }
+ __emit(A64_CBZ(1, A64_R(0), ctx->idx - bridx), ctx, bridx);
/* arg1: prog */
emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
@@ -1757,7 +1760,7 @@ static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
int args_off, int retval_off, int run_ctx_off,
- __le32 **branches)
+ int *bridx)
{
int i;
@@ -1775,7 +1778,7 @@ static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
/* Save the location of branch, and generate a nop.
* This nop will be replaced with a cbnz later.
*/
- branches[i] = ctx->image + ctx->idx;
+ bridx[i] = ctx->idx;
emit(A64_NOP, ctx);
}
}
@@ -1828,7 +1831,7 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
bool save_ret;
- __le32 **branches = NULL;
+ int *bridx = NULL;
/* trampoline stack layout:
* [ parent ip ]
@@ -1936,13 +1939,12 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
flags & BPF_TRAMP_F_RET_FENTRY_RET);
if (fmod_ret->nr_links) {
- branches = kcalloc(fmod_ret->nr_links, sizeof(__le32 *),
- GFP_KERNEL);
- if (!branches)
+ bridx = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
+ if (!bridx)
return -ENOMEM;
invoke_bpf_mod_ret(ctx, fmod_ret, args_off, retval_off,
- run_ctx_off, branches);
+ run_ctx_off, bridx);
}
if (flags & BPF_TRAMP_F_CALL_ORIG) {
@@ -1957,11 +1959,10 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
emit(A64_NOP, ctx);
}
- /* update the branches saved in invoke_bpf_mod_ret with cbnz */
- for (i = 0; i < fmod_ret->nr_links && ctx->image != NULL; i++) {
- int offset = &ctx->image[ctx->idx] - branches[i];
- *branches[i] = cpu_to_le32(A64_CBNZ(1, A64_R(10), offset));
- }
+ /* update the bridx saved in invoke_bpf_mod_ret with cbnz */
+ for (i = 0; i < fmod_ret->nr_links; i++)
+ __emit(A64_CBNZ(1, A64_R(10), ctx->idx - bridx[i]), ctx,
+ bridx[i]);
for (i = 0; i < fexit->nr_links; i++)
invoke_bpf_prog(ctx, fexit->links[i], args_off, retval_off,
@@ -2004,7 +2005,7 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
if (ctx->image)
bpf_flush_icache(ctx->image, ctx->image + ctx->idx);
- kfree(branches);
+ kfree(bridx);
return ctx->idx;
}
@@ -2018,35 +2019,27 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
int nargs = m->nr_args;
int max_insns = ((long)image_end - (long)image) / AARCH64_INSN_SIZE;
struct jit_ctx ctx = {
- .image = NULL,
+ .image = image,
.idx = 0,
.write = true,
+ .max_insns = max_insns,
};
/* the first 8 arguments are passed by registers */
if (nargs > 8)
return -ENOTSUPP;
+ jit_fill_hole(image, (unsigned int)(image_end - image));
+
ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nargs, flags);
+
if (ret < 0)
return ret;
if (ret > max_insns)
return -EFBIG;
- ctx.image = image;
- ctx.idx = 0;
-
- jit_fill_hole(image, (unsigned int)(image_end - image));
- ret = prepare_trampoline(&ctx, im, tlinks, orig_call, nargs, flags);
-
- if (ret > 0 && validate_code(&ctx) < 0)
- ret = -EINVAL;
-
- if (ret > 0)
- ret *= AARCH64_INSN_SIZE;
-
- return ret;
+ return validate_code(&ctx) < 0 ? -EINVAL : ret * AARCH64_INSN_SIZE;
}
static bool is_long_jump(void *ip, void *target)