[bpf-next,v2,1/3] bpf: track find_equal_scalars history on per-instruction level

Message ID 20240705205851.2635794-2-eddyz87@gmail.com (mailing list archive)
State Superseded
Delegated to: BPF
Series bpf: track find_equal_scalars history on per-instruction level

Commit Message

Eduard Zingerman July 5, 2024, 8:58 p.m. UTC
Use bpf_verifier_state->jmp_history to track which registers were
updated by find_equal_scalars() when conditional jump was verified.
Use recorded information in backtrack_insn() to propagate precision.

E.g. for the following program:

            while verifying instructions
  r1 = r0              |
  if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
  if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history
  r2 = r10             |
  r2 += r0             v mark_chain_precision(r0)

            while doing mark_chain_precision(r0)
  r1 = r0              ^
  if r1 < 8  goto ...  | mark r0,r1 as precise
  if r0 > 16 goto ...  | mark r0,r1 as precise
  r2 = r10             |
  r2 += r0             | mark r0 precise

Technically, this is achieved as follows:
- Use 10 bits to identify each register that gains range because of
  find_equal_scalars():
  - 3 bits for frame number;
  - 6 bits for register or stack slot number;
  - 1 bit to indicate if register is spilled.
- Use u64 as a vector of 6 such records + 4 bits for vector length.
- Augment struct bpf_jmp_history_entry with field 'linked_regs'
  representing such vector.
- When doing check_cond_jmp_op() remember up to 6 registers that
  gain range because of find_equal_scalars() in such a vector.
- Don't propagate range information and reset IDs for registers that
  don't fit in 6-value vector.
- Push a pair {instruction index, equal scalars vector}
  to bpf_verifier_state->jmp_history.
- When doing backtrack_insn() check if any of recorded linked
  registers is currently marked precise, if so mark all linked
  registers as precise.
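
To make the encoding concrete, here is a small stand-alone user-space
sketch of the packing step (an illustration only; struct and function
names are local to this example, the in-kernel implementation is in the
patch below):

#include <stdint.h>
#include <stdio.h>

/* Per-entry layout from the list above: 3 bits frameno, 6 bits
 * register/stack-slot number, 1 bit is_reg. Six such 10-bit records
 * plus a 4-bit count fill the u64 exactly (6 * 10 + 4 = 64), which is
 * where the limit of six linked registers per jump comes from.
 */
struct entry { uint8_t frameno, regno; int is_reg; };

static uint64_t pack(const struct entry *e, int cnt)
{
        uint64_t val = 0;

        for (int i = 0; i < cnt; i++) {
                uint64_t rec = e[i].frameno |
                               ((uint64_t)e[i].regno << 3) |
                               ((uint64_t)e[i].is_reg << 9);

                val = (val << 10) | rec;    /* one 10-bit record per register */
        }
        return (val << 4) | (uint64_t)cnt;  /* 4-bit count in the lowest bits */
}

int main(void)
{
        /* r0 and r1 in frame 0 gained range from the same conditional jump */
        struct entry regs[] = { { 0, 0, 1 }, { 0, 1, 1 } };

        printf("linked_regs = %#llx\n",
               (unsigned long long)pack(regs, 2));  /* prints 0x802082 */
        return 0;
}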

This also requires fixes for two test_verifier tests:
- precise: test 1
- precise: test 2

Both tests contain the following instruction sequence:

19: (bf) r2 = r9                      ; R2=scalar(id=3) R9=scalar(id=3)
20: (a5) if r2 < 0x8 goto pc+1        ; R2=scalar(id=3,umin=8)
21: (95) exit
22: (07) r2 += 1                      ; R2_w=scalar(id=3+1,...)
23: (bf) r1 = r10                     ; R1_w=fp0 R10=fp0
24: (07) r1 += -8                     ; R1_w=fp-8
25: (b7) r3 = 0                       ; R3_w=0
26: (85) call bpf_probe_read_kernel#113

The call to bpf_probe_read_kernel() at (26) forces r2 to be precise.
Previously, this forced all registers with the same id to become precise
immediately when mark_chain_precision() was called.
After this change, precision is propagated to registers sharing the
same id only when the 'if' instruction is backtracked.
Hence the verification log for both tests changes:
regs=r2,r9 -> regs=r2 for instructions 25..20.

Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
Reported-by: Hao Sun <sunhao.th@gmail.com>
Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@mail.gmail.com/
Suggested-by: Andrii Nakryiko <andrii@kernel.org>
Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
---
 include/linux/bpf_verifier.h                  |   4 +
 kernel/bpf/verifier.c                         | 231 ++++++++++++++++--
 .../bpf/progs/verifier_subprog_precision.c    |   2 +-
 .../testing/selftests/bpf/verifier/precise.c  |  20 +-
 4 files changed, 232 insertions(+), 25 deletions(-)

Comments

Andrii Nakryiko July 10, 2024, 12:34 a.m. UTC | #1
On Fri, Jul 5, 2024 at 1:59 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> Use bpf_verifier_state->jmp_history to track which registers were
> updated by find_equal_scalars() when conditional jump was verified.
> Use recorded information in backtrack_insn() to propagate precision.
>
> E.g. for the following program:
>
>             while verifying instructions
>   r1 = r0              |
>   if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
>   if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history

linked_scalars? especially now that Alexei added offsets between
linked registers

>   r2 = r10             |
>   r2 += r0             v mark_chain_precision(r0)
>
>             while doing mark_chain_precision(r0)
>   r1 = r0              ^
>   if r1 < 8  goto ...  | mark r0,r1 as precise
>   if r0 > 16 goto ...  | mark r0,r1 as precise
>   r2 = r10             |
>   r2 += r0             | mark r0 precise

let's reverse the order here so it's linear in how the algorithm
actually works (backwards)?

>
> Technically, achieve this as follows:
> - Use 10 bits to identify each register that gains range because of
>   find_equal_scalars():

should this be renamed to find_linked_scalars() nowadays?

>   - 3 bits for frame number;
>   - 6 bits for register or stack slot number;
>   - 1 bit to indicate if register is spilled.
> - Use u64 as a vector of 6 such records + 4 bits for vector length.
> - Augment struct bpf_jmp_history_entry with field 'linked_regs'
>   representing such vector.
> - When doing check_cond_jmp_op() remember up to 6 registers that
>   gain range because of find_equal_scalars() in such a vector.
> - Don't propagate range information and reset IDs for registers that
>   don't fit in 6-value vector.
> - Push a pair {instruction index, equal scalars vector}
>   to bpf_verifier_state->jmp_history.
> - When doing backtrack_insn() check if any of recorded linked
>   registers is currently marked precise, if so mark all linked
>   registers as precise.
>
> This also requires fixes for two test_verifier tests:
> - precise: test 1
> - precise: test 2
>
> Both tests contain the following instruction sequence:
>
> 19: (bf) r2 = r9                      ; R2=scalar(id=3) R9=scalar(id=3)
> 20: (a5) if r2 < 0x8 goto pc+1        ; R2=scalar(id=3,umin=8)
> 21: (95) exit
> 22: (07) r2 += 1                      ; R2_w=scalar(id=3+1,...)
> 23: (bf) r1 = r10                     ; R1_w=fp0 R10=fp0
> 24: (07) r1 += -8                     ; R1_w=fp-8
> 25: (b7) r3 = 0                       ; R3_w=0
> 26: (85) call bpf_probe_read_kernel#113
>
> The call to bpf_probe_read_kernel() at (26) forces r2 to be precise.
> Previously, this forced all registers with same id to become precise
> immediately when mark_chain_precision() is called.
> After this change, the precision is propagated to registers sharing
> same id only when 'if' instruction is backtracked.
> Hence verification log for both tests is changed:
> regs=r2,r9 -> regs=r2 for instructions 25..20.
>
> Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
> Reported-by: Hao Sun <sunhao.th@gmail.com>
> Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@mail.gmail.com/
> Suggested-by: Andrii Nakryiko <andrii@kernel.org>
> Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> ---
>  include/linux/bpf_verifier.h                  |   4 +
>  kernel/bpf/verifier.c                         | 231 ++++++++++++++++--
>  .../bpf/progs/verifier_subprog_precision.c    |   2 +-
>  .../testing/selftests/bpf/verifier/precise.c  |  20 +-
>  4 files changed, 232 insertions(+), 25 deletions(-)
>

The logic looks good (though I had a few small questions), I think.

> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index 2b54e25d2364..da450552c278 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -371,6 +371,10 @@ struct bpf_jmp_history_entry {
>         u32 prev_idx : 22;
>         /* special flags, e.g., whether insn is doing register stack spill/load */
>         u32 flags : 10;
> +       /* additional registers that need precision tracking when this
> +        * jump is backtracked, vector of six 10-bit records
> +        */
> +       u64 linked_regs;
>  };
>
>  /* Maximum number of register states that can exist at once */
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index e25ad5fb9115..ec493360607e 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -3335,9 +3335,87 @@ static bool is_jmp_point(struct bpf_verifier_env *env, int insn_idx)
>         return env->insn_aux_data[insn_idx].jmp_point;
>  }
>
> +#define ES_FRAMENO_BITS        3
> +#define ES_SPI_BITS    6
> +#define ES_ENTRY_BITS  (ES_SPI_BITS + ES_FRAMENO_BITS + 1)
> +#define ES_SIZE_BITS   4
> +#define ES_FRAMENO_MASK        ((1ul << ES_FRAMENO_BITS) - 1)
> +#define ES_SPI_MASK    ((1ul << ES_SPI_BITS)     - 1)
> +#define ES_SIZE_MASK   ((1ul << ES_SIZE_BITS)    - 1)

ull for 32-bit arches?

> +#define ES_SPI_OFF     ES_FRAMENO_BITS
> +#define ES_IS_REG_OFF  (ES_SPI_BITS + ES_FRAMENO_BITS)

ES makes no sense now, no? LR or LINKREG or something along those lines?

> +#define LINKED_REGS_MAX        6
> +
> +struct reg_or_spill {

reg_or_spill -> linked_reg ?

> +       u8 frameno:3;
> +       union {
> +               u8 spi:6;
> +               u8 regno:6;
> +       };
> +       bool is_reg:1;
> +};

Do we need these bitfields for unpacked representation? It's going to
use 2 bytes for this struct anyways. If you just use u8 for everything
you end up with 3 bytes. Bitfields are a bit slower because the
compiler will need to do more bit manipulations, so is it really worth
it?

> +
> +struct linked_regs {
> +       int cnt;
> +       struct reg_or_spill entries[LINKED_REGS_MAX];
> +};
> +

[...]

> @@ -3615,6 +3739,12 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
>                 print_bpf_insn(&cbs, insn, env->allow_ptr_leaks);
>         }
>
> +       /* If there is a history record that some registers gained range at this insn,
> +        * propagate precision marks to those registers, so that bt_is_reg_set()
> +        * accounts for these registers.
> +        */
> +       bt_sync_linked_regs(bt, hist);
> +
>         if (class == BPF_ALU || class == BPF_ALU64) {
>                 if (!bt_is_reg_set(bt, dreg))
>                         return 0;
> @@ -3844,6 +3974,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
>                          */
>                         bt_set_reg(bt, dreg);
>                         bt_set_reg(bt, sreg);
> +               } else if (BPF_SRC(insn->code) == BPF_K) {
>                          /* else dreg <cond> K

drop "else" from the comment then? I like this change.

>                           * Only dreg still needs precision before
>                           * this insn, so for the K-based conditional
> @@ -3862,6 +3993,10 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
>                         /* to be analyzed */
>                         return -ENOTSUPP;
>         }
> +       /* Propagate precision marks to linked registers, to account for
> +        * registers marked as precise in this function.
> +        */
> +       bt_sync_linked_regs(bt, hist);

Radical Andrii is fine with this, though I wonder if there is some
place outside of backtrack_insn() where the first
bt_sync_linked_regs() could be called just once?

But regardless, this is only mildly expensive when we do have linked
registers, so unlikely to have any noticeable performance effect.

>         return 0;
>  }
>
> @@ -4624,7 +4759,7 @@ static int check_stack_write_fixed_off(struct bpf_verifier_env *env,
>         }
>
>         if (insn_flags)
> -               return push_jmp_history(env, env->cur_state, insn_flags);
> +               return push_jmp_history(env, env->cur_state, insn_flags, 0);
>         return 0;
>  }
>
> @@ -4929,7 +5064,7 @@ static int check_stack_read_fixed_off(struct bpf_verifier_env *env,
>                 insn_flags = 0; /* we are not restoring spilled register */
>         }
>         if (insn_flags)
> -               return push_jmp_history(env, env->cur_state, insn_flags);
> +               return push_jmp_history(env, env->cur_state, insn_flags, 0);
>         return 0;
>  }
>
> @@ -15154,14 +15289,66 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
>         return true;
>  }
>
> -static void find_equal_scalars(struct bpf_verifier_state *vstate,
> -                              struct bpf_reg_state *known_reg)
> +static void __find_equal_scalars(struct linked_regs *reg_set, struct bpf_reg_state *reg,
> +                                u32 id, u32 frameno, u32 spi_or_reg, bool is_reg)

we should abandon "equal scalars" terminology, they don't have to be
equal, they are just linked together (potentially with a fixed
difference between them)
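
For example, such a pair can look like this (illustration only, in the
same pseudo-assembly style as the commit message):

  r1 = r0              ; r0 and r1 share the same id
  r1 += 5              ; still linked, with a constant delta of 5 (BPF_ADD_CONST)
  if r0 > 7 goto ...   ; fall-through: r0 in [0, 7], so via the link r1 in [5, 12]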


how about "collect_linked_regs"?

> +{
> +       struct reg_or_spill *e;
> +
> +       if (reg->type != SCALAR_VALUE || (reg->id & ~BPF_ADD_CONST) != id)

THIS is actually the place where I'd use u32 id:31; + bool
is_linked_reg:1; just so that it's not so easy to accidentally forget
about BPF_ADD_CONST flag (but it's unrelated to your patch)

> +               return;
> +
> +       e = linked_regs_push(reg_set);
> +       if (e) {
> +               e->frameno = frameno;
> +               e->is_reg = is_reg;
> +               e->regno = spi_or_reg;
> +       } else {
> +               reg->id = 0;
> +       }
> +}
> +

[...]

> @@ -15312,6 +15500,21 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
>                 return 0;
>         }
>
> +       /* Push scalar registers sharing same ID to jump history,
> +        * do this before creating 'other_branch', so that both
> +        * 'this_branch' and 'other_branch' share this history
> +        * if parent state is created.
> +        */
> +       if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
> +               find_equal_scalars(this_branch, src_reg->id, &linked_regs);
> +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
> +               find_equal_scalars(this_branch, dst_reg->id, &linked_regs);
> +       if (linked_regs.cnt > 1) {

if we have just one, should it be even marked as linked?

> +               err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
> +               if (err)
> +                       return err;
> +       }
> +
>         other_branch = push_stack(env, *insn_idx + insn->off + 1, *insn_idx,
>                                   false);
>         if (!other_branch)
> @@ -15336,13 +15539,13 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
>         if (BPF_SRC(insn->code) == BPF_X &&
>             src_reg->type == SCALAR_VALUE && src_reg->id &&
>             !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
> -               find_equal_scalars(this_branch, src_reg);
> -               find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
> +               copy_known_reg(this_branch, src_reg, &linked_regs);
> +               copy_known_reg(other_branch, &other_branch_regs[insn->src_reg], &linked_regs);

I liked the "sync" terminology you used for bt, so why not call this
"sync_linked_regs" ?

>         }
>         if (dst_reg->type == SCALAR_VALUE && dst_reg->id &&
>             !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) {
> -               find_equal_scalars(this_branch, dst_reg);
> -               find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
> +               copy_known_reg(this_branch, dst_reg, &linked_regs);
> +               copy_known_reg(other_branch, &other_branch_regs[insn->dst_reg], &linked_regs);
>         }
>

[...]
Eduard Zingerman July 10, 2024, 1:21 a.m. UTC | #2
On Tue, 2024-07-09 at 17:34 -0700, Andrii Nakryiko wrote:
> On Fri, Jul 5, 2024 at 1:59 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
> > 
> > Use bpf_verifier_state->jmp_history to track which registers were
> > updated by find_equal_scalars() when conditional jump was verified.
> > Use recorded information in backtrack_insn() to propagate precision.
> > 
> > E.g. for the following program:
> > 
> >             while verifying instructions
> >   r1 = r0              |
> >   if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
> >   if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history
> 
> linked_scalars? especially now that Alexei added offsets between
> linked registers

Missed this, will update.

> 
> >   r2 = r10             |
> >   r2 += r0             v mark_chain_precision(r0)
> > 
> >             while doing mark_chain_precision(r0)
> >   r1 = r0              ^
> >   if r1 < 8  goto ...  | mark r0,r1 as precise
> >   if r0 > 16 goto ...  | mark r0,r1 as precise
> >   r2 = r10             |
> >   r2 += r0             | mark r0 precise
> 
> let's reverse the order here so it's linear in how the algorithm
> actually works (backwards)?

I thought the arrow would be enough. Ok, can reverse.

> > Technically, achieve this as follows:
> > - Use 10 bits to identify each register that gains range because of
> >   find_equal_scalars():
> 
> should this be renamed to find_linked_scalars() nowadays?

That would be sync_linked_regs() if we use naming that you suggest.
Will update.

[...]

> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index e25ad5fb9115..ec493360607e 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -3335,9 +3335,87 @@ static bool is_jmp_point(struct bpf_verifier_env *env, int insn_idx)
> >         return env->insn_aux_data[insn_idx].jmp_point;
> >  }
> > 
> > +#define ES_FRAMENO_BITS        3
> > +#define ES_SPI_BITS    6
> > +#define ES_ENTRY_BITS  (ES_SPI_BITS + ES_FRAMENO_BITS + 1)
> > +#define ES_SIZE_BITS   4
> > +#define ES_FRAMENO_MASK        ((1ul << ES_FRAMENO_BITS) - 1)
> > +#define ES_SPI_MASK    ((1ul << ES_SPI_BITS)     - 1)
> > +#define ES_SIZE_MASK   ((1ul << ES_SIZE_BITS)    - 1)
> 
> ull for 32-bit arches?

Ok

> 
> > +#define ES_SPI_OFF     ES_FRAMENO_BITS
> > +#define ES_IS_REG_OFF  (ES_SPI_BITS + ES_FRAMENO_BITS)
> 
> ES makes no sense now, no? LR or LINKREG or something along those lines?
> 
> > +#define LINKED_REGS_MAX        6
> > +
> > +struct reg_or_spill {
> 
> reg_or_spill -> linked_reg ?

Ok

> 
> > +       u8 frameno:3;
> > +       union {
> > +               u8 spi:6;
> > +               u8 regno:6;
> > +       };
> > +       bool is_reg:1;
> > +};
> 
> Do we need these bitfields for unpacked representation? It's going to
> use 2 bytes for this struct anyways. If you just use u8 for everything
> you end up with 3 bytes. Bitfields are a bit slower because the
> compiler will need to do more bit manipulations, so is it really worth
> it?

Ok, will remove bitfields.

[...]

> > @@ -3844,6 +3974,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> >                          */
> >                         bt_set_reg(bt, dreg);
> >                         bt_set_reg(bt, sreg);
> > +               } else if (BPF_SRC(insn->code) == BPF_K) {
> >                          /* else dreg <cond> K
> 
> drop "else" from the comment then? I like this change.

This is actually a leftover from v1. I can drop "else" from the
comment or drop this hunk as it is not necessary for the series.

> >                           * Only dreg still needs precision before
> >                           * this insn, so for the K-based conditional
> > @@ -3862,6 +3993,10 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> >                         /* to be analyzed */
> >                         return -ENOTSUPP;
> >         }
> > +       /* Propagate precision marks to linked registers, to account for
> > +        * registers marked as precise in this function.
> > +        */
> > +       bt_sync_linked_regs(bt, hist);
> 
> Radical Andrii is fine with this, though I wonder if there is some
> place outside of backtrack_insn() where the first
> bt_sync_linked_regs() could be called just once?

The problem here is that:
- in theory linked_regs could be present for any instruction, thus
  sync() is needed after get_jmp_hist_entry call in
  __mark_chain_precision();
- backtrack_insn() might both remove and add some registers in bt,
  hence, to correctly handle bt_empty() call in __mark_chain_precision
  the sync() is also needed after backtrack_insn().

So, current placement is the simplest I could come up with.

> But regardless, this is only mildly expensive when we do have linked
> registers, so unlikely to have any noticeable performance effect.

Yes, that was my thinking as well.

[...]

> > @@ -15154,14 +15289,66 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
> >         return true;
> >  }
> > 
> > -static void find_equal_scalars(struct bpf_verifier_state *vstate,
> > -                              struct bpf_reg_state *known_reg)
> > +static void __find_equal_scalars(struct linked_regs *reg_set, struct bpf_reg_state *reg,
> > +                                u32 id, u32 frameno, u32 spi_or_reg, bool is_reg)
> 
> we should abandon "equal scalars" terminology, they don't have to be
> equal, they are just linked together (potentially with a fixed
> difference between them)
> 
> how about "collect_linked_regs"?

Sounds good.

[...]

> > @@ -15312,6 +15500,21 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> >                 return 0;
> >         }
> > 
> > +       /* Push scalar registers sharing same ID to jump history,
> > +        * do this before creating 'other_branch', so that both
> > +        * 'this_branch' and 'other_branch' share this history
> > +        * if parent state is created.
> > +        */
> > +       if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
> > +               find_equal_scalars(this_branch, src_reg->id, &linked_regs);
> > +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
> > +               find_equal_scalars(this_branch, dst_reg->id, &linked_regs);
> > +       if (linked_regs.cnt > 1) {
> 
> if we have just one, should it be even marked as linked?

Sorry, I don't understand. Do you suggest to add an additional check
in find_equal_scalars/collect_linked_regs and reset it if 'cnt' equals 1?

[...]
> 
> > +               err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
> > +               if (err)
> > +                       return err;
> > +       }
> > +
> >         other_branch = push_stack(env, *insn_idx + insn->off + 1, *insn_idx,
> >                                   false);
> >         if (!other_branch)
> > @@ -15336,13 +15539,13 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> >         if (BPF_SRC(insn->code) == BPF_X &&
> >             src_reg->type == SCALAR_VALUE && src_reg->id &&
> >             !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
> > -               find_equal_scalars(this_branch, src_reg);
> > -               find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
> > +               copy_known_reg(this_branch, src_reg, &linked_regs);
> > +               copy_known_reg(other_branch, &other_branch_regs[insn->src_reg], &linked_regs);
> 
> I liked the "sync" terminology you used for bt, so why not call this
> "sync_linked_regs" ?

I kept the current name for the function.
Suggested name makes sense, though.

[...]
Andrii Nakryiko July 10, 2024, 5:28 a.m. UTC | #3
On Tue, Jul 9, 2024 at 6:21 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Tue, 2024-07-09 at 17:34 -0700, Andrii Nakryiko wrote:
> > On Fri, Jul 5, 2024 at 1:59 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
> > >
> > > Use bpf_verifier_state->jmp_history to track which registers were
> > > updated by find_equal_scalars() when conditional jump was verified.
> > > Use recorded information in backtrack_insn() to propagate precision.
> > >
> > > E.g. for the following program:
> > >
> > >             while verifying instructions
> > >   r1 = r0              |
> > >   if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
> > >   if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history
> >
> > linked_scalars? especially now that Alexei added offsets between
> > linked registers
>
> Missed this, will update.
>
> >
> > >   r2 = r10             |
> > >   r2 += r0             v mark_chain_precision(r0)
> > >
> > >             while doing mark_chain_precision(r0)
> > >   r1 = r0              ^
> > >   if r1 < 8  goto ...  | mark r0,r1 as precise
> > >   if r0 > 16 goto ...  | mark r0,r1 as precise
> > >   r2 = r10             |
> > >   r2 += r0             | mark r0 precise
> >
> > let's reverse the order here so it's linear in how the algorithm
> > actually works (backwards)?
>
> I thought the arrow would be enough. Ok, can reverse.

it's the reverse order compared to what you'd see in the verifier log.
I did see the arrow (though it wasn't all that clear on the first
reading), but still feels like it would be better to have consistent
order with verifier log

[...]

> > > @@ -3844,6 +3974,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> > >                          */
> > >                         bt_set_reg(bt, dreg);
> > >                         bt_set_reg(bt, sreg);
> > > +               } else if (BPF_SRC(insn->code) == BPF_K) {
> > >                          /* else dreg <cond> K
> >
> > drop "else" from the comment then? I like this change.
>
> This is actually a leftover from v1. I can drop "else" from the
> comment or drop this hunk as it is not necessary for the series.

I'd keep explicit `else if`

>
> > >                           * Only dreg still needs precision before
> > >                           * this insn, so for the K-based conditional
> > > @@ -3862,6 +3993,10 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> > >                         /* to be analyzed */
> > >                         return -ENOTSUPP;
> > >         }
> > > +       /* Propagate precision marks to linked registers, to account for
> > > +        * registers marked as precise in this function.
> > > +        */
> > > +       bt_sync_linked_regs(bt, hist);
> >
> > Radical Andrii is fine with this, though I wonder if there is some
> > place outside of backtrack_insn() where the first
> > bt_sync_linked_regs() could be called just once?
>
> The problem here is that:
> - in theory linked_regs could be present for any instruction, thus
>   sync() is needed after get_jmp_hist_entry call in
>   __mark_chain_precision();
> - backtrack_insn() might both remove and add some registers in bt,
>   hence, to correctly handle bt_empty() call in __mark_chain_precision
>   the sync() is also needed after backtrack_insn().
>
> So, current placement is the simplest I could come up with.

agreed, let's keep it as is

[...]

> > > @@ -15312,6 +15500,21 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> > >                 return 0;
> > >         }
> > >
> > > +       /* Push scalar registers sharing same ID to jump history,
> > > +        * do this before creating 'other_branch', so that both
> > > +        * 'this_branch' and 'other_branch' share this history
> > > +        * if parent state is created.
> > > +        */
> > > +       if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
> > > +               find_equal_scalars(this_branch, src_reg->id, &linked_regs);
> > > +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
> > > +               find_equal_scalars(this_branch, dst_reg->id, &linked_regs);
> > > +       if (linked_regs.cnt > 1) {
> >
> > if we have just one, should it be even marked as linked?
>
> Sorry, I don't understand. Do you suggest to add an additional check
> in find_equal_scalars/collect_linked_regs and reset it if 'cnt' equals 1?

I find `if (linked_regs.cnt > 1)` check a bit weird and it feels like
it should be unnecessary. As soon as we are left with just one
"linked" register (linked with what? with itself?) it shouldn't be
linked anymore. Is there a point where we break the link between
registers where we can/should drop ID from the singularly linked
register? Why keep that scalar register ID set?

>
> [...]
> >
> > > +               err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
> > > +               if (err)
> > > +                       return err;
> > > +       }
> > > +
> > >         other_branch = push_stack(env, *insn_idx + insn->off + 1, *insn_idx,
> > >                                   false);
> > >         if (!other_branch)
> > > @@ -15336,13 +15539,13 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> > >         if (BPF_SRC(insn->code) == BPF_X &&
> > >             src_reg->type == SCALAR_VALUE && src_reg->id &&
> > >             !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
> > > -               find_equal_scalars(this_branch, src_reg);
> > > -               find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
> > > +               copy_known_reg(this_branch, src_reg, &linked_regs);
> > > +               copy_known_reg(other_branch, &other_branch_regs[insn->src_reg], &linked_regs);
> >
> > I liked the "sync" terminology you used for bt, so why not call this
> > "sync_linked_regs" ?
>
> I kept the current name for the function.
> Suggested name makes sense, though.
>
> [...]
Eduard Zingerman July 10, 2024, 6:36 a.m. UTC | #4
On Tue, 2024-07-09 at 22:28 -0700, Andrii Nakryiko wrote:

[...]

> > > >   r2 = r10             |
> > > >   r2 += r0             v mark_chain_precision(r0)
> > > > 
> > > >             while doing mark_chain_precision(r0)
> > > >   r1 = r0              ^
> > > >   if r1 < 8  goto ...  | mark r0,r1 as precise
> > > >   if r0 > 16 goto ...  | mark r0,r1 as precise
> > > >   r2 = r10             |
> > > >   r2 += r0             | mark r0 precise
> > > 
> > > let's reverse the order here so it's linear in how the algorithm
> > > actually works (backwards)?
> > 
> > I thought the arrow would be enough. Ok, can reverse.
> 
> it's the reverse order compared to what you'd see in the verifier log.
> I did see the arrow (though it wasn't all that clear on the first
> reading), but still feels like it would be better to have consistent
> order with verifier log

Ok, no problem

> 
> [...]
> 
> > > > @@ -3844,6 +3974,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> > > >                          */
> > > >                         bt_set_reg(bt, dreg);
> > > >                         bt_set_reg(bt, sreg);
> > > > +               } else if (BPF_SRC(insn->code) == BPF_K) {
> > > >                          /* else dreg <cond> K
> > > 
> > > drop "else" from the comment then? I like this change.
> > 
> > This is actually a leftover from v1. I can drop "else" from the
> > comment or drop this hunk as it is not necessary for the series.
> 
> I'd keep explicit `else if`

Ok, will do

[...]

> > > > @@ -15312,6 +15500,21 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> > > >                 return 0;
> > > >         }
> > > > 
> > > > +       /* Push scalar registers sharing same ID to jump history,
> > > > +        * do this before creating 'other_branch', so that both
> > > > +        * 'this_branch' and 'other_branch' share this history
> > > > +        * if parent state is created.
> > > > +        */
> > > > +       if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
> > > > +               find_equal_scalars(this_branch, src_reg->id, &linked_regs);
> > > > +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
> > > > +               find_equal_scalars(this_branch, dst_reg->id, &linked_regs);
> > > > +       if (linked_regs.cnt > 1) {
> > > 
> > > if we have just one, should it be even marked as linked?
> > 
> > Sorry, I don't understand. Do you suggest to add an additional check
> > in find_equal_scalars/collect_linked_regs and reset it if 'cnt' equals 1?
> 
> I find `if (linked_regs.cnt > 1)` check a bit weird and it feels like
> it should be unnecessary. As soon as we are left with just one
> "linked" register (linked with what? with itself?) it shouldn't be
> linked anymore. Is there a point where we break the link between
> registers where we can/should drop ID from the singularly linked
> register? Why keep that scalar register ID set?

I can push this check inside find_equal_scalars/collect_linked_regs, e.g.:

collect_linked_regs(... linked_regs ...)
{
	...
	if (linked_regs.cnt == 1)
		linked_regs.cnt = 0;
	...
}

But then this particular place would have to be modified as follows:

	if (linked_regs.cnt > 0) {
		err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
		if (err)
			return err;
	}

Or something similar has to be done inside push_jmp_history().

[...]
Andrii Nakryiko July 10, 2024, 3:21 p.m. UTC | #5
On Tue, Jul 9, 2024 at 11:36 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Tue, 2024-07-09 at 22:28 -0700, Andrii Nakryiko wrote:
>
> [...]
>
> > > > >   r2 = r10             |
> > > > >   r2 += r0             v mark_chain_precision(r0)
> > > > >
> > > > >             while doing mark_chain_precision(r0)
> > > > >   r1 = r0              ^
> > > > >   if r1 < 8  goto ...  | mark r0,r1 as precise
> > > > >   if r0 > 16 goto ...  | mark r0,r1 as precise
> > > > >   r2 = r10             |
> > > > >   r2 += r0             | mark r0 precise
> > > >
> > > > let's reverse the order here so it's linear in how the algorithm
> > > > actually works (backwards)?
> > >
> > > I thought the arrow would be enough. Ok, can reverse.
> >
> > it's the reverse order compared to what you'd see in the verifier log.
> > I did see the arrow (though it wasn't all that clear on the first
> > reading), but still feels like it would be better to have consistent
> > order with verifier log
>
> Ok, no problem
>
> >
> > [...]
> >
> > > > > @@ -3844,6 +3974,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> > > > >                          */
> > > > >                         bt_set_reg(bt, dreg);
> > > > >                         bt_set_reg(bt, sreg);
> > > > > +               } else if (BPF_SRC(insn->code) == BPF_K) {
> > > > >                          /* else dreg <cond> K
> > > >
> > > > drop "else" from the comment then? I like this change.
> > >
> > > This is actually a leftover from v1. I can drop "else" from the
> > > comment or drop this hunk as it is not necessary for the series.
> >
> > I'd keep explicit `else if`
>
> Ok, will do
>
> [...]
>
> > > > > @@ -15312,6 +15500,21 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> > > > >                 return 0;
> > > > >         }
> > > > >
> > > > > +       /* Push scalar registers sharing same ID to jump history,
> > > > > +        * do this before creating 'other_branch', so that both
> > > > > +        * 'this_branch' and 'other_branch' share this history
> > > > > +        * if parent state is created.
> > > > > +        */
> > > > > +       if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
> > > > > +               find_equal_scalars(this_branch, src_reg->id, &linked_regs);
> > > > > +       if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
> > > > > +               find_equal_scalars(this_branch, dst_reg->id, &linked_regs);
> > > > > +       if (linked_regs.cnt > 1) {
> > > >
> > > > if we have just one, should it be even marked as linked?
> > >
> > > Sorry, I don't understand. Do you suggest to add an additional check
> > > in find_equal_scalars/collect_linked_regs and reset it if 'cnt' equals 1?
> >
> > I find `if (linked_regs.cnt > 1)` check a bit weird and it feels like
> > it should be unnecessary. As soon as we are left with just one
> > "linked" register (linked with what? with itself?) it shouldn't be
> > linked anymore. Is there a point where we break the link between
> > registers where we can/should drop ID from the singularly linked
> > register? Why keep that scalar register ID set?
>
> I can push this check inside find_equal_scalars/collect_linked_regs, e.g.:
>
> collect_linked_regs(... linked_regs ...)
> {
>         ...
>         if (linked_regs.cnt == 1)
>                 linked_regs.cnt = 0;

I mean, fine, that's ok. But you are missing the point I'm making. I'm
saying there is somewhere in the verifier (and I'm too lazy/don't care
to go find where) where we break the link between registers (we reset the
ID on one of them, probably). What I am asking is whether we should have a
check there to also reset ID on the last remaining
"kind-of-linked-but-not-really-anymore" register.


Anyways, this doesn't have to be solved right away, so let's do this
fixup you are proposing here and keep clean "linked_regs.cnt > 0"
check below.

>         ...
> }
>
> But then this particular place would have to be modified as follows:
>
>         if (linked_regs.cnt > 0) {

yes, this makes total sense ("are there any linked regs? if not, there
is nothing to push to history")

>                 err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
>                 if (err)
>                         return err;
>         }
>
> Or something similar has to be done inside push_jmp_history().

no need to push this inside push_jmp_history(), why paying the price
of linked_regs_pack() unnecessarily?

>
> [...]

Patch

diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index 2b54e25d2364..da450552c278 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -371,6 +371,10 @@  struct bpf_jmp_history_entry {
 	u32 prev_idx : 22;
 	/* special flags, e.g., whether insn is doing register stack spill/load */
 	u32 flags : 10;
+	/* additional registers that need precision tracking when this
+	 * jump is backtracked, vector of six 10-bit records
+	 */
+	u64 linked_regs;
 };
 
 /* Maximum number of register states that can exist at once */
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index e25ad5fb9115..ec493360607e 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3335,9 +3335,87 @@  static bool is_jmp_point(struct bpf_verifier_env *env, int insn_idx)
 	return env->insn_aux_data[insn_idx].jmp_point;
 }
 
+#define ES_FRAMENO_BITS	3
+#define ES_SPI_BITS	6
+#define ES_ENTRY_BITS	(ES_SPI_BITS + ES_FRAMENO_BITS + 1)
+#define ES_SIZE_BITS	4
+#define ES_FRAMENO_MASK	((1ul << ES_FRAMENO_BITS) - 1)
+#define ES_SPI_MASK	((1ul << ES_SPI_BITS)     - 1)
+#define ES_SIZE_MASK	((1ul << ES_SIZE_BITS)    - 1)
+#define ES_SPI_OFF	ES_FRAMENO_BITS
+#define ES_IS_REG_OFF	(ES_SPI_BITS + ES_FRAMENO_BITS)
+#define LINKED_REGS_MAX	6
+
+struct reg_or_spill {
+	u8 frameno:3;
+	union {
+		u8 spi:6;
+		u8 regno:6;
+	};
+	bool is_reg:1;
+};
+
+struct linked_regs {
+	int cnt;
+	struct reg_or_spill entries[LINKED_REGS_MAX];
+};
+
+static struct reg_or_spill *linked_regs_push(struct linked_regs *s)
+{
+	if (s->cnt < LINKED_REGS_MAX)
+		return &s->entries[s->cnt++];
+
+	return NULL;
+}
+
+/* Use u64 as a vector of 6 10-bit values, use first 4-bits to track
+ * number of elements currently in stack.
+ * Pack one history entry for equal scalars as 10 bits in the following format:
+ * - 3-bits frameno
+ * - 6-bits spi_or_reg
+ * - 1-bit  is_reg
+ */
+static u64 linked_regs_pack(struct linked_regs *s)
+{
+	u64 val = 0;
+	int i;
+
+	for (i = 0; i < s->cnt; ++i) {
+		struct reg_or_spill *e = &s->entries[i];
+		u64 tmp = 0;
+
+		tmp |= e->frameno;
+		tmp |= e->spi << ES_SPI_OFF;
+		tmp |= (e->is_reg ? 1 : 0) << ES_IS_REG_OFF;
+
+		val <<= ES_ENTRY_BITS;
+		val |= tmp;
+	}
+	val <<= ES_SIZE_BITS;
+	val |= s->cnt;
+	return val;
+}
+
+static void linked_regs_unpack(u64 val, struct linked_regs *s)
+{
+	int i;
+
+	s->cnt = val & ES_SIZE_MASK;
+	val >>= ES_SIZE_BITS;
+
+	for (i = 0; i < s->cnt; ++i) {
+		struct reg_or_spill *e = &s->entries[i];
+
+		e->frameno =  val & ES_FRAMENO_MASK;
+		e->spi     = (val >> ES_SPI_OFF) & ES_SPI_MASK;
+		e->is_reg  = (val >> ES_IS_REG_OFF) & 0x1;
+		val >>= ES_ENTRY_BITS;
+	}
+}
+
 /* for any branch, call, exit record the history of jmps in the given state */
 static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_state *cur,
-			    int insn_flags)
+			    int insn_flags, u64 linked_regs)
 {
 	u32 cnt = cur->jmp_history_cnt;
 	struct bpf_jmp_history_entry *p;
@@ -3353,6 +3431,10 @@  static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
 			  "verifier insn history bug: insn_idx %d cur flags %x new flags %x\n",
 			  env->insn_idx, env->cur_hist_ent->flags, insn_flags);
 		env->cur_hist_ent->flags |= insn_flags;
+		WARN_ONCE(env->cur_hist_ent->linked_regs != 0,
+			  "verifier insn history bug: insn_idx %d linked_regs != 0: %#llx\n",
+			  env->insn_idx, env->cur_hist_ent->linked_regs);
+		env->cur_hist_ent->linked_regs = linked_regs;
 		return 0;
 	}
 
@@ -3367,6 +3449,7 @@  static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
 	p->idx = env->insn_idx;
 	p->prev_idx = env->prev_insn_idx;
 	p->flags = insn_flags;
+	p->linked_regs = linked_regs;
 	cur->jmp_history_cnt = cnt;
 	env->cur_hist_ent = p;
 
@@ -3532,6 +3615,11 @@  static inline bool bt_is_reg_set(struct backtrack_state *bt, u32 reg)
 	return bt->reg_masks[bt->frame] & (1 << reg);
 }
 
+static inline bool bt_is_frame_reg_set(struct backtrack_state *bt, u32 frame, u32 reg)
+{
+	return bt->reg_masks[frame] & (1 << reg);
+}
+
 static inline bool bt_is_frame_slot_set(struct backtrack_state *bt, u32 frame, u32 slot)
 {
 	return bt->stack_masks[frame] & (1ull << slot);
@@ -3576,6 +3664,42 @@  static void fmt_stack_mask(char *buf, ssize_t buf_sz, u64 stack_mask)
 	}
 }
 
+/* If any register R in hist->linked_regs is marked as precise in bt,
+ * do bt_set_frame_{reg,slot}(bt, R) for all registers in hist->linked_regs.
+ */
+static void bt_sync_linked_regs(struct backtrack_state *bt, struct bpf_jmp_history_entry *hist)
+{
+	struct linked_regs linked_regs;
+	bool some_precise = false;
+	int i;
+
+	if (!hist || hist->linked_regs == 0)
+		return;
+
+	linked_regs_unpack(hist->linked_regs, &linked_regs);
+	for (i = 0; i < linked_regs.cnt; ++i) {
+		struct reg_or_spill *e = &linked_regs.entries[i];
+
+		if ((e->is_reg && bt_is_frame_reg_set(bt, e->frameno, e->regno)) ||
+		    (!e->is_reg && bt_is_frame_slot_set(bt, e->frameno, e->spi))) {
+			some_precise = true;
+			break;
+		}
+	}
+
+	if (!some_precise)
+		return;
+
+	for (i = 0; i < linked_regs.cnt; ++i) {
+		struct reg_or_spill *e = &linked_regs.entries[i];
+
+		if (e->is_reg)
+			bt_set_frame_reg(bt, e->frameno, e->regno);
+		else
+			bt_set_frame_slot(bt, e->frameno, e->spi);
+	}
+}
+
 static bool calls_callback(struct bpf_verifier_env *env, int insn_idx);
 
 /* For given verifier state backtrack_insn() is called from the last insn to
@@ -3615,6 +3739,12 @@  static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
 		print_bpf_insn(&cbs, insn, env->allow_ptr_leaks);
 	}
 
+	/* If there is a history record that some registers gained range at this insn,
+	 * propagate precision marks to those registers, so that bt_is_reg_set()
+	 * accounts for these registers.
+	 */
+	bt_sync_linked_regs(bt, hist);
+
 	if (class == BPF_ALU || class == BPF_ALU64) {
 		if (!bt_is_reg_set(bt, dreg))
 			return 0;
@@ -3844,6 +3974,7 @@  static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
 			 */
 			bt_set_reg(bt, dreg);
 			bt_set_reg(bt, sreg);
+		} else if (BPF_SRC(insn->code) == BPF_K) {
 			 /* else dreg <cond> K
 			  * Only dreg still needs precision before
 			  * this insn, so for the K-based conditional
@@ -3862,6 +3993,10 @@  static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
 			/* to be analyzed */
 			return -ENOTSUPP;
 	}
+	/* Propagate precision marks to linked registers, to account for
+	 * registers marked as precise in this function.
+	 */
+	bt_sync_linked_regs(bt, hist);
 	return 0;
 }
 
@@ -4624,7 +4759,7 @@  static int check_stack_write_fixed_off(struct bpf_verifier_env *env,
 	}
 
 	if (insn_flags)
-		return push_jmp_history(env, env->cur_state, insn_flags);
+		return push_jmp_history(env, env->cur_state, insn_flags, 0);
 	return 0;
 }
 
@@ -4929,7 +5064,7 @@  static int check_stack_read_fixed_off(struct bpf_verifier_env *env,
 		insn_flags = 0; /* we are not restoring spilled register */
 	}
 	if (insn_flags)
-		return push_jmp_history(env, env->cur_state, insn_flags);
+		return push_jmp_history(env, env->cur_state, insn_flags, 0);
 	return 0;
 }
 
@@ -15154,14 +15289,66 @@  static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 	return true;
 }
 
-static void find_equal_scalars(struct bpf_verifier_state *vstate,
-			       struct bpf_reg_state *known_reg)
+static void __find_equal_scalars(struct linked_regs *reg_set, struct bpf_reg_state *reg,
+				 u32 id, u32 frameno, u32 spi_or_reg, bool is_reg)
+{
+	struct reg_or_spill *e;
+
+	if (reg->type != SCALAR_VALUE || (reg->id & ~BPF_ADD_CONST) != id)
+		return;
+
+	e = linked_regs_push(reg_set);
+	if (e) {
+		e->frameno = frameno;
+		e->is_reg = is_reg;
+		e->regno = spi_or_reg;
+	} else {
+		reg->id = 0;
+	}
+}
+
+/* For all R being scalar registers or spilled scalar registers
+ * in verifier state, save R in linked_regs if R->id == id.
+ * If there are too many Rs sharing same id, reset id for leftover Rs.
+ */
+static void find_equal_scalars(struct bpf_verifier_state *vstate, u32 id,
+			       struct linked_regs *linked_regs)
+{
+	struct bpf_func_state *func;
+	struct bpf_reg_state *reg;
+	int i, j;
+
+	id = id & ~BPF_ADD_CONST;
+	for (i = vstate->curframe; i >= 0; i--) {
+		func = vstate->frame[i];
+		for (j = 0; j < BPF_REG_FP; j++) {
+			reg = &func->regs[j];
+			__find_equal_scalars(linked_regs, reg, id, i, j, true);
+		}
+		for (j = 0; j < func->allocated_stack / BPF_REG_SIZE; j++) {
+			if (!is_spilled_reg(&func->stack[j]))
+				continue;
+			reg = &func->stack[j].spilled_ptr;
+			__find_equal_scalars(linked_regs, reg, id, i, j, false);
+		}
+	}
+}
+
+/* For all R in linked_regs, copy known_reg range into R
+ * if R->id == known_reg->id.
+ */
+static void copy_known_reg(struct bpf_verifier_state *vstate, struct bpf_reg_state *known_reg,
+			   struct linked_regs *linked_regs)
 {
 	struct bpf_reg_state fake_reg;
-	struct bpf_func_state *state;
 	struct bpf_reg_state *reg;
+	struct reg_or_spill *e;
+	int i;
 
-	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
+	for (i = 0; i < linked_regs->cnt; ++i) {
+		e = &linked_regs->entries[i];
+		reg = e->is_reg ? &vstate->frame[e->frameno]->regs[e->regno]
+				: &vstate->frame[e->frameno]->stack[e->spi].spilled_ptr;
 		if (reg->type != SCALAR_VALUE || reg == known_reg)
 			continue;
 		if ((reg->id & ~BPF_ADD_CONST) != (known_reg->id & ~BPF_ADD_CONST))
@@ -15187,7 +15374,7 @@  static void find_equal_scalars(struct bpf_verifier_state *vstate,
 			scalar_min_max_add(reg, &fake_reg);
 			reg->var_off = tnum_add(reg->var_off, fake_reg.var_off);
 		}
-	}));
+	}
 }
 
 static int check_cond_jmp_op(struct bpf_verifier_env *env,
@@ -15198,6 +15385,7 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
 	struct bpf_reg_state *dst_reg, *other_branch_regs, *src_reg = NULL;
 	struct bpf_reg_state *eq_branch_regs;
+	struct linked_regs linked_regs = {};
 	struct bpf_reg_state fake_reg = {};
 	u8 opcode = BPF_OP(insn->code);
 	bool is_jmp32;
@@ -15312,6 +15500,21 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		return 0;
 	}
 
+	/* Push scalar registers sharing same ID to jump history,
+	 * do this before creating 'other_branch', so that both
+	 * 'this_branch' and 'other_branch' share this history
+	 * if parent state is created.
+	 */
+	if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
+		find_equal_scalars(this_branch, src_reg->id, &linked_regs);
+	if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
+		find_equal_scalars(this_branch, dst_reg->id, &linked_regs);
+	if (linked_regs.cnt > 1) {
+		err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs));
+		if (err)
+			return err;
+	}
+
 	other_branch = push_stack(env, *insn_idx + insn->off + 1, *insn_idx,
 				  false);
 	if (!other_branch)
@@ -15336,13 +15539,13 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	if (BPF_SRC(insn->code) == BPF_X &&
 	    src_reg->type == SCALAR_VALUE && src_reg->id &&
 	    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
-		find_equal_scalars(this_branch, src_reg);
-		find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
+		copy_known_reg(this_branch, src_reg, &linked_regs);
+		copy_known_reg(other_branch, &other_branch_regs[insn->src_reg], &linked_regs);
 	}
 	if (dst_reg->type == SCALAR_VALUE && dst_reg->id &&
 	    !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) {
-		find_equal_scalars(this_branch, dst_reg);
-		find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
+		copy_known_reg(this_branch, dst_reg, &linked_regs);
+		copy_known_reg(other_branch, &other_branch_regs[insn->dst_reg], &linked_regs);
 	}
 
 	/* if one pointer register is compared to another pointer
@@ -17624,7 +17827,7 @@  static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 			 * the current state.
 			 */
 			if (is_jmp_point(env, env->insn_idx))
-				err = err ? : push_jmp_history(env, cur, 0);
+				err = err ? : push_jmp_history(env, cur, 0, 0);
 			err = err ? : propagate_precision(env, &sl->state);
 			if (err)
 				return err;
@@ -17892,7 +18095,7 @@  static int do_check(struct bpf_verifier_env *env)
 		}
 
 		if (is_jmp_point(env, env->insn_idx)) {
-			err = push_jmp_history(env, state, 0);
+			err = push_jmp_history(env, state, 0, 0);
 			if (err)
 				return err;
 		}
diff --git a/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c b/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c
index 6a6fad625f7e..9d415f7ce599 100644
--- a/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c
+++ b/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c
@@ -278,7 +278,7 @@  __msg("mark_precise: frame0: last_idx 14 first_idx 9")
 __msg("mark_precise: frame0: regs=r6 stack= before 13: (bf) r1 = r7")
 __msg("mark_precise: frame0: regs=r6 stack= before 12: (27) r6 *= 4")
 __msg("mark_precise: frame0: regs=r6 stack= before 11: (25) if r6 > 0x3 goto pc+4")
-__msg("mark_precise: frame0: regs=r6 stack= before 10: (bf) r6 = r0")
+__msg("mark_precise: frame0: regs=r0,r6 stack= before 10: (bf) r6 = r0")
 __msg("mark_precise: frame0: regs=r0 stack= before 9: (85) call bpf_loop")
 /* State entering callback body popped from states stack */
 __msg("from 9 to 17: frame1:")
diff --git a/tools/testing/selftests/bpf/verifier/precise.c b/tools/testing/selftests/bpf/verifier/precise.c
index 90643ccc221d..64d722199e8f 100644
--- a/tools/testing/selftests/bpf/verifier/precise.c
+++ b/tools/testing/selftests/bpf/verifier/precise.c
@@ -39,11 +39,11 @@ 
 	.result = VERBOSE_ACCEPT,
 	.errstr =
 	"mark_precise: frame0: last_idx 26 first_idx 20\
-	mark_precise: frame0: regs=r2,r9 stack= before 25\
-	mark_precise: frame0: regs=r2,r9 stack= before 24\
-	mark_precise: frame0: regs=r2,r9 stack= before 23\
-	mark_precise: frame0: regs=r2,r9 stack= before 22\
-	mark_precise: frame0: regs=r2,r9 stack= before 20\
+	mark_precise: frame0: regs=r2 stack= before 25\
+	mark_precise: frame0: regs=r2 stack= before 24\
+	mark_precise: frame0: regs=r2 stack= before 23\
+	mark_precise: frame0: regs=r2 stack= before 22\
+	mark_precise: frame0: regs=r2 stack= before 20\
 	mark_precise: frame0: parent state regs=r2,r9 stack=:\
 	mark_precise: frame0: last_idx 19 first_idx 10\
 	mark_precise: frame0: regs=r2,r9 stack= before 19\
@@ -100,11 +100,11 @@ 
 	.errstr =
 	"26: (85) call bpf_probe_read_kernel#113\
 	mark_precise: frame0: last_idx 26 first_idx 22\
-	mark_precise: frame0: regs=r2,r9 stack= before 25\
-	mark_precise: frame0: regs=r2,r9 stack= before 24\
-	mark_precise: frame0: regs=r2,r9 stack= before 23\
-	mark_precise: frame0: regs=r2,r9 stack= before 22\
-	mark_precise: frame0: parent state regs=r2,r9 stack=:\
+	mark_precise: frame0: regs=r2 stack= before 25\
+	mark_precise: frame0: regs=r2 stack= before 24\
+	mark_precise: frame0: regs=r2 stack= before 23\
+	mark_precise: frame0: regs=r2 stack= before 22\
+	mark_precise: frame0: parent state regs=r2 stack=:\
 	mark_precise: frame0: last_idx 20 first_idx 20\
 	mark_precise: frame0: regs=r2,r9 stack= before 20\
 	mark_precise: frame0: parent state regs=r2,r9 stack=:\