xref: /illumos-gate/usr/src/tools/smatch/src/avl.c (revision 1f5207b7)
1 /*
2  * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a copy
5  * of this software and associated documentation files (the "Software"), to deal
6  * in the Software without restriction, including without limitation the rights
7  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8  * copies of the Software, and to permit persons to whom the Software is
9  * furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20  * THE SOFTWARE.
21  */
22 
23 #include <assert.h>
24 #include <stdlib.h>
25 
26 #include "smatch.h"
27 #include "smatch_slist.h"
28 
29 static AvlNode *mkNode(const struct sm_state *sm);
30 static void freeNode(AvlNode *node);
31 
32 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm);
33 
34 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm);
35 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
36 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
37 
38 static int sway(AvlNode **p, int sway);
39 static void balance(AvlNode **p, int side);
40 
41 static bool checkBalances(AvlNode *node, int *height);
42 static bool checkOrder(struct stree *avl);
43 static size_t countNode(AvlNode *node);
44 
45 int unfree_stree;
46 
47 /*
48  * Utility macros for converting between
49  * "balance" values (-1 or 1) and "side" values (0 or 1).
50  *
51  * bal(0)   == -1
52  * bal(1)   == +1
53  * side(-1) == 0
54  * side(+1) == 1
55  */
56 #define bal(side) ((side) == 0 ? -1 : 1)
57 #define side(bal) ((bal)  == 1 ?  1 : 0)
58 
avl_new(void)59 static struct stree *avl_new(void)
60 {
61 	struct stree *avl = malloc(sizeof(*avl));
62 
63 	unfree_stree++;
64 	assert(avl != NULL);
65 
66 	avl->root = NULL;
67 	avl->base_stree = NULL;
68 	avl->has_states = calloc(num_checks + 1, sizeof(char));
69 	avl->count = 0;
70 	avl->stree_id = 0;
71 	avl->references = 1;
72 	return avl;
73 }
74 
free_stree(struct stree ** avl)75 void free_stree(struct stree **avl)
76 {
77 	if (!*avl)
78 		return;
79 
80 	assert((*avl)->references > 0);
81 
82 	(*avl)->references--;
83 	if ((*avl)->references != 0) {
84 		*avl = NULL;
85 		return;
86 	}
87 
88 	unfree_stree--;
89 
90 	freeNode((*avl)->root);
91 	free(*avl);
92 	*avl = NULL;
93 }
94 
avl_lookup(const struct stree * avl,const struct sm_state * sm)95 struct sm_state *avl_lookup(const struct stree *avl, const struct sm_state *sm)
96 {
97 	AvlNode *found;
98 
99 	if (!avl)
100 		return NULL;
101 	if (sm->owner != USHRT_MAX &&
102 	    !avl->has_states[sm->owner])
103 		return NULL;
104 	found = lookup(avl, avl->root, sm);
105 	if (!found)
106 		return NULL;
107 	return (struct sm_state *)found->sm;
108 }
109 
avl_lookup_node(const struct stree * avl,const struct sm_state * sm)110 AvlNode *avl_lookup_node(const struct stree *avl, const struct sm_state *sm)
111 {
112 	return lookup(avl, avl->root, sm);
113 }
114 
stree_count(const struct stree * avl)115 size_t stree_count(const struct stree *avl)
116 {
117 	if (!avl)
118 		return 0;
119 	return avl->count;
120 }
121 
clone_stree_real(struct stree * orig)122 static struct stree *clone_stree_real(struct stree *orig)
123 {
124 	struct stree *new = avl_new();
125 	AvlIter i;
126 
127 	avl_foreach(i, orig)
128 		avl_insert(&new, i.sm);
129 
130 	new->base_stree = orig->base_stree;
131 	return new;
132 }
133 
avl_insert(struct stree ** avl,const struct sm_state * sm)134 bool avl_insert(struct stree **avl, const struct sm_state *sm)
135 {
136 	size_t old_count;
137 
138 	if (!*avl)
139 		*avl = avl_new();
140 	if ((*avl)->references > 1) {
141 		(*avl)->references--;
142 		*avl = clone_stree_real(*avl);
143 	}
144 	old_count = (*avl)->count;
145 	/* fortunately we never call get_state() on "unnull_path" */
146 	if (sm->owner != USHRT_MAX)
147 		(*avl)->has_states[sm->owner] = 1;
148 	insert_sm(*avl, &(*avl)->root, sm);
149 	return (*avl)->count != old_count;
150 }
151 
avl_remove(struct stree ** avl,const struct sm_state * sm)152 bool avl_remove(struct stree **avl, const struct sm_state *sm)
153 {
154 	AvlNode *node = NULL;
155 
156 	if (!*avl)
157 		return false;
158 	/* it's fairly rare for smatch to call avl_remove */
159 	if ((*avl)->references > 1) {
160 		(*avl)->references--;
161 		*avl = clone_stree_real(*avl);
162 	}
163 
164 	remove_sm(*avl, &(*avl)->root, sm, &node);
165 
166 	if ((*avl)->count == 0)
167 		free_stree(avl);
168 
169 	if (node == NULL) {
170 		return false;
171 	} else {
172 		free(node);
173 		return true;
174 	}
175 }
176 
mkNode(const struct sm_state * sm)177 static AvlNode *mkNode(const struct sm_state *sm)
178 {
179 	AvlNode *node = malloc(sizeof(*node));
180 
181 	assert(node != NULL);
182 
183 	node->sm = sm;
184 	node->lr[0] = NULL;
185 	node->lr[1] = NULL;
186 	node->balance = 0;
187 	return node;
188 }
189 
/* Free node and all of its descendants (the sm_states they hold are not freed). */
static void freeNode(AvlNode *node)
{
	if (node == NULL)
		return;

	freeNode(node->lr[0]);
	freeNode(node->lr[1]);
	free(node);
}
197 }
198 
/*
 * Binary-search the subtree rooted at node for the entry cmp_tracker()
 * considers equal to sm.  Returns that node, or NULL when absent.
 * The avl argument is not used.
 */
