From patchwork Fri Oct 15 09:12:13 2010 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: "Xin, Xiaohui" X-Patchwork-Id: 255691 Received: from vger.kernel.org (vger.kernel.org [209.132.180.67]) by demeter1.kernel.org (8.14.4/8.14.3) with ESMTP id o9F8uOoc004100 for ; Fri, 15 Oct 2010 08:56:26 GMT Received: (majordomo@vger.kernel.org) by vger.kernel.org via listexpand id S1754920Ab0JOIyY (ORCPT ); Fri, 15 Oct 2010 04:54:24 -0400 Received: from mga01.intel.com ([192.55.52.88]:60392 "EHLO mga01.intel.com" rhost-flags-OK-OK-OK-OK) by vger.kernel.org with ESMTP id S1754706Ab0JOIxF (ORCPT ); Fri, 15 Oct 2010 04:53:05 -0400 Received: from fmsmga002.fm.intel.com ([10.253.24.26]) by fmsmga101.fm.intel.com with ESMTP; 15 Oct 2010 01:53:04 -0700 X-ExtLoop1: 1 X-IronPort-AV: E=Sophos;i="4.57,335,1283756400"; d="scan'208";a="616995453" Received: from unknown (HELO localhost.localdomain.sh.intel.com) ([10.239.36.73]) by fmsmga002.fm.intel.com with ESMTP; 15 Oct 2010 01:53:03 -0700 From: xiaohui.xin@intel.com To: netdev@vger.kernel.org, kvm@vger.kernel.org, linux-kernel@vger.kernel.org, mst@redhat.com, mingo@elte.hu, davem@davemloft.net, herbert@gondor.hengli.com.au, jdike@linux.intel.com Cc: Xin Xiaohui Subject: [PATCH v13 12/16] Add mp(mediate passthru) device. Date: Fri, 15 Oct 2010 17:12:13 +0800 Message-Id: X-Mailer: git-send-email 1.7.3 In-Reply-To: <1287133937-5538-1-git-send-email-xiaohui.xin@intel.com> References: <1287133937-5538-1-git-send-email-xiaohui.xin@intel.com> In-Reply-To: References: Sender: kvm-owner@vger.kernel.org Precedence: bulk List-ID: X-Mailing-List: kvm@vger.kernel.org X-Greylist: IP, sender and recipient auto-whitelisted, not delayed by milter-greylist-4.2.3 (demeter1.kernel.org [140.211.167.41]); Fri, 15 Oct 2010 08:56:26 +0000 (UTC) diff --git a/drivers/vhost/mpassthru.c b/drivers/vhost/mpassthru.c new file mode 100644 index 0000000..5389f3e --- /dev/null +++ b/drivers/vhost/mpassthru.c @@ -0,0 +1,1380 @@ +/* + * MPASSTHRU - Mediate passthrough device. + * Copyright (C) 2009 ZhaoYu, XinXiaohui, Dike, Jeffery G + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + */ + +#define DRV_NAME "mpassthru" +#define DRV_DESCRIPTION "Mediate passthru device driver" +#define DRV_COPYRIGHT "(C) 2009 ZhaoYu, XinXiaohui, Dike, Jeffery G" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +struct mp_struct { + struct mp_file *mfile; + struct net_device *dev; + struct page_pool *pool; + struct socket socket; + struct socket_wq wq; + struct mm_struct *mm; +}; + +struct mp_file { + atomic_t count; + struct mp_struct *mp; + struct net *net; +}; + +struct mp_sock { + struct sock sk; + struct mp_struct *mp; +}; + +/* The main function to allocate external buffers */ +static struct skb_ext_page *page_ctor(struct mp_port *port, + struct sk_buff *skb, + int npages) +{ + int i; + unsigned long flags; + struct page_pool *pool; + struct page_info *info = NULL; + + if (npages != 1) + BUG(); + pool = container_of(port, struct page_pool, port); + + spin_lock_irqsave(&pool->read_lock, flags); + if (!list_empty(&pool->readq)) { + info = list_first_entry(&pool->readq, struct page_info, list); + list_del(&info->list); + } + spin_unlock_irqrestore(&pool->read_lock, flags); + if (!info) + return NULL; + + for (i = 0; i < info->pnum; i++) + get_page(info->pages[i]); + info->skb = skb; + return &info->ext_page; +} + +static struct page_info *mp_hash_lookup(struct page_pool *pool, + struct page *page); +static struct page_info *mp_hash_delete(struct page_pool *pool, + struct page_info *info); + +static struct skb_ext_page *mp_lookup(struct net_device *dev, + struct page *page) +{ + struct mp_struct *mp = + container_of(dev->mp_port->sock->sk, struct mp_sock, sk)->mp; + struct page_pool *pool = mp->pool; + struct page_info *info; + + info = mp_hash_lookup(pool, page); + if (!info) + return NULL; + return &info->ext_page; +} + +struct page_pool *page_pool_create(struct net_device *dev, + struct socket *sock) +{ + struct page_pool *pool; + int rc; + + pool = kzalloc(sizeof(*pool), GFP_KERNEL); + if (!pool) + return NULL; + rc = netdev_mp_port_prep(dev, &pool->port); + if (rc) + goto fail; + + INIT_LIST_HEAD(&pool->readq); + spin_lock_init(&pool->read_lock); + pool->hash_table = + kzalloc(sizeof(struct page_info *) * HASH_BUCKETS, GFP_KERNEL); + if (!pool->hash_table) + goto fail; + dev_hold(dev); + pool->dev = dev; + pool->port.ctor = page_ctor; + pool->port.sock = sock; + pool->port.hash = mp_lookup; + pool->locked_pages = 0; + pool->cur_pages = 0; + pool->orig_locked_vm = 0; + /* should be protected by rtnl_lock() */ + dev->mp_port = &pool->port; + return pool; +fail: + kfree(pool); + return NULL; +} +EXPORT_SYMBOL_GPL(page_pool_create); + +void dev_change_state(struct net_device *dev) +{ + dev_change_flags(dev, dev->flags & (~IFF_UP)); + dev_change_flags(dev, dev->flags | IFF_UP); +} +EXPORT_SYMBOL_GPL(dev_change_state); + +static int mp_page_pool_attach(struct mp_struct *mp, struct page_pool *pool) +{ + int rc = 0; + /* should be protected by mp_mutex */ + if (mp->pool) { + rc = -EBUSY; + goto fail; + } + if (mp->dev != pool->dev) { + rc = -EFAULT; + goto fail; + } + mp->pool = pool; + return 0; +fail: + kfree(pool->hash_table); + kfree(pool); + dev_put(mp->dev); + return rc; +} + +struct page_info *info_dequeue(struct page_pool *pool) +{ + unsigned long flags; + struct page_info *info = NULL; + spin_lock_irqsave(&pool->read_lock, flags); + if (!list_empty(&pool->readq)) { + info = list_first_entry(&pool->readq, + struct page_info, list); + list_del(&info->list); + } + spin_unlock_irqrestore(&pool->read_lock, flags); + return info; +} + +static void mp_ki_dtor(struct kiocb *iocb) +{ + struct page_info *info = (struct page_info *)(iocb->private); + int i; + + if (info->flags == INFO_READ) { + for (i = 0; i < info->pnum; i++) { + if (info->pages[i]) { + set_page_dirty_lock(info->pages[i]); + put_page(info->pages[i]); + } + } + mp_hash_delete(info->pool, info); + if (info->skb) { + info->skb->destructor = NULL; + kfree_skb(info->skb); + } + } + /* Decrement the number of locked pages */ + info->pool->cur_pages -= info->pnum; + kmem_cache_free(ext_page_info_cache, info); + + return; +} + +static struct kiocb *create_iocb(struct page_info *info, int size) +{ + struct kiocb *iocb = NULL; + + iocb = info->iocb; + if (!iocb) + return iocb; + iocb->ki_flags = 0; + iocb->ki_users = 1; + iocb->ki_key = 0; + iocb->ki_ctx = NULL; + iocb->ki_cancel = NULL; + iocb->ki_retry = NULL; + iocb->ki_eventfd = NULL; + iocb->ki_pos = info->desc_pos; + iocb->ki_nbytes = size; + iocb->ki_dtor(iocb); + iocb->private = (void *)info; + iocb->ki_dtor = mp_ki_dtor; + + return iocb; +} + +void page_pool_destroy(struct mm_struct *mm, struct page_pool *pool) +{ + struct page_info *info; + int i; + + while ((info = info_dequeue(pool))) { + for (i = 0; i < info->pnum; i++) + if (info->pages[i]) + put_page(info->pages[i]); + create_iocb(info, 0); + kmem_cache_free(ext_page_info_cache, info); + } + down_write(&mm->mmap_sem); + mm->locked_vm -= pool->locked_pages; + up_write(&mm->mmap_sem); + + /* should be locked by rtnl_lock */ + pool->dev->mp_port = NULL; + dev_put(pool->dev); + kfree(pool->hash_table); + kfree(pool); +} +EXPORT_SYMBOL_GPL(page_pool_destroy); + +static void mp_page_pool_detach(struct mp_struct *mp) +{ + /* locked by mp_mutex */ + if (mp->pool) { + page_pool_destroy(mp->mm, mp->pool); + mp->pool = NULL; + } +} + +static void __mp_detach(struct mp_struct *mp) +{ + mp->mfile = NULL; + + dev_change_flags(mp->dev, mp->dev->flags & ~IFF_UP); + mp_page_pool_detach(mp); + dev_change_flags(mp->dev, mp->dev->flags | IFF_UP); + + /* Drop the extra count on the net device */ + dev_put(mp->dev); +} + +static DEFINE_MUTEX(mp_mutex); + +static void mp_detach(struct mp_struct *mp) +{ + mutex_lock(&mp_mutex); + __mp_detach(mp); + mutex_unlock(&mp_mutex); +} + +static struct mp_struct *mp_get(struct mp_file *mfile) +{ + struct mp_struct *mp = NULL; + if (atomic_inc_not_zero(&mfile->count)) + mp = mfile->mp; + + return mp; +} + +static void mp_put(struct mp_file *mfile) +{ + if (atomic_dec_and_test(&mfile->count)) { + if (!rtnl_is_locked()) { + rtnl_lock(); + mp_detach(mfile->mp); + rtnl_unlock(); + } else + mp_detach(mfile->mp); + } +} + +static void iocb_tag(struct kiocb *iocb) +{ + iocb->ki_flags = 1; +} + +/* The callback to destruct the external buffers or skb */ +static void page_dtor(struct skb_ext_page *ext_page) +{ + struct page_info *info; + struct page_pool *pool; + struct sock *sk; + struct sk_buff *skb; + + if (!ext_page) + return; + info = container_of(ext_page, struct page_info, ext_page); + if (!info) + return; + pool = info->pool; + skb = info->skb; + + if (info->flags == INFO_READ) { + create_iocb(info, 0); + return; + } + + /* For transmit, we should wait for the DMA finish by hardware. + * Queue the notifier to wake up the backend driver + */ + + iocb_tag(info->iocb); + sk = pool->port.sock->sk; + sk->sk_write_space(sk); + + return; +} + +/* For small exteranl buffers transmit, we don't need to call + * get_user_pages(). + */ +static struct page_info *alloc_small_page_info(struct page_pool *pool, + struct kiocb *iocb, int total) +{ + struct page_info *info = + kmem_cache_alloc(ext_page_info_cache, GFP_KERNEL); + + if (!info) + return NULL; + info->ext_page.dtor = page_dtor; + info->pool = pool; + info->flags = INFO_WRITE; + info->iocb = iocb; + info->pnum = 0; + return info; +} + +typedef u32 key_mp_t; +static inline key_mp_t mp_hash(struct page *page, int buckets) +{ + key_mp_t k; +#if BITS_PER_LONG == 64 + k = ((((unsigned long)page << 32UL) >> 32UL) / + sizeof(struct page)) % buckets ; +#elif BITS_PER_LONG == 32 + k = ((unsigned long)page / sizeof(struct page)) % buckets; +#endif + + return k; +} + +static void mp_hash_insert(struct page_pool *pool, + struct page *page, struct page_info *page_info) +{ + struct page_info *tmp; + key_mp_t key = mp_hash(page, HASH_BUCKETS); + if (!pool->hash_table[key]) { + pool->hash_table[key] = page_info; + return; + } + + tmp = pool->hash_table[key]; + while (tmp->next) + tmp = tmp->next; + + tmp->next = page_info; + page_info->prev = tmp; + return; +} + +static struct page_info *mp_hash_delete(struct page_pool *pool, + struct page_info *info) +{ + key_mp_t key = mp_hash(info->pages[0], HASH_BUCKETS); + struct page_info *tmp = NULL; + + tmp = pool->hash_table[key]; + while (tmp) { + if (tmp == info) { + if (!tmp->prev) { + pool->hash_table[key] = tmp->next; + if (tmp->next) + tmp->next->prev = NULL; + } else { + tmp->prev->next = tmp->next; + if (tmp->next) + tmp->next->prev = tmp->prev; + } + return tmp; + } + tmp = tmp->next; + } + return tmp; +} + +static struct page_info *mp_hash_lookup(struct page_pool *pool, + struct page *page) +{ + key_mp_t key = mp_hash(page, HASH_BUCKETS); + struct page_info *tmp = NULL; + + int i; + tmp = pool->hash_table[key]; + while (tmp) { + for (i = 0; i < tmp->pnum; i++) { + if (tmp->pages[i] == page) + return tmp; + } + tmp = tmp->next; + } + return tmp; +} + +/* The main function to transform the guest user space address + * to host kernel address via get_user_pages(). Thus the hardware + * can do DMA directly to the external buffer address. + */ +static struct page_info *alloc_page_info(struct page_pool *pool, + struct kiocb *iocb, struct iovec *iov, + int count, struct frag *frags, + int npages, int total) +{ + int rc; + int i, j, n = 0; + int len; + unsigned long base; + struct page_info *info = NULL; + + if (pool->cur_pages + count > pool->locked_pages) { + printk(KERN_INFO "Exceed memory lock rlimt."); + return NULL; + } + + info = kmem_cache_alloc(ext_page_info_cache, GFP_KERNEL); + + if (!info) + return NULL; + info->skb = NULL; + info->next = info->prev = NULL; + + for (i = j = 0; i < count; i++) { + base = (unsigned long)iov[i].iov_base; + len = iov[i].iov_len; + + if (!len) + continue; + n = ((base & ~PAGE_MASK) + len + ~PAGE_MASK) >> PAGE_SHIFT; + + rc = get_user_pages_fast(base, n, npages ? 1 : 0, + &info->pages[j]); + if (rc != n) + goto failed; + + while (n--) { + frags[j].offset = base & ~PAGE_MASK; + frags[j].size = min_t(int, len, + PAGE_SIZE - frags[j].offset); + len -= frags[j].size; + base += frags[j].size; + j++; + } + } + +#ifdef CONFIG_HIGHMEM + if (npages && !(dev->features & NETIF_F_HIGHDMA)) { + for (i = 0; i < j; i++) { + if (PageHighMem(info->pages[i])) + goto failed; + } + } +#endif + + info->ext_page.dtor = page_dtor; + info->ext_page.page = info->pages[0]; + info->pool = pool; + info->pnum = j; + info->iocb = iocb; + if (!npages) + info->flags = INFO_WRITE; + else + info->flags = INFO_READ; + + if (info->flags == INFO_READ) { + if (frags[0].offset == 0 && iocb->ki_iovec[0].iov_len) { + frags[0].offset = iocb->ki_iovec[0].iov_len; + pool->port.vnet_hlen = iocb->ki_iovec[0].iov_len; + } + for (i = 0; i < j; i++) + mp_hash_insert(pool, info->pages[i], info); + } + /* increment the number of locked pages */ + pool->cur_pages += j; + return info; + +failed: + for (i = 0; i < j; i++) + put_page(info->pages[i]); + + kmem_cache_free(ext_page_info_cache, info); + + return NULL; +} + +static void mp_sock_destruct(struct sock *sk) +{ + struct mp_struct *mp = container_of(sk, struct mp_sock, sk)->mp; + kfree(mp); +} + +static void mp_sock_state_change(struct sock *sk) +{ + wait_queue_head_t *wqueue = sk_sleep(sk); + if (wqueue && waitqueue_active(wqueue)) + wake_up_interruptible_sync_poll(wqueue, POLLIN); +} + +static void mp_sock_write_space(struct sock *sk) +{ + wait_queue_head_t *wqueue = sk_sleep(sk); + if (wqueue && waitqueue_active(wqueue)) + wake_up_interruptible_sync_poll(wqueue, POLLOUT); +} + +void async_data_ready(struct sock *sk, struct page_pool *pool) +{ + struct sk_buff *skb = NULL; + struct page_info *info = NULL; + int len; + + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + struct page *page; + int off; + int size = 0, i = 0; + struct skb_shared_info *shinfo = skb_shinfo(skb); + struct skb_ext_page *ext_page = + (struct skb_ext_page *)(shinfo->destructor_arg); + struct virtio_net_hdr_mrg_rxbuf hdr = { + .hdr.flags = 0, + .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE + }; + + if (skb->ip_summed == CHECKSUM_COMPLETE) + printk(KERN_INFO "Complete checksum occurs\n"); + + if (shinfo->frags[0].page == ext_page->page) { + info = container_of(ext_page, + struct page_info, + ext_page); + if (shinfo->nr_frags) + hdr.num_buffers = shinfo->nr_frags; + else + hdr.num_buffers = shinfo->nr_frags + 1; + } else { + info = container_of(ext_page, + struct page_info, + ext_page); + hdr.num_buffers = shinfo->nr_frags + 1; + } + skb_push(skb, ETH_HLEN); + + if (skb_is_gso(skb)) { + hdr.hdr.hdr_len = skb_headlen(skb); + hdr.hdr.gso_size = shinfo->gso_size; + if (shinfo->gso_type & SKB_GSO_TCPV4) + hdr.hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV4; + else if (shinfo->gso_type & SKB_GSO_TCPV6) + hdr.hdr.gso_type = VIRTIO_NET_HDR_GSO_TCPV6; + else if (shinfo->gso_type & SKB_GSO_UDP) + hdr.hdr.gso_type = VIRTIO_NET_HDR_GSO_UDP; + else + BUG(); + if (shinfo->gso_type & SKB_GSO_TCP_ECN) + hdr.hdr.gso_type |= VIRTIO_NET_HDR_GSO_ECN; + + } else + hdr.hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE; + + if (skb->ip_summed == CHECKSUM_PARTIAL) { + hdr.hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM; + hdr.hdr.csum_start = + skb->csum_start - skb_headroom(skb); + hdr.hdr.csum_offset = skb->csum_offset; + } + + off = info->hdr[0].iov_len; + len = memcpy_toiovec(info->iov, (unsigned char *)&hdr, off); + if (len) { + pr_debug("Unable to write vnet_hdr at addr '%p': '%d'\n", + info->iov, len); + goto clean; + } + + memcpy_toiovec(info->iov, skb->data, skb_headlen(skb)); + + info->iocb->ki_left = hdr.num_buffers; + if (shinfo->frags[0].page == ext_page->page) { + size = shinfo->frags[0].size + + shinfo->frags[0].page_offset - off; + i = 1; + } else { + size = skb_headlen(skb); + i = 0; + } + create_iocb(info, off + size); + for (i = i; i < shinfo->nr_frags; i++) { + page = shinfo->frags[i].page; + info = mp_hash_lookup(pool, shinfo->frags[i].page); + create_iocb(info, shinfo->frags[i].size); + } + info->skb = skb; + shinfo->nr_frags = 0; + shinfo->destructor_arg = NULL; + continue; +clean: + kfree_skb(skb); + for (i = 0; i < info->pnum; i++) + put_page(info->pages[i]); + kmem_cache_free(ext_page_info_cache, info); + } + return; +} +EXPORT_SYMBOL_GPL(async_data_ready); + +static void mp_sock_data_ready(struct sock *sk, int coming) +{ + struct mp_struct *mp = container_of(sk, struct mp_sock, sk)->mp; + struct page_pool *pool = NULL; + + pool = mp->pool; + if (!pool) + return; + return async_data_ready(sk, pool); +} + +static inline struct sk_buff *mp_alloc_skb(struct sock *sk, size_t prepad, + size_t len, size_t linear, + int noblock, int *err) +{ + struct sk_buff *skb; + + /* Under a page? Don't bother with paged skb. */ + if (prepad + len < PAGE_SIZE || !linear) + linear = len; + + skb = sock_alloc_send_pskb(sk, prepad + linear, len - linear, noblock, + err); + if (!skb) + return NULL; + + skb_reserve(skb, prepad); + skb_put(skb, linear); + skb->data_len = len - linear; + skb->len += len - linear; + + return skb; +} + +static int mp_skb_from_vnet_hdr(struct sk_buff *skb, + struct virtio_net_hdr *vnet_hdr) +{ + unsigned short gso_type = 0; + if (vnet_hdr->gso_type != VIRTIO_NET_HDR_GSO_NONE) { + switch (vnet_hdr->gso_type & ~VIRTIO_NET_HDR_GSO_ECN) { + case VIRTIO_NET_HDR_GSO_TCPV4: + gso_type = SKB_GSO_TCPV4; + break; + case VIRTIO_NET_HDR_GSO_TCPV6: + gso_type = SKB_GSO_TCPV6; + break; + case VIRTIO_NET_HDR_GSO_UDP: + gso_type = SKB_GSO_UDP; + break; + default: + return -EINVAL; + } + + if (vnet_hdr->gso_type & VIRTIO_NET_HDR_GSO_ECN) + gso_type |= SKB_GSO_TCP_ECN; + + if (vnet_hdr->gso_size == 0) + return -EINVAL; + } + + if (vnet_hdr->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) { + if (!skb_partial_csum_set(skb, vnet_hdr->csum_start, + vnet_hdr->csum_offset)) + return -EINVAL; + } + + if (vnet_hdr->gso_type != VIRTIO_NET_HDR_GSO_NONE) { + skb_shinfo(skb)->gso_size = vnet_hdr->gso_size; + skb_shinfo(skb)->gso_type = gso_type; + + /* Header must be checked, and gso_segs computed. */ + skb_shinfo(skb)->gso_type |= SKB_GSO_DODGY; + skb_shinfo(skb)->gso_segs = 0; + } + return 0; +} + +int async_sendmsg(struct sock *sk, struct kiocb *iocb, struct page_pool *pool, + struct iovec *iov, int count) +{ + struct virtio_net_hdr vnet_hdr = {0}; + int hdr_len = 0; + struct page_info *info = NULL; + struct frag frags[MAX_SKB_FRAGS]; + struct sk_buff *skb; + int total = 0, header, n, i, len, rc; + unsigned long base; + + total = iov_length(iov, count); + + if (total < ETH_HLEN) + return -EINVAL; + + if (total <= COPY_THRESHOLD) + goto copy; + + n = 0; + for (i = 0; i < count; i++) { + base = (unsigned long)iov[i].iov_base; + len = iov[i].iov_len; + if (!len) + continue; + n += ((base & ~PAGE_MASK) + len + ~PAGE_MASK) >> PAGE_SHIFT; + if (n > MAX_SKB_FRAGS) + return -EINVAL; + } + +copy: + hdr_len = sizeof(vnet_hdr); + if ((total - iocb->ki_iovec[0].iov_len) < 0) + return -EINVAL; + + rc = memcpy_fromiovecend((void *)&vnet_hdr, iocb->ki_iovec, 0, hdr_len); + if (rc < 0) + return -EINVAL; + + if ((vnet_hdr.flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) && + vnet_hdr.csum_start + vnet_hdr.csum_offset + 2 > + vnet_hdr.hdr_len) + vnet_hdr.hdr_len = vnet_hdr.csum_start + + vnet_hdr.csum_offset + 2; + + if (vnet_hdr.hdr_len > total) + return -EINVAL; + + header = total > COPY_THRESHOLD ? COPY_HDR_LEN : total; + + skb = mp_alloc_skb(sk, NET_IP_ALIGN, header, + iocb->ki_iovec[0].iov_len, 1, &rc); + if (!skb) + goto drop; + + skb_set_network_header(skb, ETH_HLEN); + memcpy_fromiovec(skb->data, iov, header); + + skb_reset_mac_header(skb); + skb->protocol = eth_hdr(skb)->h_proto; + + rc = mp_skb_from_vnet_hdr(skb, &vnet_hdr); + if (rc) + goto drop; + + if (header == total) { + rc = total; + info = alloc_small_page_info(pool, iocb, total); + } else { + info = alloc_page_info(pool, iocb, iov, count, frags, 0, total); + if (info) + for (i = 0; i < info->pnum; i++) { + skb_add_rx_frag(skb, i, info->pages[i], + frags[i].offset, frags[i].size); + info->pages[i] = NULL; + } + } + if (!pool->cur_pages) + sk->sk_state_change(sk); + + if (info != NULL) { + info->desc_pos = iocb->ki_pos; + info->skb = skb; + skb_shinfo(skb)->destructor_arg = &info->ext_page; + skb->dev = pool->dev; + create_iocb(info, total); + dev_queue_xmit(skb); + return 0; + } +drop: + kfree_skb(skb); + if (info) { + for (i = 0; i < info->pnum; i++) + put_page(info->pages[i]); + kmem_cache_free(ext_page_info_cache, info); + } + pool->dev->stats.tx_dropped++; + return -ENOMEM; +} +EXPORT_SYMBOL_GPL(async_sendmsg); + +static int mp_sendmsg(struct kiocb *iocb, struct socket *sock, + struct msghdr *m, size_t total_len) +{ + struct mp_struct *mp = container_of(sock->sk, struct mp_sock, sk)->mp; + struct page_pool *pool; + struct iovec *iov = m->msg_iov; + int count = m->msg_iovlen; + + pool = mp->pool; + if (!pool) + return -ENODEV; + return async_sendmsg(sock->sk, iocb, pool, iov, count); +} + +int async_recvmsg(struct kiocb *iocb, struct page_pool *pool, + struct iovec *iov, int count, int flags) +{ + int npages, payload; + struct page_info *info; + struct frag frags[MAX_SKB_FRAGS]; + unsigned long base; + int i, len; + unsigned long flag; + + if (!(flags & MSG_DONTWAIT)) + return -EINVAL; + + if (!pool) + return -EINVAL; + + /* Error detections in case invalid external buffer */ + if (count > 2 && iov[1].iov_len < pool->port.hdr_len && + pool->dev->features & NETIF_F_SG) { + return -EINVAL; + } + + npages = pool->port.npages; + payload = pool->port.data_len; + + /* If KVM guest virtio-net FE driver use SG feature */ + if (count > 2) { + for (i = 2; i < count; i++) { + base = (unsigned long)iov[i].iov_base & ~PAGE_MASK; + len = iov[i].iov_len; + if (npages == 1) + len = min_t(int, len, PAGE_SIZE - base); + else if (base) + break; + payload -= len; + if (payload <= 0) + goto proceed; + if (npages == 1 || (len & ~PAGE_MASK)) + break; + } + } + + if ((((unsigned long)iov[1].iov_base & ~PAGE_MASK) + - NET_SKB_PAD - NET_IP_ALIGN) >= 0) + goto proceed; + + return -EINVAL; +proceed: + /* skip the virtnet head */ + if (count > 1) { + iov++; + count--; + } + + /* Translate address to kernel */ + info = alloc_page_info(pool, iocb, iov, count, frags, npages, 0); + if (!info) + return -ENOMEM; + info->hdr[0].iov_base = iocb->ki_iovec[0].iov_base; + info->hdr[0].iov_len = iocb->ki_iovec[0].iov_len; + iocb->ki_iovec[0].iov_len = 0; + iocb->ki_left = 0; + info->desc_pos = iocb->ki_pos; + + if (count > 1) { + iov--; + count++; + } + + memcpy(info->iov, iov, sizeof(struct iovec) * count); + + spin_lock_irqsave(&pool->read_lock, flag); + list_add_tail(&info->list, &pool->readq); + spin_unlock_irqrestore(&pool->read_lock, flag); + + return 0; +} +EXPORT_SYMBOL_GPL(async_recvmsg); + +static int mp_recvmsg(struct kiocb *iocb, struct socket *sock, + struct msghdr *m, size_t total_len, + int flags) +{ + struct mp_struct *mp = container_of(sock->sk, struct mp_sock, sk)->mp; + struct page_pool *pool; + struct iovec *iov = m->msg_iov; + int count = m->msg_iovlen; + + pool = mp->pool; + if (!pool) + return -EINVAL; + + return async_recvmsg(iocb, pool, iov, count, flags); +} + +/* Ops structure to mimic raw sockets with mp device */ +static const struct proto_ops mp_socket_ops = { + .sendmsg = mp_sendmsg, + .recvmsg = mp_recvmsg, +}; + +static struct proto mp_proto = { + .name = "mp", + .owner = THIS_MODULE, + .obj_size = sizeof(struct mp_sock), +}; + +static int mp_chr_open(struct inode *inode, struct file * file) +{ + struct mp_file *mfile; + cycle_kernel_lock(); + + pr_debug("mp: mp_chr_open\n"); + mfile = kzalloc(sizeof(*mfile), GFP_KERNEL); + if (!mfile) + return -ENOMEM; + atomic_set(&mfile->count, 0); + mfile->mp = NULL; + mfile->net = get_net(current->nsproxy->net_ns); + file->private_data = mfile; + return 0; +} + +static int mp_attach(struct mp_struct *mp, struct file *file) +{ + struct mp_file *mfile = file->private_data; + int err; + + netif_tx_lock_bh(mp->dev); + + err = -EINVAL; + + if (mfile->mp) + goto out; + + err = -EBUSY; + if (mp->mfile) + goto out; + + err = 0; + mfile->mp = mp; + mp->mfile = mfile; + mp->socket.file = file; + dev_hold(mp->dev); + sock_hold(mp->socket.sk); + atomic_inc(&mfile->count); + +out: + netif_tx_unlock_bh(mp->dev); + return err; +} + +static int do_unbind(struct mp_file *mfile) +{ + struct mp_struct *mp = mp_get(mfile); + + if (!mp) + return -EINVAL; + + mp_detach(mp); + sock_put(mp->socket.sk); + mp_put(mfile); + return 0; +} + +static long mp_chr_ioctl(struct file *file, unsigned int cmd, + unsigned long arg) +{ + struct mp_file *mfile = file->private_data; + struct mp_struct *mp; + struct net_device *dev; + struct page_pool *pool; + void __user* argp = (void __user *)arg; + unsigned long __user *limitp = argp; + struct ifreq ifr; + struct sock *sk; + unsigned long limit, locked, lock_limit; + int ret; + + ret = -EINVAL; + + switch (cmd) { + case MPASSTHRU_BINDDEV: + ret = -EFAULT; + if (copy_from_user(&ifr, argp, sizeof ifr)) + break; + + ifr.ifr_name[IFNAMSIZ-1] = '\0'; + + ret = -ENODEV; + + rtnl_lock(); + dev = dev_get_by_name(mfile->net, ifr.ifr_name); + if (!dev) { + rtnl_unlock(); + break; + } + + mutex_lock(&mp_mutex); + + ret = -EBUSY; + + /* the device can be only bind once */ + if (dev_is_mpassthru(dev)) + goto err_dev_put; + + ret = -EFAULT; + + if (!(dev->features & NETIF_F_SG)) { + pr_debug("The device has no SG features.\n"); + goto err_dev_put; + } + mp = mfile->mp; + if (mp) + goto err_dev_put; + + mp = kzalloc(sizeof(*mp), GFP_KERNEL); + if (!mp) { + ret = -ENOMEM; + goto err_dev_put; + } + mp->dev = dev; + mp->mm = get_task_mm(current); + ret = -ENOMEM; + + sk = sk_alloc(mfile->net, AF_UNSPEC, GFP_KERNEL, &mp_proto); + if (!sk) + goto err_free_mp; + + mp->socket.wq = &mp->wq; + init_waitqueue_head(&mp->wq.wait); + mp->socket.ops = &mp_socket_ops; + sock_init_data(&mp->socket, sk); + sk->sk_sndbuf = INT_MAX; + container_of(sk, struct mp_sock, sk)->mp = mp; + + sk->sk_destruct = mp_sock_destruct; + sk->sk_data_ready = mp_sock_data_ready; + sk->sk_write_space = mp_sock_write_space; + sk->sk_state_change = mp_sock_state_change; + ret = mp_attach(mp, file); + if (ret < 0) + goto err_free_sk; + pool = page_pool_create(dev, &mp->socket); + if (!pool) + goto err_free_sk; + + ret = mp_page_pool_attach(mp, pool); + if (ret < 0) + goto err_free_sk; + dev_change_state(dev); +out: + mutex_unlock(&mp_mutex); + rtnl_unlock(); + break; +err_free_sk: + sk_free(sk); +err_free_mp: + kfree(mp); +err_dev_put: + dev_put(dev); + goto out; + + case MPASSTHRU_UNBINDDEV: + rtnl_lock(); + ret = do_unbind(mfile); + rtnl_unlock(); + break; + + case MPASSTHRU_SET_MEM_LOCKED: + ret = copy_from_user(&limit, limitp, sizeof limit); + if (ret < 0) + return ret; + + mp = mp_get(mfile); + if (!mp) + return -ENODEV; + + mutex_lock(&mp_mutex); + if (mp->mm != current->mm) { + mutex_unlock(&mp_mutex); + return -EPERM; + } + + limit = PAGE_ALIGN(limit) >> PAGE_SHIFT; + down_write(&mp->mm->mmap_sem); + if (!mp->pool->locked_pages) + mp->pool->orig_locked_vm = mp->mm->locked_vm; + locked = limit + mp->pool->orig_locked_vm; + lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT; + + if ((locked > lock_limit) && !capable(CAP_IPC_LOCK)) { + up_write(&mp->mm->mmap_sem); + mutex_unlock(&mp_mutex); + mp_put(mfile); + return -ENOMEM; + } + mp->mm->locked_vm = locked; + up_write(&mp->mm->mmap_sem); + + mp->pool->locked_pages = limit; + mutex_unlock(&mp_mutex); + + mp_put(mfile); + return 0; + + case MPASSTHRU_GET_MEM_LOCKED_NEED: + limit = DEFAULT_NEED; + return copy_to_user(limitp, &limit, sizeof limit); + + + default: + break; + } + return ret; +} + +static unsigned int mp_chr_poll(struct file *file, poll_table * wait) +{ + struct mp_file *mfile = file->private_data; + struct mp_struct *mp = mp_get(mfile); + struct sock *sk; + unsigned int mask = 0; + + if (!mp) + return POLLERR; + + sk = mp->socket.sk; + + poll_wait(file, &mp->wq.wait, wait); + + if (!skb_queue_empty(&sk->sk_receive_queue)) + mask |= POLLIN | POLLRDNORM; + + if (sock_writeable(sk) || + (!test_and_set_bit(SOCK_ASYNC_NOSPACE, &sk->sk_socket->flags) && + sock_writeable(sk))) + mask |= POLLOUT | POLLWRNORM; + + if (mp->dev->reg_state != NETREG_REGISTERED) + mask = POLLERR; + + mp_put(mfile); + return mask; +} + +static ssize_t mp_chr_aio_write(struct kiocb *iocb, const struct iovec *iov, + unsigned long count, loff_t pos) +{ + struct file *file = iocb->ki_filp; + struct mp_struct *mp = mp_get(file->private_data); + struct sock *sk = mp->socket.sk; + struct sk_buff *skb; + int len, err; + ssize_t result = 0; + + if (!mp) + return -EBADFD; + + /* currently, async is not supported. + * but we may support real async aio from user application, + * maybe qemu virtio-net backend. + */ + if (!is_sync_kiocb(iocb)) + return -EFAULT; + + len = iov_length(iov, count); + + if (unlikely(len < ETH_HLEN)) + return -EINVAL; + + skb = sock_alloc_send_skb(sk, len + NET_IP_ALIGN, + file->f_flags & O_NONBLOCK, &err); + + if (!skb) + return -ENOMEM; + + skb_reserve(skb, NET_IP_ALIGN); + skb_put(skb, len); + + if (skb_copy_datagram_from_iovec(skb, 0, iov, 0, len)) { + kfree_skb(skb); + return -EAGAIN; + } + + skb->protocol = eth_type_trans(skb, mp->dev); + skb->dev = mp->dev; + + dev_queue_xmit(skb); + + mp_put(file->private_data); + return result; +} + +static int mp_chr_close(struct inode *inode, struct file *file) +{ + struct mp_file *mfile = file->private_data; + + /* + * Ignore return value since an error only means there was nothing to + * do + */ + rtnl_lock(); + do_unbind(mfile); + rtnl_unlock(); + put_net(mfile->net); + kfree(mfile); + + return 0; +} + +#ifdef CONFIG_COMPAT +static long mp_chr_compat_ioctl(struct file *f, unsigned int ioctl, + unsigned long arg) +{ + return mp_chr_ioctl(f, ioctl, (unsigned long)compat_ptr(arg)); +} +#endif + +static const struct file_operations mp_fops = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .write = do_sync_write, + .aio_write = mp_chr_aio_write, + .poll = mp_chr_poll, + .unlocked_ioctl = mp_chr_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = mp_chr_compat_ioctl, +#endif + .open = mp_chr_open, + .release = mp_chr_close, +}; + +static struct miscdevice mp_miscdev = { + .minor = MISC_DYNAMIC_MINOR, + .name = "mp", + .nodename = "net/mp", + .fops = &mp_fops, +}; + +static int mp_device_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = ptr; + struct mp_port *port; + struct mp_struct *mp = NULL; + struct socket *sock = NULL; + struct sock *sk; + + port = dev->mp_port; + if (port == NULL) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UNREGISTER: + sock = dev->mp_port->sock; + mp = container_of(sock->sk, struct mp_sock, sk)->mp; + do_unbind(mp->mfile); + break; + case NETDEV_CHANGE: + sk = dev->mp_port->sock->sk; + sk->sk_state_change(sk); + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block mp_notifier_block __read_mostly = { + .notifier_call = mp_device_event, +}; + +static int mp_init(void) +{ + int err = 0; + + ext_page_info_cache = kmem_cache_create("skb_page_info", + sizeof(struct page_info), + 0, SLAB_HWCACHE_ALIGN, NULL); + if (!ext_page_info_cache) + return -ENOMEM; + + err = misc_register(&mp_miscdev); + if (err) { + printk(KERN_ERR "mp: Can't register misc device\n"); + kmem_cache_destroy(ext_page_info_cache); + } else { + printk(KERN_INFO "Registering mp misc device - minor = %d\n", + mp_miscdev.minor); + register_netdevice_notifier(&mp_notifier_block); + } + return err; +} + +void mp_exit(void) +{ + unregister_netdevice_notifier(&mp_notifier_block); + misc_deregister(&mp_miscdev); + kmem_cache_destroy(ext_page_info_cache); +} + +/* Get an underlying socket object from mp file. Returns error unless file is + * attached to a device. The returned object works like a packet socket, it + * can be used for sock_sendmsg/sock_recvmsg. The caller is responsible for + * holding a reference to the file for as long as the socket is in use. */ +struct socket *mp_get_socket(struct file *file) +{ + struct mp_file *mfile = file->private_data; + struct mp_struct *mp; + + if (file->f_op != &mp_fops) + return ERR_PTR(-EINVAL); + mp = mp_get(mfile); + if (!mp) + return ERR_PTR(-EBADFD); + mp_put(mfile); + return &mp->socket; +} +EXPORT_SYMBOL_GPL(mp_get_socket); + +module_init(mp_init); +module_exit(mp_exit); +MODULE_AUTHOR(DRV_COPYRIGHT); +MODULE_DESCRIPTION(DRV_DESCRIPTION); +MODULE_LICENSE("GPL v2");