ipv4: fix redirect handling
[linux-2.6.git] / net / ipv4 / route.c
index a8ccd9b..b2e9544 100644 (file)
 #ifdef CONFIG_SYSCTL
 #include <linux/sysctl.h>
 #endif
+#include <net/atmclip.h>
+#include <net/secure_seq.h>
 
 #define RT_FL_TOS(oldflp4) \
     ((u32)(oldflp4->flowi4_tos & (IPTOS_RT_MASK | RTO_ONLINK)))
@@ -184,6 +186,8 @@ static u32 *ipv4_cow_metrics(struct dst_entry *dst, unsigned long old)
        return p;
 }
 
+static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, const void *daddr);
+
 static struct dst_ops ipv4_dst_ops = {
        .family =               AF_INET,
        .protocol =             cpu_to_be16(ETH_P_IP),
@@ -198,6 +202,7 @@ static struct dst_ops ipv4_dst_ops = {
        .link_failure =         ipv4_link_failure,
        .update_pmtu =          ip_rt_update_pmtu,
        .local_out =            __ip_local_out,
+       .neigh_lookup =         ipv4_neigh_lookup,
 };
 
 #define ECN_OR_COST(class)     TC_PRIO_##class
@@ -411,7 +416,13 @@ static int rt_cache_seq_show(struct seq_file *seq, void *v)
                           "HHUptod\tSpecDst");
        else {
                struct rtable *r = v;
-               int len;
+               struct neighbour *n;
+               int len, HHUptod;
+
+               rcu_read_lock();
+               n = dst_get_neighbour(&r->dst);
+               HHUptod = (n && (n->nud_state & NUD_CONNECTED)) ? 1 : 0;
+               rcu_read_unlock();
 
                seq_printf(seq, "%s\t%08X\t%08X\t%8X\t%d\t%u\t%d\t"
                              "%08X\t%d\t%u\t%u\t%02X\t%d\t%1d\t%08X%n",
@@ -425,9 +436,8 @@ static int rt_cache_seq_show(struct seq_file *seq, void *v)
                        (int)((dst_metric(&r->dst, RTAX_RTT) >> 3) +
                              dst_metric(&r->dst, RTAX_RTTVAR)),
                        r->rt_key_tos,
-                       r->dst.hh ? atomic_read(&r->dst.hh->hh_refcnt) : -1,
-                       r->dst.hh ? (r->dst.hh->hh_output ==
-                                      dev_queue_xmit) : 0,
+                       -1,
+                       HHUptod,
                        r->rt_spec_dst, &len);
 
                seq_printf(seq, "%*s\n", 127 - len, "");
@@ -716,7 +726,7 @@ static inline bool compare_hash_inputs(const struct rtable *rt1,
 {
        return ((((__force u32)rt1->rt_key_dst ^ (__force u32)rt2->rt_key_dst) |
                ((__force u32)rt1->rt_key_src ^ (__force u32)rt2->rt_key_src) |
-               (rt1->rt_iif ^ rt2->rt_iif)) == 0);
+               (rt1->rt_route_iif ^ rt2->rt_route_iif)) == 0);
 }
 
 static inline int compare_keys(struct rtable *rt1, struct rtable *rt2)
@@ -725,8 +735,8 @@ static inline int compare_keys(struct rtable *rt1, struct rtable *rt2)
                ((__force u32)rt1->rt_key_src ^ (__force u32)rt2->rt_key_src) |
                (rt1->rt_mark ^ rt2->rt_mark) |
                (rt1->rt_key_tos ^ rt2->rt_key_tos) |
-               (rt1->rt_oif ^ rt2->rt_oif) |
-               (rt1->rt_iif ^ rt2->rt_iif)) == 0;
+               (rt1->rt_route_iif ^ rt2->rt_route_iif) |
+               (rt1->rt_oif ^ rt2->rt_oif)) == 0;
 }
 
 static inline int compare_netns(struct rtable *rt1, struct rtable *rt2)
@@ -1006,6 +1016,37 @@ static int slow_chain_length(const struct rtable *head)
        return length >> FRACT_BITS;
 }
 
