diff mbox series

[RFC,net-next,v2,2/6] selftests: ncdevmem: Implement devmem TCP TX

Message ID 20250130211539.428952-3-almasrymina@google.com (mailing list archive)
State New
Headers show
Series Device memory TCP TX | expand

Commit Message

Mina Almasry Jan. 30, 2025, 9:15 p.m. UTC
Add support for devmem TX in ncdevmem.

This combines the ncdevmem from RFCv1 of the devmem TCP series, which
included the TX path, with work by Stan adding the netlink API, all
refactored on top of his generic memory_provider support.

Signed-off-by: Mina Almasry <almasrymina@google.com>
Signed-off-by: Stanislav Fomichev <sdf@fomichev.me>

---

v2:
- Make errors a static variable so that we catch instances where there
  are fewer than 20 errors across different buffers.
- Fix the issue where the seed is reset to 0 instead of its starting
  value 1.
- Use 1000ULL instead of 1000 to guard against overflow (Willem).
- Do not set POLLERR (Willem).
- Update the test to use the new interface where iov_base is the
  dmabuf_offset.
- Update the test to send 2 iov instead of 1, so we get some test
  coverage over sending multiple iovs at once.
- Print the ifindex the test is using. This is useful for debugging
  failures caused by the ifindex of the socket being different from the
  one used for the dmabuf binding.
---
 .../selftests/drivers/net/hw/ncdevmem.c       | 276 +++++++++++++++++-
 1 file changed, 272 insertions(+), 4 deletions(-)

Comments

