1 /*
2  * Copyright (C) 2009 Dan Carpenter.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License
6  * as published by the Free Software Foundation; either version 2
7  * of the License, or (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
16  */
17 
/*
 * The idea here is that you have an expression and you
 * want to know what its type is.
 */
22 
23 #include "smatch.h"
24 #include "smatch_slist.h"
25 
26 struct symbol *get_real_base_type(struct symbol *sym)
27 {
28 	struct symbol *ret;
29 
30 	if (!sym)
31 		return NULL;
32 	if (sym->type == SYM_BASETYPE)
33 		return sym;
34 	ret = get_base_type(sym);
35 	if (!ret)
36 		return NULL;
37 	if (ret->type == SYM_RESTRICT || ret->type == SYM_NODE)
38 		return get_real_base_type(ret);
39 	return ret;
40 }
41 
42 int type_bytes(struct symbol *type)
43 {
44 	int bits;
45 
46 	if (type && type->type == SYM_ARRAY)
47 		return array_bytes(type);
48 
49 	bits = type_bits(type);
50 	if (bits < 0)
51 		return 0;
52 	return bits_to_bytes(bits);
53 }
54 
55 int array_bytes(struct symbol *type)
56 {
57 	if (!type || type->type != SYM_ARRAY)
58 		return 0;
59 	return bits_to_bytes(type->bit_size);
60 }
61 
62 static struct symbol *get_binop_type(struct expression *expr)
63 {
64 	struct symbol *left, *right;
65 
66 	left = get_type(expr->left);
67 	if (!left)
68 		return NULL;
69 
70 	if (expr->op == SPECIAL_LEFTSHIFT ||
71 	    expr->op == SPECIAL_RIGHTSHIFT) {
72 		if (type_positive_bits(left) < 31)
73 			return &int_ctype;
74 		return left;
75 	}
76 	right = get_type(expr->right);
77 	if (!right)
78 		return NULL;
79 
80 	if (expr->op == '-' &&
81 	    (is_ptr_type(left) && is_ptr_type(right)))
82 		return ssize_t_ctype;
83 
84 	if (left->type == SYM_PTR || left->type == SYM_ARRAY)
85 		return left;
86 	if (right->type == SYM_PTR || right->type == SYM_ARRAY)
87 		return right;
88 
89 	if (type_positive_bits(left) < 31 && type_positive_bits(right) < 31)
90 		return &int_ctype;
91 
92 	if (type_positive_bits(left) > type_positive_bits(right))
93 		return left;
94 	return right;
95 }
96 
97 static struct symbol *get_type_symbol(struct expression *expr)
98 {
99 	if (!expr || expr->type != EXPR_SYMBOL || !expr->symbol)
100 		return NULL;
101 
102 	return get_real_base_type(expr->symbol);
103 }
104 
105 static struct symbol *get_member_symbol(struct symbol_list *symbol_list, struct ident *member)
106 {
107 	struct symbol *tmp, *sub;
108 
109 	FOR_EACH_PTR(symbol_list, tmp) {
110 		if (!tmp->ident) {
111 			sub = get_real_base_type(tmp);
112 			sub = get_member_symbol(sub->symbol_list, member);
113 			if (sub)
114 				return sub;
115 			continue;
116 		}
117 		if (tmp->ident == member)
118 			return tmp;
119 	} END_FOR_EACH_PTR(tmp);
120 
121 	return NULL;
122 }
123 
124 static struct symbol *get_symbol_from_deref(struct expression *expr)
125 {
126 	struct ident *member;
127 	struct symbol *sym;
128 
129 	if (!expr || expr->type != EXPR_DEREF)
130 		return NULL;
131 
132 	member = expr->member;
133 	sym = get_type(expr->deref);
134 	if (!sym) {
135 		// sm_msg("could not find struct type");
136 		return NULL;
137 	}
138 	if (sym->type == SYM_PTR)
139 		sym = get_real_base_type(sym);
140 	sym = get_member_symbol(sym->symbol_list, member);
141 	if (!sym)
142 		return NULL;
143 	return get_real_base_type(sym);
144 }
145 
146 static struct symbol *get_return_type(struct expression *expr)
147 {
148 	struct symbol *tmp;
149 
150 	tmp = get_type(expr->fn);
151 	if (!tmp)
152 		return NULL;
153 	/* this is to handle __builtin_constant_p() */
154 	if (tmp->type != SYM_FN)
155 		tmp = get_base_type(tmp);
156 	return get_real_base_type(tmp);
157 }
158 
159 static struct symbol *get_expr_stmt_type(struct statement *stmt)
160 {
161 	if (stmt->type != STMT_COMPOUND)
162 		return NULL;
163 	stmt = last_ptr_list((struct ptr_list *)stmt->stmts);
164 	if (stmt->type == STMT_LABEL)
165 		stmt = stmt->label_statement;
166 	if (stmt->type != STMT_EXPRESSION)
167 		return NULL;
168 	return get_type(stmt->expression);
169 }
170 
171 static struct symbol *get_select_type(struct expression *expr)
172 {
173 	struct symbol *one, *two;
174 
175 	one = get_type(expr->cond_true);
176 	two = get_type(expr->cond_false);
177 	if (!one || !two)
178 		return NULL;
179 	/*
180 	 * This is a hack.  If the types are not equiv then we
181 	 * really don't know the type.  But I think guessing is
182 	 *  probably Ok here.
183 	 */
184 	if (type_positive_bits(one) > type_positive_bits(two))
185 		return one;
186 	return two;
187 }
188 
189 struct symbol *get_pointer_type(struct expression *expr)
190 {
191 	struct symbol *sym;
192 
193 	sym = get_type(expr);
194 	if (!sym)
195 		return NULL;
196 	if (sym->type == SYM_NODE) {
197 		sym = get_real_base_type(sym);
198 		if (!sym)
199 			return NULL;
200 	}
201 	if (sym->type != SYM_PTR && sym->type != SYM_ARRAY)
202 		return NULL;
203 	return get_real_base_type(sym);
204 }
205 
206 static struct symbol *fake_pointer_sym(struct expression *expr)
207 {
208 	struct symbol *sym;
209 	struct symbol *base;
210 
211 	sym = alloc_symbol(expr->pos, SYM_PTR);
212 	expr = expr->unop;
213 	base = get_type(expr);
214 	if (!base)
215 		return NULL;
216 	sym->ctype.base_type = base;
217 	return sym;
218 }
219 
220 static struct symbol *get_type_helper(struct expression *expr)
221 {
222 	struct symbol *ret;
223 
224 	expr = strip_parens(expr);
225 	if (!expr)
226 		return NULL;
227 
228 	if (expr->ctype)
229 		return expr->ctype;
230 
231 	switch (expr->type) {
232 	case EXPR_STRING:
233 		ret = &string_ctype;
234 		break;
235 	case EXPR_SYMBOL:
236 		ret = get_type_symbol(expr);
237 		break;
238 	case EXPR_DEREF:
239 		ret = get_symbol_from_deref(expr);
240 		break;
241 	case EXPR_PREOP:
242 	case EXPR_POSTOP:
243 		if (expr->op == '&')
244 			ret = fake_pointer_sym(expr);
245 		else if (expr->op == '*')
246 			ret = get_pointer_type(expr->unop);
247 		else
248 			ret = get_type(expr->unop);
249 		break;
250 	case EXPR_ASSIGNMENT:
251 		ret = get_type(expr->left);
252 		break;
253 	case EXPR_CAST:
254 	case EXPR_FORCE_CAST:
255 	case EXPR_IMPLIED_CAST:
256 		ret = get_real_base_type(expr->cast_type);
257 		break;
258 	case EXPR_COMPARE:
259 	case EXPR_BINOP:
260 		ret = get_binop_type(expr);
261 		break;
262 	case EXPR_CALL:
263 		ret = get_return_type(expr);
264 		break;
265 	case EXPR_STATEMENT:
266 		ret = get_expr_stmt_type(expr->statement);
267 		break;
268 	case EXPR_CONDITIONAL:
269 	case EXPR_SELECT:
270 		ret = get_select_type(expr);
271 		break;
272 	case EXPR_SIZEOF:
273 		ret = &ulong_ctype;
274 		break;
275 	case EXPR_LOGICAL:
276 		ret = &int_ctype;
277 		break;
278 	case EXPR_OFFSETOF:
279 		ret = &ulong_ctype;
280 		break;
281 	default:
282 		return NULL;
283 	}
284 
285 	if (ret && ret->type == SYM_TYPEOF)
286 		ret = get_type(ret->initializer);
287 
288 	expr->ctype = ret;
289 	return ret;
290 }
291 
292 static struct symbol *get_final_type_helper(struct expression *expr)
293 {
294 	/*
295 	 * The problem is that I wrote a bunch of Smatch to think that
296 	 * you could do get_type() on an expression and it would give
297 	 * you what the comparison was type promoted to.  This is wrong
298 	 * but fixing it is a big of work...  Hence this horrible hack.
299 	 *
300 	 */
301 
302 	expr = strip_parens(expr);
303 	if (!expr)
304 		return NULL;
305 
306 	if (expr->type == EXPR_COMPARE)
307 		return &int_ctype;
308 
309 	return NULL;
310 }
311 
/*
 * Type of @expr as worked out by get_type_helper(), or NULL if it cannot
 * be determined.
 */
