KVM: MMU: remove unused macros
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index e4b862eb888517833bfafb471b1b9e855e5927c7..88d36890f420dea52ac4a75d5fb7bece89d474b8 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -7,7 +7,7 @@
  * MMU support
  *
  * Copyright (C) 2006 Qumranet, Inc.
- * Copyright 2010 Red Hat, Inc. and/or its affilates.
+ * Copyright 2010 Red Hat, Inc. and/or its affiliates.
  *
  * Authors:
  *   Yaniv Kamay  <yaniv@qumranet.com>
  *
  */
 
+#include "irq.h"
 #include "mmu.h"
 #include "x86.h"
 #include "kvm_cache_regs.h"
+#include "x86.h"
 
 #include <linux/kvm_host.h>
 #include <linux/types.h>
  */
 bool tdp_enabled = false;
 
-#undef MMU_DEBUG
+enum {
+       AUDIT_PRE_PAGE_FAULT,
+       AUDIT_POST_PAGE_FAULT,
+       AUDIT_PRE_PTE_WRITE,
+       AUDIT_POST_PTE_WRITE,
+       AUDIT_PRE_SYNC,
+       AUDIT_POST_SYNC
+};
 
-#undef AUDIT
+char *audit_point_name[] = {
+       "pre page fault",
+       "post page fault",
+       "pre pte write",
+       "post pte write",
+       "pre sync",
+       "post sync"
+};
 
-#ifdef AUDIT
-static void kvm_mmu_audit(struct kvm_vcpu *vcpu, const char *msg);
-#else
-static void kvm_mmu_audit(struct kvm_vcpu *vcpu, const char *msg) {}
-#endif
+#undef MMU_DEBUG
 
 #ifdef MMU_DEBUG
 
@@ -71,7 +83,7 @@ static void kvm_mmu_audit(struct kvm_vcpu *vcpu, const char *msg) {}
 
 #endif
 
-#if defined(MMU_DEBUG) || defined(AUDIT)
+#ifdef MMU_DEBUG
 static int dbg = 0;
 module_param(dbg, bool, 0644);
 #endif
@@ -89,6 +101,8 @@ module_param(oos_shadow, bool, 0644);
        }
 #endif
 
+#define PTE_PREFETCH_NUM               8
+
 #define PT_FIRST_AVAIL_BITS_SHIFT 9
 #define PT64_SECOND_AVAIL_BITS_SHIFT 52
 
@@ -97,9 +111,6 @@ module_param(oos_shadow, bool, 0644);
 #define PT64_LEVEL_SHIFT(level) \
                (PAGE_SHIFT + (level - 1) * PT64_LEVEL_BITS)
 
-#define PT64_LEVEL_MASK(level) \
-               (((1ULL << PT64_LEVEL_BITS) - 1) << PT64_LEVEL_SHIFT(level))
-
 #define PT64_INDEX(address, level)\
        (((address) >> PT64_LEVEL_SHIFT(level)) & ((1 << PT64_LEVEL_BITS) - 1))
 
@@ -109,8 +120,6 @@ module_param(oos_shadow, bool, 0644);
 #define PT32_LEVEL_SHIFT(level) \
                (PAGE_SHIFT + (level - 1) * PT32_LEVEL_BITS)
 
-#define PT32_LEVEL_MASK(level) \
-               (((1ULL << PT32_LEVEL_BITS) - 1) << PT32_LEVEL_SHIFT(level))
 #define PT32_LVL_OFFSET_MASK(level) \
        (PT32_BASE_ADDR_MASK & ((1ULL << (PAGE_SHIFT + (((level) - 1) \
                                                * PT32_LEVEL_BITS))) - 1))
@@ -178,10 +187,10 @@ typedef void (*mmu_parent_walk_fn) (struct kvm_mmu_page *sp, u64 *spte);
 static struct kmem_cache *pte_chain_cache;
 static struct kmem_cache *rmap_desc_cache;
 static struct kmem_cache *mmu_page_header_cache;
+static struct percpu_counter kvm_total_used_mmu_pages;
 
 static u64 __read_mostly shadow_trap_nonpresent_pte;
 static u64 __read_mostly shadow_notrap_nonpresent_pte;
-static u64 __read_mostly shadow_base_present_pte;
 static u64 __read_mostly shadow_nx_mask;
 static u64 __read_mostly shadow_x_mask;        /* mutual exclusive with nx_mask */
 static u64 __read_mostly shadow_user_mask;
@@ -200,12 +209,6 @@ void kvm_mmu_set_nonpresent_ptes(u64 trap_pte, u64 notrap_pte)
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_nonpresent_ptes);
 
-void kvm_mmu_set_base_ptes(u64 base_pte)
-{
-       shadow_base_present_pte = base_pte;
-}
-EXPORT_SYMBOL_GPL(kvm_mmu_set_base_ptes);
-
 void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
                u64 dirty_mask, u64 nx_mask, u64 x_mask)
 {
@@ -281,11 +284,7 @@ static gfn_t pse36_gfn_delta(u32 gpte)
 
 static void __set_spte(u64 *sptep, u64 spte)
 {
-#ifdef CONFIG_X86_64
-       set_64bit((unsigned long *)sptep, spte);
-#else
-       set_64bit((unsigned long long *)sptep, spte);
-#endif
+       set_64bit(sptep, spte);
 }
 
 static u64 __xchg_spte(u64 *sptep, u64 new_spte)
@@ -303,18 +302,50 @@ static u64 __xchg_spte(u64 *sptep, u64 new_spte)
 #endif
 }
 
+static bool spte_has_volatile_bits(u64 spte)
+{
+       if (!shadow_accessed_mask)
+               return false;
+
+       if (!is_shadow_present_pte(spte))
+               return false;
+
+       if ((spte & shadow_accessed_mask) &&
+             (!is_writable_pte(spte) || (spte & shadow_dirty_mask)))
+               return false;
+
+       return true;
+}
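
A minimal illustration of the predicate above (not part of the patch; it assumes the non-EPT case, where shadow_accessed_mask and shadow_dirty_mask are nonzero, and only uses masks already defined in this file):

	u64 s = PT_PRESENT_MASK | PT_WRITABLE_MASK;

	spte_has_volatile_bits(s);                        /* true: hardware may still set A/D   */
	spte_has_volatile_bits(s | shadow_accessed_mask |
			       shadow_dirty_mask);        /* false: A and D already captured    */
	spte_has_volatile_bits((s & ~PT_WRITABLE_MASK) |
			       shadow_accessed_mask);     /* false: read-only and A already set */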
+
+static bool spte_is_bit_cleared(u64 old_spte, u64 new_spte, u64 bit_mask)
+{
+       return (old_spte & bit_mask) && !(new_spte & bit_mask);
+}
+
 static void update_spte(u64 *sptep, u64 new_spte)
 {
-       u64 old_spte;
+       u64 mask, old_spte = *sptep;
+
+       WARN_ON(!is_rmap_spte(new_spte));
+
+       new_spte |= old_spte & shadow_dirty_mask;
+
+       mask = shadow_accessed_mask;
+       if (is_writable_pte(old_spte))
+               mask |= shadow_dirty_mask;
 
-       if (!shadow_accessed_mask || (new_spte & shadow_accessed_mask) ||
-             !is_rmap_spte(*sptep))
+       if (!spte_has_volatile_bits(old_spte) || (new_spte & mask) == mask)
                __set_spte(sptep, new_spte);
-       else {
+       else
                old_spte = __xchg_spte(sptep, new_spte);
-               if (old_spte & shadow_accessed_mask)
-                       mark_page_accessed(pfn_to_page(spte_to_pfn(old_spte)));
-       }
+
+       if (!shadow_accessed_mask)
+               return;
+
+       if (spte_is_bit_cleared(old_spte, new_spte, shadow_accessed_mask))
+               kvm_set_pfn_accessed(spte_to_pfn(old_spte));
+       if (spte_is_bit_cleared(old_spte, new_spte, shadow_dirty_mask))
+               kvm_set_pfn_dirty(spte_to_pfn(old_spte));
 }
 
 static int mmu_topup_memory_cache(struct kvm_mmu_memory_cache *cache,
@@ -343,15 +374,15 @@ static void mmu_free_memory_cache(struct kvm_mmu_memory_cache *mc,
 static int mmu_topup_memory_cache_page(struct kvm_mmu_memory_cache *cache,
                                       int min)
 {
-       struct page *page;
+       void *page;
 
        if (cache->nobjs >= min)
                return 0;
        while (cache->nobjs < ARRAY_SIZE(cache->objects)) {
-               page = alloc_page(GFP_KERNEL);
+               page = (void *)__get_free_page(GFP_KERNEL);
                if (!page)
                        return -ENOMEM;
-               cache->objects[cache->nobjs++] = page_address(page);
+               cache->objects[cache->nobjs++] = page;
        }
        return 0;
 }
@@ -371,7 +402,7 @@ static int mmu_topup_memory_caches(struct kvm_vcpu *vcpu)
        if (r)
                goto out;
        r = mmu_topup_memory_cache(&vcpu->arch.mmu_rmap_desc_cache,
-                                  rmap_desc_cache, 4);
+                                  rmap_desc_cache, 4 + PTE_PREFETCH_NUM);
        if (r)
                goto out;
        r = mmu_topup_memory_cache_page(&vcpu->arch.mmu_page_cache, 8);
@@ -441,46 +472,46 @@ static void kvm_mmu_page_set_gfn(struct kvm_mmu_page *sp, int index, gfn_t gfn)
 }
 
 /*
- * Return the pointer to the largepage write count for a given
- * gfn, handling slots that are not large page aligned.
+ * Return the pointer to the large page information for a given gfn,
+ * handling slots that are not large page aligned.
  */
-static int *slot_largepage_idx(gfn_t gfn,
-                              struct kvm_memory_slot *slot,
-                              int level)
+static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
+                                             struct kvm_memory_slot *slot,
+                                             int level)
 {
        unsigned long idx;
 
        idx = (gfn >> KVM_HPAGE_GFN_SHIFT(level)) -
              (slot->base_gfn >> KVM_HPAGE_GFN_SHIFT(level));
-       return &slot->lpage_info[level - 2][idx].write_count;
+       return &slot->lpage_info[level - 2][idx];
 }
 
 static void account_shadowed(struct kvm *kvm, gfn_t gfn)
 {
        struct kvm_memory_slot *slot;
-       int *write_count;
+       struct kvm_lpage_info *linfo;
        int i;
 
        slot = gfn_to_memslot(kvm, gfn);
        for (i = PT_DIRECTORY_LEVEL;
             i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               write_count   = slot_largepage_idx(gfn, slot, i);
-               *write_count += 1;
+               linfo = lpage_info_slot(gfn, slot, i);
+               linfo->write_count += 1;
        }
 }
 
 static void unaccount_shadowed(struct kvm *kvm, gfn_t gfn)
 {
        struct kvm_memory_slot *slot;
-       int *write_count;
+       struct kvm_lpage_info *linfo;
        int i;
 
        slot = gfn_to_memslot(kvm, gfn);
        for (i = PT_DIRECTORY_LEVEL;
             i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               write_count   = slot_largepage_idx(gfn, slot, i);
-               *write_count -= 1;
-               WARN_ON(*write_count < 0);
+               linfo = lpage_info_slot(gfn, slot, i);
+               linfo->write_count -= 1;
+               WARN_ON(linfo->write_count < 0);
        }
 }
 
@@ -489,12 +520,12 @@ static int has_wrprotected_page(struct kvm *kvm,
                                int level)
 {
        struct kvm_memory_slot *slot;
-       int *largepage_idx;
+       struct kvm_lpage_info *linfo;
 
        slot = gfn_to_memslot(kvm, gfn);
        if (slot) {
-               largepage_idx = slot_largepage_idx(gfn, slot, level);
-               return *largepage_idx;
+               linfo = lpage_info_slot(gfn, slot, level);
+               return linfo->write_count;
        }
 
        return 1;
@@ -518,14 +549,18 @@ static int host_mapping_level(struct kvm *kvm, gfn_t gfn)
        return ret;
 }
 
-static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn)
+static bool mapping_level_dirty_bitmap(struct kvm_vcpu *vcpu, gfn_t large_gfn)
 {
        struct kvm_memory_slot *slot;
-       int host_level, level, max_level;
-
        slot = gfn_to_memslot(vcpu->kvm, large_gfn);
        if (slot && slot->dirty_bitmap)
-               return PT_PAGE_TABLE_LEVEL;
+               return true;
+       return false;
+}
+
+static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn)
+{
+       int host_level, level, max_level;
 
        host_level = host_mapping_level(vcpu->kvm, large_gfn);
 
@@ -549,16 +584,15 @@ static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn)
 static unsigned long *gfn_to_rmap(struct kvm *kvm, gfn_t gfn, int level)
 {
        struct kvm_memory_slot *slot;
-       unsigned long idx;
+       struct kvm_lpage_info *linfo;
 
        slot = gfn_to_memslot(kvm, gfn);
        if (likely(level == PT_PAGE_TABLE_LEVEL))
                return &slot->rmap[gfn - slot->base_gfn];
 
-       idx = (gfn >> KVM_HPAGE_GFN_SHIFT(level)) -
-               (slot->base_gfn >> KVM_HPAGE_GFN_SHIFT(level));
+       linfo = lpage_info_slot(gfn, slot, level);
 
-       return &slot->lpage_info[level - 2][idx].rmap_pde;
+       return &linfo->rmap_pde;
 }
 
 /*
@@ -595,6 +629,7 @@ static int rmap_add(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
                desc->sptes[0] = (u64 *)*rmapp;
                desc->sptes[1] = spte;
                *rmapp = (unsigned long)desc | 1;
+               ++count;
        } else {
                rmap_printk("rmap_add: %p %llx many->many\n", spte, *spte);
                desc = (struct kvm_rmap_desc *)(*rmapp & ~1ul);
@@ -607,7 +642,7 @@ static int rmap_add(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
                        desc = desc->more;
                }
                for (i = 0; desc->sptes[i]; ++i)
-                       ;
+                       ++count;
                desc->sptes[i] = spte;
        }
        return count;
@@ -649,18 +684,17 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
        gfn = kvm_mmu_page_get_gfn(sp, spte - sp->spt);
        rmapp = gfn_to_rmap(kvm, gfn, sp->role.level);
        if (!*rmapp) {
-               printk(KERN_ERR "rmap_remove: %p %llx 0->BUG\n", spte, *spte);
+               printk(KERN_ERR "rmap_remove: %p 0->BUG\n", spte);
                BUG();
        } else if (!(*rmapp & 1)) {
-               rmap_printk("rmap_remove:  %p %llx 1->0\n", spte, *spte);
+               rmap_printk("rmap_remove:  %p 1->0\n", spte);
                if ((u64 *)*rmapp != spte) {
-                       printk(KERN_ERR "rmap_remove:  %p %llx 1->BUG\n",
-                              spte, *spte);
+                       printk(KERN_ERR "rmap_remove:  %p 1->BUG\n", spte);
                        BUG();
                }
                *rmapp = 0;
        } else {
-               rmap_printk("rmap_remove:  %p %llx many->many\n", spte, *spte);
+               rmap_printk("rmap_remove:  %p many->many\n", spte);
                desc = (struct kvm_rmap_desc *)(*rmapp & ~1ul);
                prev_desc = NULL;
                while (desc) {
@@ -674,30 +708,36 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
                        prev_desc = desc;
                        desc = desc->more;
                }
-               pr_err("rmap_remove: %p %llx many->many\n", spte, *spte);
+               pr_err("rmap_remove: %p many->many\n", spte);
                BUG();
        }
 }
 
-static void set_spte_track_bits(u64 *sptep, u64 new_spte)
+static int set_spte_track_bits(u64 *sptep, u64 new_spte)
 {
        pfn_t pfn;
-       u64 old_spte;
+       u64 old_spte = *sptep;
+
+       if (!spte_has_volatile_bits(old_spte))
+               __set_spte(sptep, new_spte);
+       else
+               old_spte = __xchg_spte(sptep, new_spte);
 
-       old_spte = __xchg_spte(sptep, new_spte);
        if (!is_rmap_spte(old_spte))
-               return;
+               return 0;
+
        pfn = spte_to_pfn(old_spte);
        if (!shadow_accessed_mask || old_spte & shadow_accessed_mask)
                kvm_set_pfn_accessed(pfn);
-       if (is_writable_pte(old_spte))
+       if (!shadow_dirty_mask || (old_spte & shadow_dirty_mask))
                kvm_set_pfn_dirty(pfn);
+       return 1;
 }
 
 static void drop_spte(struct kvm *kvm, u64 *sptep, u64 new_spte)
 {
-       set_spte_track_bits(sptep, new_spte);
-       rmap_remove(kvm, sptep);
+       if (set_spte_track_bits(sptep, new_spte))
+               rmap_remove(kvm, sptep);
 }
 
 static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)
@@ -745,13 +785,6 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
                }
                spte = rmap_next(kvm, rmapp, spte);
        }
-       if (write_protected) {
-               pfn_t pfn;
-
-               spte = rmap_next(kvm, rmapp, NULL);
-               pfn = spte_to_pfn(*spte);
-               kvm_set_pfn_dirty(pfn);
-       }
 
        /* check for huge page mappings */
        for (i = PT_DIRECTORY_LEVEL;
@@ -847,19 +880,16 @@ static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
                end = start + (memslot->npages << PAGE_SHIFT);
                if (hva >= start && hva < end) {
                        gfn_t gfn_offset = (hva - start) >> PAGE_SHIFT;
+                       gfn_t gfn = memslot->base_gfn + gfn_offset;
 
                        ret = handler(kvm, &memslot->rmap[gfn_offset], data);
 
                        for (j = 0; j < KVM_NR_PAGE_SIZES - 1; ++j) {
-                               unsigned long idx;
-                               int sh;
-
-                               sh = KVM_HPAGE_GFN_SHIFT(PT_DIRECTORY_LEVEL+j);
-                               idx = ((memslot->base_gfn+gfn_offset) >> sh) -
-                                       (memslot->base_gfn >> sh);
-                               ret |= handler(kvm,
-                                       &memslot->lpage_info[j][idx].rmap_pde,
-                                       data);
+                               struct kvm_lpage_info *linfo;
+
+                               linfo = lpage_info_slot(gfn, memslot,
+                                                       PT_DIRECTORY_LEVEL + j);
+                               ret |= handler(kvm, &linfo->rmap_pde, data);
                        }
                        trace_kvm_age_page(hva, memslot, ret);
                        retval |= ret;
@@ -910,6 +940,35 @@ static int kvm_age_rmapp(struct kvm *kvm, unsigned long *rmapp,
        return young;
 }
 
+static int kvm_test_age_rmapp(struct kvm *kvm, unsigned long *rmapp,
+                             unsigned long data)
+{
+       u64 *spte;
+       int young = 0;
+
+       /*
+        * If there's no access bit in the secondary pte set by the
+        * hardware, it's up to gup-fast/gup to set the access bit in
+        * the primary pte or in the page structure.
+        */
+       if (!shadow_accessed_mask)
+               goto out;
+
+       spte = rmap_next(kvm, rmapp, NULL);
+       while (spte) {
+               u64 _spte = *spte;
+               BUG_ON(!(_spte & PT_PRESENT_MASK));
+               young = _spte & PT_ACCESSED_MASK;
+               if (young) {
+                       young = 1;
+                       break;
+               }
+               spte = rmap_next(kvm, rmapp, spte);
+       }
+out:
+       return young;
+}
+
 #define RMAP_RECYCLE_THRESHOLD 1000
 
 static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
@@ -930,6 +989,11 @@ int kvm_age_hva(struct kvm *kvm, unsigned long hva)
        return kvm_handle_hva(kvm, hva, 0, kvm_age_rmapp);
 }
 
+int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
+{
+       return kvm_handle_hva(kvm, hva, 0, kvm_test_age_rmapp);
+}
+
 #ifdef MMU_DEBUG
 static int is_empty_shadow_page(u64 *spt)
 {
@@ -946,16 +1010,28 @@ static int is_empty_shadow_page(u64 *spt)
 }
 #endif
 
+/*
+ * This value is the sum of all of the kvm instances'
+ * kvm->arch.n_used_mmu_pages values.  We need a global,
+ * aggregate version in order to make the slab shrinker
+ * faster.
+ */
+static inline void kvm_mod_used_mmu_pages(struct kvm *kvm, int nr)
+{
+       kvm->arch.n_used_mmu_pages += nr;
+       percpu_counter_add(&kvm_total_used_mmu_pages, nr);
+}
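
A sketch of how the aggregate is meant to be consumed (the shrinker itself is outside this hunk); the helper name below is made up for illustration, but percpu_counter_read_positive() is the cheap, slightly stale read the comment above refers to:

	static unsigned long kvm_mmu_used_pages_estimate(void)
	{
		/* No per-VM locking needed; the value may lag slightly. */
		return percpu_counter_read_positive(&kvm_total_used_mmu_pages);
	}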
+
 static void kvm_mmu_free_page(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        ASSERT(is_empty_shadow_page(sp->spt));
        hlist_del(&sp->hash_link);
        list_del(&sp->link);
-       __free_page(virt_to_page(sp->spt));
+       free_page((unsigned long)sp->spt);
        if (!sp->role.direct)
-               __free_page(virt_to_page(sp->gfns));
+               free_page((unsigned long)sp->gfns);
        kmem_cache_free(mmu_page_header_cache, sp);
-       ++kvm->arch.n_free_mmu_pages;
+       kvm_mod_used_mmu_pages(kvm, -1);
 }
 
 static unsigned kvm_page_table_hashfn(gfn_t gfn)
@@ -978,7 +1054,7 @@ static struct kvm_mmu_page *kvm_mmu_alloc_page(struct kvm_vcpu *vcpu,
        bitmap_zero(sp->slot_bitmap, KVM_MEMORY_SLOTS + KVM_PRIVATE_MEM_SLOTS);
        sp->multimapped = 0;
        sp->parent_pte = parent_pte;
-       --vcpu->kvm->arch.n_free_mmu_pages;
+       kvm_mod_used_mmu_pages(vcpu->kvm, +1);
        return sp;
 }
 
@@ -1109,7 +1185,7 @@ static void nonpaging_prefetch_page(struct kvm_vcpu *vcpu,
 }
 
 static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
-                              struct kvm_mmu_page *sp, bool clear_unsync)
+                              struct kvm_mmu_page *sp)
 {
        return 1;
 }
@@ -1239,7 +1315,7 @@ static int __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
        if (clear_unsync)
                kvm_unlink_unsync_page(vcpu->kvm, sp);
 
-       if (vcpu->arch.mmu.sync_page(vcpu, sp, clear_unsync)) {
+       if (vcpu->arch.mmu.sync_page(vcpu, sp)) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
                return 1;
        }
@@ -1280,12 +1356,12 @@ static void kvm_sync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
                        continue;
 
                WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
+               kvm_unlink_unsync_page(vcpu->kvm, s);
                if ((s->role.cr4_pae != !!is_pae(vcpu)) ||
-                       (vcpu->arch.mmu.sync_page(vcpu, s, true))) {
+                       (vcpu->arch.mmu.sync_page(vcpu, s))) {
                        kvm_mmu_prepare_zap_page(vcpu->kvm, s, &invalid_list);
                        continue;
                }
-               kvm_unlink_unsync_page(vcpu->kvm, s);
                flush = true;
        }
 
@@ -1402,7 +1478,8 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        if (role.direct)
                role.cr4_pae = 0;
        role.access = access;
-       if (!tdp_enabled && vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
+       if (!vcpu->arch.mmu.direct_map
+           && vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
                quadrant = gaddr >> (PAGE_SHIFT + (PT64_PT_BITS * level));
                quadrant &= (1 << ((PT32_PT_BITS - PT64_PT_BITS) * level)) - 1;
                role.quadrant = quadrant;
@@ -1457,6 +1534,12 @@ static void shadow_walk_init(struct kvm_shadow_walk_iterator *iterator,
        iterator->addr = addr;
        iterator->shadow_addr = vcpu->arch.mmu.root_hpa;
        iterator->level = vcpu->arch.mmu.shadow_root_level;
+
+       if (iterator->level == PT64_ROOT_LEVEL &&
+           vcpu->arch.mmu.root_level < PT64_ROOT_LEVEL &&
+           !vcpu->arch.mmu.direct_map)
+               --iterator->level;
+
        if (iterator->level == PT32E_ROOT_LEVEL) {
                iterator->shadow_addr
                        = vcpu->arch.mmu.pae_root[(addr >> 30) & 3];
@@ -1664,41 +1747,31 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
 
 /*
  * Changing the number of mmu pages allocated to the vm
- * Note: if kvm_nr_mmu_pages is too small, you will get dead lock
+ * Note: if goal_nr_mmu_pages is too small, you will get a deadlock
  */
-void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages)
+void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int goal_nr_mmu_pages)
 {
-       int used_pages;
        LIST_HEAD(invalid_list);
-
-       used_pages = kvm->arch.n_alloc_mmu_pages - kvm->arch.n_free_mmu_pages;
-       used_pages = max(0, used_pages);
-
        /*
         * If we set the number of mmu pages to be smaller than the
         * number of active pages, we must free some mmu pages before we
         * change the value.
         */
 
-       if (used_pages > kvm_nr_mmu_pages) {
-               while (used_pages > kvm_nr_mmu_pages &&
+       if (kvm->arch.n_used_mmu_pages > goal_nr_mmu_pages) {
+               while (kvm->arch.n_used_mmu_pages > goal_nr_mmu_pages &&
                        !list_empty(&kvm->arch.active_mmu_pages)) {
                        struct kvm_mmu_page *page;
 
                        page = container_of(kvm->arch.active_mmu_pages.prev,
                                            struct kvm_mmu_page, link);
-                       used_pages -= kvm_mmu_prepare_zap_page(kvm, page,
-                                                              &invalid_list);
+                       kvm_mmu_prepare_zap_page(kvm, page, &invalid_list);
+                       kvm_mmu_commit_zap_page(kvm, &invalid_list);
                }
-               kvm_mmu_commit_zap_page(kvm, &invalid_list);
-               kvm_nr_mmu_pages = used_pages;
-               kvm->arch.n_free_mmu_pages = 0;
+               goal_nr_mmu_pages = kvm->arch.n_used_mmu_pages;
        }
-       else
-               kvm->arch.n_free_mmu_pages += kvm_nr_mmu_pages
-                                        - kvm->arch.n_alloc_mmu_pages;
 
-       kvm->arch.n_alloc_mmu_pages = kvm_nr_mmu_pages;
+       kvm->arch.n_max_mmu_pages = goal_nr_mmu_pages;
 }
 
 static int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
@@ -1708,11 +1781,11 @@ static int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
        LIST_HEAD(invalid_list);
        int r;
 
-       pgprintk("%s: looking for gfn %lx\n", __func__, gfn);
+       pgprintk("%s: looking for gfn %llx\n", __func__, gfn);
        r = 0;
 
        for_each_gfn_indirect_valid_sp(kvm, sp, gfn, node) {
-               pgprintk("%s: gfn %lx role %x\n", __func__, gfn,
+               pgprintk("%s: gfn %llx role %x\n", __func__, gfn,
                         sp->role.word);
                r = 1;
                kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list);
@@ -1728,7 +1801,7 @@ static void mmu_unshadow(struct kvm *kvm, gfn_t gfn)
        LIST_HEAD(invalid_list);
 
        for_each_gfn_indirect_valid_sp(kvm, sp, gfn, node) {
-               pgprintk("%s: zap %lx %x\n",
+               pgprintk("%s: zap %llx %x\n",
                         __func__, gfn, sp->role.word);
                kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list);
        }
@@ -1914,9 +1987,9 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                    unsigned pte_access, int user_fault,
                    int write_fault, int dirty, int level,
                    gfn_t gfn, pfn_t pfn, bool speculative,
-                   bool can_unsync, bool reset_host_protection)
+                   bool can_unsync, bool host_writable)
 {
-       u64 spte;
+       u64 spte, entry = *sptep;
        int ret = 0;
 
        /*
@@ -1924,7 +1997,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
         * whether the guest actually used the pte (in order to detect
         * demand paging).
         */
-       spte = shadow_base_present_pte | shadow_dirty_mask;
+       spte = PT_PRESENT_MASK;
        if (!speculative)
                spte |= shadow_accessed_mask;
        if (!dirty)
@@ -1941,14 +2014,16 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                spte |= kvm_x86_ops->get_mt_mask(vcpu, gfn,
                        kvm_is_mmio_pfn(pfn));
 
-       if (reset_host_protection)
+       if (host_writable)
                spte |= SPTE_HOST_WRITEABLE;
+       else
+               pte_access &= ~ACC_WRITE_MASK;
 
        spte |= (u64)pfn << PAGE_SHIFT;
 
        if ((pte_access & ACC_WRITE_MASK)
-           || (!tdp_enabled && write_fault && !is_write_protection(vcpu)
-               && !user_fault)) {
+           || (!vcpu->arch.mmu.direct_map && write_fault
+               && !is_write_protection(vcpu) && !user_fault)) {
 
                if (level > PT_PAGE_TABLE_LEVEL &&
                    has_wrprotected_page(vcpu->kvm, gfn, level)) {
@@ -1959,7 +2034,8 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 
                spte |= PT_WRITABLE_MASK;
 
-               if (!tdp_enabled && !(pte_access & ACC_WRITE_MASK))
+               if (!vcpu->arch.mmu.direct_map
+                   && !(pte_access & ACC_WRITE_MASK))
                        spte &= ~PT_USER_MASK;
 
                /*
@@ -1972,7 +2048,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                        goto set_pte;
 
                if (mmu_need_write_protect(vcpu, gfn, can_unsync)) {
-                       pgprintk("%s: found shadow page for %lx, marking ro\n",
+                       pgprintk("%s: found shadow page for %llx, marking ro\n",
                                 __func__, gfn);
                        ret = 1;
                        pte_access &= ~ACC_WRITE_MASK;
@@ -1985,9 +2061,15 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                mark_page_dirty(vcpu->kvm, gfn);
 
 set_pte:
-       if (is_writable_pte(*sptep) && !is_writable_pte(spte))
-               kvm_set_pfn_dirty(pfn);
        update_spte(sptep, spte);
+       /*
+        * If we overwrite a writable spte with a read-only one we
+        * should flush remote TLBs. Otherwise rmap_write_protect
+        * will find a read-only spte, even though the writable spte
+        * might be cached on a CPU's TLB.
+        */
+       if (is_writable_pte(entry) && !is_writable_pte(*sptep))
+               kvm_flush_remote_tlbs(vcpu->kvm);
 done:
        return ret;
 }
@@ -1997,13 +2079,13 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                         int user_fault, int write_fault, int dirty,
                         int *ptwrite, int level, gfn_t gfn,
                         pfn_t pfn, bool speculative,
-                        bool reset_host_protection)
+                        bool host_writable)
 {
        int was_rmapped = 0;
        int rmap_count;
 
        pgprintk("%s: spte %llx access %x write_fault %d"
-                " user_fault %d gfn %lx\n",
+                " user_fault %d gfn %llx\n",
                 __func__, *sptep, pt_access,
                 write_fault, user_fault, gfn);
 
@@ -2022,7 +2104,7 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                        __set_spte(sptep, shadow_trap_nonpresent_pte);
                        kvm_flush_remote_tlbs(vcpu->kvm);
                } else if (pfn != spte_to_pfn(*sptep)) {
-                       pgprintk("hfn old %lx new %lx\n",
+                       pgprintk("hfn old %llx new %llx\n",
                                 spte_to_pfn(*sptep), pfn);
                        drop_spte(vcpu->kvm, sptep, shadow_trap_nonpresent_pte);
                        kvm_flush_remote_tlbs(vcpu->kvm);
@@ -2032,14 +2114,14 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 
        if (set_spte(vcpu, sptep, pte_access, user_fault, write_fault,
                      dirty, level, gfn, pfn, speculative, true,
-                     reset_host_protection)) {
+                     host_writable)) {
                if (write_fault)
                        *ptwrite = 1;
                kvm_mmu_flush_tlb(vcpu);
        }
 
        pgprintk("%s: setting spte %llx\n", __func__, *sptep);
-       pgprintk("instantiating %s PTE (%s) at %ld (%llx) addr %p\n",
+       pgprintk("instantiating %s PTE (%s) at %llx (%llx) addr %p\n",
                 is_large_pte(*sptep)? "2MB" : "4kB",
                 *sptep & PT_PRESENT_MASK ?"RW":"R", gfn,
                 *sptep, sptep);
@@ -2063,8 +2145,108 @@ static void nonpaging_new_cr3(struct kvm_vcpu *vcpu)
 {
 }
 
+static struct kvm_memory_slot *
+pte_prefetch_gfn_to_memslot(struct kvm_vcpu *vcpu, gfn_t gfn, bool no_dirty_log)
+{
+       struct kvm_memory_slot *slot;
+
+       slot = gfn_to_memslot(vcpu->kvm, gfn);
+       if (!slot || slot->flags & KVM_MEMSLOT_INVALID ||
+             (no_dirty_log && slot->dirty_bitmap))
+               slot = NULL;
+
+       return slot;
+}
+
+static pfn_t pte_prefetch_gfn_to_pfn(struct kvm_vcpu *vcpu, gfn_t gfn,
+                                    bool no_dirty_log)
+{
+       struct kvm_memory_slot *slot;
+       unsigned long hva;
+
+       slot = pte_prefetch_gfn_to_memslot(vcpu, gfn, no_dirty_log);
+       if (!slot) {
+               get_page(bad_page);
+               return page_to_pfn(bad_page);
+       }
+
+       hva = gfn_to_hva_memslot(slot, gfn);
+
+       return hva_to_pfn_atomic(vcpu->kvm, hva);
+}
+
+static int direct_pte_prefetch_many(struct kvm_vcpu *vcpu,
+                                   struct kvm_mmu_page *sp,
+                                   u64 *start, u64 *end)
+{
+       struct page *pages[PTE_PREFETCH_NUM];
+       unsigned access = sp->role.access;
+       int i, ret;
+       gfn_t gfn;
+
+       gfn = kvm_mmu_page_get_gfn(sp, start - sp->spt);
+       if (!pte_prefetch_gfn_to_memslot(vcpu, gfn, access & ACC_WRITE_MASK))
+               return -1;
+
+       ret = gfn_to_page_many_atomic(vcpu->kvm, gfn, pages, end - start);
+       if (ret <= 0)
+               return -1;
+
+       for (i = 0; i < ret; i++, gfn++, start++)
+               mmu_set_spte(vcpu, start, ACC_ALL,
+                            access, 0, 0, 1, NULL,
+                            sp->role.level, gfn,
+                            page_to_pfn(pages[i]), true, true);
+
+       return 0;
+}
+
+static void __direct_pte_prefetch(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu_page *sp, u64 *sptep)
+{
+       u64 *spte, *start = NULL;
+       int i;
+
+       WARN_ON(!sp->role.direct);
+
+       i = (sptep - sp->spt) & ~(PTE_PREFETCH_NUM - 1);
+       spte = sp->spt + i;
+
+       for (i = 0; i < PTE_PREFETCH_NUM; i++, spte++) {
+               if (*spte != shadow_trap_nonpresent_pte || spte == sptep) {
+                       if (!start)
+                               continue;
+                       if (direct_pte_prefetch_many(vcpu, sp, start, spte) < 0)
+                               break;
+                       start = NULL;
+               } else if (!start)
+                       start = spte;
+       }
+}
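
A worked example of the window arithmetic above (illustrative indices only, not from the patch):

	/*
	 * If sptep points at entry 13 of sp->spt, then
	 * i = 13 & ~(PTE_PREFETCH_NUM - 1) = 8, so entries 8..15 are scanned.
	 * A run of nonpresent entries, say 8..10, is handed to
	 * direct_pte_prefetch_many(vcpu, sp, &sp->spt[8], &sp->spt[11]) as
	 * soon as the scan reaches the entry that ends the run (a present
	 * entry or sptep itself).
	 */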
+
+static void direct_pte_prefetch(struct kvm_vcpu *vcpu, u64 *sptep)
+{
+       struct kvm_mmu_page *sp;
+
+       /*
+        * Since there is no accessed bit on EPT, there is no way to
+        * distinguish between actually accessed translations
+        * and prefetched ones, so disable pte prefetch if EPT is
+        * enabled.
+        */
+       if (!shadow_accessed_mask)
+               return;
+
+       sp = page_header(__pa(sptep));
+       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+               return;
+
+       __direct_pte_prefetch(vcpu, sp, sptep);
+}
+
 static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
-                       int level, gfn_t gfn, pfn_t pfn)
+                       int map_writable, int level, gfn_t gfn, pfn_t pfn,
+                       bool prefault)
 {
        struct kvm_shadow_walk_iterator iterator;
        struct kvm_mmu_page *sp;
@@ -2073,9 +2255,12 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
 
        for_each_shadow_entry(vcpu, (u64)gfn << PAGE_SHIFT, iterator) {
                if (iterator.level == level) {
-                       mmu_set_spte(vcpu, iterator.sptep, ACC_ALL, ACC_ALL,
+                       unsigned pte_access = ACC_ALL;
+
+                       mmu_set_spte(vcpu, iterator.sptep, ACC_ALL, pte_access,
                                     0, write, 1, &pt_write,
-                                    level, gfn, pfn, false, true);
+                                    level, gfn, pfn, prefault, map_writable);
+                       direct_pte_prefetch(vcpu, iterator.sptep);
                        ++vcpu->stat.pf_fixed;
                        break;
                }
@@ -2097,28 +2282,31 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
                        __set_spte(iterator.sptep,
                                   __pa(sp->spt)
                                   | PT_PRESENT_MASK | PT_WRITABLE_MASK
-                                  | shadow_user_mask | shadow_x_mask);
+                                  | shadow_user_mask | shadow_x_mask
+                                  | shadow_accessed_mask);
                }
        }
        return pt_write;
 }
 
-static void kvm_send_hwpoison_signal(struct kvm *kvm, gfn_t gfn)
+static void kvm_send_hwpoison_signal(unsigned long address, struct task_struct *tsk)
 {
-       char buf[1];
-       void __user *hva;
-       int r;
+       siginfo_t info;
 
-       /* Touch the page, so send SIGBUS */
-       hva = (void __user *)gfn_to_hva(kvm, gfn);
-       r = copy_from_user(buf, hva, 1);
+       info.si_signo   = SIGBUS;
+       info.si_errno   = 0;
+       info.si_code    = BUS_MCEERR_AR;
+       info.si_addr    = (void __user *)address;
+       info.si_addr_lsb = PAGE_SHIFT;
+
+       send_sig_info(SIGBUS, &info, tsk);
 }
 
 static int kvm_handle_bad_page(struct kvm *kvm, gfn_t gfn, pfn_t pfn)
 {
        kvm_release_pfn_clean(pfn);
        if (is_hwpoison_pfn(pfn)) {
-               kvm_send_hwpoison_signal(kvm, gfn);
+               kvm_send_hwpoison_signal(gfn_to_hva(kvm, gfn), current);
                return 0;
        } else if (is_fault_pfn(pfn))
                return -EFAULT;
@@ -2126,27 +2314,81 @@ static int kvm_handle_bad_page(struct kvm *kvm, gfn_t gfn, pfn_t pfn)
        return 1;
 }
 
-static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn)
+static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
+                                       gfn_t *gfnp, pfn_t *pfnp, int *levelp)
+{
+       pfn_t pfn = *pfnp;
+       gfn_t gfn = *gfnp;
+       int level = *levelp;
+
+       /*
+        * Check if it's a transparent hugepage. If this would be a
+        * hugetlbfs page, level wouldn't be set to
+        * PT_PAGE_TABLE_LEVEL and there would be no adjustment done
+        * here.
+        */
+       if (!is_error_pfn(pfn) && !kvm_is_mmio_pfn(pfn) &&
+           level == PT_PAGE_TABLE_LEVEL &&
+           PageTransCompound(pfn_to_page(pfn)) &&
+           !has_wrprotected_page(vcpu->kvm, gfn, PT_DIRECTORY_LEVEL)) {
+               unsigned long mask;
+               /*
+                * mmu_notifier_retry was successful and we hold the
+                * mmu_lock here, so the pmd can't become splitting
+                * from under us, and in turn
+                * __split_huge_page_refcount() can't run from under
+                * us and we can safely transfer the refcount from
+                * PG_tail to PG_head as we switch the pfn to tail to
+                * head.
+                */
+               *levelp = level = PT_DIRECTORY_LEVEL;
+               mask = KVM_PAGES_PER_HPAGE(level) - 1;
+               VM_BUG_ON((gfn & mask) != (pfn & mask));
+               if (pfn & mask) {
+                       gfn &= ~mask;
+                       *gfnp = gfn;
+                       kvm_release_pfn_clean(pfn);
+                       pfn &= ~mask;
+                       if (!get_page_unless_zero(pfn_to_page(pfn)))
+                               BUG();
+                       *pfnp = pfn;
+               }
+       }
+}
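
A worked example of the alignment above (illustrative numbers, not from the patch):

	/*
	 * With level = PT_DIRECTORY_LEVEL, KVM_PAGES_PER_HPAGE(level) = 512
	 * and mask = 0x1ff.  For gfn = 0x10203 and pfn = 0x45603 both values
	 * have offset 0x003 inside the 2MB region (which is what the
	 * VM_BUG_ON checks), so after the adjustment gfn = 0x10200 and
	 * pfn = 0x45600, and the fault is mapped with a single large spte
	 * covering the whole aligned region.
	 */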
+
+static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
+                        gva_t gva, pfn_t *pfn, bool write, bool *writable);
+
+static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn,
+                        bool prefault)
 {
        int r;
        int level;
+       int force_pt_level;
        pfn_t pfn;
        unsigned long mmu_seq;
+       bool map_writable;
 
-       level = mapping_level(vcpu, gfn);
-
-       /*
-        * This path builds a PAE pagetable - so we can map 2mb pages at
-        * maximum. Therefore check if the level is larger than that.
-        */
-       if (level > PT_DIRECTORY_LEVEL)
-               level = PT_DIRECTORY_LEVEL;
+       force_pt_level = mapping_level_dirty_bitmap(vcpu, gfn);
+       if (likely(!force_pt_level)) {
+               level = mapping_level(vcpu, gfn);
+               /*
+                * This path builds a PAE pagetable - so we can map
+                * 2mb pages at maximum. Therefore check if the level
+                * is larger than that.
+                */
+               if (level > PT_DIRECTORY_LEVEL)
+                       level = PT_DIRECTORY_LEVEL;
 
-       gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
+               gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
+       } else
+               level = PT_PAGE_TABLE_LEVEL;
 
        mmu_seq = vcpu->kvm->mmu_notifier_seq;
        smp_rmb();
-       pfn = gfn_to_pfn(vcpu->kvm, gfn);
+
+       if (try_async_pf(vcpu, prefault, gfn, v, &pfn, write, &map_writable))
+               return 0;
 
        /* mmio */
        if (is_error_pfn(pfn))
@@ -2156,7 +2398,10 @@ static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn)
        if (mmu_notifier_retry(vcpu, mmu_seq))
                goto out_unlock;
        kvm_mmu_free_some_pages(vcpu);
-       r = __direct_map(vcpu, v, write, level, gfn, pfn);
+       if (likely(!force_pt_level))
+               transparent_hugepage_adjust(vcpu, &gfn, &pfn, &level);
+       r = __direct_map(vcpu, v, write, map_writable, level, gfn, pfn,
+                        prefault);
        spin_unlock(&vcpu->kvm->mmu_lock);
 
 
@@ -2178,7 +2423,9 @@ static void mmu_free_roots(struct kvm_vcpu *vcpu)
        if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
                return;
        spin_lock(&vcpu->kvm->mmu_lock);
-       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL) {
+       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL &&
+           (vcpu->arch.mmu.root_level == PT64_ROOT_LEVEL ||
+            vcpu->arch.mmu.direct_map)) {
                hpa_t root = vcpu->arch.mmu.root_hpa;
 
                sp = page_header(root);
@@ -2221,83 +2468,163 @@ static int mmu_check_root(struct kvm_vcpu *vcpu, gfn_t root_gfn)
        return ret;
 }
 
-static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
+static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
 {
-       int i;
-       gfn_t root_gfn;
        struct kvm_mmu_page *sp;
-       int direct = 0;
-       u64 pdptr;
-
-       root_gfn = vcpu->arch.cr3 >> PAGE_SHIFT;
+       unsigned i;
 
        if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL) {
+               spin_lock(&vcpu->kvm->mmu_lock);
+               kvm_mmu_free_some_pages(vcpu);
+               sp = kvm_mmu_get_page(vcpu, 0, 0, PT64_ROOT_LEVEL,
+                                     1, ACC_ALL, NULL);
+               ++sp->root_count;
+               spin_unlock(&vcpu->kvm->mmu_lock);
+               vcpu->arch.mmu.root_hpa = __pa(sp->spt);
+       } else if (vcpu->arch.mmu.shadow_root_level == PT32E_ROOT_LEVEL) {
+               for (i = 0; i < 4; ++i) {
+                       hpa_t root = vcpu->arch.mmu.pae_root[i];
+
+                       ASSERT(!VALID_PAGE(root));
+                       spin_lock(&vcpu->kvm->mmu_lock);
+                       kvm_mmu_free_some_pages(vcpu);
+                       sp = kvm_mmu_get_page(vcpu, i << (30 - PAGE_SHIFT),
+                                             i << 30,
+                                             PT32_ROOT_LEVEL, 1, ACC_ALL,
+                                             NULL);
+                       root = __pa(sp->spt);
+                       ++sp->root_count;
+                       spin_unlock(&vcpu->kvm->mmu_lock);
+                       vcpu->arch.mmu.pae_root[i] = root | PT_PRESENT_MASK;
+               }
+               vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.pae_root);
+       } else
+               BUG();
+
+       return 0;
+}
+
+static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
+{
+       struct kvm_mmu_page *sp;
+       u64 pdptr, pm_mask;
+       gfn_t root_gfn;
+       int i;
+
+       root_gfn = vcpu->arch.mmu.get_cr3(vcpu) >> PAGE_SHIFT;
+
+       if (mmu_check_root(vcpu, root_gfn))
+               return 1;
+
+       /*
+        * Do we shadow a long mode page table? If so we need to
+        * write-protect the guest's page table root.
+        */
+       if (vcpu->arch.mmu.root_level == PT64_ROOT_LEVEL) {
                hpa_t root = vcpu->arch.mmu.root_hpa;
 
                ASSERT(!VALID_PAGE(root));
-               if (mmu_check_root(vcpu, root_gfn))
-                       return 1;
-               if (tdp_enabled) {
-                       direct = 1;
-                       root_gfn = 0;
-               }
+
                spin_lock(&vcpu->kvm->mmu_lock);
                kvm_mmu_free_some_pages(vcpu);
-               sp = kvm_mmu_get_page(vcpu, root_gfn, 0,
-                                     PT64_ROOT_LEVEL, direct,
-                                     ACC_ALL, NULL);
+               sp = kvm_mmu_get_page(vcpu, root_gfn, 0, PT64_ROOT_LEVEL,
+                                     0, ACC_ALL, NULL);
                root = __pa(sp->spt);
                ++sp->root_count;
                spin_unlock(&vcpu->kvm->mmu_lock);
                vcpu->arch.mmu.root_hpa = root;
                return 0;
        }
-       direct = !is_paging(vcpu);
+
+       /*
+        * We shadow a 32 bit page table. This may be a legacy 2-level
+        * or a PAE 3-level page table. In either case we need to be aware that
+        * the shadow page table may be a PAE or a long mode page table.
+        */
+       pm_mask = PT_PRESENT_MASK;
+       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL)
+               pm_mask |= PT_ACCESSED_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
+
        for (i = 0; i < 4; ++i) {
                hpa_t root = vcpu->arch.mmu.pae_root[i];
 
                ASSERT(!VALID_PAGE(root));
                if (vcpu->arch.mmu.root_level == PT32E_ROOT_LEVEL) {
-                       pdptr = kvm_pdptr_read(vcpu, i);
+                       pdptr = kvm_pdptr_read_mmu(vcpu, &vcpu->arch.mmu, i);
                        if (!is_present_gpte(pdptr)) {
                                vcpu->arch.mmu.pae_root[i] = 0;
                                continue;
                        }
                        root_gfn = pdptr >> PAGE_SHIFT;
-               } else if (vcpu->arch.mmu.root_level == 0)
-                       root_gfn = 0;
-               if (mmu_check_root(vcpu, root_gfn))
-                       return 1;
-               if (tdp_enabled) {
-                       direct = 1;
-                       root_gfn = i << 30;
+                       if (mmu_check_root(vcpu, root_gfn))
+                               return 1;
                }
                spin_lock(&vcpu->kvm->mmu_lock);
                kvm_mmu_free_some_pages(vcpu);
                sp = kvm_mmu_get_page(vcpu, root_gfn, i << 30,
-                                     PT32_ROOT_LEVEL, direct,
+                                     PT32_ROOT_LEVEL, 0,
                                      ACC_ALL, NULL);
                root = __pa(sp->spt);
                ++sp->root_count;
                spin_unlock(&vcpu->kvm->mmu_lock);
 
-               vcpu->arch.mmu.pae_root[i] = root | PT_PRESENT_MASK;
+               vcpu->arch.mmu.pae_root[i] = root | pm_mask;
        }
        vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.pae_root);
+
+       /*
+        * If we shadow a 32 bit page table with a long mode page
+        * table we enter this path.
+        */
+       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL) {
+               if (vcpu->arch.mmu.lm_root == NULL) {
+                       /*
+                        * The additional page necessary for this is only
+                        * allocated on demand.
+                        */
+
+                       u64 *lm_root;
+
+                       lm_root = (void*)get_zeroed_page(GFP_KERNEL);
+                       if (lm_root == NULL)
+                               return 1;
+
+                       lm_root[0] = __pa(vcpu->arch.mmu.pae_root) | pm_mask;
+
+                       vcpu->arch.mmu.lm_root = lm_root;
+               }
+
+               vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.lm_root);
+       }
+
        return 0;
 }
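
A sketch of the root hierarchy this builds when a 32-bit guest page table is shadowed by a long mode shadow walker (derived from the assignments above, not a separate data structure):

	/*
	 *   mmu.root_hpa --> lm_root              (one PML4-level entry used)
	 *                      lm_root[0] ------>  pae_root | pm_mask
	 *                                            pae_root[0..3] --> four
	 *                                            PT32_ROOT_LEVEL shadow
	 *                                            pages, one per 1GB of
	 *                                            guest address space
	 */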
 
+static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
+{
+       if (vcpu->arch.mmu.direct_map)
+               return mmu_alloc_direct_roots(vcpu);
+       else
+               return mmu_alloc_shadow_roots(vcpu);
+}
+
 static void mmu_sync_roots(struct kvm_vcpu *vcpu)
 {
        int i;
        struct kvm_mmu_page *sp;
 
+       if (vcpu->arch.mmu.direct_map)
+               return;
+
        if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
                return;
-       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL) {
+
+       trace_kvm_mmu_audit(vcpu, AUDIT_PRE_SYNC);
+       if (vcpu->arch.mmu.root_level == PT64_ROOT_LEVEL) {
                hpa_t root = vcpu->arch.mmu.root_hpa;
                sp = page_header(root);
                mmu_sync_children(vcpu, sp);
+               trace_kvm_mmu_audit(vcpu, AUDIT_POST_SYNC);
                return;
        }
        for (i = 0; i < 4; ++i) {
@@ -2309,6 +2636,7 @@ static void mmu_sync_roots(struct kvm_vcpu *vcpu)
                        mmu_sync_children(vcpu, sp);
                }
        }
+       trace_kvm_mmu_audit(vcpu, AUDIT_POST_SYNC);
 }
 
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
@@ -2319,15 +2647,24 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
 }
 
 static gpa_t nonpaging_gva_to_gpa(struct kvm_vcpu *vcpu, gva_t vaddr,
-                                 u32 access, u32 *error)
+                                 u32 access, struct x86_exception *exception)
 {
-       if (error)
-               *error = 0;
+       if (exception)
+               exception->error_code = 0;
        return vaddr;
 }
 
+static gpa_t nonpaging_gva_to_gpa_nested(struct kvm_vcpu *vcpu, gva_t vaddr,
+                                        u32 access,
+                                        struct x86_exception *exception)
+{
+       if (exception)
+               exception->error_code = 0;
+       return vcpu->arch.nested_mmu.translate_gpa(vcpu, vaddr, access);
+}
+
 static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
-                               u32 error_code)
+                               u32 error_code, bool prefault)
 {
        gfn_t gfn;
        int r;
@@ -2343,17 +2680,68 @@ static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
        gfn = gva >> PAGE_SHIFT;
 
        return nonpaging_map(vcpu, gva & PAGE_MASK,
-                            error_code & PFERR_WRITE_MASK, gfn);
+                            error_code & PFERR_WRITE_MASK, gfn, prefault);
+}
+
+static int kvm_arch_setup_async_pf(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn)
+{
+       struct kvm_arch_async_pf arch;
+
+       arch.token = (vcpu->arch.apf.id++ << 12) | vcpu->vcpu_id;
+       arch.gfn = gfn;
+       arch.direct_map = vcpu->arch.mmu.direct_map;
+       arch.cr3 = vcpu->arch.mmu.get_cr3(vcpu);
+
+       return kvm_setup_async_pf(vcpu, gva, gfn, &arch);
+}
+
+static bool can_do_async_pf(struct kvm_vcpu *vcpu)
+{
+       if (unlikely(!irqchip_in_kernel(vcpu->kvm) ||
+                    kvm_event_needs_reinjection(vcpu)))
+               return false;
+
+       return kvm_x86_ops->interrupt_allowed(vcpu);
 }
 
-static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa,
-                               u32 error_code)
+static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
+                        gva_t gva, pfn_t *pfn, bool write, bool *writable)
+{
+       bool async;
+
+       *pfn = gfn_to_pfn_async(vcpu->kvm, gfn, &async, write, writable);
+
+       if (!async)
+               return false; /* *pfn has correct page already */
+
+       put_page(pfn_to_page(*pfn));
+
+       if (!prefault && can_do_async_pf(vcpu)) {
+               trace_kvm_try_async_get_page(gva, gfn);
+               if (kvm_find_async_pf_gfn(vcpu, gfn)) {
+                       trace_kvm_async_pf_doublefault(gva, gfn);
+                       kvm_make_request(KVM_REQ_APF_HALT, vcpu);
+                       return true;
+               } else if (kvm_arch_setup_async_pf(vcpu, gva, gfn))
+                       return true;
+       }
+
+       *pfn = gfn_to_pfn_prot(vcpu->kvm, gfn, write, writable);
+
+       return false;
+}
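
A summary of the three outcomes of the helper above (restating the code, nothing new):

	/*
	 * 1. gfn_to_pfn_async() resolved the pfn without blocking:
	 *    return false and let the caller map *pfn right away.
	 * 2. The page is not resident and an async page fault can be used:
	 *    either queue one (kvm_arch_setup_async_pf) or, if one is already
	 *    outstanding for this gfn, request KVM_REQ_APF_HALT; return true
	 *    so the caller bails out without mapping anything.
	 * 3. Otherwise fall back to the blocking gfn_to_pfn_prot() and
	 *    return false with a valid *pfn.
	 */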
+
+static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
+                         bool prefault)
 {
        pfn_t pfn;
        int r;
        int level;
+       int force_pt_level;
        gfn_t gfn = gpa >> PAGE_SHIFT;
        unsigned long mmu_seq;
+       int write = error_code & PFERR_WRITE_MASK;
+       bool map_writable;
 
        ASSERT(vcpu);
        ASSERT(VALID_PAGE(vcpu->arch.mmu.root_hpa));
@@ -2362,21 +2750,30 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa,
        if (r)
                return r;
 
-       level = mapping_level(vcpu, gfn);
-
-       gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
+       force_pt_level = mapping_level_dirty_bitmap(vcpu, gfn);
+       if (likely(!force_pt_level)) {
+               level = mapping_level(vcpu, gfn);
+               gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
+       } else
+               level = PT_PAGE_TABLE_LEVEL;
 
        mmu_seq = vcpu->kvm->mmu_notifier_seq;
        smp_rmb();
-       pfn = gfn_to_pfn(vcpu->kvm, gfn);
+
+       if (try_async_pf(vcpu, prefault, gfn, gpa, &pfn, write, &map_writable))
+               return 0;
+
+       /* mmio */
        if (is_error_pfn(pfn))
                return kvm_handle_bad_page(vcpu->kvm, gfn, pfn);
        spin_lock(&vcpu->kvm->mmu_lock);
        if (mmu_notifier_retry(vcpu, mmu_seq))
                goto out_unlock;
        kvm_mmu_free_some_pages(vcpu);
-       r = __direct_map(vcpu, gpa, error_code & PFERR_WRITE_MASK,
-                        level, gfn, pfn);
+       if (likely(!force_pt_level))
+               transparent_hugepage_adjust(vcpu, &gfn, &pfn, &level);
+       r = __direct_map(vcpu, gpa, write, map_writable,
+                        level, gfn, pfn, prefault);
        spin_unlock(&vcpu->kvm->mmu_lock);
 
        return r;
@@ -2392,10 +2789,9 @@ static void nonpaging_free(struct kvm_vcpu *vcpu)
        mmu_free_roots(vcpu);
 }
 
-static int nonpaging_init_context(struct kvm_vcpu *vcpu)
+static int nonpaging_init_context(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu *context)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
-
        context->new_cr3 = nonpaging_new_cr3;
        context->page_fault = nonpaging_page_fault;
        context->gva_to_gpa = nonpaging_gva_to_gpa;
@@ -2406,6 +2802,8 @@ static int nonpaging_init_context(struct kvm_vcpu *vcpu)
        context->root_level = 0;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->root_hpa = INVALID_PAGE;
+       context->direct_map = true;
+       context->nx = false;
        return 0;
 }
 
@@ -2417,15 +2815,19 @@ void kvm_mmu_flush_tlb(struct kvm_vcpu *vcpu)
 
 static void paging_new_cr3(struct kvm_vcpu *vcpu)
 {
-       pgprintk("%s: cr3 %lx\n", __func__, vcpu->arch.cr3);
+       pgprintk("%s: cr3 %lx\n", __func__, kvm_read_cr3(vcpu));
        mmu_free_roots(vcpu);
 }
 
+static unsigned long get_cr3(struct kvm_vcpu *vcpu)
+{
+       return kvm_read_cr3(vcpu);
+}
+
 static void inject_page_fault(struct kvm_vcpu *vcpu,
-                             u64 addr,
-                             u32 err_code)
+                             struct x86_exception *fault)
 {
-       kvm_inject_page_fault(vcpu, addr, err_code);
+       vcpu->arch.mmu.inject_page_fault(vcpu, fault);
 }
 
 static void paging_free(struct kvm_vcpu *vcpu)
@@ -2433,12 +2835,12 @@ static void paging_free(struct kvm_vcpu *vcpu)
        nonpaging_free(vcpu);
 }
 
-static bool is_rsvd_bits_set(struct kvm_vcpu *vcpu, u64 gpte, int level)
+static bool is_rsvd_bits_set(struct kvm_mmu *mmu, u64 gpte, int level)
 {
        int bit7;
 
        bit7 = (gpte >> 7) & 1;
-       return (gpte & vcpu->arch.mmu.rsvd_bits_mask[bit7][level-1]) != 0;
+       return (gpte & mmu->rsvd_bits_mask[bit7][level-1]) != 0;
 }
 
 #define PTTYPE 64
@@ -2449,13 +2851,14 @@ static bool is_rsvd_bits_set(struct kvm_vcpu *vcpu, u64 gpte, int level)
 #include "paging_tmpl.h"
 #undef PTTYPE
 
-static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
+static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu *context,
+                                 int level)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
        int maxphyaddr = cpuid_maxphyaddr(vcpu);
        u64 exb_bit_rsvd = 0;
 
-       if (!is_nx(vcpu))
+       if (!context->nx)
                exb_bit_rsvd = rsvd_bits(63, 63);
        switch (level) {
        case PT32_ROOT_LEVEL:
@@ -2510,9 +2913,13 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
        }
 }
 
-static int paging64_init_context_common(struct kvm_vcpu *vcpu, int level)
+static int paging64_init_context_common(struct kvm_vcpu *vcpu,
+                                       struct kvm_mmu *context,
+                                       int level)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       context->nx = is_nx(vcpu);
+
+       reset_rsvds_bits_mask(vcpu, context, level);
 
        ASSERT(is_pae(vcpu));
        context->new_cr3 = paging_new_cr3;
@@ -2525,20 +2932,23 @@ static int paging64_init_context_common(struct kvm_vcpu *vcpu, int level)
        context->root_level = level;
        context->shadow_root_level = level;
        context->root_hpa = INVALID_PAGE;
+       context->direct_map = false;
        return 0;
 }
 
-static int paging64_init_context(struct kvm_vcpu *vcpu)
+static int paging64_init_context(struct kvm_vcpu *vcpu,
+                                struct kvm_mmu *context)
 {
-       reset_rsvds_bits_mask(vcpu, PT64_ROOT_LEVEL);
-       return paging64_init_context_common(vcpu, PT64_ROOT_LEVEL);
+       return paging64_init_context_common(vcpu, context, PT64_ROOT_LEVEL);
 }
 
-static int paging32_init_context(struct kvm_vcpu *vcpu)
+static int paging32_init_context(struct kvm_vcpu *vcpu,
+                                struct kvm_mmu *context)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       context->nx = false;
+
+       reset_rsvds_bits_mask(vcpu, context, PT32_ROOT_LEVEL);
 
-       reset_rsvds_bits_mask(vcpu, PT32_ROOT_LEVEL);
        context->new_cr3 = paging_new_cr3;
        context->page_fault = paging32_page_fault;
        context->gva_to_gpa = paging32_gva_to_gpa;
@@ -2549,19 +2959,21 @@ static int paging32_init_context(struct kvm_vcpu *vcpu)
        context->root_level = PT32_ROOT_LEVEL;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->root_hpa = INVALID_PAGE;
+       context->direct_map = false;
        return 0;
 }
 
-static int paging32E_init_context(struct kvm_vcpu *vcpu)
+static int paging32E_init_context(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu *context)
 {
-       reset_rsvds_bits_mask(vcpu, PT32E_ROOT_LEVEL);
-       return paging64_init_context_common(vcpu, PT32E_ROOT_LEVEL);
+       return paging64_init_context_common(vcpu, context, PT32E_ROOT_LEVEL);
 }
 
 static int init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       struct kvm_mmu *context = vcpu->arch.walk_mmu;
 
+       context->base_role.word = 0;
        context->new_cr3 = nonpaging_new_cr3;
        context->page_fault = tdp_page_fault;
        context->free = nonpaging_free;
@@ -2570,20 +2982,29 @@ static int init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        context->invlpg = nonpaging_invlpg;
        context->shadow_root_level = kvm_x86_ops->get_tdp_level();
        context->root_hpa = INVALID_PAGE;
+       context->direct_map = true;
+       context->set_cr3 = kvm_x86_ops->set_tdp_cr3;
+       context->get_cr3 = get_cr3;
+       context->inject_page_fault = kvm_inject_page_fault;
+       context->nx = is_nx(vcpu);
 
        if (!is_paging(vcpu)) {
+               context->nx = false;
                context->gva_to_gpa = nonpaging_gva_to_gpa;
                context->root_level = 0;
        } else if (is_long_mode(vcpu)) {
-               reset_rsvds_bits_mask(vcpu, PT64_ROOT_LEVEL);
+               context->nx = is_nx(vcpu);
+               reset_rsvds_bits_mask(vcpu, context, PT64_ROOT_LEVEL);
                context->gva_to_gpa = paging64_gva_to_gpa;
                context->root_level = PT64_ROOT_LEVEL;
        } else if (is_pae(vcpu)) {
-               reset_rsvds_bits_mask(vcpu, PT32E_ROOT_LEVEL);
+               context->nx = is_nx(vcpu);
+               reset_rsvds_bits_mask(vcpu, context, PT32E_ROOT_LEVEL);
                context->gva_to_gpa = paging64_gva_to_gpa;
                context->root_level = PT32E_ROOT_LEVEL;
        } else {
-               reset_rsvds_bits_mask(vcpu, PT32_ROOT_LEVEL);
+               context->nx = false;
+               reset_rsvds_bits_mask(vcpu, context, PT32_ROOT_LEVEL);
                context->gva_to_gpa = paging32_gva_to_gpa;
                context->root_level = PT32_ROOT_LEVEL;
        }
@@ -2591,33 +3012,83 @@ static int init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
        return 0;
 }
 
-static int init_kvm_softmmu(struct kvm_vcpu *vcpu)
+int kvm_init_shadow_mmu(struct kvm_vcpu *vcpu, struct kvm_mmu *context)
 {
        int r;
-
        ASSERT(vcpu);
        ASSERT(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
        if (!is_paging(vcpu))
-               r = nonpaging_init_context(vcpu);
+               r = nonpaging_init_context(vcpu, context);
        else if (is_long_mode(vcpu))
-               r = paging64_init_context(vcpu);
+               r = paging64_init_context(vcpu, context);
        else if (is_pae(vcpu))
-               r = paging32E_init_context(vcpu);
+               r = paging32E_init_context(vcpu, context);
        else
-               r = paging32_init_context(vcpu);
+               r = paging32_init_context(vcpu, context);
 
        vcpu->arch.mmu.base_role.cr4_pae = !!is_pae(vcpu);
-       vcpu->arch.mmu.base_role.cr0_wp = is_write_protection(vcpu);
+       vcpu->arch.mmu.base_role.cr0_wp  = is_write_protection(vcpu);
+
+       return r;
+}
+EXPORT_SYMBOL_GPL(kvm_init_shadow_mmu);
+
+static int init_kvm_softmmu(struct kvm_vcpu *vcpu)
+{
+       int r = kvm_init_shadow_mmu(vcpu, vcpu->arch.walk_mmu);
+
+       vcpu->arch.walk_mmu->set_cr3           = kvm_x86_ops->set_cr3;
+       vcpu->arch.walk_mmu->get_cr3           = get_cr3;
+       vcpu->arch.walk_mmu->inject_page_fault = kvm_inject_page_fault;
 
        return r;
 }
 
+static int init_kvm_nested_mmu(struct kvm_vcpu *vcpu)
+{
+       struct kvm_mmu *g_context = &vcpu->arch.nested_mmu;
+
+       g_context->get_cr3           = get_cr3;
+       g_context->inject_page_fault = kvm_inject_page_fault;
+
+       /*
+        * Note that arch.mmu.gva_to_gpa translates l2_gpa to l1_gpa. The
+        * translation of l2_gva to l1_gpa addresses is done using the
+        * arch.nested_mmu.gva_to_gpa function. Basically the gva_to_gpa
+        * functions between mmu and nested_mmu are swapped.
+        */
+       if (!is_paging(vcpu)) {
+               g_context->nx = false;
+               g_context->root_level = 0;
+               g_context->gva_to_gpa = nonpaging_gva_to_gpa_nested;
+       } else if (is_long_mode(vcpu)) {
+               g_context->nx = is_nx(vcpu);
+               reset_rsvds_bits_mask(vcpu, g_context, PT64_ROOT_LEVEL);
+               g_context->root_level = PT64_ROOT_LEVEL;
+               g_context->gva_to_gpa = paging64_gva_to_gpa_nested;
+       } else if (is_pae(vcpu)) {
+               g_context->nx = is_nx(vcpu);
+               reset_rsvds_bits_mask(vcpu, g_context, PT32E_ROOT_LEVEL);
+               g_context->root_level = PT32E_ROOT_LEVEL;
+               g_context->gva_to_gpa = paging64_gva_to_gpa_nested;
+       } else {
+               g_context->nx = false;
+               reset_rsvds_bits_mask(vcpu, g_context, PT32_ROOT_LEVEL);
+               g_context->root_level = PT32_ROOT_LEVEL;
+               g_context->gva_to_gpa = paging32_gva_to_gpa_nested;
+       }
+
+       return 0;
+}
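
The comment in init_kvm_nested_mmu() describes a two-stage lookup: the nested walker resolves an L2 virtual address through L2's page tables, and each resulting L2 physical address is then translated to an L1 physical address by the regular mmu. A conceptual sketch of that composition, with hypothetical names and deliberately simplified signatures (error handling omitted):

#include <stdint.h>

typedef uint64_t demo_gva_t;	/* guest (L2) virtual address */
typedef uint64_t demo_gpa_t;	/* guest physical address     */

struct demo_nested_ctx {
	demo_gpa_t (*walk_l2_tables)(demo_gva_t l2_gva);  /* L2 GVA -> L2 GPA */
	demo_gpa_t (*walk_l1_nested)(demo_gpa_t l2_gpa);  /* L2 GPA -> L1 GPA */
};

/* What a *_gva_to_gpa_nested-style walker does, in spirit: */
static demo_gpa_t demo_gva_to_gpa_nested(struct demo_nested_ctx *c,
					 demo_gva_t l2_gva)
{
	demo_gpa_t l2_gpa = c->walk_l2_tables(l2_gva);
	return c->walk_l1_nested(l2_gpa);
}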
+
 static int init_kvm_mmu(struct kvm_vcpu *vcpu)
 {
        vcpu->arch.update_pte.pfn = bad_pfn;
 
-       if (tdp_enabled)
+       if (mmu_is_nested(vcpu))
+               return init_kvm_nested_mmu(vcpu);
+       else if (tdp_enabled)
                return init_kvm_tdp_mmu(vcpu);
        else
                return init_kvm_softmmu(vcpu);
@@ -2652,7 +3123,7 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
        if (r)
                goto out;
        /* set_cr3() should ensure TLB has been flushed */
-       kvm_x86_ops->set_cr3(vcpu, vcpu->arch.mmu.root_hpa);
+       vcpu->arch.mmu.set_cr3(vcpu, vcpu->arch.mmu.root_hpa);
 out:
        return r;
 }
@@ -2662,6 +3133,7 @@ void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
        mmu_free_roots(vcpu);
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_unload);
 
 static void mmu_pte_write_zap_pte(struct kvm_vcpu *vcpu,
                                  struct kvm_mmu_page *sp,
@@ -2694,9 +3166,6 @@ static void mmu_pte_write_new_pte(struct kvm_vcpu *vcpu,
                return;
         }
 
-       if (is_rsvd_bits_set(vcpu, *(u64 *)new, PT_PAGE_TABLE_LEVEL))
-               return;
-
        ++vcpu->kvm->stat.mmu_pte_updated;
        if (!sp->role.cr4_pae)
                paging32_update_pte(vcpu, sp, spte, new);
@@ -2754,7 +3223,6 @@ static void mmu_guess_page_from_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                kvm_release_pfn_clean(pfn);
                return;
        }
-       vcpu->arch.update_pte.gfn = gfn;
        vcpu->arch.update_pte.pfn = pfn;
 }
 
@@ -2801,9 +3269,8 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 
        /*
         * Assume that the pte write is on a page table of the same type
-        * as the current vcpu paging mode.  This is nearly always true
-        * (might be false while changing modes).  Note it is verified later
-        * by update_pte().
+        * as the current vcpu paging mode since we update the sptes only
+        * when they have the same mode.
         */
        if ((is_pae(vcpu) && bytes == 4) || !new) {
                /* Handle a 32-bit guest writing two halves of a 64-bit gpte */
@@ -2833,11 +3300,11 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        spin_lock(&vcpu->kvm->mmu_lock);
        if (atomic_read(&vcpu->kvm->arch.invlpg_counter) != invlpg_counter)
                gentry = 0;
-       kvm_mmu_access_page(vcpu, gfn);
        kvm_mmu_free_some_pages(vcpu);
        ++vcpu->kvm->stat.mmu_pte_write;
-       kvm_mmu_audit(vcpu, "pre pte write");
+       trace_kvm_mmu_audit(vcpu, AUDIT_PRE_PTE_WRITE);
        if (guest_initiated) {
+               kvm_mmu_access_page(vcpu, gfn);
                if (gfn == vcpu->arch.last_pt_write_gfn
                    && !last_updated_pte_accessed(vcpu)) {
                        ++vcpu->arch.last_pt_write_count;
@@ -2909,7 +3376,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        }
        mmu_pte_write_flush_tlb(vcpu, zap_page, remote_flush, local_flush);
        kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
-       kvm_mmu_audit(vcpu, "post pte write");
+       trace_kvm_mmu_audit(vcpu, AUDIT_POST_PTE_WRITE);
        spin_unlock(&vcpu->kvm->mmu_lock);
        if (!is_error_pfn(vcpu->arch.update_pte.pfn)) {
                kvm_release_pfn_clean(vcpu->arch.update_pte.pfn);
@@ -2922,7 +3389,7 @@ int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva)
        gpa_t gpa;
        int r;
 
-       if (tdp_enabled)
+       if (vcpu->arch.mmu.direct_map)
                return 0;
 
        gpa = kvm_mmu_gva_to_gpa_read(vcpu, gva, NULL);
@@ -2936,29 +3403,27 @@ EXPORT_SYMBOL_GPL(kvm_mmu_unprotect_page_virt);
 
 void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu)
 {
-       int free_pages;
        LIST_HEAD(invalid_list);
 
-       free_pages = vcpu->kvm->arch.n_free_mmu_pages;
-       while (free_pages < KVM_REFILL_PAGES &&
+       while (kvm_mmu_available_pages(vcpu->kvm) < KVM_REFILL_PAGES &&
               !list_empty(&vcpu->kvm->arch.active_mmu_pages)) {
                struct kvm_mmu_page *sp;
 
                sp = container_of(vcpu->kvm->arch.active_mmu_pages.prev,
                                  struct kvm_mmu_page, link);
-               free_pages += kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
-                                                      &invalid_list);
+               kvm_mmu_prepare_zap_page(vcpu->kvm, sp, &invalid_list);
+               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
                ++vcpu->kvm->stat.mmu_recycled;
        }
-       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
 }
 
-int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code)
+int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code,
+                      void *insn, int insn_len)
 {
        int r;
        enum emulation_result er;
 
-       r = vcpu->arch.mmu.page_fault(vcpu, cr2, error_code);
+       r = vcpu->arch.mmu.page_fault(vcpu, cr2, error_code, false);
        if (r < 0)
                goto out;
 
@@ -2971,7 +3436,7 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code)
        if (r)
                goto out;
 
-       er = emulate_instruction(vcpu, cr2, error_code, 0);
+       er = x86_emulate_instruction(vcpu, cr2, 0, insn, insn_len);
 
        switch (er) {
        case EMULATE_DONE:
@@ -3012,6 +3477,8 @@ EXPORT_SYMBOL_GPL(kvm_disable_tdp);
 static void free_mmu_pages(struct kvm_vcpu *vcpu)
 {
        free_page((unsigned long)vcpu->arch.mmu.pae_root);
+       if (vcpu->arch.mmu.lm_root != NULL)
+               free_page((unsigned long)vcpu->arch.mmu.lm_root);
 }
 
 static int alloc_mmu_pages(struct kvm_vcpu *vcpu)
@@ -3053,15 +3520,6 @@ int kvm_mmu_setup(struct kvm_vcpu *vcpu)
        return init_kvm_mmu(vcpu);
 }
 
-void kvm_mmu_destroy(struct kvm_vcpu *vcpu)
-{
-       ASSERT(vcpu);
-
-       destroy_kvm_mmu(vcpu);
-       free_mmu_pages(vcpu);
-       mmu_free_memory_caches(vcpu);
-}
-
 void kvm_mmu_slot_remove_write_access(struct kvm *kvm, int slot)
 {
        struct kvm_mmu_page *sp;
@@ -3074,10 +3532,22 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm, int slot)
                        continue;
 
                pt = sp->spt;
-               for (i = 0; i < PT64_ENT_PER_PAGE; ++i)
+               for (i = 0; i < PT64_ENT_PER_PAGE; ++i) {
+                       if (!is_shadow_present_pte(pt[i]) ||
+                             !is_last_spte(pt[i], sp->role.level))
+                               continue;
+
+                       if (is_large_pte(pt[i])) {
+                               drop_spte(kvm, &pt[i],
+                                         shadow_trap_nonpresent_pte);
+                               --kvm->stat.lpages;
+                               continue;
+                       }
+
                        /* avoid RMW */
                        if (is_writable_pte(pt[i]))
-                               pt[i] &= ~PT_WRITABLE_MASK;
+                               update_spte(&pt[i], pt[i] & ~PT_WRITABLE_MASK);
+               }
        }
        kvm_flush_remote_tlbs(kvm);
 }
@@ -3111,23 +3581,22 @@ static int mmu_shrink(struct shrinker *shrink, int nr_to_scan, gfp_t gfp_mask)
 {
        struct kvm *kvm;
        struct kvm *kvm_freed = NULL;
-       int cache_count = 0;
 
-       spin_lock(&kvm_lock);
+       if (nr_to_scan == 0)
+               goto out;
+
+       raw_spin_lock(&kvm_lock);
 
        list_for_each_entry(kvm, &vm_list, vm_list) {
-               int npages, idx, freed_pages;
+               int idx, freed_pages;
                LIST_HEAD(invalid_list);
 
                idx = srcu_read_lock(&kvm->srcu);
                spin_lock(&kvm->mmu_lock);
-               npages = kvm->arch.n_alloc_mmu_pages -
-                        kvm->arch.n_free_mmu_pages;
-               cache_count += npages;
-               if (!kvm_freed && nr_to_scan > 0 && npages > 0) {
+               if (!kvm_freed && nr_to_scan > 0 &&
+                   kvm->arch.n_used_mmu_pages > 0) {
                        freed_pages = kvm_mmu_remove_some_alloc_mmu_pages(kvm,
                                                          &invalid_list);
-                       cache_count -= freed_pages;
                        kvm_freed = kvm;
                }
                nr_to_scan--;
@@ -3139,9 +3608,10 @@ static int mmu_shrink(struct shrinker *shrink, int nr_to_scan, gfp_t gfp_mask)
        if (kvm_freed)
                list_move_tail(&kvm_freed->vm_list, &vm_list);
 
-       spin_unlock(&kvm_lock);
+       raw_spin_unlock(&kvm_lock);
 
-       return cache_count;
+out:
+       return percpu_counter_read_positive(&kvm_total_used_mmu_pages);
 }
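
Instead of summing per-VM page counts under kvm_lock, the shrinker now reports a global percpu_counter (kvm_total_used_mmu_pages) that tracks used shadow pages. A sketch of the percpu_counter pattern this relies on, using only the standard kernel API as it appears in this tree; the demo_ names are placeholders and the real accounting site is outside this hunk:

#include <linux/percpu_counter.h>

static struct percpu_counter demo_used_pages;

static int demo_init(void)
{
	/* (counter, initial value), the signature used by this tree */
	return percpu_counter_init(&demo_used_pages, 0);
}

static void demo_account(long nr)
{
	percpu_counter_add(&demo_used_pages, nr);	/* cheap, mostly per-cpu */
}

static s64 demo_shrinker_count(void)
{
	/* approximate but never negative, which is what mmu_shrink returns */
	return percpu_counter_read_positive(&demo_used_pages);
}

static void demo_exit(void)
{
	percpu_counter_destroy(&demo_used_pages);
}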
 
 static struct shrinker mmu_shrinker = {
@@ -3159,12 +3629,6 @@ static void mmu_destroy_caches(void)
                kmem_cache_destroy(mmu_page_header_cache);
 }
 
-void kvm_mmu_module_exit(void)
-{
-       mmu_destroy_caches();
-       unregister_shrinker(&mmu_shrinker);
-}
-
 int kvm_mmu_module_init(void)
 {
        pte_chain_cache = kmem_cache_create("kvm_pte_chain",
@@ -3184,6 +3648,9 @@ int kvm_mmu_module_init(void)
        if (!mmu_page_header_cache)
                goto nomem;
 
+       if (percpu_counter_init(&kvm_total_used_mmu_pages, 0))
+               goto nomem;
+
        register_shrinker(&mmu_shrinker);
 
        return 0;
@@ -3258,7 +3725,7 @@ static int kvm_pv_mmu_write(struct kvm_vcpu *vcpu,
 
 static int kvm_pv_mmu_flush_tlb(struct kvm_vcpu *vcpu)
 {
-       (void)kvm_set_cr3(vcpu, vcpu->arch.cr3);
+       (void)kvm_set_cr3(vcpu, kvm_read_cr3(vcpu));
        return 1;
 }
 
@@ -3354,271 +3821,25 @@ int kvm_mmu_get_spte_hierarchy(struct kvm_vcpu *vcpu, u64 addr, u64 sptes[4])
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_get_spte_hierarchy);
 
-#ifdef AUDIT
-
-static const char *audit_msg;
-
-static gva_t canonicalize(gva_t gva)
-{
-#ifdef CONFIG_X86_64
-       gva = (long long)(gva << 16) >> 16;
-#endif
-       return gva;
-}
-
-
-typedef void (*inspect_spte_fn) (struct kvm *kvm, u64 *sptep);
-
-static void __mmu_spte_walk(struct kvm *kvm, struct kvm_mmu_page *sp,
-                           inspect_spte_fn fn)
-{
-       int i;
-
-       for (i = 0; i < PT64_ENT_PER_PAGE; ++i) {
-               u64 ent = sp->spt[i];
-
-               if (is_shadow_present_pte(ent)) {
-                       if (!is_last_spte(ent, sp->role.level)) {
-                               struct kvm_mmu_page *child;
-                               child = page_header(ent & PT64_BASE_ADDR_MASK);
-                               __mmu_spte_walk(kvm, child, fn);
-                       } else
-                               fn(kvm, &sp->spt[i]);
-               }
-       }
-}
-
-static void mmu_spte_walk(struct kvm_vcpu *vcpu, inspect_spte_fn fn)
-{
-       int i;
-       struct kvm_mmu_page *sp;
-
-       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
-               return;
-       if (vcpu->arch.mmu.shadow_root_level == PT64_ROOT_LEVEL) {
-               hpa_t root = vcpu->arch.mmu.root_hpa;
-               sp = page_header(root);
-               __mmu_spte_walk(vcpu->kvm, sp, fn);
-               return;
-       }
-       for (i = 0; i < 4; ++i) {
-               hpa_t root = vcpu->arch.mmu.pae_root[i];
-
-               if (root && VALID_PAGE(root)) {
-                       root &= PT64_BASE_ADDR_MASK;
-                       sp = page_header(root);
-                       __mmu_spte_walk(vcpu->kvm, sp, fn);
-               }
-       }
-       return;
-}
-
-static void audit_mappings_page(struct kvm_vcpu *vcpu, u64 page_pte,
-                               gva_t va, int level)
-{
-       u64 *pt = __va(page_pte & PT64_BASE_ADDR_MASK);
-       int i;
-       gva_t va_delta = 1ul << (PAGE_SHIFT + 9 * (level - 1));
-
-       for (i = 0; i < PT64_ENT_PER_PAGE; ++i, va += va_delta) {
-               u64 ent = pt[i];
-
-               if (ent == shadow_trap_nonpresent_pte)
-                       continue;
-
-               va = canonicalize(va);
-               if (is_shadow_present_pte(ent) && !is_last_spte(ent, level))
-                       audit_mappings_page(vcpu, ent, va, level - 1);
-               else {
-                       gpa_t gpa = kvm_mmu_gva_to_gpa_read(vcpu, va, NULL);
-                       gfn_t gfn = gpa >> PAGE_SHIFT;
-                       pfn_t pfn = gfn_to_pfn(vcpu->kvm, gfn);
-                       hpa_t hpa = (hpa_t)pfn << PAGE_SHIFT;
-
-                       if (is_error_pfn(pfn)) {
-                               kvm_release_pfn_clean(pfn);
-                               continue;
-                       }
-
-                       if (is_shadow_present_pte(ent)
-                           && (ent & PT64_BASE_ADDR_MASK) != hpa)
-                               printk(KERN_ERR "xx audit error: (%s) levels %d"
-                                      " gva %lx gpa %llx hpa %llx ent %llx %d\n",
-                                      audit_msg, vcpu->arch.mmu.root_level,
-                                      va, gpa, hpa, ent,
-                                      is_shadow_present_pte(ent));
-                       else if (ent == shadow_notrap_nonpresent_pte
-                                && !is_error_hpa(hpa))
-                               printk(KERN_ERR "audit: (%s) notrap shadow,"
-                                      " valid guest gva %lx\n", audit_msg, va);
-                       kvm_release_pfn_clean(pfn);
-
-               }
-       }
-}
-
-static void audit_mappings(struct kvm_vcpu *vcpu)
-{
-       unsigned i;
-
-       if (vcpu->arch.mmu.root_level == 4)
-               audit_mappings_page(vcpu, vcpu->arch.mmu.root_hpa, 0, 4);
-       else
-               for (i = 0; i < 4; ++i)
-                       if (vcpu->arch.mmu.pae_root[i] & PT_PRESENT_MASK)
-                               audit_mappings_page(vcpu,
-                                                   vcpu->arch.mmu.pae_root[i],
-                                                   i << 30,
-                                                   2);
-}
-
-static int count_rmaps(struct kvm_vcpu *vcpu)
-{
-       struct kvm *kvm = vcpu->kvm;
-       struct kvm_memslots *slots;
-       int nmaps = 0;
-       int i, j, k, idx;
-
-       idx = srcu_read_lock(&kvm->srcu);
-       slots = kvm_memslots(kvm);
-       for (i = 0; i < KVM_MEMORY_SLOTS; ++i) {
-               struct kvm_memory_slot *m = &slots->memslots[i];
-               struct kvm_rmap_desc *d;
-
-               for (j = 0; j < m->npages; ++j) {
-                       unsigned long *rmapp = &m->rmap[j];
-
-                       if (!*rmapp)
-                               continue;
-                       if (!(*rmapp & 1)) {
-                               ++nmaps;
-                               continue;
-                       }
-                       d = (struct kvm_rmap_desc *)(*rmapp & ~1ul);
-                       while (d) {
-                               for (k = 0; k < RMAP_EXT; ++k)
-                                       if (d->sptes[k])
-                                               ++nmaps;
-                                       else
-                                               break;
-                               d = d->more;
-                       }
-               }
-       }
-       srcu_read_unlock(&kvm->srcu, idx);
-       return nmaps;
-}
-
-void inspect_spte_has_rmap(struct kvm *kvm, u64 *sptep)
-{
-       unsigned long *rmapp;
-       struct kvm_mmu_page *rev_sp;
-       gfn_t gfn;
-
-       if (is_writable_pte(*sptep)) {
-               rev_sp = page_header(__pa(sptep));
-               gfn = kvm_mmu_page_get_gfn(rev_sp, sptep - rev_sp->spt);
-
-               if (!gfn_to_memslot(kvm, gfn)) {
-                       if (!printk_ratelimit())
-                               return;
-                       printk(KERN_ERR "%s: no memslot for gfn %ld\n",
-                                        audit_msg, gfn);
-                       printk(KERN_ERR "%s: index %ld of sp (gfn=%lx)\n",
-                              audit_msg, (long int)(sptep - rev_sp->spt),
-                                       rev_sp->gfn);
-                       dump_stack();
-                       return;
-               }
-
-               rmapp = gfn_to_rmap(kvm, gfn, rev_sp->role.level);
-               if (!*rmapp) {
-                       if (!printk_ratelimit())
-                               return;
-                       printk(KERN_ERR "%s: no rmap for writable spte %llx\n",
-                                        audit_msg, *sptep);
-                       dump_stack();
-               }
-       }
-
-}
-
-void audit_writable_sptes_have_rmaps(struct kvm_vcpu *vcpu)
-{
-       mmu_spte_walk(vcpu, inspect_spte_has_rmap);
-}
-
-static void check_writable_mappings_rmap(struct kvm_vcpu *vcpu)
+void kvm_mmu_destroy(struct kvm_vcpu *vcpu)
 {
-       struct kvm_mmu_page *sp;
-       int i;
-
-       list_for_each_entry(sp, &vcpu->kvm->arch.active_mmu_pages, link) {
-               u64 *pt = sp->spt;
-
-               if (sp->role.level != PT_PAGE_TABLE_LEVEL)
-                       continue;
-
-               for (i = 0; i < PT64_ENT_PER_PAGE; ++i) {
-                       u64 ent = pt[i];
-
-                       if (!(ent & PT_PRESENT_MASK))
-                               continue;
-                       if (!is_writable_pte(ent))
-                               continue;
-                       inspect_spte_has_rmap(vcpu->kvm, &pt[i]);
-               }
-       }
-       return;
-}
+       ASSERT(vcpu);
 
-static void audit_rmap(struct kvm_vcpu *vcpu)
-{
-       check_writable_mappings_rmap(vcpu);
-       count_rmaps(vcpu);
+       destroy_kvm_mmu(vcpu);
+       free_mmu_pages(vcpu);
+       mmu_free_memory_caches(vcpu);
 }
 
-static void audit_write_protection(struct kvm_vcpu *vcpu)
-{
-       struct kvm_mmu_page *sp;
-       struct kvm_memory_slot *slot;
-       unsigned long *rmapp;
-       u64 *spte;
-       gfn_t gfn;
-
-       list_for_each_entry(sp, &vcpu->kvm->arch.active_mmu_pages, link) {
-               if (sp->role.direct)
-                       continue;
-               if (sp->unsync)
-                       continue;
-
-               slot = gfn_to_memslot(vcpu->kvm, sp->gfn);
-               rmapp = &slot->rmap[gfn - slot->base_gfn];
-
-               spte = rmap_next(vcpu->kvm, rmapp, NULL);
-               while (spte) {
-                       if (is_writable_pte(*spte))
-                               printk(KERN_ERR "%s: (%s) shadow page has "
-                               "writable mappings: gfn %lx role %x\n",
-                              __func__, audit_msg, sp->gfn,
-                              sp->role.word);
-                       spte = rmap_next(vcpu->kvm, rmapp, spte);
-               }
-       }
-}
+#ifdef CONFIG_KVM_MMU_AUDIT
+#include "mmu_audit.c"
+#else
+static void mmu_audit_disable(void) { }
+#endif
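
The open-coded audit routines above are removed; auditing now lives in mmu_audit.c, pulled in only under CONFIG_KVM_MMU_AUDIT, with a no-op mmu_audit_disable() stub otherwise so kvm_mmu_module_exit() can call it unconditionally. A generic sketch of that stub-or-real-hook shape with hypothetical names (the patch itself includes the .c file directly):

/* Hypothetical names; the shape of the pattern, not kernel code. */
#ifdef CONFIG_DEMO_AUDIT
void demo_audit_disable(void);			/* provided by the audit unit */
#else
static void demo_audit_disable(void) { }	/* no-op stub when auditing is off */
#endif

void demo_module_exit(void)
{
	demo_audit_disable();	/* safe to call in either configuration */
}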
 
-static void kvm_mmu_audit(struct kvm_vcpu *vcpu, const char *msg)
+void kvm_mmu_module_exit(void)
 {
-       int olddbg = dbg;
-
-       dbg = 0;
-       audit_msg = msg;
-       audit_rmap(vcpu);
-       audit_write_protection(vcpu);
-       if (strcmp("pre pte write", audit_msg) != 0)
-               audit_mappings(vcpu);
-       audit_writable_sptes_have_rmaps(vcpu);
-       dbg = olddbg;
+       mmu_destroy_caches();
+       percpu_counter_destroy(&kvm_total_used_mmu_pages);
+       unregister_shrinker(&mmu_shrinker);
+       mmu_audit_disable();
 }
-
-#endif