Message ID | 20230713060724.389084-1-yhs@fb.com (mailing list archive)
---|---
State | Superseded
Delegated to | BPF
Series | bpf: Support new insns from cpu v4
On Wed, Jul 12, 2023 at 11:07:24PM -0700, Yonghong Song wrote:
>
> @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
>         LDST(DW, u64)
> #undef LDST
>
> +#define LDS(SIZEOP, SIZE)                                       \

LDSX ?

> +       LDX_MEMSX_##SIZEOP:                                      \
> +               DST = *(SIZE *)(unsigned long) (SRC + insn->off); \
> +               CONT;
> +
> +       LDS(B, s8)
> +       LDS(H, s16)
> +       LDS(W, s32)
> +#undef LDS

...

> @@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>                 if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
>                     insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
>                     insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
> -                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
> +                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
> +                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
> +                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
> +                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
>                         type = BPF_READ;
>                 } else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
>                            insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
> @@ -17562,6 +17642,11 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>                  */
>                 case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
>                         if (type == BPF_READ) {
> +                               /* it is hard to differentiate that the
> +                                * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
> +                                * let us use insn->imm to remember it.
> +                                */
> +                               insn->imm = BPF_MODE(insn->code);

That's a fragile approach, and the evidence is in this patch.
This part of the interpreter:

        LDX_PROBE_MEM_##SIZEOP:                                  \
                bpf_probe_read_kernel(&DST, sizeof(SIZE),        \
                                      (const void *)(long) (SRC + insn->off)); \
                DST = *((SIZE *)&DST);                           \

wasn't updated to handle sign extension.

How about

#define BPF_PROBE_MEMSX 0x40 /* same as BPF_IND */

and handling it in the JITs and the interpreter?

We need a selftest for BTF-style access to signed fields to make sure both
the interpreter and JIT handling of BPF_PROBE_MEMSX are tested.
On 7/14/23 11:13 AM, Alexei Starovoitov wrote:
> On Wed, Jul 12, 2023 at 11:07:24PM -0700, Yonghong Song wrote:
>>
>> @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
>>         LDST(DW, u64)
>> #undef LDST
>>
>> +#define LDS(SIZEOP, SIZE)                                       \
>
> LDSX ?

Ack.

>> +       LDX_MEMSX_##SIZEOP:                                      \
>> +               DST = *(SIZE *)(unsigned long) (SRC + insn->off); \
>> +               CONT;
>> +
>> +       LDS(B, s8)
>> +       LDS(H, s16)
>> +       LDS(W, s32)
>> +#undef LDS
>
> [...]
>
> That's a fragile approach, and the evidence is in this patch.
> This part of the interpreter:
>
>         LDX_PROBE_MEM_##SIZEOP:                                  \
>                 bpf_probe_read_kernel(&DST, sizeof(SIZE),        \
>                                       (const void *)(long) (SRC + insn->off)); \
>                 DST = *((SIZE *)&DST);                           \
>
> wasn't updated to handle sign extension.

Thanks for catching this!

> How about
>
> #define BPF_PROBE_MEMSX 0x40 /* same as BPF_IND */
>
> and handling it in the JITs and the interpreter?

Good idea. Will do.

> We need a selftest for BTF-style access to signed fields to make sure both
> the interpreter and JIT handling of BPF_PROBE_MEMSX are tested.

Will do.
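For context, here is a minimal sketch of what the interpreter side of the
BPF_PROBE_MEMSX suggestion could look like, mirroring the LDX_PROBE_MEM
handler quoted above. The mode value, macro and label names are assumptions
for illustration, not code from a posted patch:

        /* Sketch only: a PROBE_MEMSX handler for ___bpf_prog_run(),
         * assuming #define BPF_PROBE_MEMSX 0x40 and dispatch entries
         * analogous to INSN_3(LDX, PROBE_MEMSX, B/H/W). SIZE being a
         * signed type (s8/s16/s32) makes the final assignment to the
         * u64 register sign-extend.
         */
#define LDX_PROBE_MEMSX(SIZEOP, SIZE)                            \
        LDX_PROBE_MEMSX_##SIZEOP:                                \
                bpf_probe_read_kernel(&DST, sizeof(SIZE),        \
                                      (const void *)(long) (SRC + insn->off)); \
                DST = *((SIZE *)&DST);                           \
                CONT;
        LDX_PROBE_MEMSX(B, s8)
        LDX_PROBE_MEMSX(H, s16)
        LDX_PROBE_MEMSX(W, s32)
#undef LDX_PROBE_MEMSX

As with the existing LDX_PROBE_MEM handler, this relies on
bpf_probe_read_kernel() filling the low bytes of DST, so the widening
assignment through the signed narrow pointer performs the sign extension.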
On Wed, 2023-07-12 at 23:07 -0700, Yonghong Song wrote:
> Add interpreter/jit support for new sign-extension load insns,
> which add a new mode (BPF_MEMSX).
> Also add verifier support to recognize these insns and do proper
> verification for them. In the verifier, besides deducing proper bounds
> for the dst_reg, probed memory access is handled by remembering the insn
> mode in the insn->imm field so that proper jit insns can be emitted later.
>
> Signed-off-by: Yonghong Song <yhs@fb.com>
> ---
> [...]
> @@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>                 if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
>                     insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
>                     insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
> -                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
> +                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
> +                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
> +                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
> +                   insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {

Later in this function there is code that deals with the `is_narrower_load`
condition (line 17785 in my case). It handles the case when e.g. 1 byte is
read from a 4-byte field, by first converting such a load to a 4-byte load
and then adding BPF_RSH and BPF_AND instructions. It appears to me that this
code should handle sign extension as well.

>                         type = BPF_READ;
> [...]
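To make the concern above concrete: the narrow-load rewrite zero-extends
with BPF_RSH + BPF_AND, and the mask would clobber the sign bit of a
BPF_MEMSX load. A hypothetical sign-extending counterpart (a sketch in the
style of the surrounding insn_buf code; is_ldsx and shift are assumed from
context, and this is not the actual fix) could use a shift pair instead:

        /* Hypothetical sketch: after the widened load, the wanted 'size'
         * bytes sit at bit offset 'shift'. Move the field's sign bit up to
         * bit 63, then shift back arithmetically so it is replicated,
         * instead of masking it away with BPF_AND.
         */
        if (is_ldsx) {
                insn_buf[cnt++] = BPF_ALU64_IMM(BPF_LSH, insn->dst_reg,
                                                64 - shift - size * 8);
                insn_buf[cnt++] = BPF_ALU64_IMM(BPF_ARSH, insn->dst_reg,
                                                64 - size * 8);
        }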
On Wed, 2023-07-12 at 23:07 -0700, Yonghong Song wrote:
> Add interpreter/jit support for new sign-extension load insns,
> which add a new mode (BPF_MEMSX).
> [...]
> +       if (s64_min < 0 && s64_max < 0) {
> +               reg->smin_value = reg->s32_min_value = s64_min;
> +               reg->smax_value = reg->s32_max_value = s64_max;
> +               reg->umin_value = (u64)s64_max;
> +               reg->umax_value = (u64)s64_min;

I think the last two assignments are not correct for the following example:

        {
                "testtesttest",
                .insns = {
                BPF_EMIT_CALL(BPF_FUNC_get_prandom_u32),
                BPF_JMP_IMM(BPF_JLT, BPF_REG_0, 0xff80, 2),
                BPF_JMP_IMM(BPF_JGT, BPF_REG_0, 0xffff, 1),
                {
                        .code = BPF_ALU64 | BPF_MOV | BPF_X,
                        .dst_reg = BPF_REG_0,
                        .src_reg = BPF_REG_0,
                        .off = 8,
                        .imm = 0,
                },
                BPF_EXIT_INSN(),
                },
                .result = ACCEPT,
                .retval = 0,
        },

Here is the execution log:

        0: R1=ctx(off=0,imm=0) R10=fp0
        0: (85) call bpf_get_prandom_u32#7   ; R0_w=Pscalar()
        1: (a5) if r0 < 0xff80 goto pc+2     ; R0_w=Pscalar(umin=65408)
        2: (25) if r0 > 0xffff goto pc+1     ; R0_w=Pscalar(umin=65408,umax=65535,var_off=(0xff80; 0x7f))
        3: (bf) r0 = r0                      ; R0_w=Pscalar
                                               (smin=-128,smax=-1,
                                                umin=18'446'744'073'709'551'615,
                                                umax=18'446'744'073'709'551'488,
                                                var_off=(0xffffffffffffff80; 0x7f),
                                                u32_min=-1,u32_max=-128)
        4: (95) exit

Note that umax < umin, which should not happen.
In this case the assignments in question are:

        reg->umin_value = (u64)s64_max;  // == -1   == 0xffffffffffffffff
        reg->umax_value = (u64)s64_min;  // == -128 == 0xffffffffffffff80

> +               reg->u32_min_value = (u32)s64_max;
> +               reg->u32_max_value = (u32)s64_min;
> +               reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
> +               return;
> +       }
> [...]
On 7/18/23 5:15 PM, Eduard Zingerman wrote:
> On Wed, 2023-07-12 at 23:07 -0700, Yonghong Song wrote:
>> Add interpreter/jit support for new sign-extension load insns,
>> which add a new mode (BPF_MEMSX).
>> [...]
>> +       if (s64_min < 0 && s64_max < 0) {
>> +               reg->smin_value = reg->s32_min_value = s64_min;
>> +               reg->smax_value = reg->s32_max_value = s64_max;
>> +               reg->umin_value = (u64)s64_max;
>> +               reg->umax_value = (u64)s64_min;
>
> I think the last two assignments are not correct for the following example:
>
> [...]
>
> Note that umax < umin, which should not happen.
> In this case the assignments in question are:
>
>         reg->umin_value = (u64)s64_max;  // == -1   == 0xffffffffffffffff
>         reg->umax_value = (u64)s64_min;  // == -128 == 0xffffffffffffff80

Thanks for pointing this out. Yes, the assignments are incorrect and
mismatched. Will fix the issue and add a test for this.

>> +               reg->u32_min_value = (u32)s64_max;
>> +               reg->u32_max_value = (u32)s64_min;
>> +               reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
>> +               return;
>> +       }
>> +
>> +out:
>> +       set_sext64_default_val(reg, size);
>> +}
>> +
> [...]
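To spell out the fix agreed on above — a sketch only, assuming the swapped
operands are the sole problem: when both bounds are negative, the cast to
u64 preserves their order, so the branch could read:

        if (s64_min < 0 && s64_max < 0) {
                reg->smin_value = reg->s32_min_value = s64_min;
                reg->smax_value = reg->s32_max_value = s64_max;
                /* both bounds negative: the high bits are all ones, so the
                 * signed order carries over to the unsigned domain, e.g.
                 * (u64)-128 = 0xff...80 < (u64)-1 = 0xff...ff for Eduard's
                 * example above
                 */
                reg->umin_value = (u64)s64_min;
                reg->umax_value = (u64)s64_max;
                reg->u32_min_value = (u32)s64_min;
                reg->u32_max_value = (u32)s64_max;
                reg->var_off = tnum_range((u64)s64_min, (u64)s64_max);
                return;
        }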
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 438adb695daa..addeea95f397 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
         *pprog = prog;
 }
 
+/* LDX: dst_reg = *(s8*)(src_reg + off) */
+static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
+{
+        u8 *prog = *pprog;
+
+        switch (size) {
+        case BPF_B:
+                /* Emit 'movsx rax, byte ptr [rax + off]' */
+                EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
+                break;
+        case BPF_H:
+                /* Emit 'movsx rax, word ptr [rax + off]' */
+                EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
+                break;
+        case BPF_W:
+                /* Emit 'movsx rax, dword ptr [rax+0x14]' */
+                EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
+                break;
+        }
+        emit_insn_suffix(&prog, src_reg, dst_reg, off);
+        *pprog = prog;
+}
+
 /* STX: *(u8*)(dst_reg + off) = src_reg */
 static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
 {
@@ -1370,6 +1393,9 @@ st:                        if (is_imm8(insn->off))
                 case BPF_LDX | BPF_PROBE_MEM | BPF_W:
                 case BPF_LDX | BPF_MEM | BPF_DW:
                 case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+                case BPF_LDX | BPF_MEMSX | BPF_B:
+                case BPF_LDX | BPF_MEMSX | BPF_H:
+                case BPF_LDX | BPF_MEMSX | BPF_W:
                         insn_off = insn->off;
 
                         if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
@@ -1415,7 +1441,11 @@ st:                        if (is_imm8(insn->off))
                                 start_of_ldx = prog;
                                 end_of_jmp[-1] = start_of_ldx - end_of_jmp;
                         }
-                        emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+                        if ((BPF_MODE(insn->code) == BPF_PROBE_MEM && insn->imm == BPF_MEMSX) ||
+                            BPF_MODE(insn->code) == BPF_MEMSX)
+                                emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+                        else
+                                emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
                         if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
                                 struct exception_table_entry *ex;
                                 u8 *_insn = image + proglen + (start_of_ldx - temp);
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 600d0caebbd8..c7196302d1eb 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -19,6 +19,7 @@
 
 /* ld/ldx fields */
 #define BPF_DW          0x18    /* double word (64-bit) */
+#define BPF_MEMSX       0x80    /* load with sign extension */
 #define BPF_ATOMIC      0xc0    /* atomic memory ops - op type in immediate */
 #define BPF_XADD        0xc0    /* exclusive add - legacy name */
 
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index dc85240a0134..8a1cc658789e 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base);
         INSN_3(LDX, MEM, H),                    \
         INSN_3(LDX, MEM, W),                    \
         INSN_3(LDX, MEM, DW),                   \
+        INSN_3(LDX, MEMSX, B),                  \
+        INSN_3(LDX, MEMSX, H),                  \
+        INSN_3(LDX, MEMSX, W),                  \
         /*   Immediate based. */                \
         INSN_3(LD, IMM, DW)
 
@@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
         LDST(DW, u64)
 #undef LDST
 
+#define LDS(SIZEOP, SIZE)                                        \
+        LDX_MEMSX_##SIZEOP:                                      \
+                DST = *(SIZE *)(unsigned long) (SRC + insn->off); \
+                CONT;
+
+        LDS(B, s8)
+        LDS(H, s16)
+        LDS(W, s32)
+#undef LDS
+
 #define ATOMIC_ALU_OP(BOP, KOP)                                  \
                 case BOP:                                        \
                         if (BPF_SIZE(insn->code) == BPF_W)       \
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 81a93eeac7a0..fbe4ca72d4c1 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -5795,6 +5795,77 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
                 __reg_combine_64_into_32(reg);
 }
 
+static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
+{
+        if (size == 1) {
+                reg->smin_value = reg->s32_min_value = S8_MIN;
+                reg->smax_value = reg->s32_max_value = S8_MAX;
+        } else if (size == 2) {
+                reg->smin_value = reg->s32_min_value = S16_MIN;
+                reg->smax_value = reg->s32_max_value = S16_MAX;
+        } else {
+                /* size == 4 */
+                reg->smin_value = reg->s32_min_value = S32_MIN;
+                reg->smax_value = reg->s32_max_value = S32_MAX;
+        }
+        reg->umin_value = reg->u32_min_value = 0;
+        reg->umax_value = U64_MAX;
+        reg->u32_max_value = U32_MAX;
+        reg->var_off = tnum_unknown;
+}
+
+static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+        u64 top_smax_value, top_smin_value;
+        s64 init_s64_max, init_s64_min, s64_max, s64_min;
+        u64 num_bits = size * 8;
+
+        top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
+        top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
+
+        if (top_smax_value != top_smin_value)
+                goto out;
+
+        /* find the s64_min and s64_min after sign extension */
+        if (size == 1) {
+                init_s64_max = (s8)reg->smax_value;
+                init_s64_min = (s8)reg->smin_value;
+        } else if (size == 2) {
+                init_s64_max = (s16)reg->smax_value;
+                init_s64_min = (s16)reg->smin_value;
+        } else {
+                /* size == 4 */
+                init_s64_max = (s32)reg->smax_value;
+                init_s64_min = (s32)reg->smin_value;
+        }
+
+        s64_max = max(init_s64_max, init_s64_min);
+        s64_min = min(init_s64_max, init_s64_min);
+
+        if (s64_max >= 0 && s64_min >= 0) {
+                reg->smin_value = reg->s32_min_value = s64_min;
+                reg->smax_value = reg->s32_max_value = s64_max;
+                reg->umin_value = reg->u32_min_value = s64_min;
+                reg->umax_value = reg->u32_max_value = s64_max;
+                reg->var_off = tnum_range(s64_min, s64_max);
+                return;
+        }
+
+        if (s64_min < 0 && s64_max < 0) {
+                reg->smin_value = reg->s32_min_value = s64_min;
+                reg->smax_value = reg->s32_max_value = s64_max;
+                reg->umin_value = (u64)s64_max;
+                reg->umax_value = (u64)s64_min;
+                reg->u32_min_value = (u32)s64_max;
+                reg->u32_max_value = (u32)s64_min;
+                reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
+                return;
+        }
+
+out:
+        set_sext64_default_val(reg, size);
+}
+
 static bool bpf_map_is_rdonly(const struct bpf_map *map)
 {
         /* A map is considered read-only if the following condition are true:
@@ -5815,7 +5886,8 @@ static bool bpf_map_is_rdonly(const struct bpf_map *map)
                !bpf_map_write_active(map);
 }
 
-static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
+static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
+                               bool is_ldsx)
 {
         void *ptr;
         u64 addr;
@@ -5828,13 +5900,13 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
 
         switch (size) {
         case sizeof(u8):
-                *val = (u64)*(u8 *)ptr;
+                *val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr;
                 break;
         case sizeof(u16):
-                *val = (u64)*(u16 *)ptr;
+                *val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr;
                 break;
         case sizeof(u32):
-                *val = (u64)*(u32 *)ptr;
+                *val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr;
                 break;
         case sizeof(u64):
                 *val = *(u64 *)ptr;
@@ -6248,7 +6320,7 @@ static int check_stack_access_within_bounds(
  */
 static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno,
                             int off, int bpf_size, enum bpf_access_type t,
-                            int value_regno, bool strict_alignment_once)
+                            int value_regno, bool strict_alignment_once, bool is_ldsx)
 {
         struct bpf_reg_state *regs = cur_regs(env);
         struct bpf_reg_state *reg = regs + regno;
@@ -6309,7 +6381,7 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                         u64 val = 0;
 
                         err = bpf_map_direct_read(map, map_off, size,
-                                                  &val);
+                                                  &val, is_ldsx);
                         if (err)
                                 return err;
 
@@ -6479,8 +6551,11 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 
         if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ &&
             regs[value_regno].type == SCALAR_VALUE) {
-                /* b/h/w load zero-extends, mark upper bits as known 0 */
-                coerce_reg_to_size(&regs[value_regno], size);
+                if (!is_ldsx)
+                        /* b/h/w load zero-extends, mark upper bits as known 0 */
+                        coerce_reg_to_size(&regs[value_regno], size);
+                else
+                        coerce_reg_to_size_sx(&regs[value_regno], size);
         }
         return err;
 }
@@ -6572,17 +6647,17 @@ static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i
          * case to simulate the register fill.
          */
         err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-                               BPF_SIZE(insn->code), BPF_READ, -1, true);
+                               BPF_SIZE(insn->code), BPF_READ, -1, true, false);
         if (!err && load_reg >= 0)
                 err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
                                        BPF_SIZE(insn->code), BPF_READ, load_reg,
-                                       true);
+                                       true, false);
         if (err)
                 return err;
 
         /* Check whether we can write into the same memory. */
         err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-                               BPF_SIZE(insn->code), BPF_WRITE, -1, true);
+                               BPF_SIZE(insn->code), BPF_WRITE, -1, true, false);
         if (err)
                 return err;
 
@@ -6828,7 +6903,7 @@ static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
                         return zero_size_allowed ? 0 : -EACCES;
 
                 return check_mem_access(env, env->insn_idx, regno, offset, BPF_B,
-                                        atype, -1, false);
+                                        atype, -1, false, false);
         }
 
         fallthrough;
@@ -7200,7 +7275,7 @@ static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn
                 /* we write BPF_DW bits (8 bytes) at a time */
                 for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) {
                         err = check_mem_access(env, insn_idx, regno,
-                                               i, BPF_DW, BPF_WRITE, -1, false);
+                                               i, BPF_DW, BPF_WRITE, -1, false, false);
                         if (err)
                                 return err;
                 }
@@ -7293,7 +7368,7 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
 
                 for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) {
                         err = check_mem_access(env, insn_idx, regno,
-                                               i, BPF_DW, BPF_WRITE, -1, false);
+                                               i, BPF_DW, BPF_WRITE, -1, false, false);
                         if (err)
                                 return err;
                 }
@@ -9437,7 +9512,7 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
                  */
                 for (i = 0; i < meta.access_size; i++) {
                         err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
-                                               BPF_WRITE, -1, false);
+                                               BPF_WRITE, -1, false, false);
                         if (err)
                                 return err;
                 }
@@ -16315,7 +16390,8 @@ static int do_check(struct bpf_verifier_env *env)
                          */
                         err = check_mem_access(env, env->insn_idx, insn->src_reg,
                                                insn->off, BPF_SIZE(insn->code),
-                                               BPF_READ, insn->dst_reg, false);
+                                               BPF_READ, insn->dst_reg, false,
+                                               BPF_MODE(insn->code) == BPF_MEMSX);
                         if (err)
                                 return err;
 
@@ -16352,7 +16428,7 @@ static int do_check(struct bpf_verifier_env *env)
                         /* check that memory (dst_reg + off) is writeable */
                         err = check_mem_access(env, env->insn_idx, insn->dst_reg,
                                                insn->off, BPF_SIZE(insn->code),
-                                               BPF_WRITE, insn->src_reg, false);
+                                               BPF_WRITE, insn->src_reg, false, false);
                         if (err)
                                 return err;
 
@@ -16377,7 +16453,7 @@ static int do_check(struct bpf_verifier_env *env)
                         /* check that memory (dst_reg + off) is writeable */
                         err = check_mem_access(env, env->insn_idx, insn->dst_reg,
                                                insn->off, BPF_SIZE(insn->code),
-                                               BPF_WRITE, -1, false);
+                                               BPF_WRITE, -1, false, false);
                         if (err)
                                 return err;
 
@@ -16805,7 +16881,8 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
 
         for (i = 0; i < insn_cnt; i++, insn++) {
                 if (BPF_CLASS(insn->code) == BPF_LDX &&
-                    (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) {
+                    ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) ||
+                     insn->imm != 0)) {
                         verbose(env, "BPF_LDX uses reserved fields\n");
                         return -EINVAL;
                 }
@@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                 if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
                     insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
                     insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
-                    insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
+                    insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
+                    insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
+                    insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
+                    insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
                         type = BPF_READ;
                 } else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
                            insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
@@ -17562,6 +17642,11 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                  */
                 case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
                         if (type == BPF_READ) {
+                                /* it is hard to differentiate that the
+                                 * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
+                                 * let us use insn->imm to remember it.
+                                 */
+                                insn->imm = BPF_MODE(insn->code);
                                 insn->code = BPF_LDX | BPF_PROBE_MEM |
                                         BPF_SIZE((insn)->code);
                                 env->prog->aux->num_exentries++;
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 600d0caebbd8..c7196302d1eb 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -19,6 +19,7 @@
 
 /* ld/ldx fields */
 #define BPF_DW          0x18    /* double word (64-bit) */
+#define BPF_MEMSX       0x80    /* load with sign extension */
 #define BPF_ATOMIC      0xc0    /* atomic memory ops - op type in immediate */
 #define BPF_XADD        0xc0    /* exclusive add - legacy name */
Add interpreter/jit support for new sign-extension load insns,
which add a new mode (BPF_MEMSX).
Also add verifier support to recognize these insns and do proper
verification for them. In the verifier, besides deducing proper bounds
for the dst_reg, probed memory access is handled by remembering the insn
mode in the insn->imm field so that proper jit insns can be emitted later.

Signed-off-by: Yonghong Song <yhs@fb.com>
---
 arch/x86/net/bpf_jit_comp.c    |  32 ++++++++-
 include/uapi/linux/bpf.h       |   1 +
 kernel/bpf/core.c              |  13 ++++
 kernel/bpf/verifier.c          | 125 +++++++++++++++++++++++++++------
 tools/include/uapi/linux/bpf.h |   1 +
 5 files changed, 151 insertions(+), 21 deletions(-)
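As a usage illustration, a sign-extending load can be hand-assembled with
the usual insn-building macros. BPF_LDX_MEMSX below is a hypothetical
helper patterned on BPF_LDX_MEM from include/linux/filter.h; only the
BPF_MEMSX mode itself comes from this patch:

#include <linux/filter.h>

/* Hypothetical helper, mirroring BPF_LDX_MEM:
 * dst_reg = *(signed SIZE *)(src_reg + off)
 */
#define BPF_LDX_MEMSX(SIZE, DST, SRC, OFF)                       \
        ((struct bpf_insn) {                                     \
                .code  = BPF_LDX | BPF_SIZE(SIZE) | BPF_MEMSX,   \
                .dst_reg = DST,                                  \
                .src_reg = SRC,                                  \
                .off   = OFF,                                    \
                .imm   = 0 })

/* r0 = *(s8 *)(r1 + 0): if the byte at r1 is 0x80, r0 becomes
 * 0xffffffffffffff80 rather than 0x80 as with BPF_MEM.
 */
struct bpf_insn ldsx_b = BPF_LDX_MEMSX(BPF_B, BPF_REG_0, BPF_REG_1, 0);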