@@ -145,7 +145,8 @@ void tls_err_abort(struct sock *sk, int err);
int init_prot_info(struct tls_prot_info *prot,
const struct tls_crypto_info *crypto_info,
const struct tls_cipher_desc *cipher_desc);
-int tls_set_sw_offload(struct sock *sk, int tx);
+int tls_set_sw_offload(struct sock *sk, int tx,
+ struct tls_crypto_info *new_crypto_info);
void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
void tls_sw_strparser_done(struct tls_context *tls_ctx);
@@ -1227,7 +1227,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
context->resync_nh_reset = 1;
ctx->priv_ctx_rx = context;
- rc = tls_set_sw_offload(sk, 0);
+ rc = tls_set_sw_offload(sk, 0, NULL);
if (rc)
goto release_ctx;
@@ -423,9 +423,10 @@ static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
ctx = tls_sw_ctx_rx(tls_ctx);
psock = sk_psock_get(sk);
- if (skb_queue_empty_lockless(&ctx->rx_list) &&
- !tls_strp_msg_ready(ctx) &&
- sk_psock_queue_empty(psock))
+ if ((skb_queue_empty_lockless(&ctx->rx_list) &&
+ !tls_strp_msg_ready(ctx) &&
+ sk_psock_queue_empty(psock)) ||
+ READ_ONCE(ctx->key_update_pending))
mask &= ~(EPOLLIN | EPOLLRDNORM);
if (psock)
@@ -612,11 +613,13 @@ static int validate_crypto_info(const struct tls_crypto_info *crypto_info,
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
unsigned int optlen, int tx)
{
- struct tls_crypto_info *crypto_info;
- struct tls_crypto_info *alt_crypto_info;
+ struct tls_crypto_info *crypto_info, *alt_crypto_info;
+ struct tls_crypto_info *old_crypto_info = NULL;
struct tls_context *ctx = tls_get_ctx(sk);
const struct tls_cipher_desc *cipher_desc;
union tls_crypto_context *crypto_ctx;
+ union tls_crypto_context tmp = {};
+ bool update = false;
int rc = 0;
int conf;
@@ -633,9 +636,18 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
crypto_info = &crypto_ctx->info;
- /* Currently we don't support set crypto info more than one time */
- if (TLS_CRYPTO_INFO_READY(crypto_info))
- return -EBUSY;
+ if (TLS_CRYPTO_INFO_READY(crypto_info)) {
+ /* Currently we only support setting crypto info more
+ * than one time for TLS 1.3
+ */
+ if (crypto_info->version != TLS_1_3_VERSION)
+ return -EBUSY;
+
+ update = true;
+ old_crypto_info = crypto_info;
+ crypto_info = &tmp.info;
+ crypto_ctx = &tmp;
+ }
rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
if (rc) {
@@ -643,7 +655,14 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
goto err_crypto_info;
}
- rc = validate_crypto_info(crypto_info, alt_crypto_info);
+ if (update) {
+ /* Ensure that TLS version and ciphers are not modified */
+ if (crypto_info->version != old_crypto_info->version ||
+ crypto_info->cipher_type != old_crypto_info->cipher_type)
+ rc = -EINVAL;
+ } else {
+ rc = validate_crypto_info(crypto_info, alt_crypto_info);
+ }
if (rc)
goto err_crypto_info;
@@ -673,7 +692,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
} else {
- rc = tls_set_sw_offload(sk, 1);
+ rc = tls_set_sw_offload(sk, 1,
+ update ? crypto_info : NULL);
if (rc)
goto err_crypto_info;
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
@@ -687,14 +707,16 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
} else {
- rc = tls_set_sw_offload(sk, 0);
+ rc = tls_set_sw_offload(sk, 0,
+ update ? crypto_info : NULL);
if (rc)
goto err_crypto_info;
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
conf = TLS_SW;
}
- tls_sw_strparser_arm(sk, ctx);
+ if (!update)
+ tls_sw_strparser_arm(sk, ctx);
}
if (tx)
@@ -2716,12 +2716,22 @@ int init_prot_info(struct tls_prot_info *prot,
return 0;
}
-int tls_set_sw_offload(struct sock *sk, int tx)
+static void tls_finish_key_update(struct sock *sk, struct tls_context *tls_ctx)
{
+ struct tls_sw_context_rx *ctx = tls_ctx->priv_ctx_rx;
+
+ WRITE_ONCE(ctx->key_update_pending, false);
+ /* wake-up pre-existing poll() */
+ ctx->saved_data_ready(sk);
+}
+
+int tls_set_sw_offload(struct sock *sk, int tx,
+ struct tls_crypto_info *new_crypto_info)
+{
+ struct tls_crypto_info *crypto_info, *src_crypto_info;
struct tls_sw_context_tx *sw_ctx_tx = NULL;
struct tls_sw_context_rx *sw_ctx_rx = NULL;
const struct tls_cipher_desc *cipher_desc;
- struct tls_crypto_info *crypto_info;
char *iv, *rec_seq, *key, *salt;
struct cipher_context *cctx;
struct tls_prot_info *prot;
@@ -2733,45 +2743,47 @@ int tls_set_sw_offload(struct sock *sk, int tx)
ctx = tls_get_ctx(sk);
prot = &ctx->prot_info;
- if (tx) {
- ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
- if (!ctx->priv_ctx_tx)
- return -ENOMEM;
+ /* new_crypto_info != NULL means rekey */
+ if (!new_crypto_info) {
+ if (tx) {
+ ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
+ if (!ctx->priv_ctx_tx)
+ return -ENOMEM;
+ } else {
+ ctx->priv_ctx_rx = init_ctx_rx(ctx);
+ if (!ctx->priv_ctx_rx)
+ return -ENOMEM;
+ }
+ }
+ if (tx) {
sw_ctx_tx = ctx->priv_ctx_tx;
crypto_info = &ctx->crypto_send.info;
cctx = &ctx->tx;
aead = &sw_ctx_tx->aead_send;
} else {
- ctx->priv_ctx_rx = init_ctx_rx(ctx);
- if (!ctx->priv_ctx_rx)
- return -ENOMEM;
-
sw_ctx_rx = ctx->priv_ctx_rx;
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
aead = &sw_ctx_rx->aead_recv;
- sw_ctx_rx->key_update_pending = false;
}
- cipher_desc = get_cipher_desc(crypto_info->cipher_type);
+ src_crypto_info = new_crypto_info ?: crypto_info;
+
+ cipher_desc = get_cipher_desc(src_crypto_info->cipher_type);
if (!cipher_desc) {
rc = -EINVAL;
goto free_priv;
}
- rc = init_prot_info(prot, crypto_info, cipher_desc);
+ rc = init_prot_info(prot, src_crypto_info, cipher_desc);
if (rc)
goto free_priv;
- iv = crypto_info_iv(crypto_info, cipher_desc);
- key = crypto_info_key(crypto_info, cipher_desc);
- salt = crypto_info_salt(crypto_info, cipher_desc);
- rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
-
- memcpy(cctx->iv, salt, cipher_desc->salt);
- memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
- memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
+ iv = crypto_info_iv(src_crypto_info, cipher_desc);
+ key = crypto_info_key(src_crypto_info, cipher_desc);
+ salt = crypto_info_salt(src_crypto_info, cipher_desc);
+ rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc);
if (!*aead) {
*aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0);
@@ -2784,20 +2796,30 @@ int tls_set_sw_offload(struct sock *sk, int tx)
ctx->push_pending_record = tls_sw_push_pending_record;
+ /* setkey is the last operation that could fail during a
+ * rekey. if it succeeds, we can start modifying the
+ * context.
+ */
rc = crypto_aead_setkey(*aead, key, cipher_desc->key);
- if (rc)
- goto free_aead;
+ if (rc) {
+ if (new_crypto_info)
+ goto out;
+ else
+ goto free_aead;
+ }
- rc = crypto_aead_setauthsize(*aead, prot->tag_size);
- if (rc)
- goto free_aead;
+ if (!new_crypto_info) {
+ rc = crypto_aead_setauthsize(*aead, prot->tag_size);
+ if (rc)
+ goto free_aead;
+ }
- if (sw_ctx_rx) {
+ if (!tx && !new_crypto_info) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
tls_update_rx_zc_capable(ctx);
sw_ctx_rx->async_capable =
- crypto_info->version != TLS_1_3_VERSION &&
+ src_crypto_info->version != TLS_1_3_VERSION &&
!!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
rc = tls_strp_init(&sw_ctx_rx->strp, sk);
@@ -2805,18 +2827,33 @@ int tls_set_sw_offload(struct sock *sk, int tx)
goto free_aead;
}
+ memcpy(cctx->iv, salt, cipher_desc->salt);
+ memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
+ memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
+
+ if (new_crypto_info) {
+ unsafe_memcpy(crypto_info, new_crypto_info,
+ cipher_desc->crypto_info,
+ /* size was checked in do_tls_setsockopt_conf */);
+ memzero_explicit(new_crypto_info, cipher_desc->crypto_info);
+ if (!tx)
+ tls_finish_key_update(sk, ctx);
+ }
+
goto out;
free_aead:
crypto_free_aead(*aead);
*aead = NULL;
free_priv:
- if (tx) {
- kfree(ctx->priv_ctx_tx);
- ctx->priv_ctx_tx = NULL;
- } else {
- kfree(ctx->priv_ctx_rx);
- ctx->priv_ctx_rx = NULL;
+ if (!new_crypto_info) {
+ if (tx) {
+ kfree(ctx->priv_ctx_tx);
+ ctx->priv_ctx_tx = NULL;
+ } else {
+ kfree(ctx->priv_ctx_rx);
+ ctx->priv_ctx_rx = NULL;
+ }
}
out:
return rc;