static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm)
{
	(void)avl;

	while (node != NULL) {
		int cmp = cmp_tracker(sm, node->sm);

		if (cmp == 0)
			break;
		/* lr[0] holds smaller keys, lr[1] larger ones. */
		node = node->lr[cmp > 0];
	}
	return node;
}
213 }
214 
215 /*
216  * Insert an sm into a subtree, rebalancing if necessary.
217  *
218  * Return true if the subtree's height increased.
219  */
insert_sm(struct stree * avl,AvlNode ** p,const struct sm_state * sm)220 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm)
221 {
222 	if (*p == NULL) {
223 		*p = mkNode(sm);
224 		avl->count++;
225 		return true;
226 	} else {
227 		AvlNode *node = *p;
228 		int      cmp  = cmp_tracker(sm, node->sm);
229 
230 		if (cmp == 0) {
231 			node->sm = sm;
232 			return false;
233 		}
234 
235 		if (!insert_sm(avl, &node->lr[side(cmp)], sm))
236 			return false;
237 
238 		/* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
239 		return sway(p, cmp) != 0;
240 	}
241 }
242 
243 /*
244  * Remove the node matching the given sm.
245  * If present, return the removed node through *ret .
246  * The returned node's lr and balance are meaningless.
247  *
248  * Return true if the subtree's height decreased.
249  */
remove_sm(struct stree * avl,AvlNode ** p,const struct sm_state * sm,AvlNode ** ret)250 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
251 {
252 	if (p == NULL || *p == NULL) {
253 		return false;
254 	} else {
255 		AvlNode *node = *p;
256 		int      cmp  = cmp_tracker(sm, node->sm);
257 
258 		if (cmp == 0) {
259 			*ret = node;
260 			avl->count--;
261 
262 			if (node->lr[0] != NULL && node->lr[1] != NULL) {
263 				AvlNode *replacement;
264 				int      side;
265 				bool     shrunk;
266 
267 				/* Pick a subtree to pull the replacement from such that
268 				 * this node doesn't have to be rebalanced. */
269 				side = node->balance <= 0 ? 0 : 1;
270 
271 				shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
272 
273 				replacement->lr[0]   = node->lr[0];
274 				replacement->lr[1]   = node->lr[1];
275 				replacement->balance = node->balance;
276 				*p = replacement;
277 
278 				if (!shrunk)
279 					return false;
280 
281 				replacement->balance -= bal(side);
282 
283 				/* If tree's balance became 0, it means the tree's height shrank due to removal. */
284 				return replacement->balance == 0;
285 			}
286 
287 			if (node->lr[0] != NULL)
288 				*p = node->lr[0];
289 			else
290 				*p = node->lr[1];
291 
292 			return true;
293 
294 		} else {
295 			if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
296 				return false;
297 
298 			/* If tree's balance became 0, it means the tree's height shrank due to removal. */
299 			return sway(p, -cmp) == 0;
300 		}
301 	}
302 }
303 
304 /*
305  * Remove either the left-most (if side == 0) or right-most (if side == 1)
306  * node in a subtree, returning the removed node through *ret .
307  * The returned node's lr and balance are meaningless.
308  *
309  * The subtree must not be empty (i.e. *p must not be NULL).
310  *
311  * Return true if the subtree's height decreased.
312  */
removeExtremum(AvlNode ** p,int side,AvlNode ** ret)313 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
314 {
315 	AvlNode *node = *p;
316 
317 	if (node->lr[side] == NULL) {
318 		*ret = node;
319 		*p = node->lr[1 - side];
320 		return true;
321 	}
322 
323 	if (!removeExtremum(&node->lr[side], side, ret))
324 		return false;
325 
326 	/* If tree's balance became 0, it means the tree's height shrank due to removal. */
327 	return sway(p, -bal(side)) == 0;
328 }
329 
330 /*
331  * Rebalance a node if necessary.  Think of this function
332  * as a higher-level interface to balance().
333  *
334  * sway must be either -1 or 1, and indicates what was added to
335  * the balance of this node by a prior operation.
336  *
337  * Return the new balance of the subtree.
338  */
sway(AvlNode ** p,int sway)339 static int sway(AvlNode **p, int sway)
340 {
341 	if ((*p)->balance != sway)
342 		(*p)->balance += sway;
343 	else
344 		balance(p, side(sway));
345 
346 	return (*p)->balance;
347 }
348 
349 /*
350  * Perform tree rotations on an unbalanced node.
351  *
352  * side == 0 means the node's balance is -2 .
353  * side == 1 means the node's balance is +2 .
354  */
balance(AvlNode ** p,int side)355 static void balance(AvlNode **p, int side)
356 {
357 	AvlNode  *node  = *p,
358 	         *child = node->lr[side];
359 	int opposite    = 1 - side;
360 	int bal         = bal(side);
361 
362 	if (child->balance != -bal) {
363 		/* Left-left (side == 0) or right-right (side == 1) */
364 		node->lr[side]      = child->lr[opposite];
365 		child->lr[opposite] = node;
366 		*p = child;
367 
368 		child->balance -= bal;
369 		node->balance = -child->balance;
370 
371 	} else {
372 		/* Left-right (side == 0) or right-left (side == 1) */
373 		AvlNode *grandchild = child->lr[opposite];
374 
375 		node->lr[side]           = grandchild->lr[opposite];
376 		child->lr[opposite]      = grandchild->lr[side];
377 		grandchild->lr[side]     = child;
378 		grandchild->lr[opposite] = node;
379 		*p = grandchild;
380 
381 		node->balance       = 0;
382 		child->balance      = 0;
383 
384 		if (grandchild->balance == bal)
385 			node->balance  = -bal;
386 		else if (grandchild->balance == -bal)
387 			child->balance = bal;
388 
389 		grandchild->balance = 0;
390 	}
391 }
392 
393 
394 /************************* avl_check_invariants() *************************/
395 
avl_check_invariants(struct stree * avl)396 bool avl_check_invariants(struct stree *avl)
397 {
398 	int    dummy;
399 
400 	return checkBalances(avl->root, &dummy)
401 	    && checkOrder(avl)
402 	    && countNode(avl->root) == avl->count;
403 }
404 
/*
 * Recursively check that every node's balance equals the right subtree's
 * height minus the left subtree's, and lies within [-1, 1].
 *
 * On success *height receives the subtree height (0 for an empty subtree).
 */
static bool checkBalances(AvlNode *node, int *height)
{
	int left_h, right_h;

	if (node == NULL) {
		*height = 0;
		return true;
	}

	if (!checkBalances(node->lr[0], &left_h) ||
	    !checkBalances(node->lr[1], &right_h))
		return false;

	if (node->balance < -1 || node->balance > 1 ||
	    node->balance != right_h - left_h)
		return false;

	*height = 1 + (right_h > left_h ? right_h : left_h);
	return true;
}
424 }
425 
checkOrder(struct stree * avl)426 static bool checkOrder(struct stree *avl)
427 {
428 	AvlIter     i;
429 	const struct sm_state *last = NULL;
430 	bool        last_set = false;
431 
432 	avl_foreach(i, avl) {
433 		if (last_set && cmp_tracker(last, i.sm) >= 0)
434 			return false;
435 		last     = i.sm;
436 		last_set = true;
437 	}
438 
439 	return true;
440 }
441 
/* Count the nodes in the subtree rooted at node (0 for an empty subtree). */
static size_t countNode(AvlNode *node)
{
	size_t total = 0;

	if (node != NULL)
		total = countNode(node->lr[0]) + 1 + countNode(node->lr[1]);
	return total;
}
448 }
449 
450 
451 /************************* Traversal *************************/
452 
/*
 * Position iter on the first entry of avl for direction dir, reached by
 * following lr[dir] from the root as far as possible.
 *
 * Each node passed on the way down is pushed onto iter->stack, so that
 * avl_iter_next() can climb back up.  For a NULL or empty stree, iter->node
 * and iter->sm are set to NULL.
 */
