diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 52242a148c705c567a53fbc26699703b31b766e3..6db562d9569d552fc6ae15d68a0e787256dc7487 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -700,6 +700,23 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) return t->send_pkt(reply); } +/* This function should be called with sk_lock held and SOCK_DONE set */ +static void virtio_transport_remove_sock(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct virtio_vsock_pkt *pkt, *tmp; + + /* We don't need to take rx_lock, as the socket is closing and we are + * removing it. + */ + list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { + list_del(&pkt->list); + virtio_transport_free_pkt(pkt); + } + + vsock_remove_sock(vsk); +} + static void virtio_transport_wait_close(struct sock *sk, long timeout) { if (timeout) { @@ -732,7 +749,7 @@ static void virtio_transport_do_close(struct vsock_sock *vsk, (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { vsk->close_work_scheduled = false; - vsock_remove_sock(vsk); + virtio_transport_remove_sock(vsk); /* Release refcnt obtained when we scheduled the timeout */ sock_put(sk); @@ -795,8 +812,6 @@ static bool virtio_transport_close(struct vsock_sock *vsk) void virtio_transport_release(struct vsock_sock *vsk) { - struct virtio_vsock_sock *vvs = vsk->trans; - struct virtio_vsock_pkt *pkt, *tmp; struct sock *sk = &vsk->sk; bool remove_sock = true; @@ -804,14 +819,11 @@ void virtio_transport_release(struct vsock_sock *vsk) if (sk->sk_type == SOCK_STREAM) remove_sock = virtio_transport_close(vsk); - list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { - list_del(&pkt->list); - virtio_transport_free_pkt(pkt); + if (remove_sock) { + sock_set_flag(sk, SOCK_DONE); + virtio_transport_remove_sock(vsk); } release_sock(sk); - - if (remove_sock) - vsock_remove_sock(vsk); } EXPORT_SYMBOL_GPL(virtio_transport_release); @@ -1037,6 +1049,14 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) lock_sock(sk); + /* Check if sk has been closed before lock_sock */ + if (sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset_no_sock(pkt); + release_sock(sk); + sock_put(sk); + goto free_pkt; + } + /* Update CID in case it has changed after a transport reset event */ vsk->local_addr.svm_cid = dst.svm_cid;