@@ -205,6 +205,8 @@ void vsock_remove_sock(struct vsock_sock *vsk);
void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
bool vsock_find_cid(unsigned int cid);
+int vsock_wait_space(struct sock *sk, size_t space, int flags,
+ struct vsock_transport_send_notify_data *send_data);
/**** TAP ****/
@@ -1692,6 +1692,65 @@ static int vsock_connectible_getsockopt(struct socket *sock,
return 0;
}
+int vsock_wait_space(struct sock *sk, size_t space, int flags,
+ struct vsock_transport_send_notify_data *send_data)
+{
+ const struct vsock_transport *transport;
+ struct vsock_sock *vsk;
+ long timeout;
+ int err;
+
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+ vsk = vsock_sk(sk);
+ transport = vsk->transport;
+ timeout = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+ err = 0;
+
+ add_wait_queue(sk_sleep(sk), &wait);
+
+ while (vsock_stream_has_space(vsk) < space &&
+ sk->sk_err == 0 &&
+ !(sk->sk_shutdown & SEND_SHUTDOWN) &&
+ !(vsk->peer_shutdown & RCV_SHUTDOWN)) {
+
+ /* Don't wait for non-blocking sockets. */
+ if (timeout == 0) {
+ err = -EAGAIN;
+ goto out_err;
+ }
+
+ if (send_data) {
+ err = transport->notify_send_pre_block(vsk, send_data);
+ if (err < 0)
+ goto out_err;
+ }
+
+ release_sock(sk);
+ timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout);
+ lock_sock(sk);
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeout);
+ goto out_err;
+ } else if (timeout == 0) {
+ err = -EAGAIN;
+ goto out_err;
+ }
+ }
+
+ if (sk->sk_err) {
+ err = -sk->sk_err;
+ } else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
+ (vsk->peer_shutdown & RCV_SHUTDOWN)) {
+ err = -EPIPE;
+ }
+
+out_err:
+ remove_wait_queue(sk_sleep(sk), &wait);
+ return err;
+}
+EXPORT_SYMBOL_GPL(vsock_wait_space);
+
static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
size_t len)
{
@@ -1699,10 +1758,8 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
struct vsock_sock *vsk;
const struct vsock_transport *transport;
ssize_t total_written;
- long timeout;
int err;
struct vsock_transport_send_notify_data send_data;
- DEFINE_WAIT_FUNC(wait, woken_wake_function);
sk = sock->sk;
vsk = vsock_sk(sk);
@@ -1740,9 +1797,6 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- /* Wait for room in the produce queue to enqueue our user's data. */
- timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
-
err = transport->notify_send_init(vsk, &send_data);
if (err < 0)
goto out;
@@ -1750,39 +1804,8 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
while (total_written < len) {
ssize_t written;
- add_wait_queue(sk_sleep(sk), &wait);
- while (vsock_stream_has_space(vsk) == 0 &&
- sk->sk_err == 0 &&
- !(sk->sk_shutdown & SEND_SHUTDOWN) &&
- !(vsk->peer_shutdown & RCV_SHUTDOWN)) {
-
- /* Don't wait for non-blocking sockets. */
- if (timeout == 0) {
- err = -EAGAIN;
- remove_wait_queue(sk_sleep(sk), &wait);
- goto out_err;
- }
-
- err = transport->notify_send_pre_block(vsk, &send_data);
- if (err < 0) {
- remove_wait_queue(sk_sleep(sk), &wait);
- goto out_err;
- }
-
- release_sock(sk);
- timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout);
- lock_sock(sk);
- if (signal_pending(current)) {
- err = sock_intr_errno(timeout);
- remove_wait_queue(sk_sleep(sk), &wait);
- goto out_err;
- } else if (timeout == 0) {
- err = -EAGAIN;
- remove_wait_queue(sk_sleep(sk), &wait);
- goto out_err;
- }
- }
- remove_wait_queue(sk_sleep(sk), &wait);
+ if (vsock_wait_space(sk, 1, msg->msg_flags, &send_data))
+ goto out_err;
/* These checks occur both as part of and after the loop
* conditional since we need to check before and after