diff --git a/drivers/block/nbd.c b/drivers/block/nbd.c index fa111ed7cdae21cf3dcb45b932a4b0f98aae875f..85eba9dcba3ec80ad0bb8784dcc3d0ed88af85ce 100644 --- a/drivers/block/nbd.c +++ b/drivers/block/nbd.c @@ -1121,6 +1121,14 @@ static struct socket *nbd_get_socket(struct nbd_device *nbd, unsigned long fd, if (!sock) return NULL; + if (!sk_is_tcp(sock->sk) && + !sk_is_stream_unix(sock->sk)) { + dev_err(disk_to_dev(nbd->disk), "Unsupported socket: should be TCP or UNIX.\n"); + *err = -EINVAL; + sockfd_put(sock); + return NULL; + } + if (sock->ops->shutdown == sock_no_shutdown) { dev_err(disk_to_dev(nbd->disk), "Unsupported socket: shutdown callout must be supported.\n"); *err = -EINVAL; diff --git a/include/net/sock.h b/include/net/sock.h index b81e870a9fcc9cf56d21f86d0abf81292b20d80b..095b429231a2dff91ebf1327373400b1be460f40 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -2678,6 +2678,16 @@ static inline void skb_setup_tx_timestamp(struct sk_buff *skb, __u16 tsflags) &skb_shinfo(skb)->tskey); } +static inline bool sk_is_tcp(const struct sock *sk) +{ + return sk->sk_type == SOCK_STREAM && sk->sk_protocol == IPPROTO_TCP; +} + +static inline bool sk_is_stream_unix(const struct sock *sk) +{ + return sk->sk_family == AF_UNIX && sk->sk_type == SOCK_STREAM; +} + DECLARE_STATIC_KEY_FALSE(tcp_rx_skb_cache_key); /** * sk_eat_skb - Release a skb if it is no longer needed diff --git a/net/core/skbuff.c b/net/core/skbuff.c index c5d13c3fd0d220c3332f73c0ee82fc20d7866f1f..47ed4de1edf4d5e8c891e37081b348b3e01587bd 100644 --- a/net/core/skbuff.c +++ b/net/core/skbuff.c @@ -4930,8 +4930,7 @@ static void __skb_complete_tx_timestamp(struct sk_buff *skb, serr->header.h4.iif = skb->dev ? skb->dev->ifindex : 0; if (sk->sk_tsflags & SOF_TIMESTAMPING_OPT_ID) { serr->ee.ee_data = skb_shinfo(skb)->tskey; - if (sk->sk_protocol == IPPROTO_TCP && - sk->sk_type == SOCK_STREAM) + if (sk_is_tcp(sk)) serr->ee.ee_data -= sk->sk_tskey; } @@ -4999,8 +4998,7 @@ void __skb_tstamp_tx(struct sk_buff *orig_skb, if (tsonly) { #ifdef CONFIG_INET if ((sk->sk_tsflags & SOF_TIMESTAMPING_OPT_STATS) && - sk->sk_protocol == IPPROTO_TCP && - sk->sk_type == SOCK_STREAM) { + sk_is_tcp(sk)) { skb = tcp_get_timestamping_opt_stats(sk, orig_skb); opt_stats = true; } else diff --git a/net/core/sock.c b/net/core/sock.c index 0e4ba9d63542ce66a11d3a0fe6720399f670ef3a..88bb667aebdc228fda185c86c694fc7f2a79c3e9 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -1012,8 +1012,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname, if (val & SOF_TIMESTAMPING_OPT_ID && !(sk->sk_tsflags & SOF_TIMESTAMPING_OPT_ID)) { - if (sk->sk_protocol == IPPROTO_TCP && - sk->sk_type == SOCK_STREAM) { + if (sk_is_tcp(sk)) { if ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) { ret = -EINVAL; @@ -1202,8 +1201,7 @@ int sock_setsockopt(struct socket *sock, int level, int optname, case SO_ZEROCOPY: if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6) { - if (!((sk->sk_type == SOCK_STREAM && - sk->sk_protocol == IPPROTO_TCP) || + if (!(sk_is_tcp(sk) || (sk->sk_type == SOCK_DGRAM && sk->sk_protocol == IPPROTO_UDP))) ret = -ENOTSUPP; diff --git a/net/core/sock_map.c b/net/core/sock_map.c index beb8768a928581e5e2d8f2ca239dd6ca82e86333..07ab464d2d4ebc976837e27bc0c54de6360e1ccd 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -550,12 +550,6 @@ static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops) ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB; } -static bool sk_is_tcp(const struct sock *sk) -{ - return sk->sk_type == SOCK_STREAM && - sk->sk_protocol == IPPROTO_TCP; -} - static bool sk_is_udp(const struct sock *sk) { return sk->sk_type == SOCK_DGRAM &&