diff mbox

[v5,5/5] shared memory server for inter-VM shared memory

Message ID 1271872850-23350-1-git-send-email-cam@cs.ualberta.ca (mailing list archive)
State New, archived
Headers show

Commit Message

Cam Macdonell April 21, 2010, 6 p.m. UTC
None
diff mbox

Patch

diff --git a/contrib/ivshmem-server/Makefile b/contrib/ivshmem-server/Makefile
new file mode 100644
index 0000000..da40ffa
--- /dev/null
+++ b/contrib/ivshmem-server/Makefile
@@ -0,0 +1,16 @@ 
+CC = gcc
+CFLAGS = -O3 -Wall -Werror
+LIBS = -lrt
+
+# a very simple makefile to build the inter-VM shared memory server
+
+all: ivshmem_server
+
+.c.o:
+	$(CC) $(CFLAGS) -c $^ -o $@
+
+ivshmem_server: ivshmem_server.o send_scm.o
+	$(CC) $(CFLAGS) -o $@ $^ $(LIBS)
+
+clean:
+	rm -f *.o ivshmem_server
diff --git a/contrib/ivshmem-server/README b/contrib/ivshmem-server/README
new file mode 100644
index 0000000..b1fc2a2
--- /dev/null
+++ b/contrib/ivshmem-server/README
@@ -0,0 +1,30 @@ 
+Using the ivshmem shared memory server
+--------------------------------------
+
+This server is only supported on Linux.
+
+To use the shared memory server, first compile it.  Running 'make' should
+accomplish this.  An executable named 'ivshmem_server' will be built.
+
+to display the options run:
+
+./ivshmem_server -h
+
+Options
+-------
+
+    -h  print help message
+
+    -p <path on host>
+        unix socket to listen on.  The qemu-kvm chardev needs to connect on
+        this socket. (default: '/tmp/ivshmem_socket')
+
+    -s <string>
+        POSIX shared object to create that is the shared memory (default: 'ivshmem')
+
+    -m <#>
+        size of the POSIX object in MBs (default: 1)
+
+    -n <#>
+        number of eventfds for each guest.  This number must match the
+        'vectors' argument passed the ivshmem device. (default: 1)
diff --git a/contrib/ivshmem-server/ivshmem_server.c b/contrib/ivshmem-server/ivshmem_server.c
new file mode 100644
index 0000000..2dbf76f
--- /dev/null
+++ b/contrib/ivshmem-server/ivshmem_server.c
@@ -0,0 +1,339 @@ 
+/*
+ * A stand-alone shared memory server for inter-VM shared memory for KVM
+*/
+
+#include <errno.h>
+#include <string.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <sys/eventfd.h>
+#include <sys/mman.h>
+#include <sys/select.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include "send_scm.h"
+
+#define DEFAULT_SOCK_PATH "/tmp/ivshmem_socket"
+#define DEFAULT_SHM_OBJ "ivshmem"
+
+#define DEBUG 1
+
+typedef struct server_state {
+    vmguest_t *live_vms;
+    int nr_allocated_vms;
+    int shm_size;
+    long live_count;
+    long total_count;
+    int shm_fd;
+    char * path;
+    char * shmobj;
+    int maxfd, conn_socket;
+    long msi_vectors;
+} server_state_t;
+
+void usage(char const *prg);
+int find_set(fd_set * readset, int max);
+void print_vec(server_state_t * s, const char * c);
+
+void add_new_guest(server_state_t * s);
+void parse_args(int argc, char **argv, server_state_t * s);
+int create_listening_socket(char * path);
+
+int main(int argc, char ** argv)
+{
+    fd_set readset;
+    server_state_t * s;
+
+    s = (server_state_t *)calloc(1, sizeof(server_state_t));
+
+    s->live_count = 0;
+    s->total_count = 0;
+    parse_args(argc, argv, s);
+
+    /* open shared memory file  */
+    if ((s->shm_fd = shm_open(s->shmobj, O_CREAT|O_RDWR, S_IRWXU)) < 0)
+    {
+        fprintf(stderr, "kvm_ivshmem: could not open shared file\n");
+        exit(-1);
+    }
+
+    ftruncate(s->shm_fd, s->shm_size);
+
+    s->conn_socket = create_listening_socket(s->path);
+
+    s->maxfd = s->conn_socket;
+
+    for(;;) {
+        int ret, handle, i;
+        char buf[1024];
+
+        print_vec(s, "vm_sockets");
+
+        FD_ZERO(&readset);
+        /* conn socket is in Live_vms at posn 0 */
+        FD_SET(s->conn_socket, &readset);
+        for (i = 0; i < s->total_count; i++) {
+            if (s->live_vms[i].alive != 0) {
+                FD_SET(s->live_vms[i].sockfd, &readset);
+            }
+        }
+
+        printf("\nWaiting (maxfd = %d)\n", s->maxfd);
+
+        ret = select(s->maxfd + 1, &readset, NULL, NULL, NULL);
+
+        if (ret == -1) {
+            perror("select()");
+        }
+
+        handle = find_set(&readset, s->maxfd + 1);
+        if (handle == -1) continue;
+
+        if (handle == s->conn_socket) {
+
+            printf("[NC] new connection\n");
+            FD_CLR(s->conn_socket, &readset);
+
+            /* The Total_count is equal to the new guests VM ID */
+            add_new_guest(s);
+
+            /* update our the maximum file descriptor number */
+            s->maxfd = s->live_vms[s->total_count - 1].sockfd > s->maxfd ?
+                            s->live_vms[s->total_count - 1].sockfd : s->maxfd;
+
+            s->live_count++;
+            printf("Live_count is %ld\n", s->live_count);
+
+        } else {
+            /* then we have received a disconnection */
+            int recv_ret;
+            long i, j;
+            long deadposn = -1;
+
+            recv_ret = recv(handle, buf, 1, 0);
+
+            printf("[DC] recv returned %d\n", recv_ret);
+
+            /* find the dead VM in our list and move it do the dead list. */
+            for (i = 0; i < s->total_count; i++) {
+                if (s->live_vms[i].sockfd == handle) {
+                    deadposn = i;
+                    s->live_vms[i].alive = 0;
+                    close(s->live_vms[i].sockfd);
+
+                    for (j = 0; j < s->msi_vectors; j++) {
+                        close(s->live_vms[i].efd[j]);
+                    }
+
+                    free(s->live_vms[i].efd);
+                    s->live_vms[i].sockfd = -1;
+                    break;
+                }
+            }
+
+            for (j = 0; j < s->total_count; j++) {
+                /* update remaining clients that one client has left/died */
+                if (s->live_vms[j].alive) {
+                    printf("[UD] sending kill of fd[%ld] to %ld\n",
+                                                                deadposn, j);
+                    sendKill(s->live_vms[j].sockfd, deadposn, sizeof(deadposn));
+                }
+            }
+
+            s->live_count--;
+
+            /* close the socket for the departed VM */
+            close(handle);
+        }
+
+    }
+
+    return 0;
+}
+
+void add_new_guest(server_state_t * s) {
+
+    struct sockaddr_un remote;
+    socklen_t t = sizeof(remote);
+    long i, j;
+    int vm_sock;
+    long new_posn;
+    long neg1 = -1;
+
+    vm_sock = accept(s->conn_socket, (struct sockaddr *)&remote, &t);
+
+    if ( vm_sock == -1 ) {
+        perror("accept");
+        exit(1);
+    }
+
+    new_posn = s->total_count;
+
+    if (new_posn == s->nr_allocated_vms) {
+        printf("increasing vm slots\n");
+        s->nr_allocated_vms = s->nr_allocated_vms * 2;
+        if (s->nr_allocated_vms < 16)
+            s->nr_allocated_vms = 16;
+        s->live_vms = realloc(s->live_vms,
+                    s->nr_allocated_vms * sizeof(vmguest_t));
+
+        if (s->live_vms == NULL) {
+            fprintf(stderr, "realloc failed - quitting\n");
+            exit(-1);
+        }
+    }
+
+    s->live_vms[new_posn].posn = new_posn;
+    printf("[NC] Live_vms[%ld]\n", new_posn);
+    s->live_vms[new_posn].efd = (int *) malloc(sizeof(int));
+    for (i = 0; i < s->msi_vectors; i++) {
+        s->live_vms[new_posn].efd[i] = eventfd(0, 0);
+        printf("\tefd[%ld] = %d\n", i, s->live_vms[new_posn].efd[i]);
+    }
+    s->live_vms[new_posn].sockfd = vm_sock;
+    s->live_vms[new_posn].alive = 1;
+
+
+    sendPosition(vm_sock, new_posn);
+    sendUpdate(vm_sock, neg1, sizeof(long), s->shm_fd);
+    printf("[NC] trying to send fds to new connection\n");
+    sendRights(vm_sock, new_posn, sizeof(new_posn), s->live_vms, s->msi_vectors);
+
+    printf("[NC] Connected (count = %ld).\n", new_posn);
+    for (i = 0; i < new_posn; i++) {
+        if (s->live_vms[i].alive) {
+            // ping all clients that a new client has joined
+            printf("[UD] sending fd[%ld] to %ld\n", new_posn, i);
+            for (j = 0; j < s->msi_vectors; j++) {
+                printf("\tefd[%ld] = [%d]", j, s->live_vms[new_posn].efd[j]);
+                sendUpdate(s->live_vms[i].sockfd, new_posn,
+                        sizeof(new_posn), s->live_vms[new_posn].efd[j]);
+            }
+            printf("\n");
+        }
+    }
+
+    s->total_count++;
+}
+
+int create_listening_socket(char * path) {
+
+    struct sockaddr_un local;
+    int len, conn_socket;
+
+    if ((conn_socket = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) {
+        perror("socket");
+        exit(1);
+    }
+
+    local.sun_family = AF_UNIX;
+    strcpy(local.sun_path, path);
+    unlink(local.sun_path);
+    len = strlen(local.sun_path) + sizeof(local.sun_family);
+    if (bind(conn_socket, (struct sockaddr *)&local, len) == -1) {
+        perror("bind");
+        exit(1);
+    }
+
+    if (listen(conn_socket, 5) == -1) {
+        perror("listen");
+        exit(1);
+    }
+
+    return conn_socket;
+
+}
+
+void parse_args(int argc, char **argv, server_state_t * s) {
+
+    int c;
+
+    s->shm_size = 1024 * 1024; // default shm_size
+    s->path = NULL;
+    s->shmobj = NULL;
+    s->msi_vectors = 1;
+
+	while ((c = getopt(argc, argv, "hp:s:m:n:")) != -1) {
+
+        switch (c) {
+            // path to listening socket
+            case 'p':
+                s->path = optarg;
+                break;
+            // name of shared memory object
+            case 's':
+                s->shmobj = optarg;
+                break;
+            // size of shared memory object
+            case 'm':
+                s->shm_size = atol(optarg)*1024*1024;
+                break;
+            case 'n':
+                s->msi_vectors = atol(optarg);
+                break;
+            case 'h':
+            default:
+	            usage(argv[0]);
+		        exit(1);
+		}
+	}
+
+    if (s->path == NULL) {
+        s->path = strdup(DEFAULT_SOCK_PATH);
+    }
+
+    printf("listening socket: %s\n", s->path);
+
+    if (s->shmobj == NULL) {
+        s->shmobj = strdup(DEFAULT_SHM_OBJ);
+    }
+
+    printf("shared object: %s\n", s->shmobj);
+    printf("shared object size: %d MB\n", s->shm_size);
+
+}
+
+void print_vec(server_state_t * s, const char * c) {
+
+    int i, j;
+
+#if DEBUG
+    printf("%s (%ld) = ", c, s->total_count);
+    for (i = 0; i < s->total_count; i++) {
+        if (s->live_vms[i].alive) {
+            for (j = 0; j < s->msi_vectors; j++) {
+                printf("[%d|%d] ", s->live_vms[i].sockfd, s->live_vms[i].efd[j]);
+            }
+        }
+    }
+    printf("\n");
+#endif
+
+}
+
+int find_set(fd_set * readset, int max) {
+
+    int i;
+
+    for (i = 1; i < max; i++) {
+        if (FD_ISSET(i, readset)) {
+            return i;
+        }
+    }
+
+    printf("nothing set\n");
+    return -1;
+
+}
+
+void usage(char const *prg) {
+	fprintf(stderr, "use: %s [-h]  [-p <unix socket>] [-s <shm obj>]
+                        [-m <size in MB>] [-n <# of MSI vectors>]\n", prg);
+}
+
+
diff --git a/contrib/ivshmem-server/send_scm.c b/contrib/ivshmem-server/send_scm.c
new file mode 100644
index 0000000..b1bb4a3
--- /dev/null
+++ b/contrib/ivshmem-server/send_scm.c
@@ -0,0 +1,208 @@ 
+#include <stdint.h>
+#include <stdlib.h>
+#include <errno.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <sys/syscall.h>
+#include <sys/un.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <poll.h>
+#include "send_scm.h"
+
+#ifndef POLLRDHUP
+#define POLLRDHUP 0x2000
+#endif
+
+int readUpdate(int fd, long * posn, int * newfd)
+{
+    struct msghdr msg;
+    struct iovec iov[1];
+    struct cmsghdr *cmptr;
+    size_t len;
+    size_t msg_size = sizeof(int);
+    char control[CMSG_SPACE(msg_size)];
+
+    msg.msg_name = 0;
+    msg.msg_namelen = 0;
+    msg.msg_control = control;
+    msg.msg_controllen = sizeof(control);
+    msg.msg_flags = 0;
+    msg.msg_iov = iov;
+    msg.msg_iovlen = 1;
+
+    iov[0].iov_base = &posn;
+    iov[0].iov_len = sizeof(posn);
+
+    do {
+        len = recvmsg(fd, &msg, 0);
+    } while (len == (size_t) (-1) && (errno == EINTR || errno == EAGAIN));
+
+    printf("iov[0].buf is %ld\n", *((long *)iov[0].iov_base));
+    printf("len is %ld\n", len);
+    // TODO: Logging
+    if (len == (size_t) (-1)) {
+        perror("recvmsg()");
+        return -1;
+    }
+
+    if (msg.msg_controllen < sizeof(struct cmsghdr))
+        return *posn;
+
+    for (cmptr = CMSG_FIRSTHDR(&msg); cmptr != NULL;
+        cmptr = CMSG_NXTHDR(&msg, cmptr)) {
+        if (cmptr->cmsg_level != SOL_SOCKET ||
+            cmptr->cmsg_type != SCM_RIGHTS){
+                printf("continuing %ld\n", sizeof(size_t));
+                printf("read msg_size = %ld\n", msg_size);
+                if (cmptr->cmsg_len != sizeof(control))
+                    printf("not equal (%ld != %ld)\n",cmptr->cmsg_len,sizeof(control));
+                continue;
+        }
+
+        memcpy(newfd, CMSG_DATA(cmptr), sizeof(int));
+        printf("posn is %ld (fd = %d)\n", *posn, *newfd);
+        return 0;
+    }
+
+    fprintf(stderr, "bad data in packet\n");
+    return -1;
+}
+
+int readRights(int fd, long count, size_t count_len, int **fds, int msi_vectors)
+{
+    int j, newfd;
+
+    for (; ;){
+        long posn = 0;
+
+        readUpdate(fd, &posn, &newfd);
+        printf("reading posn %ld ", posn);
+        fds[posn] = (int *)malloc (msi_vectors * sizeof(int));
+        fds[posn][0] = newfd;
+        for (j = 1; j < msi_vectors; j++) {
+            readUpdate(fd, &posn, &newfd);
+            fds[posn][j] = newfd;
+            printf("%d.", fds[posn][j]);
+        }
+        printf("\n");
+
+        /* stop reading once i've read my own eventfds */
+        if (posn == count)
+            break;
+    }
+
+    return 0;
+}
+
+int sendKill(int fd, long const posn, size_t posn_len) {
+
+    struct cmsghdr *cmsg;
+    size_t msg_size = sizeof(int);
+    char control[CMSG_SPACE(msg_size)];
+    struct iovec iov[1];
+    size_t len;
+    struct msghdr msg = { 0, 0, iov, 1, control, sizeof control, 0 };
+
+    struct pollfd mypollfd;
+    int rv;
+
+    iov[0].iov_base = (void *) &posn;
+    iov[0].iov_len = posn_len;
+
+    // from cmsg(3)
+    cmsg = CMSG_FIRSTHDR(&msg);
+    cmsg->cmsg_level = SOL_SOCKET;
+    cmsg->cmsg_len = 0;
+    msg.msg_controllen = cmsg->cmsg_len;
+
+    printf("Killing posn %ld\n", posn);
+
+    // check if the fd is dead or not
+    mypollfd.fd = fd;
+    mypollfd.events = POLLRDHUP;
+    mypollfd.revents = 0;
+
+    rv = poll(&mypollfd, 1, 0);
+
+    printf("rv is %d\n", rv);
+
+    if (rv == 0) {
+        len = sendmsg(fd, &msg, 0);
+        if (len == (size_t) (-1)) {
+            perror("sendmsg()");
+            return -1;
+        }
+        return (len == posn_len);
+    } else {
+        printf("already dead\n");
+        return 0;
+    }
+}
+
+int sendUpdate(int fd, long posn, size_t posn_len, int sendfd)
+{
+
+    struct cmsghdr *cmsg;
+    size_t msg_size = sizeof(int);
+    char control[CMSG_SPACE(msg_size)];
+    struct iovec iov[1];
+    size_t len;
+    struct msghdr msg = { 0, 0, iov, 1, control, sizeof control, 0 };
+
+    iov[0].iov_base = (void *) (&posn);
+    iov[0].iov_len = posn_len;
+
+    // from cmsg(3)
+    cmsg = CMSG_FIRSTHDR(&msg);
+    cmsg->cmsg_level = SOL_SOCKET;
+    cmsg->cmsg_type = SCM_RIGHTS;
+    cmsg->cmsg_len = CMSG_LEN(msg_size);
+    msg.msg_controllen = cmsg->cmsg_len;
+
+    memcpy((CMSG_DATA(cmsg)), &sendfd, msg_size);
+
+    len = sendmsg(fd, &msg, 0);
+    if (len == (size_t) (-1)) {
+        perror("sendmsg()");
+        return -1;
+    }
+
+    return (len == posn_len);
+
+}
+
+int sendPosition(int fd, long const posn)
+{
+    int rv;
+
+    rv = send(fd, &posn, sizeof(long), 0);
+    if (rv != sizeof(long)) {
+        fprintf(stderr, "error sending posn\n");
+        return -1;
+    }
+
+    return 0;
+}
+
+int sendRights(int fd, long const count, size_t count_len, vmguest_t * Live_vms,
+                                                            long msi_vectors)
+{
+    /* updates about new guests are sent one at a time */
+
+    long i, j;
+
+    for (i = 0; i <= count; i++) {
+        if (Live_vms[i].alive) {
+            for (j = 0; j < msi_vectors; j++) {
+                sendUpdate(Live_vms[count].sockfd, i, sizeof(long),
+                                                        Live_vms[i].efd[j]);
+            }
+        }
+    }
+
+    return 0;
+
+}
diff --git a/contrib/ivshmem-server/send_scm.h b/contrib/ivshmem-server/send_scm.h
new file mode 100644
index 0000000..48c9a8d
--- /dev/null
+++ b/contrib/ivshmem-server/send_scm.h
@@ -0,0 +1,19 @@ 
+#ifndef SEND_SCM
+#define SEND_SCM
+
+struct vm_guest_conn {
+    int posn;
+    int sockfd;
+    int * efd;
+    int alive;
+};
+
+typedef struct vm_guest_conn vmguest_t;
+
+int readRights(int fd, long count, size_t count_len, int **fds, int msi_vectors);
+int sendRights(int fd, long const count, size_t count_len, vmguest_t *Live_vms, long msi_vectors);
+int readUpdate(int fd, long * posn, int * newfd);
+int sendUpdate(int fd, long const posn, size_t posn_len, int sendfd);
+int sendPosition(int fd, long const posn);
+int sendKill(int fd, long const posn, size_t posn_len);
+#endif