@@ -177,22 +177,51 @@ static void handle___pkvm_vcpu_put(struct kvm_cpu_context *host_ctxt)
 	}
 }
 
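+/*
+ * Convert a host-provided vcpu pointer into its hyp VA. Under protected
+ * KVM, the pointer is only accepted if it matches the host vcpu backing
+ * the shadow vcpu state currently loaded on this CPU; otherwise both the
+ * returned vcpu and *state are NULL. Without protected KVM, *state is
+ * always NULL.
+ */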
+static struct kvm_vcpu *__get_current_vcpu(struct kvm_vcpu *vcpu,
+					    struct kvm_shadow_vcpu_state **state)
+{
+	struct kvm_shadow_vcpu_state *sstate = NULL;
+
+	vcpu = kern_hyp_va(vcpu);
+
+	if (unlikely(is_protected_kvm_enabled())) {
+		sstate = pkvm_loaded_shadow_vcpu_state();
+		if (!sstate || vcpu != sstate->host_vcpu) {
+			sstate = NULL;
+			vcpu = NULL;
+		}
+	}
+
+	*state = sstate;
+	return vcpu;
+}
+
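+/* Grab the vcpu pointer from host context register @regnr and validate it. */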
+#define get_current_vcpu(ctxt, regnr, statepp)				\
+	({								\
+		DECLARE_REG(struct kvm_vcpu *, __vcpu, ctxt, regnr);	\
+		__get_current_vcpu(__vcpu, statepp);			\
+	})
+
 static void handle___kvm_vcpu_run(struct kvm_cpu_context *host_ctxt)
 {
-	DECLARE_REG(struct kvm_vcpu *, host_vcpu, host_ctxt, 1);
+	struct kvm_shadow_vcpu_state *shadow_state;
+	struct kvm_vcpu *vcpu;
 	int ret;
 
-	if (unlikely(is_protected_kvm_enabled())) {
-		struct kvm_shadow_vcpu_state *shadow_state = pkvm_loaded_shadow_vcpu_state();
-		struct kvm_vcpu *shadow_vcpu = &shadow_state->shadow_vcpu;
+	vcpu = get_current_vcpu(host_ctxt, 1, &shadow_state);
+	if (!vcpu) {
+		cpu_reg(host_ctxt, 1) = -EINVAL;
+		return;
+	}
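+
+	/* shadow_state is only non-NULL when protected KVM is enabled. */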
+	if (unlikely(shadow_state)) {
 		flush_shadow_state(shadow_state);
-		ret = __kvm_vcpu_run(shadow_vcpu);
+		ret = __kvm_vcpu_run(&shadow_state->shadow_vcpu);
 		sync_shadow_state(shadow_state);
 	} else {
-		ret = __kvm_vcpu_run(kern_hyp_va(host_vcpu));
+		ret = __kvm_vcpu_run(vcpu);
 	}
 
 	cpu_reg(host_ctxt, 1) = ret;