@@ -3,7 +3,7 @@
#define _LINUX_BPFILTER_H
#include <uapi/linux/bpfilter.h>
-#include <linux/usermode_driver.h>
+#include <linux/usermode_driver_mgmt.h>
#include <linux/sockptr.h>
struct sock;
@@ -13,13 +13,5 @@ int bpfilter_ip_get_sockopt(struct sock *sk, int optname, char __user *optval,
int __user *optlen);
void bpfilter_umh_cleanup(struct umd_info *info);
-struct bpfilter_umh_ops {
- struct umd_info info;
- /* since ip_getsockopt() can run in parallel, serialize access to umh */
- struct mutex lock;
- int (*sockopt)(struct sock *sk, int optname, sockptr_t optval,
- unsigned int optlen, bool is_set);
- int (*start)(void);
-};
-extern struct bpfilter_umh_ops bpfilter_ops;
+extern struct umd_mgmt bpfilter_ops;
#endif
@@ -1,113 +1,21 @@
// SPDX-License-Identifier: GPL-2.0
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
-#include <linux/init.h>
#include <linux/module.h>
-#include <linux/umh.h>
#include <linux/bpfilter.h>
-#include <linux/sched.h>
-#include <linux/sched/signal.h>
-#include <linux/fs.h>
-#include <linux/file.h>
#include "msgfmt.h"
extern char bpfilter_umh_start;
extern char bpfilter_umh_end;
-static void shutdown_umh(void)
-{
- struct umd_info *info = &bpfilter_ops.info;
- struct pid *tgid = info->tgid;
-
- if (tgid) {
- kill_pid(tgid, SIGKILL, 1);
- wait_event(tgid->wait_pidfd, thread_group_exited(tgid));
- bpfilter_umh_cleanup(info);
- }
-}
-
-static int bpfilter_process_sockopt(struct sock *sk, int optname,
- sockptr_t optval, unsigned int optlen,
- bool is_set)
-{
- struct mbox_request req = {
- .is_set = is_set,
- .pid = current->pid,
- .cmd = optname,
- .addr = (uintptr_t)optval.user,
- .len = optlen,
- };
- struct mbox_reply reply;
- int err;
-
- if (sockptr_is_kernel(optval)) {
- pr_err("kernel access not supported\n");
- return -EFAULT;
- }
- err = umd_send_recv(&bpfilter_ops.info, &req, sizeof(req), &reply,
- sizeof(reply));
- if (err) {
- shutdown_umh();
- return err;
- }
-
- return reply.status;
-}
-
-static int start_umh(void)
-{
- struct mbox_request req = { .pid = current->pid };
- struct mbox_reply reply;
- int err;
-
- /* fork usermode process */
- err = fork_usermode_driver(&bpfilter_ops.info);
- if (err)
- return err;
- pr_info("Loaded bpfilter_umh pid %d\n", pid_nr(bpfilter_ops.info.tgid));
-
- /* health check that usermode process started correctly */
- if (umd_send_recv(&bpfilter_ops.info, &req, sizeof(req), &reply,
- sizeof(reply)) != 0 || reply.status != 0) {
- shutdown_umh();
- return -EFAULT;
- }
-
- return 0;
-}
-
static int __init load_umh(void)
{
- int err;
-
- err = umd_load_blob(&bpfilter_ops.info,
- &bpfilter_umh_start,
- &bpfilter_umh_end - &bpfilter_umh_start);
- if (err)
- return err;
-
- mutex_lock(&bpfilter_ops.lock);
- err = start_umh();
- if (!err && IS_ENABLED(CONFIG_INET)) {
- bpfilter_ops.sockopt = &bpfilter_process_sockopt;
- bpfilter_ops.start = &start_umh;
- }
- mutex_unlock(&bpfilter_ops.lock);
- if (err)
- umd_unload_blob(&bpfilter_ops.info);
- return err;
+ return umd_mgmt_load(&bpfilter_ops, &bpfilter_umh_start,
+ &bpfilter_umh_end);
}
static void __exit fini_umh(void)
{
- mutex_lock(&bpfilter_ops.lock);
- if (IS_ENABLED(CONFIG_INET)) {
- shutdown_umh();
- bpfilter_ops.start = NULL;
- bpfilter_ops.sockopt = NULL;
- }
- mutex_unlock(&bpfilter_ops.lock);
-
- umd_unload_blob(&bpfilter_ops.info);
+ umd_mgmt_unload(&bpfilter_ops);
}
module_init(load_umh);
module_exit(fini_umh);
@@ -8,8 +8,9 @@
#include <linux/kmod.h>
#include <linux/fs.h>
#include <linux/file.h>
+#include "../../bpfilter/msgfmt.h"
-struct bpfilter_umh_ops bpfilter_ops;
+struct umd_mgmt bpfilter_ops;
EXPORT_SYMBOL_GPL(bpfilter_ops);
void bpfilter_umh_cleanup(struct umd_info *info)
@@ -21,40 +22,49 @@ void bpfilter_umh_cleanup(struct umd_info *info)
}
EXPORT_SYMBOL_GPL(bpfilter_umh_cleanup);
-static int bpfilter_mbox_request(struct sock *sk, int optname, sockptr_t optval,
- unsigned int optlen, bool is_set)
+static int bpfilter_process_sockopt(struct sock *sk, int optname,
+ sockptr_t optval, unsigned int optlen,
+ bool is_set)
{
+ struct mbox_request req = {
+ .is_set = is_set,
+ .pid = current->pid,
+ .cmd = optname,
+ .addr = (uintptr_t)optval.user,
+ .len = optlen,
+ };
+ struct mbox_reply reply;
int err;
- mutex_lock(&bpfilter_ops.lock);
- if (!bpfilter_ops.sockopt) {
- mutex_unlock(&bpfilter_ops.lock);
- request_module("bpfilter");
- mutex_lock(&bpfilter_ops.lock);
- if (!bpfilter_ops.sockopt) {
- err = -ENOPROTOOPT;
- goto out;
- }
+ if (sockptr_is_kernel(optval)) {
+ pr_err("kernel access not supported\n");
+ return -EFAULT;
}
- if (bpfilter_ops.info.tgid &&
- thread_group_exited(bpfilter_ops.info.tgid))
- bpfilter_umh_cleanup(&bpfilter_ops.info);
+ err = umd_mgmt_send_recv(&bpfilter_ops, &req, sizeof(req), &reply,
+ sizeof(reply));
+ if (err)
+ return err;
- if (!bpfilter_ops.info.tgid) {
- err = bpfilter_ops.start();
- if (err)
- goto out;
- }
- err = bpfilter_ops.sockopt(sk, optname, optval, optlen, is_set);
-out:
- mutex_unlock(&bpfilter_ops.lock);
- return err;
+ return reply.status;
+}
+
+static int bpfilter_post_start_umh(struct umd_mgmt *mgmt)
+{
+ struct mbox_request req = { .pid = current->pid };
+ struct mbox_reply reply;
+
+ /* health check that usermode process started correctly */
+ if (umd_send_recv(&bpfilter_ops.info, &req, sizeof(req), &reply,
+ sizeof(reply)) != 0 || reply.status != 0)
+ return -EFAULT;
+
+ return 0;
}
int bpfilter_ip_set_sockopt(struct sock *sk, int optname, sockptr_t optval,
unsigned int optlen)
{
- return bpfilter_mbox_request(sk, optname, optval, optlen, true);
+ return bpfilter_process_sockopt(sk, optname, optval, optlen, true);
}
int bpfilter_ip_get_sockopt(struct sock *sk, int optname, char __user *optval,
@@ -65,8 +75,8 @@ int bpfilter_ip_get_sockopt(struct sock *sk, int optname, char __user *optval,
if (get_user(len, optlen))
return -EFAULT;
- return bpfilter_mbox_request(sk, optname, USER_SOCKPTR(optval), len,
- false);
+ return bpfilter_process_sockopt(sk, optname, USER_SOCKPTR(optval), len,
+ false);
}
static int __init bpfilter_sockopt_init(void)
@@ -74,6 +84,9 @@ static int __init bpfilter_sockopt_init(void)
mutex_init(&bpfilter_ops.lock);
bpfilter_ops.info.tgid = NULL;
bpfilter_ops.info.driver_name = "bpfilter_umh";
+ bpfilter_ops.post_start = bpfilter_post_start_umh;
+ bpfilter_ops.kmod = "bpfilter";
+ bpfilter_ops.kmod_loaded = false;
return 0;
}