diff mbox series

[RFC,v2,9/9] netfilter: hook_jit: add prog cache

Message ID 20221005141309.31758-10-fw@strlen.de (mailing list archive)
State RFC
Delegated to: Netdev Maintainers
Headers show
Series netfilter: bpf base hook program generator | expand

Checks

Context Check Description
netdev/tree_selection success Guessed tree name to be net-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix success Link
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 2 this patch: 2
netdev/cc_maintainers warning 9 maintainers not CCed: kuba@kernel.org davem@davemloft.net pablo@netfilter.org netfilter-devel@vger.kernel.org kadlec@netfilter.org netdev@vger.kernel.org coreteam@netfilter.org edumazet@google.com pabeni@redhat.com
netdev/build_clang success Errors and warnings before: 11 this patch: 9
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 2 this patch: 2
netdev/checkpatch warning CHECK: multiple assignments should be avoided WARNING: line length of 87 exceeds 80 columns WARNING: line length of 93 exceeds 80 columns WARNING: line length of 96 exceeds 80 columns
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Florian Westphal Oct. 5, 2022, 2:13 p.m. UTC
This allows to re-use the same program.  For example, a nft
ruleset that attaches filter basechains to input, forward, output would
use the same program for all three hook points.

The cache is intentionally netns agnostic, so same config
in different netns will all use same programs.

Signed-off-by: Florian Westphal <fw@strlen.de>
---
 net/netfilter/nf_hook_bpf.c | 150 ++++++++++++++++++++++++++++++++++++
 1 file changed, 150 insertions(+)
diff mbox series

Patch

diff --git a/net/netfilter/nf_hook_bpf.c b/net/netfilter/nf_hook_bpf.c
index dab13b803801..0ca2e4404b1b 100644
--- a/net/netfilter/nf_hook_bpf.c
+++ b/net/netfilter/nf_hook_bpf.c
@@ -38,6 +38,24 @@  struct nf_hook_prog {
 	unsigned int pos;
 };
 
+struct nf_hook_bpf_prog {
+	struct rcu_head rcu_head;
+
+	struct hlist_node node_key;
+	struct hlist_node node_prog;
+	u32 key;
+	u16 hook_count;
+	refcount_t refcnt;
+	struct bpf_prog	*prog;
+	unsigned long hooks[32];
+};
+
+#define NF_BPF_PROG_HT_BITS	8
+
+/* users need to hold nf_hook_mutex */
+static DEFINE_HASHTABLE(nf_bpf_progs_ht_key, NF_BPF_PROG_HT_BITS);
+static DEFINE_HASHTABLE(nf_bpf_progs_ht_prog, NF_BPF_PROG_HT_BITS);
+
 static bool emit(struct nf_hook_prog *p, struct bpf_insn insn)
 {
 	if (WARN_ON_ONCE(p->pos >= BPF_MAXINSNS))
@@ -398,12 +416,112 @@  struct bpf_prog *nf_hook_bpf_create_fb(void)
 	return prog;
 }
 
