[RFC,bpf-next,v2,1/2] bpf, x64: Fix tailcall infinite loop

Message ID 20230818151216.7686-2-hffilwlqm@gmail.com (mailing list archive)
State Superseded
Delegated to: BPF
Series bpf, x64: Fix tailcall infinite loop

Checks

Context Check Description
bpf/vmtest-bpf-next-PR pending PR summary
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-6 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-2 success Logs for build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-4 success Logs for build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-5 success Logs for build for x86_64 with llvm-16
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 2828 this patch: 2828
netdev/cc_maintainers warning 16 maintainers not CCed: tglx@linutronix.de hpa@zytor.com dsahern@kernel.org mingo@redhat.com kpsingh@kernel.org x86@kernel.org john.fastabend@gmail.com sdf@google.com netdev@vger.kernel.org martin.lau@linux.dev yonghong.song@linux.dev dave.hansen@linux.intel.com davem@davemloft.net jolsa@kernel.org haoluo@google.com bp@alien8.de
netdev/build_clang success Errors and warnings before: 1526 this patch: 1526
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 2856 this patch: 2856
netdev/checkpatch warning WARNING: line length of 82 exceeds 80 columns
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline fail Was 0 now: 1
bpf/vmtest-bpf-next-VM_Test-3 success Logs for build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-8 pending Logs for test_maps on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 pending Logs for test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-16 pending Logs for test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-26 pending Logs for test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-27 success Logs for test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for test_maps on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-11 success Logs for test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-13 success Logs for test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-15 success Logs for test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-19 success Logs for test_progs_no_alu32_parallel on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-20 success Logs for test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-21 success Logs for test_progs_no_alu32_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-22 success Logs for test_progs_parallel on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-24 success Logs for test_progs_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-25 success Logs for test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for test_verifier on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-29 success Logs for veristat
bpf/vmtest-bpf-next-VM_Test-14 success Logs for test_progs on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-18 success Logs for test_progs_no_alu32 on x86_64 with llvm-16

Commit Message

Leon Hwang Aug. 18, 2023, 3:12 p.m. UTC
Since commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall
handling in JIT"), tailcalls on x64 work better than before.

Since commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms
for x64 JIT"), tailcalls are able to run in BPF subprograms on x64.

Since commit 5b92a28aae4dd0f8 ("bpf: Support attaching tracing BPF program
to other BPF programs"), a BPF program is able to trace other BPF programs.

What happens when they are all combined?

1. FENTRY/FEXIT on a BPF subprogram.
2. A tailcall runs in the BPF subprogram.
3. The tailcall calls itself.

As a result, a tailcall infinite loop comes up, and the loop halts the
machine.

As we know, in a tail call context, tail_call_cnt is propagated between BPF
subprograms via the stack and the RAX register. So do the same in
trampolines.
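
For illustration, a minimal sketch of such a setup could look as follows
(hypothetical program, map and section names, not the selftest added later in
this series; user space still has to store looping_prog()'s fd in slot 0 of
the prog array and set the fentry attach target to the loaded program):

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

char LICENSE[] SEC("license") = "GPL";

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__uint(value_size, sizeof(__u32));
} jmp_table SEC(".maps");

static __noinline int subprog_tail(struct __sk_buff *skb)
{
	/* Tail call to the program in slot 0; once user space stores
	 * looping_prog()'s fd there, this calls back into ourselves.
	 */
	bpf_tail_call_static(skb, &jmp_table, 0);
	return 0;
}

SEC("tc")
int looping_prog(struct __sk_buff *skb)
{
	/* tail_call_cnt is handed to the subprogram via RAX. */
	return subprog_tail(skb);
}

/* Attaching this to subprog_tail() of the loaded looping_prog (e.g. with
 * bpf_program__set_attach_target()) inserts a trampoline in front of the
 * subprogram. Without this fix the trampoline clobbers RAX, so
 * tail_call_cnt never reaches MAX_TAIL_CALL_CNT and the loop never stops.
 */
SEC("fentry/subprog_tail")
int BPF_PROG(fentry_on_subprog, struct __sk_buff *skb)
{
	return 0;
}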

Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
---
 arch/x86/net/bpf_jit_comp.c | 40 +++++++++++++++++++++++++++++--------
 include/linux/bpf.h         |  5 +++++
 kernel/bpf/trampoline.c     |  4 ++--
 kernel/bpf/verifier.c       | 31 +++++++++++++++++++++-------
 4 files changed, 63 insertions(+), 17 deletions(-)

Comments

Leon Hwang Aug. 18, 2023, 3:25 p.m. UTC | #1
On 2023/8/18 23:12, Leon Hwang wrote:
> From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall
> handling in JIT"), the tailcall on x64 works better than before.
> 
> From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms
> for x64 JIT"), tailcall is able to run in BPF subprograms on x64.
> 
> From commit 5b92a28aae4dd0f8 ("bpf: Support attaching tracing BPF program
> to other BPF programs"), BPF program is able to trace other BPF programs.
> 
> How about combining them all together?
> 
> 1. FENTRY/FEXIT on a BPF subprogram.
> 2. A tailcall runs in the BPF subprogram.
> 3. The tailcall calls itself.
> 
> As a result, a tailcall infinite loop comes up. And the loop would halt
> the machine.
> 
> As we know, in tail call context, the tail_call_cnt propagates by stack
> and RAX register between BPF subprograms. So do it in trampolines.
> 
> Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
> ---
>  arch/x86/net/bpf_jit_comp.c | 40 +++++++++++++++++++++++++++++--------
>  include/linux/bpf.h         |  5 +++++
>  kernel/bpf/trampoline.c     |  4 ++--
>  kernel/bpf/verifier.c       | 31 +++++++++++++++++++++-------
>  4 files changed, 63 insertions(+), 17 deletions(-)
> 
> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> index a5930042139d3..1ad17d7de5eee 100644
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -303,8 +303,12 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>  	prog += X86_PATCH_SIZE;
>  	if (!ebpf_from_cbpf) {
>  		if (tail_call_reachable && !is_subprog)
> +			/* When it's the entry of the whole tailcall context,
> +			 * zeroing rax means initialising tail_call_cnt.
> +			 */
>  			EMIT2(0x31, 0xC0); /* xor eax, eax */
>  		else
> +			// Keep the same instruction layout.
>  			EMIT2(0x66, 0x90); /* nop2 */
>  	}
>  	EMIT1(0x55);             /* push rbp */
> @@ -1018,6 +1022,10 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
>  
>  #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
>  
> +/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> +#define RESTORE_TAIL_CALL_CNT(stack)				\
> +	EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
> +
>  static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
>  		  int oldproglen, struct jit_context *ctx, bool jmp_padding)
>  {
> @@ -1623,9 +1631,7 @@ st:			if (is_imm8(insn->off))
>  
>  			func = (u8 *) __bpf_call_base + imm32;
>  			if (tail_call_reachable) {
> -				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> -				EMIT3_off32(0x48, 0x8B, 0x85,
> -					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
> +				RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
>  				if (!imm32)
>  					return -EINVAL;
>  				offs = 7 + x86_call_depth_emit_accounting(&prog, func);
> @@ -2298,7 +2304,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>   * push rbp
>   * mov rbp, rsp
>   * sub rsp, 16                     // space for skb and dev
> - * push rbx                        // temp regs to pass start time
> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
> + * mov rax, 2                      // cache number of argument to rax
> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
>   * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
>   * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
> @@ -2323,7 +2331,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>   * push rbp
>   * mov rbp, rsp
>   * sub rsp, 24                     // space for skb, dev, return value
> - * push rbx                        // temp regs to pass start time
> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
> + * mov rax, 2                      // cache number of argument to rax
> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
>   * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
>   * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
> @@ -2400,6 +2410,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>  	 *                     [ ...        ]
>  	 *                     [ stack_arg2 ]
>  	 * RBP - arg_stack_off [ stack_arg1 ]
> +	 * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
>  	 */
>  
>  	/* room for return value of orig_call or fentry prog */
> @@ -2464,6 +2475,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>  	else
>  		/* sub rsp, stack_size */
>  		EMIT4(0x48, 0x83, 0xEC, stack_size);
> +	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +		EMIT1(0x50);		/* push rax */
>  	/* mov QWORD PTR [rbp - rbx_off], rbx */
>  	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
>  
> @@ -2516,9 +2529,15 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>  		restore_regs(m, &prog, regs_off);
>  		save_args(m, &prog, arg_stack_off, true);
>  
> +		if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +			/* Before calling the original function, restore the
> +			 * tail_call_cnt from stack to rax.
> +			 */
> +			RESTORE_TAIL_CALL_CNT(stack_size);
> +
>  		if (flags & BPF_TRAMP_F_ORIG_STACK) {
> -			emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
> -			EMIT2(0xff, 0xd0); /* call *rax */
> +			emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
> +			EMIT2(0xff, 0xd3); /* call *rbx */ // FIXME: Confirm 0xd3?

To avoid rax conflicting with the tail call, change the calling register
from rax to rbx.

But I'm unable to confirm the opcode.

Then, I asked ChatGPT to list `call` instructions and their corresponding
opcodes:

Certainly! Here's a table that provides `call` instructions along with
their corresponding opcodes in x86-64 assembly:

| `call` Register | Opcode (Hex) | Opcode (Binary) |
|-----------------|--------------|-----------------|
| `rax`           | `FF D0`      | `11111111 11010000` |
| `rcx`           | `FF D1`      | `11111111 11010001` |
| `rdx`           | `FF D2`      | `11111111 11010010` |
| `rbx`           | `FF D3`      | `11111111 11010011` |
| `rsp`           | `FF D4`      | `11111111 11010100` |
| `rbp`           | `FF D5`      | `11111111 11010101` |
| `rsi`           | `FF D6`      | `11111111 11010110` |
| `rdi`           | `FF D7`      | `11111111 11010111` |
| `r8`            | `41 FF D0`   | `01000001 11111111 11010000` |
| `r9`            | `41 FF D1`   | `01000001 11111111 11010001` |
| `r10`           | `41 FF D2`   | `01000001 11111111 11010010` |
| `r11`           | `41 FF D3`   | `01000001 11111111 11010011` |
| `r12`           | `41 FF D4`   | `01000001 11111111 11010100` |
| `r13`           | `41 FF D5`   | `01000001 11111111 11010101` |
| `r14`           | `41 FF D6`   | `01000001 11111111 11010110` |
| `r15`           | `41 FF D7`   | `01000001 11111111 11010111` |

EMIT2(0xff, 0xd3); /* call *rbx */, is it right?
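
For what it's worth, the encoding can also be derived by hand: `call r/m64`
is opcode 0xFF with /2 as the opcode extension in the reg field, and rbx is
register number 3, so the register-direct ModRM byte is
0xC0 | (2 << 3) | 3 = 0xD3. A standalone sketch (not kernel code) that
prints the bytes:

#include <stdio.h>

/* ModRM byte for "FF /2" (call r/m64) with a register-direct operand:
 * mod = 0b11 (register-direct), reg = 2 (the /2 opcode extension),
 * rm  = register number (rax=0, rcx=1, rdx=2, rbx=3, ...).
 */
static unsigned char call_reg_modrm(unsigned char rm)
{
	return 0xC0 | (2 << 3) | (rm & 0x7);
}

int main(void)
{
	printf("call *rax: ff %02x\n", call_reg_modrm(0));  /* ff d0 */
	printf("call *rbx: ff %02x\n", call_reg_modrm(3));  /* ff d3 */
	return 0;
}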

Thanks,
Leon

[...]
Alexei Starovoitov Aug. 21, 2023, 10:33 p.m. UTC | #2
On Fri, Aug 18, 2023 at 8:12 AM Leon Hwang <hffilwlqm@gmail.com> wrote:
>
> From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall
> handling in JIT"), the tailcall on x64 works better than before.
>
> From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms
> for x64 JIT"), tailcall is able to run in BPF subprograms on x64.
>
> From commit 5b92a28aae4dd0f8 ("bpf: Support attaching tracing BPF program
> to other BPF programs"), BPF program is able to trace other BPF programs.
>
> How about combining them all together?
>
> 1. FENTRY/FEXIT on a BPF subprogram.
> 2. A tailcall runs in the BPF subprogram.
> 3. The tailcall calls itself.
>
> As a result, a tailcall infinite loop comes up. And the loop would halt
> the machine.
>
> As we know, in tail call context, the tail_call_cnt propagates by stack
> and RAX register between BPF subprograms. So do it in trampolines.
>
> Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
> ---
>  arch/x86/net/bpf_jit_comp.c | 40 +++++++++++++++++++++++++++++--------
>  include/linux/bpf.h         |  5 +++++
>  kernel/bpf/trampoline.c     |  4 ++--
>  kernel/bpf/verifier.c       | 31 +++++++++++++++++++++-------
>  4 files changed, 63 insertions(+), 17 deletions(-)
>
> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> index a5930042139d3..1ad17d7de5eee 100644
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -303,8 +303,12 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>         prog += X86_PATCH_SIZE;
>         if (!ebpf_from_cbpf) {
>                 if (tail_call_reachable && !is_subprog)
> +                       /* When it's the entry of the whole tailcall context,
> +                        * zeroing rax means initialising tail_call_cnt.
> +                        */
>                         EMIT2(0x31, 0xC0); /* xor eax, eax */
>                 else
> +                       // Keep the same instruction layout.

No c++ style comments please.

>                         EMIT2(0x66, 0x90); /* nop2 */
>         }
>         EMIT1(0x55);             /* push rbp */
> @@ -1018,6 +1022,10 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
>
>  #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
>
> +/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> +#define RESTORE_TAIL_CALL_CNT(stack)                           \
> +       EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
> +
>  static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
>                   int oldproglen, struct jit_context *ctx, bool jmp_padding)
>  {
> @@ -1623,9 +1631,7 @@ st:                       if (is_imm8(insn->off))
>
>                         func = (u8 *) __bpf_call_base + imm32;
>                         if (tail_call_reachable) {
> -                               /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> -                               EMIT3_off32(0x48, 0x8B, 0x85,
> -                                           -round_up(bpf_prog->aux->stack_depth, 8) - 8);
> +                               RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
>                                 if (!imm32)
>                                         return -EINVAL;
>                                 offs = 7 + x86_call_depth_emit_accounting(&prog, func);
> @@ -2298,7 +2304,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>   * push rbp
>   * mov rbp, rsp
>   * sub rsp, 16                     // space for skb and dev
> - * push rbx                        // temp regs to pass start time
> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
> + * mov rax, 2                      // cache number of argument to rax

What does it mean?

> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack

Here // is ok since it's inside /* */

>   * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
>   * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
> @@ -2323,7 +2331,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>   * push rbp
>   * mov rbp, rsp
>   * sub rsp, 24                     // space for skb, dev, return value
> - * push rbx                        // temp regs to pass start time
> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
> + * mov rax, 2                      // cache number of argument to rax
> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
>   * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
>   * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
> @@ -2400,6 +2410,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>          *                     [ ...        ]
>          *                     [ stack_arg2 ]
>          * RBP - arg_stack_off [ stack_arg1 ]
> +        * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
>          */
>
>         /* room for return value of orig_call or fentry prog */
> @@ -2464,6 +2475,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>         else
>                 /* sub rsp, stack_size */
>                 EMIT4(0x48, 0x83, 0xEC, stack_size);
> +       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +               EMIT1(0x50);            /* push rax */
>         /* mov QWORD PTR [rbp - rbx_off], rbx */
>         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
>
> @@ -2516,9 +2529,15 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>                 restore_regs(m, &prog, regs_off);
>                 save_args(m, &prog, arg_stack_off, true);
>
> +               if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +                       /* Before calling the original function, restore the
> +                        * tail_call_cnt from stack to rax.
> +                        */
> +                       RESTORE_TAIL_CALL_CNT(stack_size);
> +
>                 if (flags & BPF_TRAMP_F_ORIG_STACK) {
> -                       emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
> -                       EMIT2(0xff, 0xd0); /* call *rax */
> +                       emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
> +                       EMIT2(0xff, 0xd3); /* call *rbx */ // FIXME: Confirm 0xd3?

please no FIXME like comments.
You have to be confident in the code you're submitting.
llvm-mc -triple=x86_64 -show-encoding -x86-asm-syntax=intel
-output-asm-variant=1 <<< 'call rbx'

>                 } else {
>                         /* call original function */
>                         if (emit_rsb_call(&prog, orig_call, prog)) {
> @@ -2569,7 +2588,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>                         ret = -EINVAL;
>                         goto cleanup;
>                 }
> -       }
> +       } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> +               /* Before running the original function, restore the
> +                * tail_call_cnt from stack to rax.
> +                */
> +               RESTORE_TAIL_CALL_CNT(stack_size);
> +
>         /* restore return value of orig_call or fentry prog back into RAX */
>         if (save_ret)
>                 emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index cfabbcf47bdb8..c8df257ea435d 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -1028,6 +1028,11 @@ struct btf_func_model {
>   */
>  #define BPF_TRAMP_F_SHARE_IPMODIFY     BIT(6)
>
> +/* Indicate that current trampoline is in a tail call context. Then, it has to
> + * cache and restore tail_call_cnt to avoid infinite tail call loop.
> + */
> +#define BPF_TRAMP_F_TAIL_CALL_CTX      BIT(7)
> +
>  /* Each call __bpf_prog_enter + call bpf_func + call __bpf_prog_exit is ~50
>   * bytes on x86.
>   */
> diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
> index 78acf28d48732..16ab5da7161f2 100644
> --- a/kernel/bpf/trampoline.c
> +++ b/kernel/bpf/trampoline.c
> @@ -415,8 +415,8 @@ static int bpf_trampoline_update(struct bpf_trampoline *tr, bool lock_direct_mut
>                 goto out;
>         }
>
> -       /* clear all bits except SHARE_IPMODIFY */
> -       tr->flags &= BPF_TRAMP_F_SHARE_IPMODIFY;
> +       /* clear all bits except SHARE_IPMODIFY and TAIL_CALL_CTX */
> +       tr->flags &= (BPF_TRAMP_F_SHARE_IPMODIFY | BPF_TRAMP_F_TAIL_CALL_CTX);
>
>         if (tlinks[BPF_TRAMP_FEXIT].nr_links ||
>             tlinks[BPF_TRAMP_MODIFY_RETURN].nr_links) {
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 4ccca1f6c9981..52ba9b043f16e 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -19246,6 +19246,21 @@ static int check_non_sleepable_error_inject(u32 btf_id)
>         return btf_id_set_contains(&btf_non_sleepable_error_inject, btf_id);
>  }
>
> +static inline int find_subprog_index(const struct bpf_prog *prog,
> +                                    u32 btf_id)
> +{
> +       struct bpf_prog_aux *aux = prog->aux;
> +       int i, subprog = -1;
> +
> +       for (i = 0; i < aux->func_info_cnt; i++)
> +               if (aux->func_info[i].type_id == btf_id) {
> +                       subprog = i;
> +                       break;
> +               }
> +
> +       return subprog;
> +}
> +
>  int bpf_check_attach_target(struct bpf_verifier_log *log,
>                             const struct bpf_prog *prog,
>                             const struct bpf_prog *tgt_prog,
> @@ -19254,9 +19269,9 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
>  {
>         bool prog_extension = prog->type == BPF_PROG_TYPE_EXT;
>         const char prefix[] = "btf_trace_";
> -       int ret = 0, subprog = -1, i;
>         const struct btf_type *t;
>         bool conservative = true;
> +       int ret = 0, subprog;
>         const char *tname;
>         struct btf *btf;
>         long addr = 0;
> @@ -19291,11 +19306,7 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
>                         return -EINVAL;
>                 }
>
> -               for (i = 0; i < aux->func_info_cnt; i++)
> -                       if (aux->func_info[i].type_id == btf_id) {
> -                               subprog = i;
> -                               break;
> -                       }
> +               subprog = find_subprog_index(tgt_prog, btf_id);
>                 if (subprog == -1) {
>                         bpf_log(log, "Subprog %s doesn't exist\n", tname);
>                         return -EINVAL;
> @@ -19559,7 +19570,7 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
>         struct bpf_attach_target_info tgt_info = {};
>         u32 btf_id = prog->aux->attach_btf_id;
>         struct bpf_trampoline *tr;
> -       int ret;
> +       int ret, subprog;
>         u64 key;
>
>         if (prog->type == BPF_PROG_TYPE_SYSCALL) {
> @@ -19629,6 +19640,12 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
>         if (!tr)
>                 return -ENOMEM;
>
> +       if (tgt_prog && tgt_prog->aux->tail_call_reachable) {
> +               subprog = find_subprog_index(tgt_prog, btf_id);
> +               tr->flags = subprog > 0 && tgt_prog->aux->func[subprog]->is_func ?
> +                           BPF_TRAMP_F_TAIL_CALL_CTX : 0;

If prog has subprogs all of them will 'is_func', no?
What's the point of the search ?
Just tgt_prog->aux->tail_call_reachable and func_cnt > 0 would be enough?
Leon Hwang Aug. 22, 2023, 3:17 a.m. UTC | #3
On 22/8/23 06:33, Alexei Starovoitov wrote:
> On Fri, Aug 18, 2023 at 8:12 AM Leon Hwang <hffilwlqm@gmail.com> wrote:
>>
>> From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall
>> handling in JIT"), the tailcall on x64 works better than before.
>>
>> From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms
>> for x64 JIT"), tailcall is able to run in BPF subprograms on x64.
>>
>> From commit 5b92a28aae4dd0f8 ("bpf: Support attaching tracing BPF program
>> to other BPF programs"), BPF program is able to trace other BPF programs.
>>
>> How about combining them all together?
>>
>> 1. FENTRY/FEXIT on a BPF subprogram.
>> 2. A tailcall runs in the BPF subprogram.
>> 3. The tailcall calls itself.
>>
>> As a result, a tailcall infinite loop comes up. And the loop would halt
>> the machine.
>>
>> As we know, in tail call context, the tail_call_cnt propagates by stack
>> and RAX register between BPF subprograms. So do it in trampolines.
>>
>> Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
>> ---
>>  arch/x86/net/bpf_jit_comp.c | 40 +++++++++++++++++++++++++++++--------
>>  include/linux/bpf.h         |  5 +++++
>>  kernel/bpf/trampoline.c     |  4 ++--
>>  kernel/bpf/verifier.c       | 31 +++++++++++++++++++++-------
>>  4 files changed, 63 insertions(+), 17 deletions(-)
>>
>> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
>> index a5930042139d3..1ad17d7de5eee 100644
>> --- a/arch/x86/net/bpf_jit_comp.c
>> +++ b/arch/x86/net/bpf_jit_comp.c
>> @@ -303,8 +303,12 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
>>         prog += X86_PATCH_SIZE;
>>         if (!ebpf_from_cbpf) {
>>                 if (tail_call_reachable && !is_subprog)
>> +                       /* When it's the entry of the whole tailcall context,
>> +                        * zeroing rax means initialising tail_call_cnt.
>> +                        */
>>                         EMIT2(0x31, 0xC0); /* xor eax, eax */
>>                 else
>> +                       // Keep the same instruction layout.
> 
> No c++ style comments please.

Got it.

> 
>>                         EMIT2(0x66, 0x90); /* nop2 */
>>         }
>>         EMIT1(0x55);             /* push rbp */
>> @@ -1018,6 +1022,10 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
>>
>>  #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
>>
>> +/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
>> +#define RESTORE_TAIL_CALL_CNT(stack)                           \
>> +       EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
>> +
>>  static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
>>                   int oldproglen, struct jit_context *ctx, bool jmp_padding)
>>  {
>> @@ -1623,9 +1631,7 @@ st:                       if (is_imm8(insn->off))
>>
>>                         func = (u8 *) __bpf_call_base + imm32;
>>                         if (tail_call_reachable) {
>> -                               /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
>> -                               EMIT3_off32(0x48, 0x8B, 0x85,
>> -                                           -round_up(bpf_prog->aux->stack_depth, 8) - 8);
>> +                               RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
>>                                 if (!imm32)
>>                                         return -EINVAL;
>>                                 offs = 7 + x86_call_depth_emit_accounting(&prog, func);
>> @@ -2298,7 +2304,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>>   * push rbp
>>   * mov rbp, rsp
>>   * sub rsp, 16                     // space for skb and dev
>> - * push rbx                        // temp regs to pass start time
>> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
>> + * mov rax, 2                      // cache number of argument to rax
> 
> What does it mean?

I think these are the instructions corresponding to the following code
snippet in arch_prepare_bpf_trampoline().

	/* Store number of argument registers of the traced function:
	 *   mov rax, nr_regs
	 *   mov QWORD PTR [rbp - nregs_off], rax
	 */
	emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs);
	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off);

> 
>> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
> 
> Here // is ok since it's inside /* */

Got it.

> 
>>   * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
>>   * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
>>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
>> @@ -2323,7 +2331,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>>   * push rbp
>>   * mov rbp, rsp
>>   * sub rsp, 24                     // space for skb, dev, return value
>> - * push rbx                        // temp regs to pass start time
>> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
>> + * mov rax, 2                      // cache number of argument to rax
>> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
>>   * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
>>   * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
>>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
>> @@ -2400,6 +2410,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>          *                     [ ...        ]
>>          *                     [ stack_arg2 ]
>>          * RBP - arg_stack_off [ stack_arg1 ]
>> +        * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
>>          */
>>
>>         /* room for return value of orig_call or fentry prog */
>> @@ -2464,6 +2475,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>         else
>>                 /* sub rsp, stack_size */
>>                 EMIT4(0x48, 0x83, 0xEC, stack_size);
>> +       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>> +               EMIT1(0x50);            /* push rax */
>>         /* mov QWORD PTR [rbp - rbx_off], rbx */
>>         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
>>
>> @@ -2516,9 +2529,15 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>                 restore_regs(m, &prog, regs_off);
>>                 save_args(m, &prog, arg_stack_off, true);
>>
>> +               if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>> +                       /* Before calling the original function, restore the
>> +                        * tail_call_cnt from stack to rax.
>> +                        */
>> +                       RESTORE_TAIL_CALL_CNT(stack_size);
>> +
>>                 if (flags & BPF_TRAMP_F_ORIG_STACK) {
>> -                       emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
>> -                       EMIT2(0xff, 0xd0); /* call *rax */
>> +                       emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
>> +                       EMIT2(0xff, 0xd3); /* call *rbx */ // FIXME: Confirm 0xd3?
> 
> please no FIXME like comments.
> You have to be confident in the code you're submitting.
> llvm-mc -triple=x86_64 -show-encoding -x86-asm-syntax=intel
> -output-asm-variant=1 <<< 'call rbx'

Got it. Thanks for the guide.

> 
>>                 } else {
>>                         /* call original function */
>>                         if (emit_rsb_call(&prog, orig_call, prog)) {
>> @@ -2569,7 +2588,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>                         ret = -EINVAL;
>>                         goto cleanup;
>>                 }
>> -       }
>> +       } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>> +               /* Before running the original function, restore the
>> +                * tail_call_cnt from stack to rax.
>> +                */
>> +               RESTORE_TAIL_CALL_CNT(stack_size);
>> +
>>         /* restore return value of orig_call or fentry prog back into RAX */
>>         if (save_ret)
>>                 emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
>> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
>> index cfabbcf47bdb8..c8df257ea435d 100644
>> --- a/include/linux/bpf.h
>> +++ b/include/linux/bpf.h
>> @@ -1028,6 +1028,11 @@ struct btf_func_model {
>>   */
>>  #define BPF_TRAMP_F_SHARE_IPMODIFY     BIT(6)
>>
>> +/* Indicate that current trampoline is in a tail call context. Then, it has to
>> + * cache and restore tail_call_cnt to avoid infinite tail call loop.
>> + */
>> +#define BPF_TRAMP_F_TAIL_CALL_CTX      BIT(7)
>> +
>>  /* Each call __bpf_prog_enter + call bpf_func + call __bpf_prog_exit is ~50
>>   * bytes on x86.
>>   */
>> diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
>> index 78acf28d48732..16ab5da7161f2 100644
>> --- a/kernel/bpf/trampoline.c
>> +++ b/kernel/bpf/trampoline.c
>> @@ -415,8 +415,8 @@ static int bpf_trampoline_update(struct bpf_trampoline *tr, bool lock_direct_mut
>>                 goto out;
>>         }
>>
>> -       /* clear all bits except SHARE_IPMODIFY */
>> -       tr->flags &= BPF_TRAMP_F_SHARE_IPMODIFY;
>> +       /* clear all bits except SHARE_IPMODIFY and TAIL_CALL_CTX */
>> +       tr->flags &= (BPF_TRAMP_F_SHARE_IPMODIFY | BPF_TRAMP_F_TAIL_CALL_CTX);
>>
>>         if (tlinks[BPF_TRAMP_FEXIT].nr_links ||
>>             tlinks[BPF_TRAMP_MODIFY_RETURN].nr_links) {
>> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
>> index 4ccca1f6c9981..52ba9b043f16e 100644
>> --- a/kernel/bpf/verifier.c
>> +++ b/kernel/bpf/verifier.c
>> @@ -19246,6 +19246,21 @@ static int check_non_sleepable_error_inject(u32 btf_id)
>>         return btf_id_set_contains(&btf_non_sleepable_error_inject, btf_id);
>>  }
>>
>> +static inline int find_subprog_index(const struct bpf_prog *prog,
>> +                                    u32 btf_id)
>> +{
>> +       struct bpf_prog_aux *aux = prog->aux;
>> +       int i, subprog = -1;
>> +
>> +       for (i = 0; i < aux->func_info_cnt; i++)
>> +               if (aux->func_info[i].type_id == btf_id) {
>> +                       subprog = i;
>> +                       break;
>> +               }
>> +
>> +       return subprog;
>> +}
>> +
>>  int bpf_check_attach_target(struct bpf_verifier_log *log,
>>                             const struct bpf_prog *prog,
>>                             const struct bpf_prog *tgt_prog,
>> @@ -19254,9 +19269,9 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
>>  {
>>         bool prog_extension = prog->type == BPF_PROG_TYPE_EXT;
>>         const char prefix[] = "btf_trace_";
>> -       int ret = 0, subprog = -1, i;
>>         const struct btf_type *t;
>>         bool conservative = true;
>> +       int ret = 0, subprog;
>>         const char *tname;
>>         struct btf *btf;
>>         long addr = 0;
>> @@ -19291,11 +19306,7 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
>>                         return -EINVAL;
>>                 }
>>
>> -               for (i = 0; i < aux->func_info_cnt; i++)
>> -                       if (aux->func_info[i].type_id == btf_id) {
>> -                               subprog = i;
>> -                               break;
>> -                       }
>> +               subprog = find_subprog_index(tgt_prog, btf_id);
>>                 if (subprog == -1) {
>>                         bpf_log(log, "Subprog %s doesn't exist\n", tname);
>>                         return -EINVAL;
>> @@ -19559,7 +19570,7 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
>>         struct bpf_attach_target_info tgt_info = {};
>>         u32 btf_id = prog->aux->attach_btf_id;
>>         struct bpf_trampoline *tr;
>> -       int ret;
>> +       int ret, subprog;
>>         u64 key;
>>
>>         if (prog->type == BPF_PROG_TYPE_SYSCALL) {
>> @@ -19629,6 +19640,12 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
>>         if (!tr)
>>                 return -ENOMEM;
>>
>> +       if (tgt_prog && tgt_prog->aux->tail_call_reachable) {
>> +               subprog = find_subprog_index(tgt_prog, btf_id);
>> +               tr->flags = subprog > 0 && tgt_prog->aux->func[subprog]->is_func ?
>> +                           BPF_TRAMP_F_TAIL_CALL_CTX : 0;
> 
> If prog has subprogs all of them will 'is_func', no?
> What's the point of the search ?
> Just tgt_prog->aux->tail_call_reachable and func_cnt > 0 would be enough?

tgt_prog->aux->tail_call_reachable and subprog > 0 would be enough?
It has to confirm that the attaching target is a subprog of tgt_prog instead of
tgt_prog itself.

In a tail call context, when 'call'-ing a func, tail_call_cnt will be restored into rax.

static int do_jit() {
			/* call */
		case BPF_JMP | BPF_CALL: {
			int offs;

			func = (u8 *) __bpf_call_base + imm32;
			if (tail_call_reachable) {
				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
				EMIT3_off32(0x48, 0x8B, 0x85,
					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
				/* ... */
			}
}

As a result, when 'call'-ing a subprog, tail_call_cnt is transferred via rax.
Do all subprogs run via 'call', including non-'is_func' subprogs?

The point of the search is to confirm that the attaching subprog runs via 'call'.

Currently, I'm sure that the combination of tgt_prog->aux->tail_call_reachable,
subprog > 0 and tgt_prog->aux->func[subprog]->is_func is the case to be fixed.

Thanks,
Leon
Alexei Starovoitov Aug. 22, 2023, 9:29 p.m. UTC | #4
On Mon, Aug 21, 2023 at 8:17 PM Leon Hwang <hffilwlqm@gmail.com> wrote:
>
>
>
> On 22/8/23 06:33, Alexei Starovoitov wrote:
> > On Fri, Aug 18, 2023 at 8:12 AM Leon Hwang <hffilwlqm@gmail.com> wrote:
> >>
> >> From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall
> >> handling in JIT"), the tailcall on x64 works better than before.
> >>
> >> From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms
> >> for x64 JIT"), tailcall is able to run in BPF subprograms on x64.
> >>
> >> From commit 5b92a28aae4dd0f8 ("bpf: Support attaching tracing BPF program
> >> to other BPF programs"), BPF program is able to trace other BPF programs.
> >>
> >> How about combining them all together?
> >>
> >> 1. FENTRY/FEXIT on a BPF subprogram.
> >> 2. A tailcall runs in the BPF subprogram.
> >> 3. The tailcall calls itself.
> >>
> >> As a result, a tailcall infinite loop comes up. And the loop would halt
> >> the machine.
> >>
> >> As we know, in tail call context, the tail_call_cnt propagates by stack
> >> and RAX register between BPF subprograms. So do it in trampolines.
> >>
> >> Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
> >> ---
> >>  arch/x86/net/bpf_jit_comp.c | 40 +++++++++++++++++++++++++++++--------
> >>  include/linux/bpf.h         |  5 +++++
> >>  kernel/bpf/trampoline.c     |  4 ++--
> >>  kernel/bpf/verifier.c       | 31 +++++++++++++++++++++-------
> >>  4 files changed, 63 insertions(+), 17 deletions(-)
> >>
> >> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> >> index a5930042139d3..1ad17d7de5eee 100644
> >> --- a/arch/x86/net/bpf_jit_comp.c
> >> +++ b/arch/x86/net/bpf_jit_comp.c
> >> @@ -303,8 +303,12 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
> >>         prog += X86_PATCH_SIZE;
> >>         if (!ebpf_from_cbpf) {
> >>                 if (tail_call_reachable && !is_subprog)
> >> +                       /* When it's the entry of the whole tailcall context,
> >> +                        * zeroing rax means initialising tail_call_cnt.
> >> +                        */
> >>                         EMIT2(0x31, 0xC0); /* xor eax, eax */
> >>                 else
> >> +                       // Keep the same instruction layout.
> >
> > No c++ style comments please.
>
> Got it.
>
> >
> >>                         EMIT2(0x66, 0x90); /* nop2 */
> >>         }
> >>         EMIT1(0x55);             /* push rbp */
> >> @@ -1018,6 +1022,10 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
> >>
> >>  #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
> >>
> >> +/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> >> +#define RESTORE_TAIL_CALL_CNT(stack)                           \
> >> +       EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
> >> +
> >>  static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
> >>                   int oldproglen, struct jit_context *ctx, bool jmp_padding)
> >>  {
> >> @@ -1623,9 +1631,7 @@ st:                       if (is_imm8(insn->off))
> >>
> >>                         func = (u8 *) __bpf_call_base + imm32;
> >>                         if (tail_call_reachable) {
> >> -                               /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
> >> -                               EMIT3_off32(0x48, 0x8B, 0x85,
> >> -                                           -round_up(bpf_prog->aux->stack_depth, 8) - 8);
> >> +                               RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
> >>                                 if (!imm32)
> >>                                         return -EINVAL;
> >>                                 offs = 7 + x86_call_depth_emit_accounting(&prog, func);
> >> @@ -2298,7 +2304,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
> >>   * push rbp
> >>   * mov rbp, rsp
> >>   * sub rsp, 16                     // space for skb and dev
> >> - * push rbx                        // temp regs to pass start time
> >> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
> >> + * mov rax, 2                      // cache number of argument to rax
> >
> > What does it mean?
>
> I think it's the corresponding instruction to the following code snippet
> in arch_prepare_bpf_trampoline().
>
>         /* Store number of argument registers of the traced function:
>          *   mov rax, nr_regs
>          *   mov QWORD PTR [rbp - nregs_off], rax
>          */
>         emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs);
>         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off);

Ahh. I see.
The comment on top of arch_prepare_bpf_trampoline() is hopelessly obsolete.
Don't touch it in this patch set. We probably should delete it at some point
or take an effort to update it thoroughly.
Earlier recommendation to you was to update this comment:
/* Generated trampoline stack layout:

> >
> >> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
> >
> > Here // is ok since it's inside /* */
>
> Got it.
>
> >
> >>   * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
> >>   * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
> >>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
> >> @@ -2323,7 +2331,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
> >>   * push rbp
> >>   * mov rbp, rsp
> >>   * sub rsp, 24                     // space for skb, dev, return value
> >> - * push rbx                        // temp regs to pass start time
> >> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
> >> + * mov rax, 2                      // cache number of argument to rax
> >> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
> >>   * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
> >>   * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
> >>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
> >> @@ -2400,6 +2410,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
> >>          *                     [ ...        ]
> >>          *                     [ stack_arg2 ]
> >>          * RBP - arg_stack_off [ stack_arg1 ]
> >> +        * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
> >>          */
> >>
> >>         /* room for return value of orig_call or fentry prog */
> >> @@ -2464,6 +2475,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
> >>         else
> >>                 /* sub rsp, stack_size */
> >>                 EMIT4(0x48, 0x83, 0xEC, stack_size);
> >> +       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> >> +               EMIT1(0x50);            /* push rax */
> >>         /* mov QWORD PTR [rbp - rbx_off], rbx */
> >>         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
> >>
> >> @@ -2516,9 +2529,15 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
> >>                 restore_regs(m, &prog, regs_off);
> >>                 save_args(m, &prog, arg_stack_off, true);
> >>
> >> +               if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> >> +                       /* Before calling the original function, restore the
> >> +                        * tail_call_cnt from stack to rax.
> >> +                        */
> >> +                       RESTORE_TAIL_CALL_CNT(stack_size);
> >> +
> >>                 if (flags & BPF_TRAMP_F_ORIG_STACK) {
> >> -                       emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
> >> -                       EMIT2(0xff, 0xd0); /* call *rax */
> >> +                       emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
> >> +                       EMIT2(0xff, 0xd3); /* call *rbx */ // FIXME: Confirm 0xd3?
> >
> > please no FIXME like comments.
> > You have to be confident in the code you're submitting.
> > llvm-mc -triple=x86_64 -show-encoding -x86-asm-syntax=intel
> > -output-asm-variant=1 <<< 'call rbx'
>
> Got it. Thanks for the guide.
>
> >
> >>                 } else {
> >>                         /* call original function */
> >>                         if (emit_rsb_call(&prog, orig_call, prog)) {
> >> @@ -2569,7 +2588,12 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
> >>                         ret = -EINVAL;
> >>                         goto cleanup;
> >>                 }
> >> -       }
> >> +       } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
> >> +               /* Before running the original function, restore the
> >> +                * tail_call_cnt from stack to rax.
> >> +                */
> >> +               RESTORE_TAIL_CALL_CNT(stack_size);
> >> +
> >>         /* restore return value of orig_call or fentry prog back into RAX */
> >>         if (save_ret)
> >>                 emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
> >> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> >> index cfabbcf47bdb8..c8df257ea435d 100644
> >> --- a/include/linux/bpf.h
> >> +++ b/include/linux/bpf.h
> >> @@ -1028,6 +1028,11 @@ struct btf_func_model {
> >>   */
> >>  #define BPF_TRAMP_F_SHARE_IPMODIFY     BIT(6)
> >>
> >> +/* Indicate that current trampoline is in a tail call context. Then, it has to
> >> + * cache and restore tail_call_cnt to avoid infinite tail call loop.
> >> + */
> >> +#define BPF_TRAMP_F_TAIL_CALL_CTX      BIT(7)
> >> +
> >>  /* Each call __bpf_prog_enter + call bpf_func + call __bpf_prog_exit is ~50
> >>   * bytes on x86.
> >>   */
> >> diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
> >> index 78acf28d48732..16ab5da7161f2 100644
> >> --- a/kernel/bpf/trampoline.c
> >> +++ b/kernel/bpf/trampoline.c
> >> @@ -415,8 +415,8 @@ static int bpf_trampoline_update(struct bpf_trampoline *tr, bool lock_direct_mut
> >>                 goto out;
> >>         }
> >>
> >> -       /* clear all bits except SHARE_IPMODIFY */
> >> -       tr->flags &= BPF_TRAMP_F_SHARE_IPMODIFY;
> >> +       /* clear all bits except SHARE_IPMODIFY and TAIL_CALL_CTX */
> >> +       tr->flags &= (BPF_TRAMP_F_SHARE_IPMODIFY | BPF_TRAMP_F_TAIL_CALL_CTX);
> >>
> >>         if (tlinks[BPF_TRAMP_FEXIT].nr_links ||
> >>             tlinks[BPF_TRAMP_MODIFY_RETURN].nr_links) {
> >> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> >> index 4ccca1f6c9981..52ba9b043f16e 100644
> >> --- a/kernel/bpf/verifier.c
> >> +++ b/kernel/bpf/verifier.c
> >> @@ -19246,6 +19246,21 @@ static int check_non_sleepable_error_inject(u32 btf_id)
> >>         return btf_id_set_contains(&btf_non_sleepable_error_inject, btf_id);
> >>  }
> >>
> >> +static inline int find_subprog_index(const struct bpf_prog *prog,
> >> +                                    u32 btf_id)
> >> +{
> >> +       struct bpf_prog_aux *aux = prog->aux;
> >> +       int i, subprog = -1;
> >> +
> >> +       for (i = 0; i < aux->func_info_cnt; i++)
> >> +               if (aux->func_info[i].type_id == btf_id) {
> >> +                       subprog = i;
> >> +                       break;
> >> +               }
> >> +
> >> +       return subprog;
> >> +}
> >> +
> >>  int bpf_check_attach_target(struct bpf_verifier_log *log,
> >>                             const struct bpf_prog *prog,
> >>                             const struct bpf_prog *tgt_prog,
> >> @@ -19254,9 +19269,9 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
> >>  {
> >>         bool prog_extension = prog->type == BPF_PROG_TYPE_EXT;
> >>         const char prefix[] = "btf_trace_";
> >> -       int ret = 0, subprog = -1, i;
> >>         const struct btf_type *t;
> >>         bool conservative = true;
> >> +       int ret = 0, subprog;
> >>         const char *tname;
> >>         struct btf *btf;
> >>         long addr = 0;
> >> @@ -19291,11 +19306,7 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
> >>                         return -EINVAL;
> >>                 }
> >>
> >> -               for (i = 0; i < aux->func_info_cnt; i++)
> >> -                       if (aux->func_info[i].type_id == btf_id) {
> >> -                               subprog = i;
> >> -                               break;
> >> -                       }
> >> +               subprog = find_subprog_index(tgt_prog, btf_id);
> >>                 if (subprog == -1) {
> >>                         bpf_log(log, "Subprog %s doesn't exist\n", tname);
> >>                         return -EINVAL;
> >> @@ -19559,7 +19570,7 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
> >>         struct bpf_attach_target_info tgt_info = {};
> >>         u32 btf_id = prog->aux->attach_btf_id;
> >>         struct bpf_trampoline *tr;
> >> -       int ret;
> >> +       int ret, subprog;
> >>         u64 key;
> >>
> >>         if (prog->type == BPF_PROG_TYPE_SYSCALL) {
> >> @@ -19629,6 +19640,12 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
> >>         if (!tr)
> >>                 return -ENOMEM;
> >>
> >> +       if (tgt_prog && tgt_prog->aux->tail_call_reachable) {
> >> +               subprog = find_subprog_index(tgt_prog, btf_id);
> >> +               tr->flags = subprog > 0 && tgt_prog->aux->func[subprog]->is_func ?
> >> +                           BPF_TRAMP_F_TAIL_CALL_CTX : 0;
> >
> > If prog has subprogs all of them will 'is_func', no?
> > What's the point of the search ?
> > Just tgt_prog->aux->tail_call_reachable and func_cnt > 0 would be enough?
>
> tgt_prog->aux->tail_call_reachable and subprog > 0 would be enough?
> It has to confirm that the attaching target is a subprog of tgt_prog instead of
> tgt_prog itself.
>
> In tail call context, when 'call' a func, tail_call_cnt will be restored to rax.
>
> static int do_jit() {
>                         /* call */
>                 case BPF_JMP | BPF_CALL: {
>                         int offs;
>
>                         func = (u8 *) __bpf_call_base + imm32;
>                         if (tail_call_reachable) {
>                                 /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
>                                 EMIT3_off32(0x48, 0x8B, 0x85,
>                                             -round_up(bpf_prog->aux->stack_depth, 8) - 8);
>                                 /* ... */
>                         }
> }
>
> As a result, when 'call' a subprog, tail_call_cnt will be transferred by rax.
> Do all of subprogs run by 'call', including not-'is_func' subprogs?

Let me ask again. Do you see a subprog that has is_func==0 ?
Leon Hwang Aug. 23, 2023, 1:49 a.m. UTC | #5
On 23/8/23 05:29, Alexei Starovoitov wrote:
> On Mon, Aug 21, 2023 at 8:17 PM Leon Hwang <hffilwlqm@gmail.com> wrote:
>>
>>
>>
>> On 22/8/23 06:33, Alexei Starovoitov wrote:
>>> On Fri, Aug 18, 2023 at 8:12 AM Leon Hwang <hffilwlqm@gmail.com> wrote:
>>>>

[SNIP]

>>>>   * sub rsp, 16                     // space for skb and dev
>>>> - * push rbx                        // temp regs to pass start time
>>>> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
>>>> + * mov rax, 2                      // cache number of argument to rax
>>>
>>> What does it mean?
>>
>> I think it's the corresponding instruction to the following code snippet
>> in arch_prepare_bpf_trampoline().
>>
>>         /* Store number of argument registers of the traced function:
>>          *   mov rax, nr_regs
>>          *   mov QWORD PTR [rbp - nregs_off], rax
>>          */
>>         emit_mov_imm64(&prog, BPF_REG_0, 0, (u32) nr_regs);
>>         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -nregs_off);
> 
> Ahh. I see.
> The comment on top of arch_prepare_bpf_trampoline() is hopelessly obsolete.
> Don't touch it in this patch set. We probably should delete it at some point
> or take an effort to update it thoroughly.

Got it.

> Earlier recommendation to you was to update this comment:
> /* Generated trampoline stack layout:
> 
>>>
>>>> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
>>>
>>> Here // is ok since it's inside /* */
>>
>> Got it.
>>
>>>
>>>>   * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
>>>>   * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
>>>>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
>>>> @@ -2323,7 +2331,9 @@ static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
>>>>   * push rbp
>>>>   * mov rbp, rsp
>>>>   * sub rsp, 24                     // space for skb, dev, return value
>>>> - * push rbx                        // temp regs to pass start time
>>>> + * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
>>>> + * mov rax, 2                      // cache number of argument to rax
>>>> + * mov qword ptr [rbp - 32], rax   // save number of argument to stack
>>>>   * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
>>>>   * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
>>>>   * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
>>>> @@ -2400,6 +2410,7 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>>>          *                     [ ...        ]
>>>>          *                     [ stack_arg2 ]
>>>>          * RBP - arg_stack_off [ stack_arg1 ]
>>>> +        * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
>>>>          */
>>>>
>>>>         /* room for return value of orig_call or fentry prog */
>>>> @@ -2464,6 +2475,8 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>>>         else
>>>>                 /* sub rsp, stack_size */
>>>>                 EMIT4(0x48, 0x83, 0xEC, stack_size);
>>>> +       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>>>> +               EMIT1(0x50);            /* push rax */
>>>>         /* mov QWORD PTR [rbp - rbx_off], rbx */
>>>>         emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
>>>>
>>>> @@ -2516,9 +2529,15 @@ int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
>>>>                 restore_regs(m, &prog, regs_off);
>>>>                 save_args(m, &prog, arg_stack_off, true);
>>>>
>>>> +               if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
>>>> +                       /* Before calling the original function, restore the
>>>> +                        * tail_call_cnt from stack to rax.
>>>> +                        */
>>>> +                       RESTORE_TAIL_CALL_CNT(stack_size);
>>>> +
>>>>                 if (flags & BPF_TRAMP_F_ORIG_STACK) {
>>>> -                       emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
>>>> -                       EMIT2(0xff, 0xd0); /* call *rax */
>>>> +                       emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
>>>> +                       EMIT2(0xff, 0xd3); /* call *rbx */ // FIXME: Confirm 0xd3?
>>>
>>> please no FIXME like comments.
>>> You have to be confident in the code you're submitting.
>>> llvm-mc -triple=x86_64 -show-encoding -x86-asm-syntax=intel
>>> -output-asm-variant=1 <<< 'call rbx'
>>
>> Got it. Thanks for the guide.
>>
>>>
>>>>                 } else {
>>>>                         /* call original function */
>>>>                         if (emit_rsb_call(&prog, orig_call, prog)) {

[SNIP]

>>>>
>>>>         if (prog->type == BPF_PROG_TYPE_SYSCALL) {
>>>> @@ -19629,6 +19640,12 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
>>>>         if (!tr)
>>>>                 return -ENOMEM;
>>>>
>>>> +       if (tgt_prog && tgt_prog->aux->tail_call_reachable) {
>>>> +               subprog = find_subprog_index(tgt_prog, btf_id);
>>>> +               tr->flags = subprog > 0 && tgt_prog->aux->func[subprog]->is_func ?
>>>> +                           BPF_TRAMP_F_TAIL_CALL_CTX : 0;
>>>
>>> If prog has subprogs, all of them will be 'is_func', no?
>>> What's the point of the search ?
>>> Just tgt_prog->aux->tail_call_reachable and func_cnt > 0 would be enough?
>>
>> tgt_prog->aux->tail_call_reachable and subprog > 0 would be enough?
>> The check has to confirm that the attach target is a subprog of tgt_prog rather
>> than tgt_prog itself.
>>
>> In a tail-call context, when emitting a 'call' to a func, tail_call_cnt is restored into rax:
>>
>> static int do_jit() {
>>                         /* call */
>>                 case BPF_JMP | BPF_CALL: {
>>                         int offs;
>>
>>                         func = (u8 *) __bpf_call_base + imm32;
>>                         if (tail_call_reachable) {
>>                                 /* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
>>                                 EMIT3_off32(0x48, 0x8B, 0x85,
>>                                             -round_up(bpf_prog->aux->stack_depth, 8) - 8);
>>                                 /* ... */
>>                         }
>> }
>>
>> As a result, when 'call'ing a subprog, tail_call_cnt is passed along in rax.
>> Are all subprogs invoked via 'call', including subprogs that are not 'is_func'?
> 
> Let me ask again. Do you see a subprog that has is_func==0 ?

Oh, I get it.

In jit_subprogs(), all subprogs are marked 'is_func'.
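
For reference, the loop in jit_subprogs() (kernel/bpf/verifier.c) sets it for
every subprog image it creates; roughly like this (quoted from memory, so the
surrounding details may differ):

static int jit_subprogs(struct bpf_verifier_env *env) {
	/* ... */
	for (i = 0; i < env->subprog_cnt; i++) {
		/* ... */
		/* every JITed subprog is created as a standalone function */
		func[i]->is_func = 1;
		/* ... */
	}
	/* ... */
}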

So, it's unnecessary to check tgt_prog->aux->func[subprog]->is_func.
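
A minimal sketch of the simplified flag setting I have in mind for the next
revision (assuming only the is_func test is dropped and everything else stays
as in this patch):

	if (tgt_prog && tgt_prog->aux->tail_call_reachable) {
		/* Set the flag only when attaching to a subprog (index > 0),
		 * not to the main prog of the tail-call-reachable tgt_prog.
		 */
		subprog = find_subprog_index(tgt_prog, btf_id);
		tr->flags = subprog > 0 ? BPF_TRAMP_F_TAIL_CALL_CTX : 0;
	}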

I'll submit a new RFC PATCH later.

Thanks,
Leon

Patch

diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index a5930042139d3..1ad17d7de5eee 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -303,8 +303,12 @@  static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf,
 	prog += X86_PATCH_SIZE;
 	if (!ebpf_from_cbpf) {
 		if (tail_call_reachable && !is_subprog)
+			/* When it's the entry of the whole tailcall context,
+			 * zeroing rax means initialising tail_call_cnt.
+			 */
 			EMIT2(0x31, 0xC0); /* xor eax, eax */
 		else
+			// Keep the same instruction layout.
 			EMIT2(0x66, 0x90); /* nop2 */
 	}
 	EMIT1(0x55);             /* push rbp */
@@ -1018,6 +1022,10 @@  static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op)
 
 #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
 
+/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
+#define RESTORE_TAIL_CALL_CNT(stack)				\
+	EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8)
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image,
 		  int oldproglen, struct jit_context *ctx, bool jmp_padding)
 {
@@ -1623,9 +1631,7 @@  st:			if (is_imm8(insn->off))
 
 			func = (u8 *) __bpf_call_base + imm32;
 			if (tail_call_reachable) {
-				/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */
-				EMIT3_off32(0x48, 0x8B, 0x85,
-					    -round_up(bpf_prog->aux->stack_depth, 8) - 8);
+				RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth);
 				if (!imm32)
 					return -EINVAL;
 				offs = 7 + x86_call_depth_emit_accounting(&prog, func);
@@ -2298,7 +2304,9 @@  static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
  * push rbp
  * mov rbp, rsp
  * sub rsp, 16                     // space for skb and dev
- * push rbx                        // temp regs to pass start time
+ * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
+ * mov rax, 2                      // cache number of argument to rax
+ * mov qword ptr [rbp - 32], rax   // save number of argument to stack
  * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
  * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
@@ -2323,7 +2331,9 @@  static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
  * push rbp
  * mov rbp, rsp
  * sub rsp, 24                     // space for skb, dev, return value
- * push rbx                        // temp regs to pass start time
+ * mov qword ptr [rbp - 40], rbx   // temp regs to pass start time
+ * mov rax, 2                      // cache number of argument to rax
+ * mov qword ptr [rbp - 32], rax   // save number of argument to stack
  * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
  * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
  * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
@@ -2400,6 +2410,7 @@  int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
 	 *                     [ ...        ]
 	 *                     [ stack_arg2 ]
 	 * RBP - arg_stack_off [ stack_arg1 ]
+	 * RSP                 [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX
 	 */
 
 	/* room for return value of orig_call or fentry prog */
@@ -2464,6 +2475,8 @@  int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
 	else
 		/* sub rsp, stack_size */
 		EMIT4(0x48, 0x83, 0xEC, stack_size);
+	if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+		EMIT1(0x50);		/* push rax */
 	/* mov QWORD PTR [rbp - rbx_off], rbx */
 	emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off);
 
@@ -2516,9 +2529,15 @@  int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
 		restore_regs(m, &prog, regs_off);
 		save_args(m, &prog, arg_stack_off, true);
 
+		if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+			/* Before calling the original function, restore the
+			 * tail_call_cnt from stack to rax.
+			 */
+			RESTORE_TAIL_CALL_CNT(stack_size);
+
 		if (flags & BPF_TRAMP_F_ORIG_STACK) {
-			emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8);
-			EMIT2(0xff, 0xd0); /* call *rax */
+			emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8);
+			EMIT2(0xff, 0xd3); /* call *rbx */ // FIXME: Confirm 0xd3?
 		} else {
 			/* call original function */
 			if (emit_rsb_call(&prog, orig_call, prog)) {
@@ -2569,7 +2588,12 @@  int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image, void *i
 			ret = -EINVAL;
 			goto cleanup;
 		}
-	}
+	} else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+		/* Before running the original function, restore the
+		 * tail_call_cnt from stack to rax.
+		 */
+		RESTORE_TAIL_CALL_CNT(stack_size);
+
 	/* restore return value of orig_call or fentry prog back into RAX */
 	if (save_ret)
 		emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index cfabbcf47bdb8..c8df257ea435d 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -1028,6 +1028,11 @@  struct btf_func_model {
  */
 #define BPF_TRAMP_F_SHARE_IPMODIFY	BIT(6)
 
+/* Indicate that current trampoline is in a tail call context. Then, it has to
+ * cache and restore tail_call_cnt to avoid infinite tail call loop.
+ */
+#define BPF_TRAMP_F_TAIL_CALL_CTX	BIT(7)
+
 /* Each call __bpf_prog_enter + call bpf_func + call __bpf_prog_exit is ~50
  * bytes on x86.
  */
diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
index 78acf28d48732..16ab5da7161f2 100644
--- a/kernel/bpf/trampoline.c
+++ b/kernel/bpf/trampoline.c
@@ -415,8 +415,8 @@  static int bpf_trampoline_update(struct bpf_trampoline *tr, bool lock_direct_mut
 		goto out;
 	}
 
-	/* clear all bits except SHARE_IPMODIFY */
-	tr->flags &= BPF_TRAMP_F_SHARE_IPMODIFY;
+	/* clear all bits except SHARE_IPMODIFY and TAIL_CALL_CTX */
+	tr->flags &= (BPF_TRAMP_F_SHARE_IPMODIFY | BPF_TRAMP_F_TAIL_CALL_CTX);
 
 	if (tlinks[BPF_TRAMP_FEXIT].nr_links ||
 	    tlinks[BPF_TRAMP_MODIFY_RETURN].nr_links) {
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 4ccca1f6c9981..52ba9b043f16e 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -19246,6 +19246,21 @@  static int check_non_sleepable_error_inject(u32 btf_id)
 	return btf_id_set_contains(&btf_non_sleepable_error_inject, btf_id);
 }
 
+static inline int find_subprog_index(const struct bpf_prog *prog,
+				     u32 btf_id)
+{
+	struct bpf_prog_aux *aux = prog->aux;
+	int i, subprog = -1;
+
+	for (i = 0; i < aux->func_info_cnt; i++)
+		if (aux->func_info[i].type_id == btf_id) {
+			subprog = i;
+			break;
+		}
+
+	return subprog;
+}
+
 int bpf_check_attach_target(struct bpf_verifier_log *log,
 			    const struct bpf_prog *prog,
 			    const struct bpf_prog *tgt_prog,
@@ -19254,9 +19269,9 @@  int bpf_check_attach_target(struct bpf_verifier_log *log,
 {
 	bool prog_extension = prog->type == BPF_PROG_TYPE_EXT;
 	const char prefix[] = "btf_trace_";
-	int ret = 0, subprog = -1, i;
 	const struct btf_type *t;
 	bool conservative = true;
+	int ret = 0, subprog;
 	const char *tname;
 	struct btf *btf;
 	long addr = 0;
@@ -19291,11 +19306,7 @@  int bpf_check_attach_target(struct bpf_verifier_log *log,
 			return -EINVAL;
 		}
 
-		for (i = 0; i < aux->func_info_cnt; i++)
-			if (aux->func_info[i].type_id == btf_id) {
-				subprog = i;
-				break;
-			}
+		subprog = find_subprog_index(tgt_prog, btf_id);
 		if (subprog == -1) {
 			bpf_log(log, "Subprog %s doesn't exist\n", tname);
 			return -EINVAL;
@@ -19559,7 +19570,7 @@  static int check_attach_btf_id(struct bpf_verifier_env *env)
 	struct bpf_attach_target_info tgt_info = {};
 	u32 btf_id = prog->aux->attach_btf_id;
 	struct bpf_trampoline *tr;
-	int ret;
+	int ret, subprog;
 	u64 key;
 
 	if (prog->type == BPF_PROG_TYPE_SYSCALL) {
@@ -19629,6 +19640,12 @@  static int check_attach_btf_id(struct bpf_verifier_env *env)
 	if (!tr)
 		return -ENOMEM;
 
+	if (tgt_prog && tgt_prog->aux->tail_call_reachable) {
+		subprog = find_subprog_index(tgt_prog, btf_id);
+		tr->flags = subprog > 0 && tgt_prog->aux->func[subprog]->is_func ?
+			    BPF_TRAMP_F_TAIL_CALL_CTX : 0;
+	}
+
 	prog->aux->dst_trampoline = tr;
 	return 0;
 }