[NETFILTER]: xt_connbytes: fix division by zero
[linux-2.6.git] / net / netfilter / xt_connbytes.c
1 /* Kernel module to match connection tracking byte counter.
2  * GPL (C) 2002 Martin Devera (devik@cdi.cz).
3  *
4  * 2004-07-20 Harald Welte <laforge@netfilter.org>
5  *      - reimplemented to use per-connection accounting counters
6  *      - add functionality to match number of packets
7  *      - add functionality to match average packet size
8  *      - add support to match directions separately
9  * 2005-10-16 Harald Welte <laforge@netfilter.org>
10  *      - Port to x_tables
11  *
12  */
13 #include <linux/module.h>
14 #include <linux/skbuff.h>
15 #include <net/netfilter/nf_conntrack_compat.h>
16 #include <linux/netfilter/x_tables.h>
17 #include <linux/netfilter/xt_connbytes.h>
18
19 #include <asm/div64.h>
20 #include <asm/bitops.h>
21
22 MODULE_LICENSE("GPL");
23 MODULE_AUTHOR("Harald Welte <laforge@netfilter.org>");
24 MODULE_DESCRIPTION("iptables match for matching number of pkts/bytes per connection");
25 MODULE_ALIAS("ipt_connbytes");
26
/*
 * Divide a 64-bit dividend by a 64-bit divisor and return the quotient.
 *
 * The actual division is done by do_div() with a 32-bit divisor.  When
 * the divisor does not fit in 32 bits, both operands are shifted right by
 * the same number of bits until it does, so the result is then only an
 * approximation ("dynamic precision").
 *
 * The divisor must be non-zero; this function does not check for that.
 */
