@@ -1053,7 +1053,7 @@ int io_recvzc(struct io_kiocb *req, unsigned int issue_flags)
if (flags & MSG_WAITALL)
min_ret = zc->len;
- ret = io_zc_rx_recv(sock, zc->datalen, flags);
+ ret = io_zc_rx_recv(ifq, sock, zc->datalen, flags);
if (ret < min_ret) {
if (ret == -EAGAIN && force_nonblock) {
if (issue_flags & IO_URING_F_MULTISHOT)
@@ -577,7 +577,7 @@ static struct io_zc_rx_ifq *io_zc_rx_ifq_skb(struct sk_buff *skb)
}
static int zc_rx_recv_frag(struct io_zc_rx_ifq *ifq, const skb_frag_t *frag,
- int off, int len)
+ int off, int len, bool zc_skb)
{
struct io_uring_rbuf_cqe *cqe;
unsigned int cq_idx, queued, free, entries;
@@ -588,7 +588,7 @@ static int zc_rx_recv_frag(struct io_zc_rx_ifq *ifq, const skb_frag_t *frag,
page = skb_frag_page(frag);
off += skb_frag_off(frag);
- if (likely(ifq && is_zc_rx_page(page))) {
+ if (likely(zc_skb && is_zc_rx_page(page))) {
mask = ifq->cq_entries - 1;
pgid = page_private(page) & 0xffffffff;
io_zc_rx_get_buf_uref(ifq->pool, pgid);
@@ -618,14 +618,19 @@ static int
zc_rx_recv_skb(read_descriptor_t *desc, struct sk_buff *skb,
unsigned int offset, size_t len)
{
- struct io_zc_rx_ifq *ifq;
+ struct io_zc_rx_ifq *ifq = desc->arg.data;
+ struct io_zc_rx_ifq *skb_ifq;
struct sk_buff *frag_iter;
unsigned start, start_off;
int i, copy, end, off;
+ bool zc_skb = true;
int ret = 0;
- ifq = io_zc_rx_ifq_skb(skb);
- if (!ifq) {
+ skb_ifq = io_zc_rx_ifq_skb(skb);
+ if (unlikely(ifq != skb_ifq)) {
+ zc_skb = false;
+ if (WARN_ON_ONCE(skb_ifq))
+ return -EFAULT;
pr_debug("non zerocopy pages are not supported\n");
return -EFAULT;
}
@@ -649,7 +654,7 @@ zc_rx_recv_skb(read_descriptor_t *desc, struct sk_buff *skb,
copy = len;
off = offset - start;
- ret = zc_rx_recv_frag(ifq, frag, off, copy);
+ ret = zc_rx_recv_frag(ifq, frag, off, copy, zc_skb);
if (ret < 0)
goto out;
@@ -690,16 +695,18 @@ zc_rx_recv_skb(read_descriptor_t *desc, struct sk_buff *skb,
return offset - start_off;
}
-static int io_zc_rx_tcp_read(struct sock *sk)
+static int io_zc_rx_tcp_read(struct io_zc_rx_ifq *ifq, struct sock *sk)
{
read_descriptor_t rd_desc = {
.count = 1,
+ .arg.data = ifq,
};
return tcp_read_sock(sk, &rd_desc, zc_rx_recv_skb);
}
-static int io_zc_rx_tcp_recvmsg(struct sock *sk, unsigned int recv_limit,
+static int io_zc_rx_tcp_recvmsg(struct io_zc_rx_ifq *ifq, struct sock *sk,
+ unsigned int recv_limit,
int flags, int *addr_len)
{
size_t used;
@@ -712,7 +719,7 @@ static int io_zc_rx_tcp_recvmsg(struct sock *sk, unsigned int recv_limit,
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
while (recv_limit) {
- ret = io_zc_rx_tcp_read(sk);
+ ret = io_zc_rx_tcp_read(ifq, sk);
if (ret < 0)
break;
if (!ret) {
@@ -767,7 +774,8 @@ static int io_zc_rx_tcp_recvmsg(struct sock *sk, unsigned int recv_limit,
return ret;
}
-int io_zc_rx_recv(struct socket *sock, unsigned int limit, unsigned int flags)
+int io_zc_rx_recv(struct io_zc_rx_ifq *ifq, struct socket *sock,
+ unsigned int limit, unsigned int flags)
{
struct sock *sk = sock->sk;
const struct proto *prot;
@@ -783,7 +791,7 @@ int io_zc_rx_recv(struct socket *sock, unsigned int limit, unsigned int flags)
sock_rps_record_flow(sk);
- ret = io_zc_rx_tcp_recvmsg(sk, limit, flags, &addr_len);
+ ret = io_zc_rx_tcp_recvmsg(ifq, sk, limit, flags, &addr_len);
return ret;
}
@@ -62,6 +62,7 @@ static inline int io_register_zc_rx_sock(struct io_ring_ctx *ctx,
int io_recvzc(struct io_kiocb *req, unsigned int issue_flags);
int io_recvzc_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe);
-int io_zc_rx_recv(struct socket *sock, unsigned int limit, unsigned int flags);
+int io_zc_rx_recv(struct io_zc_rx_ifq *ifq, struct socket *sock,
+ unsigned int limit, unsigned int flags);
#endif