@@ -1673,6 +1673,37 @@ static void mptcp_set_nospace(struct sock *sk)
set_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags);
}
+static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msghdr *msg,
+ size_t len, int *copied_syn)
+{
+ unsigned int saved_flags = msg->msg_flags;
+ struct mptcp_sock *msk = mptcp_sk(sk);
+ int ret;
+
+ lock_sock(ssk);
+ msg->msg_flags |= MSG_DONTWAIT;
+ msk->connect_flags = O_NONBLOCK;
+ msk->is_sendmsg = 1;
+ ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL);
+ msk->is_sendmsg = 0;
+ msg->msg_flags = saved_flags;
+ release_sock(ssk);
+
+ /* do the blocking bits of inet_stream_connect outside the ssk socket lock */
+ if (ret == -EINPROGRESS && !(msg->msg_flags & MSG_DONTWAIT)) {
+ ret = __inet_stream_connect(sk->sk_socket, msg->msg_name,
+ msg->msg_namelen, msg->msg_flags, 1);
+
+ /* Keep the same behaviour of plain TCP: zero the copied bytes in
+ * case of any error, except timeout or signal
+ */
+ if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR)
+ *copied_syn = 0;
+ }
+
+ return ret;
+}
+
static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
{
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -1693,26 +1724,14 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
ssock = __mptcp_nmpc_socket(msk);
if (unlikely(ssock && inet_sk(ssock->sk)->defer_connect)) {
- struct sock *ssk = ssock->sk;
int copied_syn = 0;
- lock_sock(ssk);
-
- msk->connect_flags = (msg->msg_flags & MSG_DONTWAIT) ? O_NONBLOCK : 0;
- msk->is_sendmsg = 1;
- ret = tcp_sendmsg_fastopen(ssk, msg, &copied_syn, len, NULL);
- msk->is_sendmsg = 0;
+ ret = mptcp_sendmsg_fastopen(sk, ssock->sk, msg, len, &copied_syn);
copied += copied_syn;
- if (ret == -EINPROGRESS && copied_syn > 0) {
- /* reflect the new state on the MPTCP socket */
- inet_sk_state_store(sk, inet_sk_state_load(ssk));
- release_sock(ssk);
+ if (ret == -EINPROGRESS && copied_syn > 0)
goto out;
- } else if (ret) {
- release_sock(ssk);
+ else if (ret)
goto do_error;
- }
- release_sock(ssk);
}
timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);