1 /*
2  * Copyright (c) 2000 by Sun Microsystems, Inc.
3  * All rights reserved.
4  *
 * Routines to compress and uncompress TCP packets (for transmission
 * over low-speed serial lines).
7  *
8  * Copyright (c) 1989 Regents of the University of California.
9  * All rights reserved.
10  *
11  * Redistribution and use in source and binary forms are permitted
12  * provided that the above copyright notice and this paragraph are
13  * duplicated in all such forms and that any documentation,
14  * advertising materials, and other materials related to such
15  * distribution and use acknowledge that the software was developed
16  * by the University of California, Berkeley.  The name of the
17  * University may not be used to endorse or promote products derived
18  * from this software without specific prior written permission.
19  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
21  * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
22  *
23  *	Van Jacobson (van@helios.ee.lbl.gov), Dec 31, 1989:
24  *	- Initial distribution.
25  *
26  * Modified June 1993 by Paul Mackerras, paulus@cs.anu.edu.au,
27  * so that the entire packet being decompressed doesn't have
28  * to be in contiguous memory (just the compressed header).
29  */
30 
31 /*
32  * This version is used under STREAMS in Solaris 2
33  *
34  * $Id: vjcompress.c,v 1.10 1999/09/15 23:49:06 masputra Exp $
35  */
36 
37 #include <sys/types.h>
38 #include <sys/param.h>
39 #include <sys/byteorder.h>	/* for ntohl, etc. */
40 #include <sys/systm.h>
41 #include <sys/sysmacros.h>
42 
43 #include <netinet/in.h>
44 #include <netinet/in_systm.h>
45 #include <netinet/ip.h>
46 #include <netinet/tcp.h>
47 
48 #include <net/ppp_defs.h>
49 #include <net/vjcompress.h>
50 
#pragma ident	"%Z%%M%	%I%	%E% SMI"

/*
 * Statistics counters.  INCR() refers to a variable named `comp'
 * (a struct vjcompress pointer) that must be in scope where it is used.
 * Defining VJ_NO_STATS compiles the counters out entirely.
 */
#ifndef VJ_NO_STATS
#define	INCR(counter) ++comp->stats.counter
#else
#define	INCR(counter)
#endif

#define	BCMP(p1, p2, n) bcmp((char *)(p1), (char *)(p2), (unsigned int)(n))

#undef  BCOPY
#define	BCOPY(p1, p2, n) bcopy((char *)(p1), (char *)(p2), (unsigned int)(n))

/*
 * Byte-level header field accessors.
 *
 * I'd like to use offsetof(struct ip,ip_hl) and offsetof(struct
 * tcp,th_off), but these are bitfields: ip_hl is the low nibble of IP
 * header byte 0 and th_off is the high nibble of TCP header byte 12.
 *
 * The pointer argument is parenthesized so that an expression argument
 * such as `hdr + off' is converted as a whole; without the parentheses
 * only `hdr' would be cast and the offset would be scaled in bytes.
 */
#define	getip_hl(bp)	(((uchar_t *)(bp))[0] & 0x0F)
#define	getth_off(bp)	(((uchar_t *)(bp))[12] >> 4)
#define	getip_p(bp)	(((uchar_t *)(bp))[offsetof(struct ip, ip_p)])
#define	setip_p(bp, v)	(((uchar_t *)(bp))[offsetof(struct ip, ip_p)] = (v))
72 
73 /*
74  * vj_compress_init()
75  */
76 void
77 vj_compress_init(struct vjcompress *comp, int max_state)
78 {
79 	register uint_t		i;
80 	register struct cstate	*tstate = comp->tstate;
81 
82 	if (max_state == -1) {
83 		max_state = MAX_STATES - 1;
84 	}
85 
86 	bzero((char *)comp, sizeof (*comp));
87 
88 	for (i = max_state; i > 0; --i) {
89 		tstate[i].cs_id = i & 0xff;
90 		tstate[i].cs_next = &tstate[i - 1];
91 	}
92 
93 	tstate[0].cs_next = &tstate[max_state];
94 	tstate[0].cs_id = 0;
95 
96 	comp->last_cs = &tstate[0];
97 	comp->last_recv = 255;
98 	comp->last_xmit = 255;
99 	comp->flags = VJF_TOSS;
100 }
101 
/*
 * Variable-length delta encoding used in compressed headers.  A value
 * whose low 16 bits are in 1..255 is sent as one byte.  Anything else
 * (zero, or 256..65535) is sent as three bytes: a zero escape byte
 * followed by the 16-bit value, most significant byte first.
 *
 * ENCODE(n) is for values known to be non-zero.  ENCODEZ(n) also checks
 * for zero, which has to be encoded in the long, 3-byte form.  Both
 * append to `cp', a uchar_t pointer that must be in scope, and advance
 * it.  The argument may be evaluated more than once.
 *
 * The DECODE macros read one such value at `cp' and advance `cp' past
 * it:
 *   DECODEL(f) adds the value to the 32-bit network-order field f,
 *   DECODES(f) adds the value to the 16-bit network-order field f,
 *   DECODEU(f) stores the value in the 16-bit network-order field f.
 * They do no bounds checking; the caller must make sure the encoded
 * bytes are present.
 *
 * Every macro expands to do { } while (0), so it behaves as a single
 * statement, for example as the body of an unbraced if/else.
 */