+static u32 nf_hook_entries_hash(const struct nf_hook_entries *new)
+{
+	u32 i = 0, hook_count = new->num_hook_entries;
+	u32 a, b, c;
+
+	a = b = c = JHASH_INITVAL + hook_count;
+
+	while (hook_count > 3) {
+		a += hash32_ptr(new->hooks[i].hook);
+		b += hash32_ptr(new->hooks[i + 1].hook);
+		c += hash32_ptr(new->hooks[i + 2].hook);
+		__jhash_mix(a, b, c);
+		hook_count -= 3;
+		i += 3;
+	}
+
+	switch (hook_count) {
+	case 3:
+		c += hash32_ptr(new->hooks[i + 2].hook);
+		fallthrough;
+	case 2:
+		b += hash32_ptr(new->hooks[i + 1].hook);
+		fallthrough;
+	case 1:
+		a += hash32_ptr(new->hooks[i].hook);
+		__jhash_final(a, b, c);
+		break;
+	}
+
+	return c;
+}
+
+static struct bpf_prog *nf_hook_bpf_find_prog_by_key(const struct nf_hook_entries *new, u32 key)
+{
+	int i, hook_count = new->num_hook_entries;
+	struct nf_hook_bpf_prog *pc;
+
+	hash_for_each_possible(nf_bpf_progs_ht_key, pc, node_key, key) {
+		if (pc->hook_count != hook_count ||
+		    pc->key != key)
+			continue;
+
+		for (i = 0; i < hook_count; i++) {
+			if (pc->hooks[i] != (unsigned long)new->hooks[i].hook)
+				break;
+		}
+
+		if (i == hook_count) {
+			refcount_inc(&pc->refcnt);
+			return pc->prog;
+		}
+	}
+
+	return NULL;
+}
+
+static struct nf_hook_bpf_prog *nf_hook_bpf_find_prog(const struct bpf_prog *p)
+{
+	struct nf_hook_bpf_prog *pc;
+
+	hash_for_each_possible(nf_bpf_progs_ht_prog, pc, node_prog, (unsigned long)p) {
+		if (pc->prog == p)
+			return pc;
+	}
+
+	return NULL;
+}
+
+static void nf_hook_bpf_prog_store(const struct nf_hook_entries *new,
+				   struct bpf_prog *prog, u32 key)
+{
+	unsigned int i, hook_count = new->num_hook_entries;
+	struct nf_hook_bpf_prog *alloc;
+
+	if (hook_count >= ARRAY_SIZE(alloc->hooks))
+		return;
+
+	alloc = kzalloc(sizeof(*alloc), GFP_KERNEL);
+	if (!alloc)
+		return;
+
+	alloc->hook_count = new->num_hook_entries;
+	alloc->prog = prog;
+	alloc->key = key;
+
+	for (i = 0; i < hook_count; i++)
+		alloc->hooks[i] = (unsigned long)new->hooks[i].hook;
+
+	hash_add(nf_bpf_progs_ht_key, &alloc->node_key, key);
+	hash_add(nf_bpf_progs_ht_prog, &alloc->node_prog, (unsigned long)prog);
+	refcount_set(&alloc->refcnt, 1);
+
+	bpf_prog_inc(prog);
+}
+
 struct bpf_prog *nf_hook_bpf_create(const struct nf_hook_entries *new)
 {
+	u32 key = nf_hook_entries_hash(new);
 	struct bpf_prog *prog;
 	struct nf_hook_prog p;
 	int err;
 
+	prog = nf_hook_bpf_find_prog_by_key(new, key);
+	if (prog)
+		return prog;
+
 	err = nf_hook_prog_init(&p);
 	if (err)
 		return NULL;
@@ -413,12 +531,44 @@  struct bpf_prog *nf_hook_bpf_create(const struct nf_hook_entries *new)
 		goto err;
 
 	prog = nf_hook_jit_compile(p.insns, p.pos);
+	if (prog)
+		nf_hook_bpf_prog_store(new, prog, key);
 err:
 	nf_hook_prog_free(&p);
 	return prog;
 }
 
+static void __nf_hook_free_prog(struct rcu_head *head)
+{
+	struct nf_hook_bpf_prog *old = container_of(head, struct nf_hook_bpf_prog, rcu_head);
+
+	bpf_prog_put(old->prog);
+	kfree(old);
+}
+
+static void nf_hook_free_prog(struct nf_hook_bpf_prog *old)
+{
+	call_rcu(&old->rcu_head, __nf_hook_free_prog);
+}
+
 void nf_hook_bpf_change_prog(struct bpf_dispatcher *d, struct bpf_prog *from, struct bpf_prog *to)
 {
+	if (from == to)
+		return;
+
+	if (from) {
+		struct nf_hook_bpf_prog *old;
+
+		old = nf_hook_bpf_find_prog(from);
+		if (old) {
+			WARN_ON_ONCE(from != old->prog);
+			if (refcount_dec_and_test(&old->refcnt)) {
+				hash_del(&old->node_key);
+				hash_del(&old->node_prog);
+				nf_hook_free_prog(old);
+			}
+		}
+	}
+
 	bpf_dispatcher_change_prog(d, from, to);
 }