diff --git a/include/linux/mm.h b/include/linux/mm.h
index 3794258f39d19fc871b6fc6b2297c4bdb800474e..f939691a8bbc1054485446b113f60d004f00438f 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1734,6 +1734,7 @@ static inline struct page *follow_page(struct vm_area_struct *vma,
 #define FOLL_HWPOISON	0x100	/* check page is hwpoisoned */
 #define FOLL_NUMA	0x200	/* force NUMA hinting page fault */
 #define FOLL_MIGRATION	0x400	/* wait for page to replace migration entry */
+#define FOLL_COW	0x4000	/* internal GUP flag */
 
 typedef int (*pte_fn_t)(pte_t *pte, pgtable_t token, unsigned long addr,
 			void *data);
diff --git a/mm/memory.c b/mm/memory.c
index b8a46ed6d5eeb6937e6dd6bb4a1954f9f8236171..0602b35bede9dc24b2389cd8985fa259fe1301ea 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -1467,6 +1467,16 @@ int zap_vma_ptes(struct vm_area_struct *vma, unsigned long address,
 }
 EXPORT_SYMBOL_GPL(zap_vma_ptes);
 
+/*
+ * FOLL_FORCE can write to even unwritable pte's, but only
+ * after we've gone through a COW cycle and they are dirty.
+ */
+static inline bool can_follow_write_pte(pte_t pte, unsigned int flags)
+{
+	return pte_write(pte) ||
+		((flags & FOLL_FORCE) && (flags & FOLL_COW) && pte_dirty(pte));
+}
+
 /**
  * follow_page_mask - look up a page descriptor from a user-virtual address
  * @vma: vm_area_struct mapping @address
@@ -1574,7 +1584,7 @@ split_fallthrough:
 	}
 	if ((flags & FOLL_NUMA) && pte_numa(pte))
 		goto no_page;
-	if ((flags & FOLL_WRITE) && !pte_write(pte))
+	if ((flags & FOLL_WRITE) && !can_follow_write_pte(pte, flags))
 		goto unlock;
 
 	page = vm_normal_page(vma, address, pte);
@@ -1894,7 +1904,7 @@ long __get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
 				 */
 				if ((ret & VM_FAULT_WRITE) &&
 				    !(vma->vm_flags & VM_WRITE))
-					foll_flags &= ~FOLL_WRITE;
+					foll_flags |= FOLL_COW;
 
 				cond_resched();
 			}