@@ -71,25 +71,41 @@ static int dh_set_secret(struct crypto_kpp *tfm, const void *buf,
{
struct dh_ctx *ctx = dh_get_ctx(tfm);
struct dh params;
+ char key[CRYPTO_DH_MAX_PRIVKEY_SIZE];
+ int err;
/* Free the old MPI key if any */
dh_clear_ctx(ctx);
- if (crypto_dh_decode_key(buf, len, ¶ms) < 0)
+ err = crypto_dh_decode_key(buf, len, ¶ms);
+ if (err)
goto err_clear_ctx;
- if (dh_set_params(ctx, ¶ms) < 0)
+ if (!params.key_size) {
+ err = crypto_dh_gen_privkey(params.group_id, key,
+ ¶ms.key_size);
+ if (err)
+ goto err_clear_ctx;
+ params.key = key;
+ }
+
+ err = dh_set_params(ctx, ¶ms);
+ if (err)
goto err_clear_ctx;
ctx->xa = mpi_read_raw_data(params.key, params.key_size);
- if (!ctx->xa)
+ if (!ctx->xa) {
+ err = -EINVAL;
goto err_clear_ctx;
+ }
+
+ memzero_explicit(key, sizeof(key));
return 0;
err_clear_ctx:
dh_clear_ctx(ctx);
- return -EINVAL;
+ return err;
}
/*