struct symbol *get_type(struct expression *expr)
{
	struct symbol *type;

	type = get_type_helper(expr);
	return type;
}
316 
/*
 * Like get_type(), except that the special cases recognised by
 * get_final_type_helper() take precedence over the ordinary lookup.
 */
struct symbol *get_final_type(struct expression *expr)
{
	struct symbol *final = get_final_type_helper(expr);

	if (!final)
		final = get_type_helper(expr);
	return final;
}
326 
327 struct symbol *get_promoted_type(struct symbol *left, struct symbol *right)
328 {
329 	struct symbol *ret = &int_ctype;
330 
331 	if (type_positive_bits(left) > type_positive_bits(ret))
332 		ret = left;
333 	if (type_positive_bits(right) > type_positive_bits(ret))
334 		ret = right;
335 
336 	if (type_is_ptr(left))
337 		ret = left;
338 	if (type_is_ptr(right))
339 		ret = right;
340 
341 	return ret;
342 }
343 
344 int type_signed(struct symbol *base_type)
345 {
346 	if (!base_type)
347 		return 0;
348 	if (base_type->ctype.modifiers & MOD_SIGNED)
349 		return 1;
350 	return 0;
351 }
352 
/*
 * Returns 1 if the type of @expr is known and type_unsigned() accepts it,
 * 0 otherwise.
 */