Stanislav Fomichev Jan. 30, 2025, 11:05 p.m. UTC | #1
On 01/30, Mina Almasry wrote:
> Add support for devmem TX in ncdevmem.
> 
> This is a combination of the ncdevmem from the devmem TCP series RFCv1
> which included the TX path, and work by Stan to include the netlink API
> and refactored on top of his generic memory_provider support.
> 
> Signed-off-by: Mina Almasry <almasrymina@google.com>
> Signed-off-by: Stanislav Fomichev <sdf@fomichev.me>
> 
> ---
> 
> v2:
> - make errors a static variable so that we catch instances where there
>   are less than 20 errors across different buffers.
> - Fix the issue where the seed is reset to 0 instead of its starting
>   value 1.
> - Use 1000ULL instead of 1000 to guard against overflow (Willem).
> - Do not set POLLERR (Willem).
> - Update the test to use the new interface where iov_base is the
>   dmabuf_offset.
> - Update the test to send 2 iov instead of 1, so we get some test
>   coverage over sending multiple iovs at once.
> - Print the ifindex the test is using, useful for debugging issues where
>   maybe the test may fail because the ifindex of the socket is different
>   from the dmabuf binding.
> ---
>  .../selftests/drivers/net/hw/ncdevmem.c       | 276 +++++++++++++++++-
>  1 file changed, 272 insertions(+), 4 deletions(-)
> 
> diff --git a/tools/testing/selftests/drivers/net/hw/ncdevmem.c b/tools/testing/selftests/drivers/net/hw/ncdevmem.c
> index 19a6969643f4..8455f19ecd1a 100644
> --- a/tools/testing/selftests/drivers/net/hw/ncdevmem.c
> +++ b/tools/testing/selftests/drivers/net/hw/ncdevmem.c
> @@ -40,15 +40,18 @@
>  #include <fcntl.h>
>  #include <malloc.h>
>  #include <error.h>
> +#include <poll.h>
>  
>  #include <arpa/inet.h>
>  #include <sys/socket.h>
>  #include <sys/mman.h>
>  #include <sys/ioctl.h>
>  #include <sys/syscall.h>
> +#include <sys/time.h>
>  
>  #include <linux/memfd.h>
>  #include <linux/dma-buf.h>
> +#include <linux/errqueue.h>
>  #include <linux/udmabuf.h>
>  #include <libmnl/libmnl.h>
>  #include <linux/types.h>
> @@ -80,6 +83,8 @@ static int num_queues = -1;
>  static char *ifname;
>  static unsigned int ifindex;
>  static unsigned int dmabuf_id;
> +static uint32_t tx_dmabuf_id;
> +static int waittime_ms = 500;
>  
>  struct memory_buffer {
>  	int fd;
> @@ -93,6 +98,8 @@ struct memory_buffer {
>  struct memory_provider {
>  	struct memory_buffer *(*alloc)(size_t size);
>  	void (*free)(struct memory_buffer *ctx);
> +	void (*memcpy_to_device)(struct memory_buffer *dst, size_t off,
> +				 void *src, int n);
>  	void (*memcpy_from_device)(void *dst, struct memory_buffer *src,
>  				   size_t off, int n);
>  };
> @@ -153,6 +160,20 @@ static void udmabuf_free(struct memory_buffer *ctx)
>  	free(ctx);
>  }
>  
> +static void udmabuf_memcpy_to_device(struct memory_buffer *dst, size_t off,
> +				     void *src, int n)
> +{
> +	struct dma_buf_sync sync = {};
> +
> +	sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_WRITE;
> +	ioctl(dst->fd, DMA_BUF_IOCTL_SYNC, &sync);
> +
> +	memcpy(dst->buf_mem + off, src, n);
> +
> +	sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_WRITE;
> +	ioctl(dst->fd, DMA_BUF_IOCTL_SYNC, &sync);
> +}
> +
>  static void udmabuf_memcpy_from_device(void *dst, struct memory_buffer *src,
>  				       size_t off, int n)
>  {
> @@ -170,6 +191,7 @@ static void udmabuf_memcpy_from_device(void *dst, struct memory_buffer *src,
>  static struct memory_provider udmabuf_memory_provider = {
>  	.alloc = udmabuf_alloc,
>  	.free = udmabuf_free,
> +	.memcpy_to_device = udmabuf_memcpy_to_device,
>  	.memcpy_from_device = udmabuf_memcpy_from_device,
>  };
>  
> @@ -188,7 +210,7 @@ void validate_buffer(void *line, size_t size)
>  {
>  	static unsigned char seed = 1;
>  	unsigned char *ptr = line;
> -	int errors = 0;
> +	static int errors;
>  	size_t i;
>  
>  	for (i = 0; i < size; i++) {
> @@ -202,7 +224,7 @@ void validate_buffer(void *line, size_t size)
>  		}
>  		seed++;
>  		if (seed == do_validation)
> -			seed = 0;
> +			seed = 1;
>  	}
>  
>  	fprintf(stdout, "Validated buffer\n");
> @@ -394,6 +416,49 @@ static int bind_rx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
>  	return -1;
>  }
>  
> +static int bind_tx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
> +			 struct ynl_sock **ys)
> +{
> +	struct netdev_bind_tx_req *req = NULL;
> +	struct netdev_bind_tx_rsp *rsp = NULL;
> +	struct ynl_error yerr;
> +
> +	*ys = ynl_sock_create(&ynl_netdev_family, &yerr);
> +	if (!*ys) {
> +		fprintf(stderr, "YNL: %s\n", yerr.msg);
> +		return -1;
> +	}
> +
> +	req = netdev_bind_tx_req_alloc();
> +	netdev_bind_tx_req_set_ifindex(req, ifindex);
> +	netdev_bind_tx_req_set_fd(req, dmabuf_fd);
> +
> +	rsp = netdev_bind_tx(*ys, req);
> +	if (!rsp) {
> +		perror("netdev_bind_tx");
> +		goto err_close;
> +	}
> +
> +	if (!rsp->_present.id) {
> +		perror("id not present");
> +		goto err_close;
> +	}
> +
> +	fprintf(stderr, "got tx dmabuf id=%d\n", rsp->id);
> +	tx_dmabuf_id = rsp->id;
> +
> +	netdev_bind_tx_req_free(req);
> +	netdev_bind_tx_rsp_free(rsp);
> +
> +	return 0;
> +
> +err_close:
> +	fprintf(stderr, "YNL failed: %s\n", (*ys)->err.msg);
> +	netdev_bind_tx_req_free(req);
> +	ynl_sock_destroy(*ys);
> +	return -1;
> +}
> +
>  static void enable_reuseaddr(int fd)
>  {
>  	int opt = 1;
> @@ -432,7 +497,7 @@ static int parse_address(const char *str, int port, struct sockaddr_in6 *sin6)
>  	return 0;
>  }
>  
> -int do_server(struct memory_buffer *mem)
> +static int do_server(struct memory_buffer *mem)
>  {
>  	char ctrl_data[sizeof(int) * 20000];
>  	struct netdev_queue_id *queues;
> @@ -686,6 +751,207 @@ void run_devmem_tests(void)
>  	provider->free(mem);
>  }
>  
> +static uint64_t gettimeofday_ms(void)
> +{
> +	struct timeval tv;
> +
> +	gettimeofday(&tv, NULL);
> +	return (tv.tv_sec * 1000ULL) + (tv.tv_usec / 1000ULL);
> +}
> +
> +static int do_poll(int fd)
> +{
> +	struct pollfd pfd;
> +	int ret;
> +
> +	pfd.revents = 0;
> +	pfd.fd = fd;
> +
> +	ret = poll(&pfd, 1, waittime_ms);
> +	if (ret == -1)
> +		error(1, errno, "poll");
> +
> +	return ret && (pfd.revents & POLLERR);
> +}
> +
> +static void wait_compl(int fd)
> +{
> +	int64_t tstop = gettimeofday_ms() + waittime_ms;
> +	char control[CMSG_SPACE(100)] = {};
> +	struct sock_extended_err *serr;
> +	struct msghdr msg = {};
> +	struct cmsghdr *cm;
> +	int retries = 10;
> +	__u32 hi, lo;
> +	int ret;
> +
> +	msg.msg_control = control;
> +	msg.msg_controllen = sizeof(control);
> +
> +	while (gettimeofday_ms() < tstop) {
> +		if (!do_poll(fd))
> +			continue;
> +
> +		ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
> +		if (ret < 0) {
> +			if (errno == EAGAIN)
> +				continue;
> +			error(1, ret, "recvmsg(MSG_ERRQUEUE)");
> +			return;
> +		}
> +		if (msg.msg_flags & MSG_CTRUNC)
> +			error(1, 0, "MSG_CTRUNC\n");
> +
> +		for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
> +			if (cm->cmsg_level != SOL_IP &&
> +			    cm->cmsg_level != SOL_IPV6)
> +				continue;
> +			if (cm->cmsg_level == SOL_IP &&
> +			    cm->cmsg_type != IP_RECVERR)
> +				continue;
> +			if (cm->cmsg_level == SOL_IPV6 &&
> +			    cm->cmsg_type != IPV6_RECVERR)
> +				continue;
> +
> +			serr = (void *)CMSG_DATA(cm);
> +			if (serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY)
> +				error(1, 0, "wrong origin %u", serr->ee_origin);
> +			if (serr->ee_errno != 0)
> +				error(1, 0, "wrong errno %d", serr->ee_errno);
> +
> +			hi = serr->ee_data;
> +			lo = serr->ee_info;
> +
> +			fprintf(stderr, "tx complete [%d,%d]\n", lo, hi);
> +			return;
> +		}
> +	}
> +
> +	error(1, 0, "did not receive tx completion");
> +}
> +
> +static int do_client(struct memory_buffer *mem)
> +{
> +	char ctrl_data[CMSG_SPACE(sizeof(struct dmabuf_tx_cmsg))];
> +	struct sockaddr_in6 server_sin;
> +	struct sockaddr_in6 client_sin;
> +	struct dmabuf_tx_cmsg ddmabuf;
> +	struct ynl_sock *ys = NULL;
> +	struct msghdr msg = {};
> +	ssize_t line_size = 0;
> +	struct cmsghdr *cmsg;
> +	struct iovec iov[2];
> +	uint64_t off = 100;
> +	char *line = NULL;
> +	size_t len = 0;
> +	int socket_fd;
> +	int ret, mid;
> +	int opt = 1;
> +
> +	ret = parse_address(server_ip, atoi(port), &server_sin);
> +	if (ret < 0)
> +		error(1, 0, "parse server address");
> +
> +	socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
> +	if (socket_fd < 0)
> +		error(1, socket_fd, "create socket");
> +
> +	enable_reuseaddr(socket_fd);
> +
> +	ret = setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, ifname,
> +			 strlen(ifname) + 1);
> +	if (ret)
> +		error(1, ret, "bindtodevice");
> +
> +	if (bind_tx_queue(ifindex, mem->fd, &ys))
> +		error(1, 0, "Failed to bind\n");
> +
> +	ret = parse_address(client_ip, atoi(port), &client_sin);
> +	if (ret < 0)
> +		error(1, 0, "parse client address");
> +
> +	ret = bind(socket_fd, &client_sin, sizeof(client_sin));
> +	if (ret)
> +		error(1, ret, "bind");
> +
> +	ret = setsockopt(socket_fd, SOL_SOCKET, SO_ZEROCOPY, &opt, sizeof(opt));
> +	if (ret)
> +		error(1, ret, "set sock opt");
> +
> +	fprintf(stderr, "Connect to %s %d (via %s)\n", server_ip,
> +		ntohs(server_sin.sin6_port), ifname);
> +
> +	ret = connect(socket_fd, &server_sin, sizeof(server_sin));
> +	if (ret)
> +		error(1, ret, "connect");
> +
> +	while (1) {
> +		free(line);
> +		line = NULL;
> +		/* Subtract 1 from line_size to remove trailing newlines that
> +		 * get_line are surely to parse...
> +		 */
> +		line_size = getline(&line, &len, stdin) - 1;

Why not send the '\n' as well? If we skip the '\n', it's not keeping
netcat-like behavior :-(

> +
> +		if (line_size < 0)
> +			break;

[..]

> +		mid = (line_size / 2) + 1;
> +
> +		iov[0].iov_base = (void *)100;
> +		iov[0].iov_len = mid;
> +		iov[1].iov_base = (void *)2000;
> +		iov[1].iov_len = line_size - mid;

This seems a bit hard-coded. We should at least test that mid is < 2000?

But ideally we should have two modes for tx with a flag (and run them
both from the selftest):
- pass one big iov, this will test the sendmsg path which creates
  multiple skbs internally
- break 'line' into N sections (as you do here), but maybe have more
  control over the number of sections?

Maybe let's have a new --max-iov-size flag? Then we can call ncdevmem
with --max-iov-size <some prime number close to 4k> to exercise all
sorts of weird offsets?

(seems ok to also follow up on that separately)
Mina Almasry Jan. 30, 2025, 11:29 p.m. UTC | #2
On Thu, Jan 30, 2025 at 3:05 PM Stanislav Fomichev <stfomichev@gmail.com> wrote:
>
> On 01/30, Mina Almasry wrote:
> > Add support for devmem TX in ncdevmem.
> >
> > This is a combination of the ncdevmem from the devmem TCP series RFCv1
> > which included the TX path, and work by Stan to include the netlink API
> > and refactored on top of his generic memory_provider support.
> >
> > Signed-off-by: Mina Almasry <almasrymina@google.com>
> > Signed-off-by: Stanislav Fomichev <sdf@fomichev.me>
> >
> > ---
> >
> > v2:
> > - make errors a static variable so that we catch instances where there
> >   are less than 20 errors across different buffers.
> > - Fix the issue where the seed is reset to 0 instead of its starting
> >   value 1.
> > - Use 1000ULL instead of 1000 to guard against overflow (Willem).
> > - Do not set POLLERR (Willem).
> > - Update the test to use the new interface where iov_base is the
> >   dmabuf_offset.
> > - Update the test to send 2 iov instead of 1, so we get some test
> >   coverage over sending multiple iovs at once.
> > - Print the ifindex the test is using, useful for debugging issues where
> >   maybe the test may fail because the ifindex of the socket is different
> >   from the dmabuf binding.
> > ---
> >  .../selftests/drivers/net/hw/ncdevmem.c       | 276 +++++++++++++++++-
> >  1 file changed, 272 insertions(+), 4 deletions(-)
> >
> > diff --git a/tools/testing/selftests/drivers/net/hw/ncdevmem.c b/tools/testing/selftests/drivers/net/hw/ncdevmem.c
> > index 19a6969643f4..8455f19ecd1a 100644
> > --- a/tools/testing/selftests/drivers/net/hw/ncdevmem.c
> > +++ b/tools/testing/selftests/drivers/net/hw/ncdevmem.c
> > @@ -40,15 +40,18 @@
> >  #include <fcntl.h>
> >  #include <malloc.h>
> >  #include <error.h>
> > +#include <poll.h>
> >
> >  #include <arpa/inet.h>
> >  #include <sys/socket.h>
> >  #include <sys/mman.h>
> >  #include <sys/ioctl.h>
> >  #include <sys/syscall.h>
> > +#include <sys/time.h>
> >
> >  #include <linux/memfd.h>
> >  #include <linux/dma-buf.h>
> > +#include <linux/errqueue.h>
> >  #include <linux/udmabuf.h>
> >  #include <libmnl/libmnl.h>
> >  #include <linux/types.h>
> > @@ -80,6 +83,8 @@ static int num_queues = -1;
> >  static char *ifname;
> >  static unsigned int ifindex;
> >  static unsigned int dmabuf_id;
> > +static uint32_t tx_dmabuf_id;
> > +static int waittime_ms = 500;
> >
> >  struct memory_buffer {
> >       int fd;
> > @@ -93,6 +98,8 @@ struct memory_buffer {
> >  struct memory_provider {
> >       struct memory_buffer *(*alloc)(size_t size);
> >       void (*free)(struct memory_buffer *ctx);
> > +     void (*memcpy_to_device)(struct memory_buffer *dst, size_t off,
> > +                              void *src, int n);
> >       void (*memcpy_from_device)(void *dst, struct memory_buffer *src,
> >                                  size_t off, int n);
> >  };
> > @@ -153,6 +160,20 @@ static void udmabuf_free(struct memory_buffer *ctx)
> >       free(ctx);
> >  }
> >
> > +static void udmabuf_memcpy_to_device(struct memory_buffer *dst, size_t off,
> > +                                  void *src, int n)
> > +{
> > +     struct dma_buf_sync sync = {};
> > +
> > +     sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_WRITE;
> > +     ioctl(dst->fd, DMA_BUF_IOCTL_SYNC, &sync);
> > +
> > +     memcpy(dst->buf_mem + off, src, n);
> > +
> > +     sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_WRITE;
> > +     ioctl(dst->fd, DMA_BUF_IOCTL_SYNC, &sync);
> > +}
> > +
> >  static void udmabuf_memcpy_from_device(void *dst, struct memory_buffer *src,
> >                                      size_t off, int n)
> >  {
> > @@ -170,6 +191,7 @@ static void udmabuf_memcpy_from_device(void *dst, struct memory_buffer *src,
> >  static struct memory_provider udmabuf_memory_provider = {
> >       .alloc = udmabuf_alloc,
> >       .free = udmabuf_free,
> > +     .memcpy_to_device = udmabuf_memcpy_to_device,
> >       .memcpy_from_device = udmabuf_memcpy_from_device,
> >  };
> >
> > @@ -188,7 +210,7 @@ void validate_buffer(void *line, size_t size)
> >  {
> >       static unsigned char seed = 1;
> >       unsigned char *ptr = line;
> > -     int errors = 0;
> > +     static int errors;
> >       size_t i;
> >
> >       for (i = 0; i < size; i++) {
> > @@ -202,7 +224,7 @@ void validate_buffer(void *line, size_t size)
> >               }
> >               seed++;
> >               if (seed == do_validation)
> > -                     seed = 0;
> > +                     seed = 1;
> >       }
> >
> >       fprintf(stdout, "Validated buffer\n");
> > @@ -394,6 +416,49 @@ static int bind_rx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
> >       return -1;
> >  }
> >
> > +static int bind_tx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
> > +                      struct ynl_sock **ys)
> > +{
> > +     struct netdev_bind_tx_req *req = NULL;
> > +     struct netdev_bind_tx_rsp *rsp = NULL;
> > +     struct ynl_error yerr;
> > +
> > +     *ys = ynl_sock_create(&ynl_netdev_family, &yerr);
> > +     if (!*ys) {
> > +             fprintf(stderr, "YNL: %s\n", yerr.msg);
> > +             return -1;
> > +     }
> > +
> > +     req = netdev_bind_tx_req_alloc();
> > +     netdev_bind_tx_req_set_ifindex(req, ifindex);
> > +     netdev_bind_tx_req_set_fd(req, dmabuf_fd);
> > +
> > +     rsp = netdev_bind_tx(*ys, req);
> > +     if (!rsp) {
> > +             perror("netdev_bind_tx");
> > +             goto err_close;
> > +     }
> > +
> > +     if (!rsp->_present.id) {
> > +             perror("id not present");
> > +             goto err_close;
> > +     }
> > +
> > +     fprintf(stderr, "got tx dmabuf id=%d\n", rsp->id);
> > +     tx_dmabuf_id = rsp->id;
> > +
> > +     netdev_bind_tx_req_free(req);
> > +     netdev_bind_tx_rsp_free(rsp);
> > +
> > +     return 0;
> > +
> > +err_close:
> > +     fprintf(stderr, "YNL failed: %s\n", (*ys)->err.msg);
> > +     netdev_bind_tx_req_free(req);
> > +     ynl_sock_destroy(*ys);
> > +     return -1;
> > +}
> > +
> >  static void enable_reuseaddr(int fd)
> >  {
> >       int opt = 1;
> > @@ -432,7 +497,7 @@ static int parse_address(const char *str, int port, struct sockaddr_in6 *sin6)
> >       return 0;
> >  }
> >
> > -int do_server(struct memory_buffer *mem)
> > +static int do_server(struct memory_buffer *mem)
> >  {
> >       char ctrl_data[sizeof(int) * 20000];
> >       struct netdev_queue_id *queues;
> > @@ -686,6 +751,207 @@ void run_devmem_tests(void)
> >       provider->free(mem);
> >  }
> >
> > +static uint64_t gettimeofday_ms(void)
> > +{
> > +     struct timeval tv;
> > +
> > +     gettimeofday(&tv, NULL);
> > +     return (tv.tv_sec * 1000ULL) + (tv.tv_usec / 1000ULL);
> > +}
> > +
> > +static int do_poll(int fd)
> > +{
> > +     struct pollfd pfd;
> > +     int ret;
> > +
> > +     pfd.revents = 0;
> > +     pfd.fd = fd;
> > +
> > +     ret = poll(&pfd, 1, waittime_ms);
> > +     if (ret == -1)
> > +             error(1, errno, "poll");
> > +
> > +     return ret && (pfd.revents & POLLERR);
> > +}
> > +
> > +static void wait_compl(int fd)
> > +{
> > +     int64_t tstop = gettimeofday_ms() + waittime_ms;
> > +     char control[CMSG_SPACE(100)] = {};
> > +     struct sock_extended_err *serr;
> > +     struct msghdr msg = {};
> > +     struct cmsghdr *cm;
> > +     int retries = 10;
> > +     __u32 hi, lo;
> > +     int ret;
> > +
> > +     msg.msg_control = control;
> > +     msg.msg_controllen = sizeof(control);
> > +
> > +     while (gettimeofday_ms() < tstop) {
> > +             if (!do_poll(fd))
> > +                     continue;
> > +
> > +             ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
> > +             if (ret < 0) {
> > +                     if (errno == EAGAIN)
> > +                             continue;
> > +                     error(1, ret, "recvmsg(MSG_ERRQUEUE)");
> > +                     return;
> > +             }
> > +             if (msg.msg_flags & MSG_CTRUNC)
> > +                     error(1, 0, "MSG_CTRUNC\n");
> > +
> > +             for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
> > +                     if (cm->cmsg_level != SOL_IP &&
> > +                         cm->cmsg_level != SOL_IPV6)
> > +                             continue;
> > +                     if (cm->cmsg_level == SOL_IP &&
> > +                         cm->cmsg_type != IP_RECVERR)
> > +                             continue;
> > +                     if (cm->cmsg_level == SOL_IPV6 &&
> > +                         cm->cmsg_type != IPV6_RECVERR)
> > +                             continue;
> > +
> > +                     serr = (void *)CMSG_DATA(cm);
> > +                     if (serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY)
> > +                             error(1, 0, "wrong origin %u", serr->ee_origin);
> > +                     if (serr->ee_errno != 0)
> > +                             error(1, 0, "wrong errno %d", serr->ee_errno);
> > +
> > +                     hi = serr->ee_data;
> > +                     lo = serr->ee_info;
> > +
> > +                     fprintf(stderr, "tx complete [%d,%d]\n", lo, hi);
> > +                     return;
> > +             }
> > +     }
> > +
> > +     error(1, 0, "did not receive tx completion");
> > +}
> > +
> > +static int do_client(struct memory_buffer *mem)
> > +{
> > +     char ctrl_data[CMSG_SPACE(sizeof(struct dmabuf_tx_cmsg))];
> > +     struct sockaddr_in6 server_sin;
> > +     struct sockaddr_in6 client_sin;
> > +     struct dmabuf_tx_cmsg ddmabuf;
> > +     struct ynl_sock *ys = NULL;
> > +     struct msghdr msg = {};
> > +     ssize_t line_size = 0;
> > +     struct cmsghdr *cmsg;
> > +     struct iovec iov[2];
> > +     uint64_t off = 100;
> > +     char *line = NULL;
> > +     size_t len = 0;
> > +     int socket_fd;
> > +     int ret, mid;
> > +     int opt = 1;
> > +
> > +     ret = parse_address(server_ip, atoi(port), &server_sin);
> > +     if (ret < 0)
> > +             error(1, 0, "parse server address");
> > +
> > +     socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
> > +     if (socket_fd < 0)
> > +             error(1, socket_fd, "create socket");
> > +
> > +     enable_reuseaddr(socket_fd);
> > +
> > +     ret = setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, ifname,
> > +                      strlen(ifname) + 1);
> > +     if (ret)
> > +             error(1, ret, "bindtodevice");
> > +
> > +     if (bind_tx_queue(ifindex, mem->fd, &ys))
> > +             error(1, 0, "Failed to bind\n");
> > +
> > +     ret = parse_address(client_ip, atoi(port), &client_sin);
> > +     if (ret < 0)
> > +             error(1, 0, "parse client address");
> > +
> > +     ret = bind(socket_fd, &client_sin, sizeof(client_sin));
> > +     if (ret)
> > +             error(1, ret, "bind");
> > +
> > +     ret = setsockopt(socket_fd, SOL_SOCKET, SO_ZEROCOPY, &opt, sizeof(opt));
> > +     if (ret)
> > +             error(1, ret, "set sock opt");
> > +
> > +     fprintf(stderr, "Connect to %s %d (via %s)\n", server_ip,
> > +             ntohs(server_sin.sin6_port), ifname);
> > +
> > +     ret = connect(socket_fd, &server_sin, sizeof(server_sin));
> > +     if (ret)
> > +             error(1, ret, "connect");
> > +
> > +     while (1) {
> > +             free(line);
> > +             line = NULL;
> > +             /* Subtract 1 from line_size to remove trailing newlines that
> > +              * get_line are surely to parse...
> > +              */
> > +             line_size = getline(&line, &len, stdin) - 1;
>
> Why not send the '\n' as well? If we skip the '\n', it's not keeping
> netcat-like behavior :-(
>

Ah, this is to make the validation on the RX side work. The validation
expects a repeating pattern:

1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, ....

With no newlines.

But it does become weird that TX doesn't match netcat. Let me think on
this a bit. Maybe I can resolve this in a way where the validation
works but also the tx side behaves like netcat. Maybe the RX
validation can skip newlines or something. Maybe I can massage how I
invoke the test.

This can become a rabbit hole because I do want to invoke multiple
sendmsg() in one iteration of the test as well, and not overcomplicate
the series.

> > +
> > +             if (line_size < 0)
> > +                     break;
>
> [..]
>
> > +             mid = (line_size / 2) + 1;
> > +
> > +             iov[0].iov_base = (void *)100;
> > +             iov[0].iov_len = mid;
> > +             iov[1].iov_base = (void *)2000;
> > +             iov[1].iov_len = line_size - mid;
>
> This seems a bit hard-coded. We should at least test that mid is < 2000?
>

Yep, I should at least do that. FWIW (although I missed it in this
iteration), at the top of the file I put docs that list exactly what I
run for others to repro the results (and nipa should eventually run
something similar; I have that on my todo list), but the test should be
more flexible, so that it at least catches instances where mid is too large.

> But ideally we should have two modes for tx with a flag (and run them
> both from the selftest):
> - pass one big iov, this will test the sendmsg path which creates
>   multiple skbs internally
> - break 'line' into N sections (as you do here), but maybe have more
>   control over the number of sections?
>
> Maybe let's have a new --max-iov-size flag? Then we can call ncdevmem
> with --max-iov-size <some prime number close to 4k> to exercise all
> sorts of weird offsets?
>
> (seems ok to also follow up on that separately)

Yeah, improvements to tests are always possible; let's have some
reasonable tests in the first iteration and expand on them later.
diff mbox series

Patch

diff --git a/tools/testing/selftests/drivers/net/hw/ncdevmem.c b/tools/testing/selftests/drivers/net/hw/ncdevmem.c
index 19a6969643f4..8455f19ecd1a 100644
--- a/tools/testing/selftests/drivers/net/hw/ncdevmem.c
+++ b/tools/testing/selftests/drivers/net/hw/ncdevmem.c
@@ -40,15 +40,18 @@ 
 #include <fcntl.h>
 #include <malloc.h>
 #include <error.h>
+#include <poll.h>
 
 #include <arpa/inet.h>
 #include <sys/socket.h>
 #include <sys/mman.h>
 #include <sys/ioctl.h>
 #include <sys/syscall.h>
+#include <sys/time.h>
 
 #include <linux/memfd.h>
 #include <linux/dma-buf.h>
+#include <linux/errqueue.h>
 #include <linux/udmabuf.h>
 #include <libmnl/libmnl.h>
 #include <linux/types.h>
@@ -80,6 +83,8 @@  static int num_queues = -1;
 static char *ifname;
 static unsigned int ifindex;
 static unsigned int dmabuf_id;
+static uint32_t tx_dmabuf_id;
+static int waittime_ms = 500;
 
 struct memory_buffer {
 	int fd;
@@ -93,6 +98,8 @@  struct memory_buffer {
 struct memory_provider {
 	struct memory_buffer *(*alloc)(size_t size);
 	void (*free)(struct memory_buffer *ctx);
+	void (*memcpy_to_device)(struct memory_buffer *dst, size_t off,
+				 void *src, int n);
 	void (*memcpy_from_device)(void *dst, struct memory_buffer *src,
 				   size_t off, int n);
 };
@@ -153,6 +160,20 @@  static void udmabuf_free(struct memory_buffer *ctx)
 	free(ctx);
 }
 
+static void udmabuf_memcpy_to_device(struct memory_buffer *dst, size_t off,
+				     void *src, int n)
+{
+	struct dma_buf_sync sync = {};
+
+	sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_WRITE;
+	ioctl(dst->fd, DMA_BUF_IOCTL_SYNC, &sync);
+
+	memcpy(dst->buf_mem + off, src, n);
+
+	sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_WRITE;
+	ioctl(dst->fd, DMA_BUF_IOCTL_SYNC, &sync);
+}
+
 static void udmabuf_memcpy_from_device(void *dst, struct memory_buffer *src,
 				       size_t off, int n)
 {
@@ -170,6 +191,7 @@  static void udmabuf_memcpy_from_device(void *dst, struct memory_buffer *src,
 static struct memory_provider udmabuf_memory_provider = {
 	.alloc = udmabuf_alloc,
 	.free = udmabuf_free,
+	.memcpy_to_device = udmabuf_memcpy_to_device,
 	.memcpy_from_device = udmabuf_memcpy_from_device,
 };
 
@@ -188,7 +210,7 @@  void validate_buffer(void *line, size_t size)
 {
 	static unsigned char seed = 1;
 	unsigned char *ptr = line;
-	int errors = 0;
+	static int errors;
 	size_t i;
 
 	for (i = 0; i < size; i++) {
@@ -202,7 +224,7 @@  void validate_buffer(void *line, size_t size)
 		}
 		seed++;
 		if (seed == do_validation)
-			seed = 0;
+			seed = 1;
 	}
 
 	fprintf(stdout, "Validated buffer\n");
@@ -394,6 +416,49 @@  static int bind_rx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
 	return -1;
 }
 
+static int bind_tx_queue(unsigned int ifindex, unsigned int dmabuf_fd,
+			 struct ynl_sock **ys)
+{
+	struct netdev_bind_tx_req *req = NULL;
+	struct netdev_bind_tx_rsp *rsp = NULL;
+	struct ynl_error yerr;
+
+	*ys = ynl_sock_create(&ynl_netdev_family, &yerr);
+	if (!*ys) {
+		fprintf(stderr, "YNL: %s\n", yerr.msg);
+		return -1;
+	}
+
+	req = netdev_bind_tx_req_alloc();
+	netdev_bind_tx_req_set_ifindex(req, ifindex);
+	netdev_bind_tx_req_set_fd(req, dmabuf_fd);
+
+	rsp = netdev_bind_tx(*ys, req);
+	if (!rsp) {
+		perror("netdev_bind_tx");
+		goto err_close;
+	}
+
+	if (!rsp->_present.id) {
+		perror("id not present");
+		goto err_close;
+	}
+
+	fprintf(stderr, "got tx dmabuf id=%d\n", rsp->id);
+	tx_dmabuf_id = rsp->id;
+
+	netdev_bind_tx_req_free(req);
+	netdev_bind_tx_rsp_free(rsp);
+
+	return 0;
+
+err_close:
+	fprintf(stderr, "YNL failed: %s\n", (*ys)->err.msg);
+	netdev_bind_tx_req_free(req);
+	ynl_sock_destroy(*ys);
+	return -1;
+}
+
 static void enable_reuseaddr(int fd)
 {
 	int opt = 1;
@@ -432,7 +497,7 @@  static int parse_address(const char *str, int port, struct sockaddr_in6 *sin6)
 	return 0;
 }
 
-int do_server(struct memory_buffer *mem)
+static int do_server(struct memory_buffer *mem)
 {
 	char ctrl_data[sizeof(int) * 20000];
 	struct netdev_queue_id *queues;
@@ -686,6 +751,207 @@  void run_devmem_tests(void)
 	provider->free(mem);
 }
 
+static uint64_t gettimeofday_ms(void)
+{
+	struct timeval tv;
+
+	gettimeofday(&tv, NULL);
+	return (tv.tv_sec * 1000ULL) + (tv.tv_usec / 1000ULL);
+}
+
+static int do_poll(int fd)
+{
+	struct pollfd pfd;
+	int ret;
+
+	pfd.revents = 0;
+	pfd.fd = fd;
+
+	ret = poll(&pfd, 1, waittime_ms);
+	if (ret == -1)
+		error(1, errno, "poll");
+
+	return ret && (pfd.revents & POLLERR);
+}
+
+static void wait_compl(int fd)
+{
+	int64_t tstop = gettimeofday_ms() + waittime_ms;
+	char control[CMSG_SPACE(100)] = {};
+	struct sock_extended_err *serr;
+	struct msghdr msg = {};
+	struct cmsghdr *cm;
+	int retries = 10;
+	__u32 hi, lo;
+	int ret;
+
+	msg.msg_control = control;
+	msg.msg_controllen = sizeof(control);
+
+	while (gettimeofday_ms() < tstop) {
+		if (!do_poll(fd))
+			continue;
+
+		ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+		if (ret < 0) {
+			if (errno == EAGAIN)
+				continue;
+			error(1, ret, "recvmsg(MSG_ERRQUEUE)");
+			return;
+		}
+		if (msg.msg_flags & MSG_CTRUNC)
+			error(1, 0, "MSG_CTRUNC\n");
+
+		for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
+			if (cm->cmsg_level != SOL_IP &&
+			    cm->cmsg_level != SOL_IPV6)
+				continue;
+			if (cm->cmsg_level == SOL_IP &&
+			    cm->cmsg_type != IP_RECVERR)
+				continue;
+			if (cm->cmsg_level == SOL_IPV6 &&
+			    cm->cmsg_type != IPV6_RECVERR)
+				continue;
+
+			serr = (void *)CMSG_DATA(cm);
+			if (serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY)
+				error(1, 0, "wrong origin %u", serr->ee_origin);
+			if (serr->ee_errno != 0)
+				error(1, 0, "wrong errno %d", serr->ee_errno);
+
+			hi = serr->ee_data;
+			lo = serr->ee_info;
+
+			fprintf(stderr, "tx complete [%d,%d]\n", lo, hi);
+			return;
+		}
+	}
+
+	error(1, 0, "did not receive tx completion");
+}
+
+static int do_client(struct memory_buffer *mem)
+{
+	char ctrl_data[CMSG_SPACE(sizeof(struct dmabuf_tx_cmsg))];
+	struct sockaddr_in6 server_sin;
+	struct sockaddr_in6 client_sin;
+	struct dmabuf_tx_cmsg ddmabuf;
+	struct ynl_sock *ys = NULL;
+	struct msghdr msg = {};
+	ssize_t line_size = 0;
+	struct cmsghdr *cmsg;
+	struct iovec iov[2];
+	uint64_t off = 100;
+	char *line = NULL;
+	size_t len = 0;
+	int socket_fd;
+	int ret, mid;
+	int opt = 1;
+
+	ret = parse_address(server_ip, atoi(port), &server_sin);
+	if (ret < 0)
+		error(1, 0, "parse server address");
+
+	socket_fd = socket(AF_INET6, SOCK_STREAM, 0);
+	if (socket_fd < 0)
+		error(1, socket_fd, "create socket");
+
+	enable_reuseaddr(socket_fd);
+
+	ret = setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, ifname,
+			 strlen(ifname) + 1);
+	if (ret)
+		error(1, ret, "bindtodevice");
+
+	if (bind_tx_queue(ifindex, mem->fd, &ys))
+		error(1, 0, "Failed to bind\n");
+
+	ret = parse_address(client_ip, atoi(port), &client_sin);
+	if (ret < 0)
+		error(1, 0, "parse client address");
+
+	ret = bind(socket_fd, &client_sin, sizeof(client_sin));
+	if (ret)
+		error(1, ret, "bind");
+
+	ret = setsockopt(socket_fd, SOL_SOCKET, SO_ZEROCOPY, &opt, sizeof(opt));
+	if (ret)
+		error(1, ret, "set sock opt");
+
+	fprintf(stderr, "Connect to %s %d (via %s)\n", server_ip,
+		ntohs(server_sin.sin6_port), ifname);
+
+	ret = connect(socket_fd, &server_sin, sizeof(server_sin));
+	if (ret)
+		error(1, ret, "connect");
+
+	while (1) {
+		free(line);
+		line = NULL;
+		/* Subtract 1 from line_size to remove trailing newlines that
+		 * get_line are surely to parse...
+		 */
+		line_size = getline(&line, &len, stdin) - 1;
+
+		if (line_size < 0)
+			break;
+
+		mid = (line_size / 2) + 1;
+
+		iov[0].iov_base = (void *)100;
+		iov[0].iov_len = mid;
+		iov[1].iov_base = (void *)2000;
+		iov[1].iov_len = line_size - mid;
+
+		provider->memcpy_to_device(mem, (size_t)iov[0].iov_base, line,
+					   iov[0].iov_len);
+		provider->memcpy_to_device(mem, (size_t)iov[1].iov_base,
+					   line + iov[0].iov_len,
+					   iov[1].iov_len);
+
+		fprintf(stderr,
+			"read line_size=%ld off=%d iov[0].iov_base=%d, iov[0].iov_len=%d, iov[1].iov_base=%d, iov[1].iov_len=%d\n",
+			line_size, off, iov[0].iov_base, iov[0].iov_len,
+			iov[1].iov_base, iov[1].iov_len);
+
+		msg.msg_iov = iov;
+		msg.msg_iovlen = 2;
+
+		msg.msg_control = ctrl_data;
+		msg.msg_controllen = sizeof(ctrl_data);
+
+		cmsg = CMSG_FIRSTHDR(&msg);
+		cmsg->cmsg_level = SOL_SOCKET;
+		cmsg->cmsg_type = SCM_DEVMEM_DMABUF;
+		cmsg->cmsg_len = CMSG_LEN(sizeof(struct dmabuf_tx_cmsg));
+
+		ddmabuf.dmabuf_id = tx_dmabuf_id;
+
+		*((struct dmabuf_tx_cmsg *)CMSG_DATA(cmsg)) = ddmabuf;
+
+		ret = sendmsg(socket_fd, &msg, MSG_ZEROCOPY);
+		if (ret < 0)
+			error(1, errno, "Failed sendmsg");
+
+		fprintf(stderr, "sendmsg_ret=%d\n", ret);
+
+		if (ret != line_size)
+			error(1, errno, "Did not send all bytes");
+
+		wait_compl(socket_fd);
+	}
+
+	fprintf(stderr, "%s: tx ok\n", TEST_PREFIX);
+
+	free(line);
+	close(socket_fd);
+
+	if (ys)
+		ynl_sock_destroy(ys);
+
+	return 0;
+}
+
 int main(int argc, char *argv[])
 {
 	struct memory_buffer *mem;
@@ -729,6 +995,8 @@  int main(int argc, char *argv[])
 
 	ifindex = if_nametoindex(ifname);
 
+	fprintf(stderr, "using ifindex=%u\n", ifindex);
+
 	if (!server_ip && !client_ip) {
 		if (start_queue < 0 && num_queues < 0) {
 			num_queues = rxq_num(ifindex);
@@ -779,7 +1047,7 @@  int main(int argc, char *argv[])
 		error(1, 0, "Missing -p argument\n");
 
 	mem = provider->alloc(getpagesize() * NUM_PAGES);
-	ret = is_server ? do_server(mem) : 1;
+	ret = is_server ? do_server(mem) : do_client(mem);
 	provider->free(mem);
 
 	return ret;