diff --git a/include/net/bluetooth/bluetooth.h b/include/net/bluetooth/bluetooth.h index fa37bdd00fa5b14f65fd14eaf4792a61a8807312..222985cea468f6745731d87588f6438199c26d77 100644 --- a/include/net/bluetooth/bluetooth.h +++ b/include/net/bluetooth/bluetooth.h @@ -314,6 +314,7 @@ int bt_sock_register(int proto, const struct net_proto_family *ops); void bt_sock_unregister(int proto); void bt_sock_link(struct bt_sock_list *l, struct sock *s); void bt_sock_unlink(struct bt_sock_list *l, struct sock *s); +bool bt_sock_linked(struct bt_sock_list *l, struct sock *s); struct sock *bt_sock_alloc(struct net *net, struct socket *sock, struct proto *prot, int proto, gfp_t prio, int kern); int bt_sock_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, diff --git a/net/bluetooth/af_bluetooth.c b/net/bluetooth/af_bluetooth.c index a06a5c02b03e5ce121c8a8d655db6b16f82a52c2..a6a71c7d7169bb07827500f9c0148f69268b1879 100644 --- a/net/bluetooth/af_bluetooth.c +++ b/net/bluetooth/af_bluetooth.c @@ -175,6 +175,28 @@ void bt_sock_unlink(struct bt_sock_list *l, struct sock *sk) } EXPORT_SYMBOL(bt_sock_unlink); +bool bt_sock_linked(struct bt_sock_list *l, struct sock *s) +{ + struct sock *sk; + + if (!l || !s) + return false; + + read_lock(&l->lock); + + sk_for_each(sk, &l->head) { + if (s == sk) { + read_unlock(&l->lock); + return true; + } + } + + read_unlock(&l->lock); + + return false; +} +EXPORT_SYMBOL(bt_sock_linked); + void bt_accept_enqueue(struct sock *parent, struct sock *sk, bool bh) { BT_DBG("parent %p, sk %p", parent, sk); diff --git a/net/bluetooth/sco.c b/net/bluetooth/sco.c index 2083c0552857e693af9deaca64f82197d95f1577..3fc0c40a47f74b385c784518aa92a416dc483ec5 100644 --- a/net/bluetooth/sco.c +++ b/net/bluetooth/sco.c @@ -76,6 +76,16 @@ struct sco_pinfo { #define SCO_CONN_TIMEOUT (HZ * 40) #define SCO_DISCONN_TIMEOUT (HZ * 2) +static struct sock *sco_sock_hold(struct sco_conn *conn) +{ + if (!conn || !bt_sock_linked(&sco_sk_list, conn->sk)) + return NULL; + + sock_hold(conn->sk); + + return conn->sk; +} + static void sco_sock_timeout(struct work_struct *work) { struct sco_conn *conn = container_of(work, struct sco_conn, @@ -87,9 +97,7 @@ static void sco_sock_timeout(struct work_struct *work) sco_conn_unlock(conn); return; } - sk = conn->sk; - if (sk) - sock_hold(sk); + sk = sco_sock_hold(conn); sco_conn_unlock(conn); if (!sk) @@ -192,11 +200,10 @@ static void sco_conn_del(struct hci_conn *hcon, int err) /* Kill socket */ sco_conn_lock(conn); - sk = conn->sk; + sk = sco_sock_hold(conn); sco_conn_unlock(conn); if (sk) { - sock_hold(sk); lock_sock(sk); sco_sock_clear_timer(sk); sco_chan_del(sk, err);