8e243589045ff64d0ae5ee67271151f0e06cbc68
[linux-2.6.git] / net / ipv4 / ipcomp.c
1 /*
2  * IP Payload Compression Protocol (IPComp) - RFC3173.
3  *
4  * Copyright (c) 2003 James Morris <jmorris@intercode.com.au>
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by the Free
8  * Software Foundation; either version 2 of the License, or (at your option) 
9  * any later version.
10  *
11  * Todo:
12  *   - Tunable compression parameters.
13  *   - Compression stats.
14  *   - Adaptive compression.
15  */
16 #include <linux/config.h>
17 #include <linux/module.h>
18 #include <asm/scatterlist.h>
19 #include <asm/semaphore.h>
20 #include <linux/crypto.h>
21 #include <linux/pfkeyv2.h>
22 #include <linux/percpu.h>
23 #include <linux/smp.h>
24 #include <linux/list.h>
25 #include <linux/vmalloc.h>
26 #include <linux/rtnetlink.h>
27 #include <linux/mutex.h>
28 #include <net/ip.h>
29 #include <net/xfrm.h>
30 #include <net/icmp.h>
31 #include <net/ipcomp.h>
32 #include <net/protocol.h>
33
34 struct ipcomp_tfms {
35         struct list_head list;
36         struct crypto_tfm **tfms;
37         int users;
38 };
39
40 static DEFINE_MUTEX(ipcomp_resource_mutex);
41 static void **ipcomp_scratches;
42 static int ipcomp_scratch_users;
43 static LIST_HEAD(ipcomp_tfms_list);
44
45 static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb)
46 {
47         int err, plen, dlen;
48         struct ipcomp_data *ipcd = x->data;
49         u8 *start, *scratch;
50         struct crypto_tfm *tfm;
51         int cpu;
52         
53         plen = skb->len;
54         dlen = IPCOMP_SCRATCH_SIZE;
55         start = skb->data;
56
57         cpu = get_cpu();
58         scratch = *per_cpu_ptr(ipcomp_scratches, cpu);
59         tfm = *per_cpu_ptr(ipcd->tfms, cpu);
60
61         err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen);
62         if (err)
63                 goto out;
64
65         if (dlen < (plen + sizeof(struct ip_comp_hdr))) {
66                 err = -EINVAL;
67                 goto out;
68         }
69
70         err = pskb_expand_head(skb, 0, dlen - plen, GFP_ATOMIC);
71         if (err)
72                 goto out;
73                 
74         skb_put(skb, dlen - plen);
75         memcpy(skb->data, scratch, dlen);
76 out:    
77         put_cpu();
78         return err;
79 }
80
81 static int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb)
82 {
83         int err = 0;
84         struct iphdr *iph;
85         struct ip_comp_hdr *ipch;
86
87         if ((skb_is_nonlinear(skb) || skb_cloned(skb)) &&
88             skb_linearize(skb, GFP_ATOMIC) != 0) {
89                 err = -ENOMEM;
90                 goto out;
91         }
92
93         skb->ip_summed = CHECKSUM_NONE;
94
95         /* Remove ipcomp header and decompress original payload */      
96         iph = skb->nh.iph;
97         ipch = (void *)skb->data;
98         iph->protocol = ipch->nexthdr;
99         skb->h.raw = skb->nh.raw + sizeof(*ipch);
100         __skb_pull(skb, sizeof(*ipch));
101         err = ipcomp_decompress(x, skb);
102
103 out:    
104         return err;
105 }
106
107 static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb)
108 {
109         int err, plen, dlen, ihlen;
110         struct iphdr *iph = skb->nh.iph;
111         struct ipcomp_data *ipcd = x->data;
112         u8 *start, *scratch;
113         struct crypto_tfm *tfm;
114         int cpu;
115         
116         ihlen = iph->ihl * 4;
117         plen = skb->len - ihlen;
118         dlen = IPCOMP_SCRATCH_SIZE;
119         start = skb->data + ihlen;
120
121         cpu = get_cpu();
122         scratch = *per_cpu_ptr(ipcomp_scratches, cpu);
123         tfm = *per_cpu_ptr(ipcd->tfms, cpu);
124
125         err = crypto_comp_compress(tfm, start, plen, scratch, &dlen);
126         if (err)
127                 goto out;
128
129         if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) {
130                 err = -EMSGSIZE;
131                 goto out;
132         }
133         
134         memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen);
135         put_cpu();
136
137         pskb_trim(skb, ihlen + dlen + sizeof(struct ip_comp_hdr));
138         return 0;
139         
140 out:    
141         put_cpu();
142         return err;
143 }
144
145 static int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb)
146 {
147         int err;
148         struct iphdr *iph;
149         struct ip_comp_hdr *ipch;
150         struct ipcomp_data *ipcd = x->data;
151         int hdr_len = 0;
152
153         iph = skb->nh.iph;
154         iph->tot_len = htons(skb->len);
155         hdr_len = iph->ihl * 4;
156         if ((skb->len - hdr_len) < ipcd->threshold) {
157                 /* Don't bother compressing */
158                 goto out_ok;
159         }
160
161         if ((skb_is_nonlinear(skb) || skb_cloned(skb)) &&
162             skb_linearize(skb, GFP_ATOMIC) != 0) {
163                 goto out_ok;
164         }
165         
166         err = ipcomp_compress(x, skb);
167         iph = skb->nh.iph;
168
169         if (err) {
170                 goto out_ok;
171         }
172
173         /* Install ipcomp header, convert into ipcomp datagram. */
174         iph->tot_len = htons(skb->len);
175         ipch = (struct ip_comp_hdr *)((char *)iph + iph->ihl * 4);
176         ipch->nexthdr = iph->protocol;
177         ipch->flags = 0;
178         ipch->cpi = htons((u16 )ntohl(x->id.spi));
179         iph->protocol = IPPROTO_COMP;
180         ip_send_check(iph);
181         return 0;
182
183 out_ok:
184         if (x->props.mode)
185                 ip_send_check(iph);
186         return 0;
187 }
188
189 static void ipcomp4_err(struct sk_buff *skb, u32 info)
190 {
191         u32 spi;
192         struct iphdr *iph = (struct iphdr *)skb->data;
193         struct ip_comp_hdr *ipch = (struct ip_comp_hdr *)(skb->data+(iph->ihl<<2));
194         struct xfrm_state *x;
195
196         if (skb->h.icmph->type != ICMP_DEST_UNREACH ||
197             skb->h.icmph->code != ICMP_FRAG_NEEDED)
198                 return;
199
200         spi = htonl(ntohs(ipch->cpi));
201         x = xfrm_state_lookup((xfrm_address_t *)&iph->daddr,
202                               spi, IPPROTO_COMP, AF_INET);
203         if (!x)
204                 return;
205         NETDEBUG(KERN_DEBUG "pmtu discovery on SA IPCOMP/%08x/%u.%u.%u.%u\n",
206                  spi, NIPQUAD(iph->daddr));
207         xfrm_state_put(x);
208 }
209
210 /* We always hold one tunnel user reference to indicate a tunnel */ 
211 static struct xfrm_state *ipcomp_tunnel_create(struct xfrm_state *x)
212 {
213         struct xfrm_state *t;
214         
215         t = xfrm_state_alloc();
216         if (t == NULL)
217                 goto out;
218
219         t->id.proto = IPPROTO_IPIP;
220         t->id.spi = x->props.saddr.a4;
221         t->id.daddr.a4 = x->id.daddr.a4;
222         memcpy(&t->sel, &x->sel, sizeof(t->sel));
223         t->props.family = AF_INET;
224         t->props.mode = 1;
225         t->props.saddr.a4 = x->props.saddr.a4;
226         t->props.flags = x->props.flags;
227
228         if (xfrm_init_state(t))
229                 goto error;
230
231         atomic_set(&t->tunnel_users, 1);
232 out:
233         return t;
234
235 error:
236         t->km.state = XFRM_STATE_DEAD;
237         xfrm_state_put(t);
238         t = NULL;
239         goto out;
240 }
241
242 /*
243  * Must be protected by xfrm_cfg_mutex.  State and tunnel user references are
244  * always incremented on success.
245  */
246 static int ipcomp_tunnel_attach(struct xfrm_state *x)
247 {
248         int err = 0;
249         struct xfrm_state *t;
250
251         t = xfrm_state_lookup((xfrm_address_t *)&x->id.daddr.a4,
252                               x->props.saddr.a4, IPPROTO_IPIP, AF_INET);
253         if (!t) {
254                 t = ipcomp_tunnel_create(x);
255                 if (!t) {
256                         err = -EINVAL;
257                         goto out;
258                 }
259                 xfrm_state_insert(t);
260                 xfrm_state_hold(t);
261         }
262         x->tunnel = t;
263         atomic_inc(&t->tunnel_users);
264 out:
265         return err;
266 }
267
268 static void ipcomp_free_scratches(void)
269 {
270         int i;
271         void **scratches;
272
273         if (--ipcomp_scratch_users)
274                 return;
275
276         scratches = ipcomp_scratches;
277         if (!scratches)
278                 return;
279
280         for_each_possible_cpu(i)
281                 vfree(*per_cpu_ptr(scratches, i));
282
283         free_percpu(scratches);
284 }
285
286 static void **ipcomp_alloc_scratches(void)
287 {
288         int i;
289         void **scratches;
290
291         if (ipcomp_scratch_users++)
292                 return ipcomp_scratches;
293
294         scratches = alloc_percpu(void *);
295         if (!scratches)
296                 return NULL;
297
298         ipcomp_scratches = scratches;
299
300         for_each_possible_cpu(i) {
301                 void *scratch = vmalloc(IPCOMP_SCRATCH_SIZE);
302                 if (!scratch)
303                         return NULL;
304                 *per_cpu_ptr(scratches, i) = scratch;
305         }
306
307         return scratches;
308 }
309
310 static void ipcomp_free_tfms(struct crypto_tfm **tfms)
311 {
312         struct ipcomp_tfms *pos;
313         int cpu;
314
315         list_for_each_entry(pos, &ipcomp_tfms_list, list) {
316                 if (pos->tfms == tfms)
317                         break;
318         }
319
320         BUG_TRAP(pos);
321
322         if (--pos->users)
323                 return;
324
325         list_del(&pos->list);
326         kfree(pos);
327
328         if (!tfms)
329                 return;
330
331         for_each_possible_cpu(cpu) {
332                 struct crypto_tfm *tfm = *per_cpu_ptr(tfms, cpu);
333                 crypto_free_tfm(tfm);
334         }
335         free_percpu(tfms);
336 }
337
338 static struct crypto_tfm **ipcomp_alloc_tfms(const char *alg_name)
339 {
340         struct ipcomp_tfms *pos;
341         struct crypto_tfm **tfms;
342         int cpu;
343
344         /* This can be any valid CPU ID so we don't need locking. */
345         cpu = raw_smp_processor_id();
346
347         list_for_each_entry(pos, &ipcomp_tfms_list, list) {
348                 struct crypto_tfm *tfm;
349
350                 tfms = pos->tfms;
351                 tfm = *per_cpu_ptr(tfms, cpu);
352
353                 if (!strcmp(crypto_tfm_alg_name(tfm), alg_name)) {
354                         pos->users++;
355                         return tfms;
356                 }
357         }
358
359         pos = kmalloc(sizeof(*pos), GFP_KERNEL);
360         if (!pos)
361                 return NULL;
362
363         pos->users = 1;
364         INIT_LIST_HEAD(&pos->list);
365         list_add(&pos->list, &ipcomp_tfms_list);
366
367         pos->tfms = tfms = alloc_percpu(struct crypto_tfm *);
368         if (!tfms)
369                 goto error;
370
371         for_each_possible_cpu(cpu) {
372                 struct crypto_tfm *tfm = crypto_alloc_tfm(alg_name, 0);
373                 if (!tfm)
374                         goto error;
375                 *per_cpu_ptr(tfms, cpu) = tfm;
376         }
377
378         return tfms;
379
380 error:
381         ipcomp_free_tfms(tfms);
382         return NULL;
383 }
384
385 static void ipcomp_free_data(struct ipcomp_data *ipcd)
386 {
387         if (ipcd->tfms)
388                 ipcomp_free_tfms(ipcd->tfms);
389         ipcomp_free_scratches();
390 }
391
392 static void ipcomp_destroy(struct xfrm_state *x)
393 {
394         struct ipcomp_data *ipcd = x->data;
395         if (!ipcd)
396                 return;
397         xfrm_state_delete_tunnel(x);
398         mutex_lock(&ipcomp_resource_mutex);
399         ipcomp_free_data(ipcd);
400         mutex_unlock(&ipcomp_resource_mutex);
401         kfree(ipcd);
402 }
403
404 static int ipcomp_init_state(struct xfrm_state *x)
405 {
406         int err;
407         struct ipcomp_data *ipcd;
408         struct xfrm_algo_desc *calg_desc;
409
410         err = -EINVAL;
411         if (!x->calg)
412                 goto out;
413
414         if (x->encap)
415                 goto out;
416
417         err = -ENOMEM;
418         ipcd = kmalloc(sizeof(*ipcd), GFP_KERNEL);
419         if (!ipcd)
420                 goto out;
421
422         memset(ipcd, 0, sizeof(*ipcd));
423         x->props.header_len = 0;
424         if (x->props.mode)
425                 x->props.header_len += sizeof(struct iphdr);
426
427         mutex_lock(&ipcomp_resource_mutex);
428         if (!ipcomp_alloc_scratches())
429                 goto error;
430
431         ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name);
432         if (!ipcd->tfms)
433                 goto error;
434         mutex_unlock(&ipcomp_resource_mutex);
435
436         if (x->props.mode) {
437                 err = ipcomp_tunnel_attach(x);
438                 if (err)
439                         goto error_tunnel;
440         }
441
442         calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0);
443         BUG_ON(!calg_desc);
444         ipcd->threshold = calg_desc->uinfo.comp.threshold;
445         x->data = ipcd;
446         err = 0;
447 out:
448         return err;
449
450 error_tunnel:
451         mutex_lock(&ipcomp_resource_mutex);
452 error:
453         ipcomp_free_data(ipcd);
454         mutex_unlock(&ipcomp_resource_mutex);
455         kfree(ipcd);
456         goto out;
457 }
458
459 static struct xfrm_type ipcomp_type = {
460         .description    = "IPCOMP4",
461         .owner          = THIS_MODULE,
462         .proto          = IPPROTO_COMP,
463         .init_state     = ipcomp_init_state,
464         .destructor     = ipcomp_destroy,
465         .input          = ipcomp_input,
466         .output         = ipcomp_output
467 };
468
469 static struct net_protocol ipcomp4_protocol = {
470         .handler        =       xfrm4_rcv,
471         .err_handler    =       ipcomp4_err,
472         .no_policy      =       1,
473 };
474
475 static int __init ipcomp4_init(void)
476 {
477         if (xfrm_register_type(&ipcomp_type, AF_INET) < 0) {
478                 printk(KERN_INFO "ipcomp init: can't add xfrm type\n");
479                 return -EAGAIN;
480         }
481         if (inet_add_protocol(&ipcomp4_protocol, IPPROTO_COMP) < 0) {
482                 printk(KERN_INFO "ipcomp init: can't add protocol\n");
483                 xfrm_unregister_type(&ipcomp_type, AF_INET);
484                 return -EAGAIN;
485         }
486         return 0;
487 }
488
489 static void __exit ipcomp4_fini(void)
490 {
491         if (inet_del_protocol(&ipcomp4_protocol, IPPROTO_COMP) < 0)
492                 printk(KERN_INFO "ip ipcomp close: can't remove protocol\n");
493         if (xfrm_unregister_type(&ipcomp_type, AF_INET) < 0)
494                 printk(KERN_INFO "ip ipcomp close: can't remove xfrm type\n");
495 }
496
497 module_init(ipcomp4_init);
498 module_exit(ipcomp4_fini);
499
500 MODULE_LICENSE("GPL");
501 MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173");
502 MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>");
503