udp: complete port availability checking
Eric Dumazet [Thu, 9 Oct 2008 21:51:27 +0000 (14:51 -0700)]
While looking at UDP port randomization, I noticed it
was litle bit pessimistic, not looking at type of sockets
(IPV6/IPV4) and not looking at bound addresses if any.

We should perform same tests than when binding to a
specific port.

This permits a cleanup of udp_lib_get_port()

Signed-off-by: Eric Dumazet <dada1@cosmosbay.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

net/ipv4/udp.c

index 67d8430..eacf4cf 100644 (file)
@@ -122,14 +122,23 @@ EXPORT_SYMBOL(sysctl_udp_wmem_min);
 atomic_t udp_memory_allocated;
 EXPORT_SYMBOL(udp_memory_allocated);
 
-static inline int __udp_lib_lport_inuse(struct net *net, __u16 num,
-                                       const struct hlist_head udptable[])
+static int udp_lib_lport_inuse(struct net *net, __u16 num,
+                              const struct hlist_head udptable[],
+                              struct sock *sk,
+                              int (*saddr_comp)(const struct sock *sk1,
+                                                const struct sock *sk2))
 {
-       struct sock *sk;
+       struct sock *sk2;
        struct hlist_node *node;
 
-       sk_for_each(sk, node, &udptable[udp_hashfn(net, num)])
-               if (net_eq(sock_net(sk), net) && sk->sk_hash == num)
+       sk_for_each(sk2, node, &udptable[udp_hashfn(net, num)])
+               if (net_eq(sock_net(sk2), net)                  &&
+                   sk2 != sk                                   &&
+                   sk2->sk_hash == num                         &&
+                   (!sk2->sk_reuse || !sk->sk_reuse)           &&
+                   (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
+                       || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
+                   (*saddr_comp)(sk, sk2))
                        return 1;
        return 0;
 }
@@ -146,9 +155,6 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                                         const struct sock *sk2 )    )
 {
        struct hlist_head *udptable = sk->sk_prot->h.udp_hash;
-       struct hlist_node *node;
-       struct hlist_head *head;
-       struct sock *sk2;
        int    error = 1;
        struct net *net = sock_net(sk);
 
@@ -165,32 +171,21 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                rand = net_random();
                snum = first = rand % remaining + low;
                rand |= 1;
-               while (__udp_lib_lport_inuse(net, snum, udptable)) {
+               while (udp_lib_lport_inuse(net, snum, udptable, sk,
+                                          saddr_comp)) {
                        do {
                                snum = snum + rand;
                        } while (snum < low || snum > high);
                        if (snum == first)
                                goto fail;
                }
-       } else {
-               head = &udptable[udp_hashfn(net, snum)];
-
-               sk_for_each(sk2, node, head)
-                       if (sk2->sk_hash == snum                             &&
-                           sk2 != sk                                        &&
-                           net_eq(sock_net(sk2), net)                       &&
-                           (!sk2->sk_reuse        || !sk->sk_reuse)         &&
-                           (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
-                            || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                           (*saddr_comp)(sk, sk2)                             )
-                               goto fail;
-       }
+       } else if (udp_lib_lport_inuse(net, snum, udptable, sk, saddr_comp))
+               goto fail;
 
        inet_sk(sk)->num = snum;
        sk->sk_hash = snum;
        if (sk_unhashed(sk)) {
-               head = &udptable[udp_hashfn(net, snum)];
-               sk_add_node(sk, head);
+               sk_add_node(sk, &udptable[udp_hashfn(net, snum)]);
                sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
        }
        error = 0;