int expr_unsigned(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type && type_unsigned(type);
}
364 
/*
 * Returns 1 if the type of @expr is known and type_signed() accepts it,
 * 0 otherwise.
 */
int expr_signed(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	if (!type)
		return 0;
	return type_signed(type) ? 1 : 0;
}
376 
377 int returns_unsigned(struct symbol *sym)
378 {
379 	if (!sym)
380 		return 0;
381 	sym = get_base_type(sym);
382 	if (!sym || sym->type != SYM_FN)
383 		return 0;
384 	sym = get_base_type(sym);
385 	return type_unsigned(sym);
386 }
387 
/* Result of type_is_ptr() applied to the type of @expr. */
int is_pointer(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type_is_ptr(type);
}
392 
393 int returns_pointer(struct symbol *sym)
394 {
395 	if (!sym)
396 		return 0;
397 	sym = get_base_type(sym);
398 	if (!sym || sym->type != SYM_FN)
399 		return 0;
400 	sym = get_base_type(sym);
401 	if (sym->type == SYM_PTR)
402 		return 1;
403 	return 0;
404 }
405 
406 sval_t sval_type_max(struct symbol *base_type)
407 {
408 	sval_t ret;
409 
410 	if (!base_type || !type_bits(base_type))
411 		base_type = &llong_ctype;
412 	ret.type = base_type;
413 
414 	ret.value = (~0ULL) >> (64 - type_positive_bits(base_type));
415 	return ret;
416 }
417 
418 sval_t sval_type_min(struct symbol *base_type)
419 {
420 	sval_t ret;
421 
422 	if (!base_type || !type_bits(base_type))
423 		base_type = &llong_ctype;
424 	ret.type = base_type;
425 
426 	if (type_unsigned(base_type) || is_ptr_type(base_type)) {
427 		ret.value = 0;
428 		return ret;
429 	}
430 
431 	ret.value = (~0ULL) << type_positive_bits(base_type);
432 
433 	return ret;
434 }
435 
/*
 * Result of type_bits() for the type of @expr, or 0 when the type is
 * unknown.
 */