#define	ENCODE(n) do {						\
	if ((ushort_t)(n) >= 256) {				\
		*cp++ = 0;					\
		cp[1] = (n) & 0xff;				\
		cp[0] = ((n) >> 8) & 0xff;			\
		cp += 2;					\
	} else {						\
		*cp++ = (n) & 0xff;				\
	}							\
} while (0)

#define	ENCODEZ(n) do {						\
	if ((ushort_t)(n) >= 256 || (ushort_t)(n) == 0) {	\
		*cp++ = 0;					\
		cp[1] = (n) & 0xff;				\
		cp[0] = ((n) >> 8) & 0xff;			\
		cp += 2;					\
	} else {						\
		*cp++ = (n) & 0xff;				\
	}							\
} while (0)

#define	DECODEL(f) do {							\
	if (*cp == 0) {							\
		uint32_t tmp = ntohl(f) + ((cp[1] << 8) | cp[2]);	\
		(f) = htonl(tmp);					\
		cp += 3;						\
	} else {							\
		uint32_t tmp = ntohl(f) + (uint32_t)*cp++;		\
		(f) = htonl(tmp);					\
	}								\
} while (0)

#define	DECODES(f) do {							\
	if (*cp == 0) {							\
		ushort_t tmp = ntohs(f) + ((cp[1] << 8) | cp[2]);	\
		(f) = htons(tmp);					\
		cp += 3;						\
	} else {							\
		ushort_t tmp = ntohs(f) + (uint32_t)*cp++;		\
		(f) = htons(tmp);					\
	}								\
} while (0)

