diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index e7cb9d70cad9fcd47b77fffa079a901098dd5d56..c64ad1fca4e4af4dd75ab2bcd7e73d93dec12542 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -777,6 +777,34 @@ void dup_userfaultfd_complete(struct list_head *fcs)
 	}
 }
 
+void dup_userfaultfd_fail(struct list_head *fcs)
+{
+	struct userfaultfd_fork_ctx *fctx, *n;
+
+	/*
+	 * An error has occurred on fork, we will tear memory down, but have
+	 * allocated memory for fctx's and raised reference counts for both the
+	 * original and child contexts (and on the mm for each as a result).
+	 *
+	 * These would ordinarily be taken care of by a user handling the event,
+	 * but we are no longer doing so, so manually clean up here.
+	 *
+	 * mm tear down will take care of cleaning up VMA contexts.
+	 */
+	list_for_each_entry_safe(fctx, n, fcs, list) {
+		struct userfaultfd_ctx *octx = fctx->orig;
+		struct userfaultfd_ctx *ctx = fctx->new;
+
+		atomic_dec(&octx->mmap_changing);
+		VM_BUG_ON(atomic_read(&octx->mmap_changing) < 0);
+		userfaultfd_ctx_put(octx);
+		userfaultfd_ctx_put(ctx);
+
+		list_del(&fctx->list);
+		kfree(fctx);
+	}
+}
+
 void mremap_userfaultfd_prep(struct vm_area_struct *vma,
 			     struct vm_userfaultfd_ctx *vm_ctx)
 {
diff --git a/include/linux/ksm.h b/include/linux/ksm.h
index c8144f9ca9d8152bc3c3bb1ae80bebab1e7a1601..691c1f54254e62975b5725a8c03bd84b580c5ded 100644
--- a/include/linux/ksm.h
+++ b/include/linux/ksm.h
@@ -54,20 +54,10 @@ static inline long mm_ksm_zero_pages(struct mm_struct *mm)
 	return atomic_long_read(&mm->ksm_zero_pages);
 }
 
-static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
+static inline void ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 {
-	int ret;
-
-	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags)) {
-		ret = __ksm_enter(mm);
-		if (ret)
-			return ret;
-	}
-
-	if (test_bit(MMF_VM_MERGE_ANY, &oldmm->flags))
-		set_bit(MMF_VM_MERGE_ANY, &mm->flags);
-
-	return 0;
+	if (test_bit(MMF_VM_MERGEABLE, &oldmm->flags))
+		__ksm_enter(mm);
 }
 
 static inline void ksm_exit(struct mm_struct *mm)
@@ -111,9 +101,8 @@ static inline int ksm_disable(struct mm_struct *mm)
 	return 0;
 }
 
-static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
+static inline void ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 {
-	return 0;
 }
 
 static inline void ksm_exit(struct mm_struct *mm)
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index 9427d5fccf7b65f453df1141263a2dfe98e66ff5..aa8e3725f10347bf08304853e710e2f910d68607 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -186,6 +186,7 @@ static inline bool vma_can_userfault(struct vm_area_struct *vma,
 
 extern int dup_userfaultfd(struct vm_area_struct *, struct list_head *);
 extern void dup_userfaultfd_complete(struct list_head *);
+void dup_userfaultfd_fail(struct list_head *);
 
 extern void mremap_userfaultfd_prep(struct vm_area_struct *,
 				    struct vm_userfaultfd_ctx *);
@@ -261,6 +262,10 @@ static inline void dup_userfaultfd_complete(struct list_head *l)
 {
 }
 
+static inline void dup_userfaultfd_fail(struct list_head *l)
+{
+}
+
 static inline void mremap_userfaultfd_prep(struct vm_area_struct *vma,
 					   struct vm_userfaultfd_ctx *ctx)
 {
diff --git a/kernel/fork.c b/kernel/fork.c
index a8a30a21799a6cb085958e53f05ae8428fb95245..34cddf4a8b26321e4d19a4b83a058ac9f6714e9e 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -685,11 +685,6 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
 	mm->exec_vm = oldmm->exec_vm;
 	mm->stack_vm = oldmm->stack_vm;
 
-	retval = ksm_fork(mm, oldmm);
-	if (retval)
-		goto out;
-	khugepaged_fork(mm, oldmm);
-
 	/* Use __mt_dup() to efficiently build an identical maple tree. */
 	retval = __mt_dup(&oldmm->mm_mt, &mm->mm_mt, GFP_KERNEL);
 	if (unlikely(retval))
@@ -792,6 +787,8 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
 	vma_iter_free(&vmi);
 	if (!retval) {
 		mt_set_in_rcu(vmi.mas.tree);
+		ksm_fork(mm, oldmm);
+		khugepaged_fork(mm, oldmm);
 	} else if (mpnt) {
 		/*
 		 * The entire maple tree has already been duplicated. If the
@@ -807,7 +804,10 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
 	mmap_write_unlock(mm);
 	flush_tlb_mm(oldmm);
 	mmap_write_unlock(oldmm);
-	dup_userfaultfd_complete(&uf);
+	if (!retval)
+		dup_userfaultfd_complete(&uf);
+	else
+		dup_userfaultfd_fail(&uf);
 fail_uprobe_end:
 	uprobe_end_dup_mmap();
 	return retval;
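
A note on the idiom dup_userfaultfd_fail() relies on: list_for_each_entry_safe() caches the next node in a second cursor ('n') before the loop body runs, so the current entry can be unlinked and freed mid-walk without the iterator touching freed memory. The standalone userspace sketch below illustrates why the extra cursor is needed; it uses a plain singly linked list and hypothetical names (fork_ctx, id) as stand-ins for the kernel's struct list_head machinery, not the kernel implementation itself.

#include <stdio.h>
#include <stdlib.h>

/* Hypothetical stand-in for the kernel's list-linked context object. */
struct fork_ctx {
	int id;
	struct fork_ctx *next;
};

int main(void)
{
	struct fork_ctx *head = NULL;
	struct fork_ctx *ctx, *n;

	/* Build a short list: 2 -> 1 -> 0. */
	for (int i = 0; i < 3; i++) {
		ctx = malloc(sizeof(*ctx));
		if (!ctx)
			abort();
		ctx->id = i;
		ctx->next = head;
		head = ctx;
	}

	/*
	 * Mirror of the _safe iteration pattern: 'n' caches the successor
	 * before the body frees 'ctx'. A plain walk that advanced via
	 * 'ctx = ctx->next' after free(ctx) would be a use-after-free.
	 */
	for (ctx = head, n = ctx ? ctx->next : NULL; ctx;
	     ctx = n, n = ctx ? ctx->next : NULL) {
		printf("cleaning up ctx %d\n", ctx->id);
		free(ctx);
	}

	return 0;
}

dup_userfaultfd_fail() has the same shape: each iteration drops the references taken by dup_userfaultfd(), unlinks the fctx, and frees it, all safely because the cursor has already advanced past the entry being destroyed.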