void avl_iter_begin(AvlIter *iter, struct stree *avl, AvlDirection dir)
{
	AvlNode *cur = (avl != NULL) ? avl->root : NULL;

	iter->stack_index = 0;
	iter->direction   = dir;

	if (cur == NULL) {
		iter->node = NULL;
		iter->sm   = NULL;
		return;
	}

	for (; cur->lr[dir] != NULL; cur = cur->lr[dir])
		iter->stack[iter->stack_index++] = cur;

	iter->node = cur;
	iter->sm   = (struct sm_state *)cur->sm;
}
474 }
475 
/*
 * Advance iter to the next entry in its direction.
 *
 * If the current node has a subtree ahead (lr[1 - dir]), move to that
 * subtree's first node, pushing the nodes passed onto iter->stack.
 * Otherwise pop the nearest pending ancestor.  When neither exists, the
 * iteration ends with iter->node and iter->sm set to NULL.
 *
 * Calling this on a finished iterator does nothing.
 */
void avl_iter_next(AvlIter *iter)
{
	AvlDirection dir = iter->direction;
	AvlNode *next;

	if (iter->node == NULL)
		return;

	next = iter->node->lr[1 - dir];
	if (next == NULL) {
		if (iter->stack_index > 0) {
			next = iter->stack[--iter->stack_index];
		} else {
			iter->node = NULL;
			iter->sm   = NULL;
			return;
		}
	} else {
		for (; next->lr[dir] != NULL; next = next->lr[dir])
			iter->stack[iter->stack_index++] = next;
	}

	iter->node = next;
	iter->sm   = (struct sm_state *)next->sm;
}
500 }
501 
clone_stree(struct stree * orig)502 struct stree *clone_stree(struct stree *orig)
503 {
504 	if (!orig)
505 		return NULL;
506 
507 	orig->references++;
508 	return orig;
509 }
510 
set_stree_id(struct stree ** stree,int stree_id)511 void set_stree_id(struct stree **stree, int stree_id)
512 {
513 	if ((*stree)->stree_id != 0)
514 		*stree = clone_stree_real(*stree);
515 
516 	(*stree)->stree_id = stree_id;
517 }
518 
get_stree_id(struct stree * stree)519 int get_stree_id(struct stree *stree)
520 {
521 	if (!stree)
522 		return -1;
523 	return stree->stree_id;
524 }
525