int nr_bits(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type ? type_bits(type) : 0;
}
445 
446 int is_void_pointer(struct expression *expr)
447 {
448 	struct symbol *type;
449 
450 	type = get_type(expr);
451 	if (!type || type->type != SYM_PTR)
452 		return 0;
453 	type = get_real_base_type(type);
454 	if (type == &void_ctype)
455 		return 1;
456 	return 0;
457 }
458 
459 int is_char_pointer(struct expression *expr)
460 {
461 	struct symbol *type;
462 
463 	type = get_type(expr);
464 	if (!type || type->type != SYM_PTR)
465 		return 0;
466 	type = get_real_base_type(type);
467 	if (type == &char_ctype)
468 		return 1;
469 	return 0;
470 }
471 
472 int is_string(struct expression *expr)
473 {
474 	expr = strip_expr(expr);
475 	if (!expr || expr->type != EXPR_STRING)
476 		return 0;
477 	if (expr->string)
478 		return 1;
479 	return 0;
480 }
481 
482 int is_static(struct expression *expr)
483 {
484 	char *name;
485 	struct symbol *sym;
486 	int ret = 0;
487 
488 	name = expr_to_str_sym(expr, &sym);
489 	if (!name || !sym)
490 		goto free;
491 
492 	if (sym->ctype.modifiers & MOD_STATIC)
493 		ret = 1;
494 free:
495 	free_string(name);
496 	return ret;
497 }
498 
499 int is_local_variable(struct expression *expr)
500 {
501 	struct symbol *sym;
502 	char *name;
503 
504 	name = expr_to_var_sym(expr, &sym);
505 	free_string(name);
506 	if (!sym || !sym->scope || !sym->scope->token || !cur_func_sym)
507 		return 0;
508 	if (cmp_pos(sym->scope->token->pos, cur_func_sym->pos) < 0)
509 		return 0;
510 	if (is_static(expr))
511 		return 0;
512 	return 1;
513 }
514 
515 int types_equiv(struct symbol *one, struct symbol *two)
516 {
517 	if (!one && !two)
518 		return 1;
519 	if (!one || !two)
520 		return 0;
521 	if (one->type != two->type)
522 		return 0;
523 	if (one->type == SYM_PTR)
524 		return types_equiv(get_real_base_type(one), get_real_base_type(two));
525 	if (type_positive_bits(one) != type_positive_bits(two))
526 		return 0;
527 	return 1;
528 }
529 
530 int fn_static(void)
531 {
532 	return !!(cur_func_sym->ctype.modifiers & MOD_STATIC);
533 }
534 
535 const char *global_static(void)
536 {
537 	if (cur_func_sym->ctype.modifiers & MOD_STATIC)
538 		return "static";
539 	else
540 		return "global";
541 }
542 
543 struct symbol *cur_func_return_type(void)
544 {
545 	struct symbol *sym;
546 
547 	sym = get_real_base_type(cur_func_sym);
548 	if (!sym || sym->type != SYM_FN)
549 		return NULL;
550 	sym = get_real_base_type(sym);
551 	return sym;
552 }
553 
554 struct symbol *get_arg_type(struct expression *fn, int arg)
555 {
556 	struct symbol *fn_type;
557 	struct symbol *tmp;
558 	struct symbol *arg_type;
559 	int i;
560 
561 	fn_type = get_type(fn);
562 	if (!fn_type)
563 		return NULL;
564 	if (fn_type->type == SYM_PTR)
565 		fn_type = get_real_base_type(fn_type);
566 	if (fn_type->type != SYM_FN)
567 		return NULL;
568 
569 	i = 0;
570 	FOR_EACH_PTR(fn_type->arguments, tmp) {
571 		arg_type = get_real_base_type(tmp);
572 		if (i == arg) {
573 			return arg_type;
574 		}
575 		i++;
576 	} END_FOR_EACH_PTR(tmp);
577 
578 	return NULL;
579 }
580 
581 static struct symbol *get_member_from_string(struct symbol_list *symbol_list, const char *name)
582 {
583 	struct symbol *tmp, *sub;
584 	int chunk_len;
585 
586 	if (strncmp(name, ".", 1) == 0)
587 		name += 1;
588 	else if (strncmp(name, "->", 2) == 0)
589 		name += 2;
590 
591 	FOR_EACH_PTR(symbol_list, tmp) {
592 		if (!tmp->ident) {
593 			sub = get_real_base_type(tmp);
594 			sub = get_member_from_string(sub->symbol_list, name);
595 			if (sub)
596 				return sub;
597 			continue;
598 		}
599 
600 		if (strcmp(tmp->ident->name, name) == 0)
601 			return tmp;
602 
603 		chunk_len = tmp->ident->len;
604 		if (strncmp(tmp->ident->name, name, chunk_len) == 0 &&
605 		    (name[chunk_len] == '.' || name[chunk_len] == '-')) {
606 			sub = get_real_base_type(tmp);
607 			if (sub->type == SYM_PTR)
608 				sub = get_real_base_type(sub);
609 			return get_member_from_string(sub->symbol_list, name + chunk_len);
610 		}
611 
612 	} END_FOR_EACH_PTR(tmp);
613 
614 	return NULL;
615 }
616 
617 struct symbol *get_member_type_from_key(struct expression *expr, const char *key)
618 {
619 	struct symbol *sym;
620 
621 	if (strcmp(key, "$") == 0)
622 		return get_type(expr);
623 
624 	if (strcmp(key, "*$") == 0) {
625 		sym = get_type(expr);
626 		if (!sym || sym->type != SYM_PTR)
627 			return NULL;
628 		return get_real_base_type(sym);
629 	}
630 
631 	sym = get_type(expr);
632 	if (!sym)
633 		return NULL;
634 	if (sym->type == SYM_PTR)
635 		sym = get_real_base_type(sym);
636 
637 	key = key + 1;
638 	sym = get_member_from_string(sym->symbol_list, key);
639 	if (!sym)
640 		return NULL;
641 	return get_real_base_type(sym);
642 }
643 
644 struct symbol *get_arg_type_from_key(struct expression *fn, int param, struct expression *arg, const char *key)
645 {
646 	struct symbol *type;
647 
648 	if (!key)
649 		return NULL;
650 	if (strcmp(key, "$") == 0)
651 		return get_arg_type(fn, param);
652 	if (strcmp(key, "*$") == 0) {
653 		type = get_arg_type(fn, param);
654 		if (!type || type->type != SYM_PTR)
655 			return NULL;
656 		return get_real_base_type(type);
657 	}
658 	return get_member_type_from_key(arg, key);
659 }
660 
661 int is_struct(struct expression *expr)
662 {
663 	struct symbol *type;
664 
665 	type = get_type(expr);
666 	if (type && type->type == SYM_STRUCT)
667 		return 1;
668 	return 0;
669 }
670 
671 static struct {
672 	struct symbol *sym;
673 	const char *name;
674 } base_types[] = {
675 	{&bool_ctype, "bool"},
676 	{&void_ctype, "void"},
677 	{&type_ctype, "type"},
678 	{&char_ctype, "char"},
679 	{&schar_ctype, "schar"},
680 	{&uchar_ctype, "uchar"},
681 	{&short_ctype, "short"},
682 	{&sshort_ctype, "sshort"},
683 	{&ushort_ctype, "ushort"},
684 	{&int_ctype, "int"},
685 	{&sint_ctype, "sint"},
686 	{&uint_ctype, "uint"},
687 	{&long_ctype, "long"},
688 	{&slong_ctype, "slong"},
689 	{&ulong_ctype, "ulong"},
690 	{&llong_ctype, "llong"},
691 	{&sllong_ctype, "sllong"},
692 	{&ullong_ctype, "ullong"},
693 	{&lllong_ctype, "lllong"},
694 	{&slllong_ctype, "slllong"},
695 	{&ulllong_ctype, "ulllong"},
696 	{&float_ctype, "float"},
697 	{&double_ctype, "double"},
698 	{&ldouble_ctype, "ldouble"},
699 	{&string_ctype, "string"},
700 	{&ptr_ctype, "ptr"},
701 	{&lazy_ptr_ctype, "lazy_ptr"},
702 	{&incomplete_ctype, "incomplete"},
703 	{&label_ctype, "label"},
704 	{&bad_ctype, "bad"},
705 	{&null_ctype, "null"},
706 };
707 
708 static const char *base_type_str(struct symbol *sym)
709 {
710 	int i;
711 
712 	for (i = 0; i < ARRAY_SIZE(base_types); i++) {
713 		if (sym == base_types[i].sym)
714 			return base_types[i].name;
715 	}
716 	return "<unknown>";
717 }
718 
719 static int type_str_helper(char *buf, int size, struct symbol *type)
720 {
721 	int n;
722 
723 	if (!type)
724 		return snprintf(buf, size, "<unknown>");
725 
726 	if (type->type == SYM_BASETYPE) {
727 		return snprintf(buf, size, "%s", base_type_str(type));
728 	} else if (type->type == SYM_PTR) {
729 		type = get_real_base_type(type);
730 		n = type_str_helper(buf, size, type);
731 		if (n > size)
732 			return n;
733 		return n + snprintf(buf + n, size - n, "*");
734 	} else if (type->type == SYM_ARRAY) {
735 		type = get_real_base_type(type);
736 		n = type_str_helper(buf, size, type);
737 		if (n > size)
738 			return n;
739 		return n + snprintf(buf + n, size - n, "[]");
740 	} else if (type->type == SYM_STRUCT) {
741 		return snprintf(buf, size, "struct %s", type->ident ? type->ident->name : "");
742 	} else if (type->type == SYM_UNION) {
743 		if (type->ident)
744 			return snprintf(buf, size, "union %s", type->ident->name);
745 		else
746 			return snprintf(buf, size, "anonymous union");
747 	} else if (type->type == SYM_FN) {
748 		struct symbol *arg, *return_type, *arg_type;
749 		int i;
750 
751 		return_type = get_real_base_type(type);
752 		n = type_str_helper(buf, size, return_type);
753 		if (n > size)
754 			return n;
755 		n += snprintf(buf + n, size - n, "(*)(");
756 		if (n > size)
757 			return n;
758 
759 		i = 0;
760 		FOR_EACH_PTR(type->arguments, arg) {
761 			if (i++)
762 				n += snprintf(buf + n, size - n, ", ");
763 			if (n > size)
764 				return n;
765 			arg_type = get_real_base_type(arg);
766 			n += type_str_helper(buf + n, size - n, arg_type);
767 			if (n > size)
768 				return n;
769 		} END_FOR_EACH_PTR(arg);
770 
771 		return n + snprintf(buf + n, size - n, ")");
772 	} else if (type->type == SYM_NODE) {
773 		n = snprintf(buf, size, "node {");
774 		if (n > size)
775 			return n;
776 		type = get_real_base_type(type);
777 		n += type_str_helper(buf + n, size - n, type);
778 		if (n > size)
779 			return n;
780 		return n + snprintf(buf + n, size - n, "}");
781 	} else if (type->type == SYM_ENUM) {
782 		return snprintf(buf, size, "enum %s", type->ident ? type->ident->name : "<unknown>");
783 	} else {
784 		return snprintf(buf, size, "<type %d>", type->type);
785 	}
786 }
787 
/*
 * Describe @type as a string.
 *
 * The text is written into a 256 byte static buffer, so the returned
 * pointer must not be freed and its contents are overwritten by the next
 * call.
 */
char *type_to_str(struct symbol *type)
{
	static char str[256];

	str[0] = '\0';
	(void)type_str_helper(str, sizeof(str), type);
	return str;
}
796