static u_int64_t div64_64(u_int64_t dividend, u_int64_t divisor)
{
        u_int32_t divisor32;

        if (divisor <= 0xffffffffULL) {
                /* Already fits: plain truncation loses nothing. */
                divisor32 = divisor;
        } else {
                /* Number of significant bits above bit 31. */
                unsigned int excess_bits = fls(divisor >> 32);

                divisor32 = divisor >> excess_bits;
                dividend >>= excess_bits;
        }

        /* do_div() leaves the quotient in its first argument. */
        do_div(dividend, divisor32);
        return dividend;
}
42
43 static int
44 match(const struct sk_buff *skb,
45       const struct net_device *in,
46       const struct net_device *out,
47       const struct xt_match *match,
48       const void *matchinfo,
49       int offset,
50       unsigned int protoff,
51       int *hotdrop)
52 {
53         const struct xt_connbytes_info *sinfo = matchinfo;
54         u_int64_t what = 0;     /* initialize to make gcc happy */
55         u_int64_t bytes = 0;
56         u_int64_t pkts = 0;
57         const struct ip_conntrack_counter *counters;
58
59         if (!(counters = nf_ct_get_counters(skb)))
60                 return 0; /* no match */
61
62         switch (sinfo->what) {
63         case XT_CONNBYTES_PKTS:
64                 switch (sinfo->direction) {
65                 case XT_CONNBYTES_DIR_ORIGINAL:
66                         what = counters[IP_CT_DIR_ORIGINAL].packets;
67                         break;
68                 case XT_CONNBYTES_DIR_REPLY:
69                         what = counters[IP_CT_DIR_REPLY].packets;
70                         break;
71                 case XT_CONNBYTES_DIR_BOTH:
72                         what = counters[IP_CT_DIR_ORIGINAL].packets;
73                         what += counters[IP_CT_DIR_REPLY].packets;
74                         break;
75                 }
76                 break;
77         case XT_CONNBYTES_BYTES:
78                 switch (sinfo->direction) {
79                 case XT_CONNBYTES_DIR_ORIGINAL:
80                         what = counters[IP_CT_DIR_ORIGINAL].bytes;
81                         break;
82                 case XT_CONNBYTES_DIR_REPLY:
83                         what = counters[IP_CT_DIR_REPLY].bytes;
84                         break;
85                 case XT_CONNBYTES_DIR_BOTH:
86                         what = counters[IP_CT_DIR_ORIGINAL].bytes;
87                         what += counters[IP_CT_DIR_REPLY].bytes;
88                         break;
89                 }
90                 break;
91         case XT_CONNBYTES_AVGPKT:
92                 switch (sinfo->direction) {
93                 case XT_CONNBYTES_DIR_ORIGINAL:
94                         bytes = counters[IP_CT_DIR_ORIGINAL].bytes;
95                         pkts  = counters[IP_CT_DIR_ORIGINAL].packets;
96                         break;
97                 case XT_CONNBYTES_DIR_REPLY:
98                         bytes = counters[IP_CT_DIR_REPLY].bytes;
99                         pkts  = counters[IP_CT_DIR_REPLY].packets;
100                         break;
101                 case XT_CONNBYTES_DIR_BOTH:
102                         bytes = counters[IP_CT_DIR_ORIGINAL].bytes +
103                                 counters[IP_CT_DIR_REPLY].bytes;
104                         pkts  = counters[IP_CT_DIR_ORIGINAL].packets +
105                                 counters[IP_CT_DIR_REPLY].packets;
106                         break;
107                 }
108                 if (pkts != 0)
109                         what = div64_64(bytes, pkts);
110                 break;
111         }
112
113         if (sinfo->count.to)
114                 return (what <= sinfo->count.to && what >= sinfo->count.from);
115         else
116                 return (what >= sinfo->count.from);
117 }
118
119 static int check(const char *tablename,
120                  const void *ip,
121                  const struct xt_match *match,
122                  void *matchinfo,
123                  unsigned int hook_mask)
124 {
125         const struct xt_connbytes_info *sinfo = matchinfo;
126
127         if (sinfo->what != XT_CONNBYTES_PKTS &&
128             sinfo->what != XT_CONNBYTES_BYTES &&
129             sinfo->what != XT_CONNBYTES_AVGPKT)
130                 return 0;
131
132         if (sinfo->direction != XT_CONNBYTES_DIR_ORIGINAL &&
133             sinfo->direction != XT_CONNBYTES_DIR_REPLY &&
134             sinfo->direction != XT_CONNBYTES_DIR_BOTH)
135                 return 0;
136
137         if (nf_ct_l3proto_try_module_get(match->family) < 0) {
138                 printk(KERN_WARNING "can't load conntrack support for "
139                                     "proto=%d\n", match->family);
140                 return 0;
141         }
142
143         return 1;
144 }
145
146 static void
147 destroy(const struct xt_match *match, void *matchinfo)
148 {
149         nf_ct_l3proto_module_put(match->family);
150 }
151
152 static struct xt_match xt_connbytes_match[] = {
153         {
154                 .name           = "connbytes",
155                 .family         = AF_INET,
156                 .checkentry     = check,
157                 .match          = match,
158                 .destroy        = destroy,
159                 .matchsize      = sizeof(struct xt_connbytes_info),
160                 .me             = THIS_MODULE
161         },
162         {
163                 .name           = "connbytes",
164                 .family         = AF_INET6,
165                 .checkentry     = check,
166                 .match          = match,
167                 .destroy        = destroy,
168                 .matchsize      = sizeof(struct xt_connbytes_info),
169                 .me             = THIS_MODULE
170         },
171 };
172
173 static int __init xt_connbytes_init(void)
174 {
175         return xt_register_matches(xt_connbytes_match,
176                                    ARRAY_SIZE(xt_connbytes_match));
177 }
178
179 static void __exit xt_connbytes_fini(void)
180 {
181         xt_unregister_matches(xt_connbytes_match,
182                               ARRAY_SIZE(xt_connbytes_match));
183 }
184
/* Hook the init/exit functions into module load and unload. */
module_init(xt_connbytes_init);
module_exit(xt_connbytes_fini);