#define	DECODEU(f) do {							\
	if (*cp == 0) {							\
		(f) = htons((cp[1] << 8) | cp[2]);			\
		cp += 3;						\
	} else {							\
		(f) = htons((uint32_t)*cp++);				\
	}								\
} while (0)
158 
159 uint_t
160 vj_compress_tcp(register struct ip *ip, uint_t mlen, struct vjcompress *comp,
161 	int compress_cid, uchar_t **vjhdrp)
162 {
163 	register struct cstate	*cs = comp->last_cs->cs_next;
164 	register uint_t		hlen = getip_hl(ip);
165 	register struct tcphdr	*oth;
166 	register struct tcphdr	*th;
167 	register uint_t		deltaS;
168 	register uint_t		deltaA;
169 	register uint_t		changes = 0;
170 	uchar_t			new_seq[16];
171 	register uchar_t	*cp = new_seq;
172 	register uint_t		thlen;
173 
174 	/*
175 	 * Bail if this is an IP fragment or if the TCP packet isn't
176 	 * `compressible' (i.e., ACK isn't set or some other control bit is
177 	 * set).  (We assume that the caller has already made sure the
178 	 * packet is IP proto TCP)
179 	 */
180 	if ((ip->ip_off & htons(0x3fff)) || mlen < 40) {
181 		return (TYPE_IP);
182 	}
183 
184 	th = (struct tcphdr *)&((int *)ip)[hlen];
185 
186 	if ((th->th_flags & (TH_SYN|TH_FIN|TH_RST|TH_ACK)) != TH_ACK) {
187 		return (TYPE_IP);
188 	}
189 
190 	thlen = (hlen + getth_off(th)) << 2;
191 	if (thlen > mlen) {
192 		return (TYPE_IP);
193 	}
194 
195 	/*
196 	 * Packet is compressible -- we're going to send either a
197 	 * COMPRESSED_TCP or UNCOMPRESSED_TCP packet.  Either way we need
198 	 * to locate (or create) the connection state.  Special case the
199 	 * most recently used connection since it's most likely to be used
200 	 * again & we don't have to do any reordering if it's used.
201 	 */
202 	INCR(vjs_packets);
203 
204 	if (ip->ip_src.s_addr != cs->cs_ip.ip_src.s_addr ||
205 		ip->ip_dst.s_addr != cs->cs_ip.ip_dst.s_addr ||
206 		*(int *)th != ((int *)&cs->cs_ip)[getip_hl(&cs->cs_ip)]) {
207 
208 		/*
209 		 * Wasn't the first -- search for it.
210 		 *
211 		 * States are kept in a circularly linked list with
212 		 * last_cs pointing to the end of the list.  The
213 		 * list is kept in lru order by moving a state to the
214 		 * head of the list whenever it is referenced.  Since
215 		 * the list is short and, empirically, the connection
216 		 * we want is almost always near the front, we locate
217 		 * states via linear search.  If we don't find a state
218 		 * for the datagram, the oldest state is (re-)used.
219 		 */
220 		register struct cstate	*lcs;
221 		register struct cstate	*lastcs = comp->last_cs;
222 
223 		do {
224 			lcs = cs; cs = cs->cs_next;
225 
226 			INCR(vjs_searches);
227 
228 			if (ip->ip_src.s_addr == cs->cs_ip.ip_src.s_addr &&
229 				ip->ip_dst.s_addr == cs->cs_ip.ip_dst.s_addr &&
230 				*(int *)th == ((int *)
231 					&cs->cs_ip)[getip_hl(&cs->cs_ip)]) {
232 
233 				goto found;
234 			}
235 
236 		} while (cs != lastcs);
237 
238 		/*
239 		 * Didn't find it -- re-use oldest cstate.  Send an
240 		 * uncompressed packet that tells the other side what
241 		 * connection number we're using for this conversation.
242 		 * Note that since the state list is circular, the oldest
243 		 * state points to the newest and we only need to set
244 		 * last_cs to update the lru linkage.
245 		 */
246 		INCR(vjs_misses);
247 
248 		comp->last_cs = lcs;
249 
250 		goto uncompressed;
251 
252 found:
253 		/*
254 		 * Found it -- move to the front on the connection list.
255 		 */
256 		if (cs == lastcs) {
257 			comp->last_cs = lcs;
258 		} else {
259 			lcs->cs_next = cs->cs_next;
260 			cs->cs_next = lastcs->cs_next;
261 			lastcs->cs_next = cs;
262 		}
263 	}
264 
265 	/*
266 	 * Make sure that only what we expect to change changed. The first
267 	 * line of the `if' checks the IP protocol version, header length &
268 	 * type of service.  The 2nd line checks the "Don't fragment" bit.
269 	 * The 3rd line checks the time-to-live and protocol (the protocol
270 	 * check is unnecessary but costless).  The 4th line checks the TCP
271 	 * header length.  The 5th line checks IP options, if any.  The 6th
272 	 * line checks TCP options, if any.  If any of these things are
273 	 * different between the previous & current datagram, we send the
274 	 * current datagram `uncompressed'.
275 	 */
276 	oth = (struct tcphdr *)&((int *)&cs->cs_ip)[hlen];
277 
278 	/* Used to check for IP options. */
279 	deltaS = hlen;
280 
281 	if (((ushort_t *)ip)[0] != ((ushort_t *)&cs->cs_ip)[0] ||
282 		((ushort_t *)ip)[3] != ((ushort_t *)&cs->cs_ip)[3] ||
283 		((ushort_t *)ip)[4] != ((ushort_t *)&cs->cs_ip)[4] ||
284 		getth_off(th) != getth_off(oth) ||
285 		(deltaS > 5 &&
286 			BCMP(ip + 1, &cs->cs_ip + 1, (deltaS - 5) << 2)) ||
287 		(getth_off(th) > 5 &&
288 			BCMP(th + 1, oth + 1, (getth_off(th) - 5) << 2))) {
289 
290 		goto uncompressed;
291 	}
292 
293 	/*
294 	 * Figure out which of the changing fields changed.  The
295 	 * receiver expects changes in the order: urgent, window,
296 	 * ack, seq (the order minimizes the number of temporaries
297 	 * needed in this section of code).
298 	 */
299 	if (th->th_flags & TH_URG) {
300 
301 		deltaS = ntohs(th->th_urp);
302 
303 		ENCODEZ(deltaS);
304 
305 		changes |= NEW_U;
306 
307 	} else if (th->th_urp != oth->th_urp) {
308 
309 		/*
310 		 * argh! URG not set but urp changed -- a sensible
311 		 * implementation should never do this but RFC793
312 		 * doesn't prohibit the change so we have to deal
313 		 * with it
314 		 */
315 		goto uncompressed;
316 	}
317 
318 	if ((deltaS = (ushort_t)(ntohs(th->th_win) - ntohs(oth->th_win))) > 0) {
319 		ENCODE(deltaS);
320 
321 		changes |= NEW_W;
322 	}
323 
324 	if ((deltaA = ntohl(th->th_ack) - ntohl(oth->th_ack)) > 0) {
325 		if (deltaA > 0xffff) {
326 			goto uncompressed;
327 		}
328 
329 		ENCODE(deltaA);
330 
331 		changes |= NEW_A;
332 	}
333 
334 	if ((deltaS = ntohl(th->th_seq) - ntohl(oth->th_seq)) > 0) {
335 		if (deltaS > 0xffff) {
336 			goto uncompressed;
337 		}
338 
339 		ENCODE(deltaS);
340 
341 		changes |= NEW_S;
342 	}
343 
344 	switch (changes) {
345 
346 	case 0:
347 		/*
348 		 * Nothing changed. If this packet contains data and the
349 		 * last one didn't, this is probably a data packet following
350 		 * an ack (normal on an interactive connection) and we send
351 		 * it compressed.  Otherwise it's probably a retransmit,
352 		 * retransmitted ack or window probe.  Send it uncompressed
353 		 * in case the other side missed the compressed version.
354 		 */
355 		if (ip->ip_len != cs->cs_ip.ip_len &&
356 					ntohs(cs->cs_ip.ip_len) == thlen) {
357 			break;
358 		}
359 
360 		/* (otherwise fall through) */
361 		/* FALLTHRU */
362 
363 	case SPECIAL_I:
364 	case SPECIAL_D:
365 
366 		/*
367 		 * actual changes match one of our special case encodings --
368 		 * send packet uncompressed.
369 		 */
370 		goto uncompressed;
371 
372 	case NEW_S|NEW_A:
373 
374 		if (deltaS == deltaA &&
375 				deltaS == ntohs(cs->cs_ip.ip_len) - thlen) {
376 
377 			/*
378 			 * special case for echoed terminal traffic
379 			 */
380 			changes = SPECIAL_I;
381 			cp = new_seq;
382 		}
383 
384 		break;
385 
386 	case NEW_S:
387 
388 		if (deltaS == ntohs(cs->cs_ip.ip_len) - thlen) {
389 
390 			/*
391 			 * special case for data xfer
392 			 */
393 			changes = SPECIAL_D;
394 			cp = new_seq;
395 		}
396 
397 		break;
398 	}
399 
400 	deltaS = ntohs(ip->ip_id) - ntohs(cs->cs_ip.ip_id);
401 	if (deltaS != 1) {
402 		ENCODEZ(deltaS);
403 
404 		changes |= NEW_I;
405 	}
406 
407 	if (th->th_flags & TH_PUSH) {
408 		changes |= TCP_PUSH_BIT;
409 	}
410 
411 	/*
412 	 * Grab the cksum before we overwrite it below.  Then update our
413 	 * state with this packet's header.
414 	 */
415 	deltaA = ntohs(th->th_sum);
416 
417 	BCOPY(ip, &cs->cs_ip, thlen);
418 
419 	/*
420 	 * We want to use the original packet as our compressed packet.
421 	 * (cp - new_seq) is the number of bytes we need for compressed
422 	 * sequence numbers.  In addition we need one byte for the change
423 	 * mask, one for the connection id and two for the tcp checksum.
424 	 * So, (cp - new_seq) + 4 bytes of header are needed.  thlen is how
425 	 * many bytes of the original packet to toss so subtract the two to
426 	 * get the new packet size.
427 	 */
428 	deltaS = cp - new_seq;
429 
430 	cp = (uchar_t *)ip;
431 
432 	if (compress_cid == 0 || comp->last_xmit != cs->cs_id) {
433 		comp->last_xmit = cs->cs_id;
434 
435 		thlen -= deltaS + 4;
436 
437 		*vjhdrp = (cp += thlen);
438 
439 		*cp++ = changes | NEW_C;
440 		*cp++ = cs->cs_id;
441 	} else {
442 		thlen -= deltaS + 3;
443 
444 		*vjhdrp = (cp += thlen);
445 
446 		*cp++ = changes & 0xff;
447 	}
448 
449 	*cp++ = (deltaA >> 8) & 0xff;
450 	*cp++ = deltaA & 0xff;
451 
452 	BCOPY(new_seq, cp, deltaS);
453 
454 	INCR(vjs_compressed);
455 
456 	return (TYPE_COMPRESSED_TCP);
457 
458 	/*
459 	 * Update connection state cs & send uncompressed packet (that is,
460 	 * a regular ip/tcp packet but with the 'conversation id' we hope
461 	 * to use on future compressed packets in the protocol field).
462 	 */
463 uncompressed:
464 
465 	BCOPY(ip, &cs->cs_ip, thlen);
466 
467 	ip->ip_p = cs->cs_id;
468 	comp->last_xmit = cs->cs_id;
469 
470 	return (TYPE_UNCOMPRESSED_TCP);
471 }
472 
473 /*
474  * vj_uncompress_err()
475  *
476  * Called when we may have missed a packet.
477  */
478 void
479 vj_uncompress_err(struct vjcompress *comp)
480 {
481 	comp->flags |= VJF_TOSS;
482 
483 	INCR(vjs_errorin);
484 }
485 
486 /*
487  * vj_uncompress_uncomp()
488  *
489  * "Uncompress" a packet of type TYPE_UNCOMPRESSED_TCP.
490  */
491 int
492 vj_uncompress_uncomp(uchar_t *buf, int buflen, struct vjcompress *comp)
493 {
494 	register uint_t		hlen;
495 	register struct cstate	*cs;
496 
497 	hlen = getip_hl(buf) << 2;
498 
499 	if (getip_p(buf) >= MAX_STATES ||
500 	    hlen + sizeof (struct tcphdr) > buflen ||
501 	    (hlen += getth_off(buf+hlen) << 2) > buflen || hlen > MAX_HDR) {
502 
503 		comp->flags |= VJF_TOSS;
504 
505 		INCR(vjs_errorin);
506 
507 		return (0);
508 	}
509 
510 	cs = &comp->rstate[comp->last_recv = getip_p(buf)];
511 	comp->flags &= ~VJF_TOSS;
512 	setip_p(buf, IPPROTO_TCP);
513 
514 	BCOPY(buf, &cs->cs_ip, hlen);
515 
516 	cs->cs_hlen = hlen & 0xff;
517 
518 	INCR(vjs_uncompressedin);
519 
520 	return (1);
521 }
522 
523 /*
524  * vj_uncompress_tcp()
525  *
526  * Uncompress a packet of type TYPE_COMPRESSED_TCP.
527  * The packet starts at buf and is of total length total_len.
528  * The first buflen bytes are at buf; this must include the entire
529  * compressed TCP/IP header.  This procedure returns the length
530  * of the VJ header, with a pointer to the uncompressed IP header
531  * in *hdrp and its length in *hlenp.
532  */
533 int
534 vj_uncompress_tcp(uchar_t *buf, int buflen, int total_len,
535 	struct vjcompress *comp, uchar_t **hdrp, uint_t *hlenp)
536 {
537 	register uchar_t	*cp;
538 	register uint_t		hlen;
539 	register uint_t		changes;
540 	register struct tcphdr	*th;
541 	register struct cstate	*cs;
542 	register ushort_t	*bp;
543 	register uint_t		vjlen;
544 	register uint32_t	tmp;
545 
546 	INCR(vjs_compressedin);
547 
548 	cp = buf;
549 	changes = *cp++;
550 
551 	if (changes & NEW_C) {
552 		/*
553 		 * Make sure the state index is in range, then grab the state.
554 		 * If we have a good state index, clear the 'discard' flag.
555 		 */
556 		if (*cp >= MAX_STATES) {
557 			goto bad;
558 		}
559 
560 		comp->flags &= ~VJF_TOSS;
561 		comp->last_recv = *cp++;
562 	} else {
563 		/*
564 		 * this packet has an implicit state index.  If we've
565 		 * had a line error since the last time we got an
566 		 * explicit state index, we have to toss the packet
567 		 */
568 		if (comp->flags & VJF_TOSS) {
569 			INCR(vjs_tossed);
570 			return (-1);
571 		}
572 	}
573 
574 	cs = &comp->rstate[comp->last_recv];
575 	hlen = getip_hl(&cs->cs_ip) << 2;
576 
577 	th = (struct tcphdr *)((uint32_t *)&cs->cs_ip+hlen/sizeof (uint32_t));
578 	th->th_sum = htons((*cp << 8) | cp[1]);
579 
580 	cp += 2;
581 
582 	if (changes & TCP_PUSH_BIT) {
583 		th->th_flags |= TH_PUSH;
584 	} else {
585 		th->th_flags &= ~TH_PUSH;
586 	}
587 
588 	switch (changes & SPECIALS_MASK) {
589 
590 	case SPECIAL_I:
591 
592 		{
593 
594 		register uint32_t	i;
595 
596 		i = ntohs(cs->cs_ip.ip_len) - cs->cs_hlen;
597 
598 		tmp = ntohl(th->th_ack) + i;
599 		th->th_ack = htonl(tmp);
600 
601 		tmp = ntohl(th->th_seq) + i;
602 		th->th_seq = htonl(tmp);
603 
604 		}
605 
606 		break;
607 
608 	case SPECIAL_D:
609 
610 		tmp = ntohl(th->th_seq) + ntohs(cs->cs_ip.ip_len) - cs->cs_hlen;
611 		th->th_seq = htonl(tmp);
612 
613 		break;
614 
615 	default:
616 
617 		if (changes & NEW_U) {
618 			th->th_flags |= TH_URG;
619 			DECODEU(th->th_urp);
620 		} else {
621 			th->th_flags &= ~TH_URG;
622 		}
623 
624 		if (changes & NEW_W) {
625 			DECODES(th->th_win);
626 		}
627 
628 		if (changes & NEW_A) {
629 			DECODEL(th->th_ack);
630 		}
631 
632 		if (changes & NEW_S) {
633 			DECODEL(th->th_seq);
634 		}
635 
636 		break;
637 	}
638 
639 	if (changes & NEW_I) {
640 		DECODES(cs->cs_ip.ip_id);
641 	} else {
642 		cs->cs_ip.ip_id = ntohs(cs->cs_ip.ip_id) + 1;
643 		cs->cs_ip.ip_id = htons(cs->cs_ip.ip_id);
644 	}
645 
646 	/*
647 	 * At this point, cp points to the first byte of data in the
648 	 * packet.  Fill in the IP total length and update the IP
649 	 * header checksum.
650 	 */
651 	vjlen = cp - buf;
652 	buflen -= vjlen;
653 	if (buflen < 0) {
654 		/*
655 		 * we must have dropped some characters (crc should detect
656 		 * this but the old slip framing won't)
657 		 */
658 		goto bad;
659 	}
660 
661 	total_len += cs->cs_hlen - vjlen;
662 	cs->cs_ip.ip_len = htons(total_len);
663 
664 	/*
665 	 * recompute the ip header checksum
666 	 */
667 	bp = (ushort_t *)&cs->cs_ip;
668 	cs->cs_ip.ip_sum = 0;
669 
670 	for (changes = 0; hlen > 0; hlen -= 2) {
671 		changes += *bp++;
672 	}
673 
674 	changes = (changes & 0xffff) + (changes >> 16);
675 	changes = (changes & 0xffff) + (changes >> 16);
676 	cs->cs_ip.ip_sum = ~ changes;
677 
678 	*hdrp = (uchar_t *)&cs->cs_ip;
679 	*hlenp = cs->cs_hlen;
680 
681 	return (vjlen);
682 
683 bad:
684 
685 	comp->flags |= VJF_TOSS;
686 
687 	INCR(vjs_errorin);
688 
689 	return (-1);
690 }
691