@@ -646,7 +646,7 @@ static struct socket *drbd_try_connect(struct drbd_connection *connection)
* stay C_WF_CONNECTION, don't go Disconnecting! */
disconnect_on_error = 0;
what = "connect";
- err = sock->ops->connect(sock, (struct sockaddr *) &peer_in6, peer_addr_len, 0);
+ err = kernel_connect(sock, (struct sockaddr *)&peer_in6, peer_addr_len, 0);
out:
if (err < 0) {
@@ -993,7 +993,7 @@ static int kernel_bindconnect(struct socket *s, struct sockaddr *laddr,
ret = s->ops->bind(s, laddr, laddrlen);
if (ret)
return ret;
- ret = s->ops->connect(s, raddr, raddrlen, flags);
+ ret = kernel_connect(s, raddr, raddrlen, flags);
return ret < 0 ? ret : 0;
}
@@ -1328,7 +1328,7 @@ static int kernel_bindconnect(struct socket *s, struct sockaddr *laddr,
if (rv < 0)
return rv;
- rv = s->ops->connect(s, raddr, size, flags);
+ rv = kernel_connect(s, raddr, size, flags);
return rv < 0 ? rv : 0;
}
@@ -1818,7 +1818,7 @@ static int dlm_tcp_bind(struct socket *sock)
static int dlm_tcp_connect(struct connection *con, struct socket *sock,
struct sockaddr *addr, int addr_len)
{
- return sock->ops->connect(sock, addr, addr_len, O_NONBLOCK);
+ return kernel_connect(sock, addr, addr_len, O_NONBLOCK);
}
static int dlm_tcp_listen_validate(void)
@@ -1876,12 +1876,12 @@ static int dlm_sctp_connect(struct connection *con, struct socket *sock,
int ret;
/*
- * Make sock->ops->connect() function return in specified time,
+ * Make kernel_connect() function return in specified time,
* since O_NONBLOCK argument in connect() function does not work here,
* then, we should restore the default value of this attribute.
*/
sock_set_sndtimeo(sock->sk, 5);
- ret = sock->ops->connect(sock, addr, addr_len, 0);
+ ret = kernel_connect(sock, addr, addr_len, 0);
sock_set_sndtimeo(sock->sk, 0);
return ret;
}
@@ -1636,10 +1636,10 @@ static void o2net_start_connect(struct work_struct *work)
remoteaddr.sin_addr.s_addr = node->nd_ipv4_address;
remoteaddr.sin_port = node->nd_ipv4_port;
- ret = sc->sc_sock->ops->connect(sc->sc_sock,
- (struct sockaddr *)&remoteaddr,
- sizeof(remoteaddr),
- O_NONBLOCK);
+ ret = kernel_connect(sc->sc_sock,
+ (struct sockaddr *)&remoteaddr,
+ sizeof(remoteaddr),
+ O_NONBLOCK);
if (ret == -EINPROGRESS)
ret = 0;
@@ -3042,8 +3042,8 @@ generic_ip_connect(struct TCP_Server_Info *server)
socket->sk->sk_sndbuf,
socket->sk->sk_rcvbuf, socket->sk->sk_rcvtimeo);
- rc = socket->ops->connect(socket, saddr, slen,
- server->noblockcnt ? O_NONBLOCK : 0);
+ rc = kernel_connect(socket, saddr, slen,
+ server->noblockcnt ? O_NONBLOCK : 0);
/*
* When mounting SMB root file systems, we do not want to block in
* connect. Otherwise bail out and then let cifs_reconnect() perform
@@ -1019,9 +1019,9 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
}
}
- err = READ_ONCE(csocket->ops)->connect(csocket,
- (struct sockaddr *)&sin_server,
- sizeof(struct sockaddr_in), 0);
+ err = kernel_connect(csocket,
+ (struct sockaddr *)&sin_server,
+ sizeof(struct sockaddr_in), 0);
if (err < 0) {
pr_err("%s (%d): problem connecting socket to %s\n",
__func__, task_pid_nr(current), addr);
@@ -1060,8 +1060,8 @@ p9_fd_create_unix(struct p9_client *client, const char *addr, char *args)
return err;
}
- err = READ_ONCE(csocket->ops)->connect(csocket, (struct sockaddr *)&sun_server,
- sizeof(struct sockaddr_un) - 1, 0);
+ err = kernel_connect(csocket, (struct sockaddr *)&sun_server,
+ sizeof(struct sockaddr_un) - 1, 0);
if (err < 0) {
pr_err("%s (%d): problem connecting socket: %s: %d\n",
__func__, task_pid_nr(current), addr, err);
@@ -459,8 +459,8 @@ int ceph_tcp_connect(struct ceph_connection *con)
set_sock_callbacks(sock, con);
con_sock_state_connecting(con);
- ret = sock->ops->connect(sock, (struct sockaddr *)&ss, sizeof(ss),
- O_NONBLOCK);
+ ret = kernel_connect(sock, (struct sockaddr *)&ss, sizeof(ss),
+ O_NONBLOCK);
if (ret == -EINPROGRESS) {
dout("connect %s EINPROGRESS sk_state = %u\n",
ceph_pr_addr(&con->peer_addr),
@@ -1505,8 +1505,8 @@ static int make_send_sock(struct netns_ipvs *ipvs, int id,
}
get_mcast_sockaddr(&mcast_addr, &salen, &ipvs->mcfg, id);
- result = sock->ops->connect(sock, (struct sockaddr *) &mcast_addr,
- salen, 0);
+ result = kernel_connect(sock, (struct sockaddr *)&mcast_addr,
+ salen, 0);
if (result < 0) {
pr_err("Error connecting to the multicast addr\n");
goto error;
@@ -173,7 +173,7 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
* own the socket
*/
rds_tcp_set_callbacks(sock, cp);
- ret = sock->ops->connect(sock, addr, addrlen, O_NONBLOCK);
+ ret = kernel_connect(sock, addr, addrlen, O_NONBLOCK);
rdsdebug("connect to address %pI6c returned %d\n", &conn->c_faddr, ret);
if (ret == -EINPROGRESS)
@@ -3572,6 +3572,9 @@ int kernel_connect(struct socket *sock, struct sockaddr *addr, int addrlen,
{
struct sockaddr_storage address;
+ if (addrlen < 0 || addrlen > sizeof(address))
+ return -EINVAL;
+
memcpy(&address, addr, addrlen);
return READ_ONCE(sock->ops)->connect(sock, (struct sockaddr *)&address,
commit 0bdf399342c5 ("net: Avoid address overwrite in kernel_connect") ensured that kernel_connect() will not overwrite the address parameter in cases where BPF connect hooks perform an address rewrite. This change replaces all direct calls to sock->ops->connect() with kernel_connect() to make these call safe. This patch also introduces a sanity check to kernel_connect() for addrlen. Link: https://lore.kernel.org/netdev/20230912013332.2048422-1-jrife@google.com/ Fixes: d74bad4e74ee ("bpf: Hooks for sys_connect") Signed-off-by: Jordan Rife <jrife@google.com> --- v2->v3: Add "Fixes" tag. Check for positivity in addrlen sanity check. v1->v2: Split up original patch into patch series. Insulate calls with kernel_connect() instead of pushing address copy deeper into sock->ops->connect(). drivers/block/drbd/drbd_receiver.c | 2 +- drivers/infiniband/hw/erdma/erdma_cm.c | 2 +- drivers/infiniband/sw/siw/siw_cm.c | 2 +- fs/dlm/lowcomms.c | 6 +++--- fs/ocfs2/cluster/tcp.c | 8 ++++---- fs/smb/client/connect.c | 4 ++-- net/9p/trans_fd.c | 10 +++++----- net/ceph/messenger.c | 4 ++-- net/netfilter/ipvs/ip_vs_sync.c | 4 ++-- net/rds/tcp_connect.c | 2 +- net/socket.c | 3 +++ 11 files changed, 25 insertions(+), 22 deletions(-)