+static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, const void *daddr)
+{
+       struct neigh_table *tbl = &arp_tbl;
+       static const __be32 inaddr_any = 0;
+       struct net_device *dev = dst->dev;
+       const __be32 *pkey = daddr;
+       struct neighbour *n;
+
+#if defined(CONFIG_ATM_CLIP) || defined(CONFIG_ATM_CLIP_MODULE)
+       if (dev->type == ARPHRD_ATM)
+               tbl = clip_tbl_hook;
+#endif
+       if (dev->flags & (IFF_LOOPBACK | IFF_POINTOPOINT))
+               pkey = &inaddr_any;
+
+       n = __ipv4_neigh_lookup(tbl, dev, *(__force u32 *)pkey);
+       if (n)
+               return n;
+       return neigh_create(tbl, pkey, dev);
+}
+
+static int rt_bind_neighbour(struct rtable *rt)
+{
+       struct neighbour *n = ipv4_neigh_lookup(&rt->dst, &rt->rt_gateway);
+       if (IS_ERR(n))
+               return PTR_ERR(n);
+       dst_set_neighbour(&rt->dst, n);
+
+       return 0;
+}
+
 static struct rtable *rt_intern_hash(unsigned hash, struct rtable *rt,
                                     struct sk_buff *skb, int ifindex)
 {
@@ -1042,7 +1083,7 @@ restart:
 
                rt->dst.flags |= DST_NOCACHE;
                if (rt->rt_type == RTN_UNICAST || rt_is_output_route(rt)) {
-                       int err = arp_bind_neighbour(&rt->dst);
+                       int err = rt_bind_neighbour(rt);
                        if (err) {
                                if (net_ratelimit())
                                        printk(KERN_WARNING
@@ -1138,7 +1179,7 @@ restart:
           route or unicast forwarding path.
         */
        if (rt->rt_type == RTN_UNICAST || rt_is_output_route(rt)) {
-               int err = arp_bind_neighbour(&rt->dst);
+               int err = rt_bind_neighbour(rt);
                if (err) {
                        spin_unlock_bh(rt_hash_lock_addr(hash));
 
@@ -1268,11 +1309,42 @@ static void rt_del(unsigned hash, struct rtable *rt)
        spin_unlock_bh(rt_hash_lock_addr(hash));
 }
 
+static int check_peer_redir(struct dst_entry *dst, struct inet_peer *peer)
+{
+       struct rtable *rt = (struct rtable *) dst;
+       __be32 orig_gw = rt->rt_gateway;
+       struct neighbour *n, *old_n;
+
+       dst_confirm(&rt->dst);
+
+       rt->rt_gateway = peer->redirect_learned.a4;
+
+       n = ipv4_neigh_lookup(&rt->dst, &rt->rt_gateway);
+       if (IS_ERR(n))
+               return PTR_ERR(n);
+       old_n = xchg(&rt->dst._neighbour, n);
+       if (old_n)
+               neigh_release(old_n);
+       if (!n || !(n->nud_state & NUD_VALID)) {
+               if (n)
+                       neigh_event_send(n, NULL);
+               rt->rt_gateway = orig_gw;
+               return -EAGAIN;
+       } else {
+               rt->rt_flags |= RTCF_REDIRECTED;
+               call_netevent_notifiers(NETEVENT_NEIGH_UPDATE, n);
+       }
+       return 0;
+}
+
 /* called in rcu_read_lock() section */
 void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw,
                    __be32 saddr, struct net_device *dev)
 {
+       int s, i;
        struct in_device *in_dev = __in_dev_get_rcu(dev);
+       __be32 skeys[2] = { saddr, 0 };
+       int    ikeys[2] = { dev->ifindex, 0 };
        struct inet_peer *peer;
        struct net *net;
 
@@ -1295,13 +1367,43 @@ void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw,
                        goto reject_redirect;
        }
 
-       peer = inet_getpeer_v4(daddr, 1);
-       if (peer) {
-               peer->redirect_learned.a4 = new_gw;
+       for (s = 0; s < 2; s++) {
+               for (i = 0; i < 2; i++) {
+                       unsigned int hash;
+                       struct rtable __rcu **rthp;
+                       struct rtable *rt;
 
-               inet_putpeer(peer);
+                       hash = rt_hash(daddr, skeys[s], ikeys[i], rt_genid(net));
+
+                       rthp = &rt_hash_table[hash].chain;
+
+                       while ((rt = rcu_dereference(*rthp)) != NULL) {
+                               rthp = &rt->dst.rt_next;
+
+                               if (rt->rt_key_dst != daddr ||
+                                   rt->rt_key_src != skeys[s] ||
+                                   rt->rt_oif != ikeys[i] ||
+                                   rt_is_input_route(rt) ||
+                                   rt_is_expired(rt) ||
+                                   !net_eq(dev_net(rt->dst.dev), net) ||
+                                   rt->dst.error ||
+                                   rt->dst.dev != dev ||
+                                   rt->rt_gateway != old_gw)
+                                       continue;
 
-               atomic_inc(&__rt_peer_genid);
+                               if (!rt->peer)
+                                       rt_bind_peer(rt, rt->rt_dst, 1);
+
+                               peer = rt->peer;
+                               if (peer) {
+                                       if (peer->redirect_learned.a4 != new_gw) {
+                                               peer->redirect_learned.a4 = new_gw;
+                                               atomic_inc(&__rt_peer_genid);
+                                       }
+                                       check_peer_redir(&rt->dst, peer);
+                               }
+                       }
+               }
        }
        return;
 
@@ -1531,11 +1633,10 @@ unsigned short ip_rt_frag_needed(struct net *net, const struct iphdr *iph,
                        est_mtu = mtu;
                        peer->pmtu_learned = mtu;
                        peer->pmtu_expires = pmtu_expires;
+                       atomic_inc(&__rt_peer_genid);
                }
 
                inet_putpeer(peer);
-
-               atomic_inc(&__rt_peer_genid);
        }
        return est_mtu ? : new_mtu;
 }
@@ -1588,30 +1689,6 @@ static void ip_rt_update_pmtu(struct dst_entry *dst, u32 mtu)
        }
 }
 
-static int check_peer_redir(struct dst_entry *dst, struct inet_peer *peer)
-{
-       struct rtable *rt = (struct rtable *) dst;
-       __be32 orig_gw = rt->rt_gateway;
-
-       dst_confirm(&rt->dst);
-
-       neigh_release(rt->dst.neighbour);
-       rt->dst.neighbour = NULL;
-
-       rt->rt_gateway = peer->redirect_learned.a4;
-       if (arp_bind_neighbour(&rt->dst) ||
-           !(rt->dst.neighbour->nud_state & NUD_VALID)) {
-               if (rt->dst.neighbour)
-                       neigh_event_send(rt->dst.neighbour, NULL);
-               rt->rt_gateway = orig_gw;
-               return -EAGAIN;
-       } else {
-               rt->rt_flags |= RTCF_REDIRECTED;
-               call_netevent_notifiers(NETEVENT_NEIGH_UPDATE,
-                                       rt->dst.neighbour);
-       }
-       return 0;
-}
 
 static struct dst_entry *ipv4_dst_check(struct dst_entry *dst, u32 cookie)
 {
@@ -1703,7 +1780,7 @@ void ip_rt_get_source(u8 *addr, struct sk_buff *skb, struct rtable *rt)
                memset(&fl4, 0, sizeof(fl4));
                fl4.daddr = iph->daddr;
                fl4.saddr = iph->saddr;
-               fl4.flowi4_tos = iph->tos;
+               fl4.flowi4_tos = RT_TOS(iph->tos);
                fl4.flowi4_oif = rt->dst.dev->ifindex;
                fl4.flowi4_iif = skb->dev->ifindex;
                fl4.flowi4_mark = skb->mark;
@@ -2280,8 +2357,7 @@ int ip_route_input_common(struct sk_buff *skb, __be32 daddr, __be32 saddr,
             rth = rcu_dereference(rth->dst.rt_next)) {
                if ((((__force u32)rth->rt_key_dst ^ (__force u32)daddr) |
                     ((__force u32)rth->rt_key_src ^ (__force u32)saddr) |
-                    (rth->rt_iif ^ iif) |
-                    rth->rt_oif |
+                    (rth->rt_route_iif ^ iif) |
                     (rth->rt_key_tos ^ tos)) == 0 &&
                    rth->rt_mark == skb->mark &&
                    net_eq(dev_net(rth->dst.dev), net) &&
@@ -2708,6 +2784,7 @@ static struct dst_ops ipv4_dst_blackhole_ops = {
        .default_advmss         =       ipv4_default_advmss,
        .update_pmtu            =       ipv4_rt_blackhole_update_pmtu,
        .cow_metrics            =       ipv4_rt_blackhole_cow_metrics,
+       .neigh_lookup           =       ipv4_neigh_lookup,
 };
 
 struct dst_entry *ipv4_blackhole_route(struct net *net, struct dst_entry *dst_orig)