@@ -8,6 +8,7 @@
#include <net/page_pool/helpers.h>
#include <net/page_pool/memory_provider.h>
#include <trace/events/page_pool.h>
+#include <net/netdev_rx_queue.h>
#include <net/tcp.h>
#include <net/rps.h>
@@ -36,6 +37,65 @@ static inline struct io_zcrx_area *io_zcrx_iov_to_area(const struct net_iov *nio
return container_of(owner, struct io_zcrx_area, nia);
}
+/*
+ * Bind RX queue @ifq_idx of ifq->dev to the zcrx memory provider and restart
+ * it. Caller holds RTNL. Returns 0, -EINVAL, -EEXIST or the restart error.
+ */
+static int io_open_zc_rxq(struct io_zcrx_ifq *ifq, unsigned int ifq_idx)
+{
+	struct net_device *dev = ifq->dev;
+	struct netdev_rx_queue *rxq;
+	int ret;
+
+	ASSERT_RTNL();
+	if (ifq_idx >= dev->num_rx_queues)
+		return -EINVAL;
+	ifq_idx = array_index_nospec(ifq_idx, dev->num_rx_queues); /* Spectre v1 */
+	rxq = __netif_get_rx_queue(dev, ifq_idx);
+	if (rxq->mp_params.mp_priv)	/* queue already has a memory provider */
+		return -EEXIST;
+
+	ifq->if_rxq = ifq_idx;
+	rxq->mp_params.mp_ops = &io_uring_pp_zc_ops;
+	rxq->mp_params.mp_priv = ifq;
+	ret = netdev_rx_queue_restart(dev, ifq_idx);
+	if (ret) {	/* roll back so the queue and ifq are left unbound */
+		rxq->mp_params.mp_ops = NULL;
+		rxq->mp_params.mp_priv = NULL;
+		ifq->if_rxq = -1;
+	}
+	return ret;
+}
+
+/*
+ * Undo io_open_zc_rxq(): strip the memory provider from the bound RX queue
+ * and restart it. No-op when no queue is bound (if_rxq == -1). Takes RTNL.
+ */
+static void io_close_zc_rxq(struct io_zcrx_ifq *ifq)
+{
+	struct net_device *dev = ifq->dev;
+	struct netdev_rx_queue *rxq;
+
+	if (ifq->if_rxq == -1)
+		return;
+	rtnl_lock();
+	if (WARN_ON_ONCE(ifq->if_rxq >= dev->num_rx_queues))
+		goto out_unlock;
+	rxq = __netif_get_rx_queue(dev, ifq->if_rxq);
+	/* the queue should still be carrying our provider */
+	WARN_ON_ONCE(rxq->mp_params.mp_priv != ifq);
+	rxq->mp_params.mp_ops = NULL;
+	rxq->mp_params.mp_priv = NULL;
+	if (netdev_rx_queue_restart(dev, ifq->if_rxq))
+		pr_devel("io_uring: can't restart a queue on zcrx close\n");
+	rtnl_unlock();
+	/* only the success path marks the ifq unbound */
+	ifq->if_rxq = -1;
+	return;
+out_unlock:
+	rtnl_unlock();
+}
+
static int io_allocate_rbuf_ring(struct io_zcrx_ifq *ifq,
struct io_uring_zcrx_ifq_reg *reg)
{
@@ -156,9 +216,12 @@ static struct io_zcrx_ifq *io_zcrx_ifq_alloc(struct io_ring_ctx *ctx)
static void io_zcrx_ifq_free(struct io_zcrx_ifq *ifq)
{
+	io_close_zc_rxq(ifq);	/* first: the queue's mp_priv points at this ifq */
+
if (ifq->area)
io_zcrx_free_area(ifq->area);
-
+	/* dev may still be NULL here; netdev_put() ignores a NULL device */
+	netdev_put(ifq->dev, &ifq->netdev_tracker);
io_free_rbuf_ring(ifq);
kfree(ifq);
}
@@ -214,7 +277,18 @@ int io_register_zcrx_ifq(struct io_ring_ctx *ctx,
goto err;
ifq->rq_entries = reg.rq_entries;
- ifq->if_rxq = reg.if_rxq;
+
+ ret = -ENODEV;
+ rtnl_lock();
+ ifq->dev = netdev_get_by_index(current->nsproxy->net_ns, reg.if_idx,
+ &ifq->netdev_tracker, GFP_KERNEL);
+ if (!ifq->dev)
+ goto err_rtnl_unlock;
+
+ ret = io_open_zc_rxq(ifq, reg.if_rxq);
+ if (ret)
+ goto err_rtnl_unlock;
+ rtnl_unlock();
ring_sz = sizeof(struct io_uring);
rqes_sz = sizeof(struct io_uring_zcrx_rqe) * ifq->rq_entries;
@@ -224,15 +298,20 @@ int io_register_zcrx_ifq(struct io_ring_ctx *ctx,
reg.offsets.tail = offsetof(struct io_uring, tail);
if (copy_to_user(arg, ®, sizeof(reg))) {
+ io_close_zc_rxq(ifq);
ret = -EFAULT;
goto err;
}
if (copy_to_user(u64_to_user_ptr(reg.area_ptr), &area, sizeof(area))) {
+ io_close_zc_rxq(ifq);
ret = -EFAULT;
goto err;
}
ctx->ifq = ifq;
return 0;
+
+err_rtnl_unlock:
+ rtnl_unlock();
err:
io_zcrx_ifq_free(ifq);
return ret;
@@ -254,6 +333,9 @@ void io_unregister_zcrx_ifqs(struct io_ring_ctx *ctx)
void io_shutdown_zcrx_ifqs(struct io_ring_ctx *ctx)
{
lockdep_assert_held(&ctx->uring_lock);
+	if (!ctx->ifq)
+		return;
+	io_close_zc_rxq(ctx->ifq);	/* unbind the RX queue; ifq stays allocated */
}
static void io_zcrx_get_buf_uref(struct net_iov *niov)
@@ -5,6 +5,7 @@
#include <linux/io_uring_types.h>
#include <linux/socket.h>
#include <net/page_pool/types.h>
+#include <net/net_trackers.h>
#define IO_ZC_RX_UREF 0x10000
#define IO_ZC_RX_KREF_MASK (IO_ZC_RX_UREF - 1)
@@ -37,6 +38,7 @@ struct io_zcrx_ifq {
struct page **rqe_pages;
u32 if_rxq;
+ netdevice_tracker netdev_tracker;
};
#if defined(CONFIG_IO_URING_ZCRX)