@@ -43,7 +43,7 @@
#include <net/udp.h>
#include <net/tcp.h>
#include <net/tcp_states.h>
-#include <net/tls.h>
+#include <net/tls_prot.h>
#include <net/handshake.h>
#include <linux/uaccess.h>
#include <linux/highmem.h>
@@ -226,27 +226,30 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining)
}
static int
-svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg,
+svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
struct cmsghdr *cmsg, int ret)
{
- if (cmsg->cmsg_level == SOL_TLS &&
- cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
- u8 content_type = *((u8 *)CMSG_DATA(cmsg));
-
- switch (content_type) {
- case TLS_RECORD_TYPE_DATA:
- /* TLS sets EOR at the end of each application data
- * record, even though there might be more frames
- * waiting to be decrypted.
- */
- msg->msg_flags &= ~MSG_EOR;
- break;
- case TLS_RECORD_TYPE_ALERT:
- ret = -ENOTCONN;
- break;
- default:
- ret = -EAGAIN;
- }
+ u8 content_type = tls_get_record_type(sock->sk, cmsg);
+ u8 level, description;
+
+ switch (content_type) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ msg->msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(sock->sk, msg, &level, &description);
+ ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+ -ENOTCONN : -EAGAIN;
+ break;
+ default:
+ /* discard this record type */
+ ret = -EAGAIN;
}
return ret;
}
@@ -258,13 +261,14 @@ svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg)
struct cmsghdr cmsg;
u8 buf[CMSG_SPACE(sizeof(u8))];
} u;
+ struct socket *sock = svsk->sk_sock;
int ret;
msg->msg_control = &u;
msg->msg_controllen = sizeof(u);
- ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT);
+ ret = sock_recvmsg(sock, msg, MSG_DONTWAIT);
if (unlikely(msg->msg_controllen != sizeof(u)))
- ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret);
+ ret = svc_tcp_sock_process_cmsg(sock, msg, &u.cmsg, ret);
return ret;
}
@@ -47,7 +47,7 @@
#include <net/checksum.h>
#include <net/udp.h>
#include <net/tcp.h>
-#include <net/tls.h>
+#include <net/tls_prot.h>
#include <net/handshake.h>
#include <linux/bvec.h>
@@ -360,24 +360,27 @@ static int
xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg,
struct cmsghdr *cmsg, int ret)
{
- if (cmsg->cmsg_level == SOL_TLS &&
- cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
- u8 content_type = *((u8 *)CMSG_DATA(cmsg));
-
- switch (content_type) {
- case TLS_RECORD_TYPE_DATA:
- /* TLS sets EOR at the end of each application data
- * record, even though there might be more frames
- * waiting to be decrypted.
- */
- msg->msg_flags &= ~MSG_EOR;
- break;
- case TLS_RECORD_TYPE_ALERT:
- ret = -ENOTCONN;
- break;
- default:
- ret = -EAGAIN;
- }
+ u8 content_type = tls_get_record_type(sock->sk, cmsg);
+ u8 level, description;
+
+ switch (content_type) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ /* TLS sets EOR at the end of each application data
+ * record, even though there might be more frames
+ * waiting to be decrypted.
+ */
+ msg->msg_flags &= ~MSG_EOR;
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(sock->sk, msg, &level, &description);
+ ret = (level == TLS_ALERT_LEVEL_FATAL) ?
+ -EACCES : -EAGAIN;
+ break;
+ default:
+ /* discard this record type */
+ ret = -EAGAIN;
}
return ret;
}
@@ -777,6 +780,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport)
}
if (ret == -ESHUTDOWN)
kernel_sock_shutdown(transport->sock, SHUT_RDWR);
+ else if (ret == -EACCES)
+ xprt_wake_pending_tasks(&transport->xprt, -EACCES);
else
xs_poll_check_readable(transport);
out: