1 /*
2  * This file and its contents are supplied under the terms of the
3  * Common Development and Distribution License ("CDDL"), version 1.0.
4  * You may only use this file in accordance with the terms of version
5  * 1.0 of the CDDL.
6  *
7  * A full copy of the text of the CDDL should have accompanied this
8  * source.  A copy of the CDDL is also available via the Internet at
9  * http://www.illumos.org/license/CDDL.
10  */
11 
12 /*
13  * Copyright 2020 OmniOS Community Edition (OmniOSce) Association.
14  */
15 
16 /*
17  * Test file descriptor passing via SCM_RIGHTS, and in particular what happens
18  * on message truncation in terms of the represented size of the data in the
19  * control message. Ensure that no file descriptors are leaked - the kernel
20  * must close any that would not fit in the available buffer space and the
21  * userland application must close the rest.
22  */
23 
24 #include <stdio.h>
25 #include <errno.h>
26 #include <fcntl.h>
27 #include <signal.h>
28 #include <stdlib.h>
29 #include <string.h>
30 #include <strings.h>
31 #include <unistd.h>
32 #include <libproc.h>
33 
34 #include <sys/types.h>
35 #include <sys/param.h>
36 #include <sys/socket.h>
37 #include <sys/stat.h>
38 #include <sys/un.h>
39 #include <sys/wait.h>
40 #include <assert.h>
41 #include <alloca.h>
42 #include <err.h>
43 
/* Set by the -d command line option; enables verbose diagnostic output. */
static boolean_t debug;
45 
46 typedef struct cmsg_test {
47 	char *name;		/* Name of the test */
48 	uint_t send;		/* Number of FDs to send */
49 	uint_t recv;		/* Size receive buffer for this number of FDs */
50 	size_t predata;		/* Prepend dummy cmsg of this size */
51 	int bufsize;		/* Explicitly set receive buffer size. */
52 				/* Overrides 'recv' if non-zero */
53 	uint_t x_controllen;	/* Expected received msg_controllen */
54 	uint_t x_cmsg_datalen;	/* Expected received cmsg data length */
55 	uint32_t x_flags;	/* Expected received msf_flags */
56 } cmsg_test_t;
57 
58 static cmsg_test_t tests[] = {
59 	{
60 		.name = "send 1, recv 1",
61 		.send = 1,
62 		.recv = 1,
63 		.predata = 0,
64 		.bufsize = 0,
65 		.x_controllen = 16,
66 		.x_cmsg_datalen = 4,
67 		.x_flags = 0,
68 	},
69 	{
70 		.name = "send 10, recv 10",
71 		.send = 10,
72 		.recv = 10,
73 		.predata = 0,
74 		.bufsize = 0,
75 		.x_controllen = 52,
76 		.x_cmsg_datalen = 40,
77 		.x_flags = 0,
78 	},
79 	{
80 		.name = "send 2, recv 1",
81 		.send = 2,
82 		.recv = 1,
83 		.predata = 0,
84 		.bufsize = 0,
85 		.x_controllen = 16,
86 		.x_cmsg_datalen = 4,
87 		.x_flags = MSG_CTRUNC,
88 	},
89 	{
90 		.name = "send 2, recv 1, buffer 5",
91 		.send = 2,
92 		.recv = 1,
93 		.predata = 0,
94 		.bufsize = sizeof (int) * 2 - 3,
95 		.x_controllen = 17,
96 		.x_cmsg_datalen = 5,
97 		.x_flags = MSG_CTRUNC,
98 	},
99 	{
100 		.name = "send 2, recv 1, buffer 6",
101 		.send = 2,
102 		.recv = 1,
103 		.predata = 0,
104 		.bufsize = sizeof (int) * 2 - 2,
105 		.x_controllen = 18,
106 		.x_cmsg_datalen = 6,
107 		.x_flags = MSG_CTRUNC,
108 	},
109 	{
110 		.name = "send 2, recv 1, buffer 7",
111 		.send = 2,
112 		.recv = 1,
113 		.predata = 0,
114 		.bufsize = sizeof (int) * 2 - 1,
115 		.x_controllen = 19,
116 		.x_cmsg_datalen = 7,
117 		.x_flags = MSG_CTRUNC,
118 	},
119 
120 	/* Tests where there is no room allowed for data */
121 
122 	{
123 		.name = "send 2, recv 0, hdronly",
124 		.send = 2,
125 		.recv = 0,
126 		.predata = 0,
127 		.bufsize = 0,
128 		.x_controllen = 12,
129 		.x_cmsg_datalen = 0,
130 		.x_flags = MSG_CTRUNC,
131 	},
132 
133 	{
134 		.name = "send 2, recv 0, hdr - 1",
135 		.send = 2,
136 		.recv = 0,
137 		.predata = 0,
138 		.bufsize = -1,
139 		.x_controllen = 11,
140 		.x_cmsg_datalen = 0,
141 		.x_flags = MSG_CTRUNC,
142 	},
143 
144 	{
145 		.name = "send 2, recv 0, hdr - 5",
146 		.send = 2,
147 		.recv = 0,
148 		.predata = 0,
149 		.bufsize = -5,
150 		.x_controllen = 7,
151 		.x_cmsg_datalen = 0,
152 		.x_flags = MSG_CTRUNC,
153 	},
154 
155 	/* Tests where SCM_RIGHTS is not the first message */
156 
157 	{
158 		.name = "send 1, recv 1, pre 8",
159 		.send = 1,
160 		.recv = 1,
161 		.predata = 8,
162 		.bufsize = 0,
163 		.x_controllen = 36,
164 		.x_cmsg_datalen = 4,
165 		.x_flags = 0,
166 	},
167 	{
168 		.name = "send 1, recv 1, pre 7",
169 		.send = 1,
170 		.recv = 1,
171 		.predata = 7,
172 		.bufsize = 0,
173 		.x_controllen = 35,
174 		.x_cmsg_datalen = 4,
175 		.x_flags = 0,
176 	},
177 	{
178 		.name = "send 1, recv 1, pre 6",
179 		.send = 1,
180 		.recv = 1,
181 		.predata = 6,
182 		.bufsize = 0,
183 		.x_controllen = 34,
184 		.x_cmsg_datalen = 4,
185 		.x_flags = 0,
186 	},
187 	{
188 		.name = "send 1, recv 1, pre 5",
189 		.send = 1,
190 		.recv = 1,
191 		.predata = 5,
192 		.bufsize = 0,
193 		.x_controllen = 33,
194 		.x_cmsg_datalen = 4,
195 		.x_flags = 0,
196 	},
197 
198 	{
199 		.name = "send 2, recv 1, pre 8",
200 		.send = 2,
201 		.recv = 1,
202 		.predata = 8,
203 		.bufsize = 0,
204 		.x_controllen = 36,
205 		.x_cmsg_datalen = 8,
206 		.x_flags = MSG_CTRUNC,
207 	},
208 	{
209 		.name = "send 2, recv 1, pre 7",
210 		.send = 2,
211 		.recv = 1,
212 		.predata = 7,
213 		.bufsize = 0,
214 		.x_controllen = 36,
215 		.x_cmsg_datalen = 8,
216 		.x_flags = MSG_CTRUNC,
217 	},
218 	{
219 		.name = "send 2, recv 1, pre 6",
220 		.send = 2,
221 		.recv = 1,
222 		.predata = 6,
223 		.bufsize = 0,
224 		.x_controllen = 36,
225 		.x_cmsg_datalen = 8,
226 		.x_flags = MSG_CTRUNC,
227 	},
228 	{
229 		.name = "send 2, recv 1, pre 5",
230 		.send = 2,
231 		.recv = 1,
232 		.predata = 5,
233 		.bufsize = 0,
234 		.x_controllen = 36,
235 		.x_cmsg_datalen = 8,
236 		.x_flags = MSG_CTRUNC,
237 	},
238 	{
239 		.name = "send 2, recv 1, pre 4",
240 		.send = 2,
241 		.recv = 1,
242 		.predata = 4,
243 		.bufsize = 0,
244 		.x_controllen = 32,
245 		.x_cmsg_datalen = 8,
246 		.x_flags = MSG_CTRUNC,
247 	},
248 	{
249 		.name = "send 2, recv 1, pre 3",
250 		.send = 2,
251 		.recv = 1,
252 		.predata = 3,
253 		.bufsize = 0,
254 		.x_controllen = 32,
255 		.x_cmsg_datalen = 8,
256 		.x_flags = MSG_CTRUNC,
257 	},
258 	{
259 		.name = "send 2, recv 1, pre 2",
260 		.send = 2,
261 		.recv = 1,
262 		.predata = 2,
263 		.bufsize = 0,
264 		.x_controllen = 32,
265 		.x_cmsg_datalen = 8,
266 		.x_flags = MSG_CTRUNC,
267 	},
268 	{
269 		.name = "send 2, recv 1, pre 1",
270 		.send = 2,
271 		.recv = 1,
272 		.predata = 1,
273 		.bufsize = 0,
274 		.x_controllen = 32,
275 		.x_cmsg_datalen = 8,
276 		.x_flags = MSG_CTRUNC,
277 	},
278 
279 	{
280 		.name = "send 2, recv 1, pre 8, buffer 5",
281 		.send = 2,
282 		.recv = 1,
283 		.predata = 8,
284 		.bufsize = sizeof (int) * 2 - 3,
285 		.x_controllen = 37,
286 		.x_cmsg_datalen = 8,
287 		.x_flags = MSG_CTRUNC,
288 	},
289 	{
290 		.name = "send 2, recv 1, pre 8, buffer 6",
291 		.send = 2,
292 		.recv = 1,
293 		.predata = 8,
294 		.bufsize = sizeof (int) * 2 - 2,
295 		.x_controllen = 38,
296 		.x_cmsg_datalen = 8,
297 		.x_flags = MSG_CTRUNC,
298 	},
299 	{
300 		.name = "send 2, recv 1, pre 8, buffer 7",
301 		.send = 2,
302 		.recv = 1,
303 		.predata = 8,
304 		.bufsize = sizeof (int) * 2 - 1,
305 		.x_controllen = 39,
306 		.x_cmsg_datalen = 8,
307 		.x_flags = MSG_CTRUNC,
308 	},
309 	{
310 		.name = "send 10, recv 1, pre 8",
311 		.send = 10,
312 		.recv = 1,
313 		.predata = 8,
314 		.bufsize = 0,
315 		.x_controllen = 36,
316 		.x_cmsg_datalen = 24,
317 		.x_flags = MSG_CTRUNC,
318 	},
319 
320 	/* End of tests */
321 
322 	{
323 		.name = NULL
324 	}
325 };
326 
/*
 * Descriptors that the leak check in fdwalkcb() expects to be open, each -1
 * until opened: the local socket (listening socket in the server, connected
 * socket in the client), the /dev/null descriptor that the server duplicates
 * for sending, and the server's accepted connection.
 */
static int sock = -1;
static int testfd = -1;
static int cfd = -1;

/* Number of unexpected descriptors found by the most recent check_fds() */
static int fdcount;
329 
330 static int
331 fdwalkcb(const prfdinfo_t *info, void *arg)
332 {
333 	if (!S_ISDIR(info->pr_mode) && info->pr_fd > 2 &&
334 	    info->pr_fd != sock && info->pr_fd != testfd &&
335 	    info->pr_fd != cfd) {
336 		if (debug) {
337 			fprintf(stderr, "%s: unexpected fd: %d\n",
338 			    (char *)arg, info->pr_fd);
339 		}
340 		fdcount++;
341 	}
342 
343 	return (0);
344 
345 }
346 
347 static void
348 check_fds(char *tag)
349 {
350 	fdcount = 0;
351 	proc_fdwalk(getpid(), fdwalkcb, tag);
352 }
353 
354 static void
355 send_and_wait(pid_t pid, sigset_t *set, int osig, int isig)
356 {
357 	int sig;
358 
359 	if (osig > 0)
360 		kill(pid, osig);
361 
362 	if (isig > 0) {
363 		if (sigwait(set, &sig) != 0) {
364 			err(EXIT_FAILURE,
365 			    "sigwait failed waiting for %d", isig);
366 		}
367 		if (sig == SIGINT) {
368 			exit(1);
369 		}
370 		if (sig != isig) {
371 			err(EXIT_FAILURE,
372 			    "sigwait returned unexpected signal %d", sig);
373 		}
374 	}
375 }
376 
377 static void
378 sendtest(cmsg_test_t *t)
379 {
380 	struct msghdr msg;
381 	struct cmsghdr *cm;
382 	struct iovec iov;
383 	ssize_t nbytes;
384 	char c = '*';
385 	int i, *p;
386 
387 	bzero(&msg, sizeof (msg));
388 
389 	msg.msg_name = NULL;
390 	msg.msg_namelen = 0;
391 
392 	iov.iov_base = &c;
393 	iov.iov_len = sizeof (c);
394 	msg.msg_iov = &iov;
395 	msg.msg_iovlen = 1;
396 
397 	msg.msg_flags = 0;
398 
399 	msg.msg_controllen = CMSG_SPACE(sizeof (int) * t->send);
400 
401 	if (t->predata > 0) {
402 		/* A dummy cmsg will be inserted at the head of the data */
403 		msg.msg_controllen += CMSG_SPACE(t->predata);
404 	}
405 
406 	msg.msg_control = alloca(msg.msg_controllen);
407 	bzero(msg.msg_control, msg.msg_controllen);
408 
409 	cm = CMSG_FIRSTHDR(&msg);
410 
411 	if (t->predata > 0) {
412 		/* Insert the dummy cmsg */
413 		cm->cmsg_len = CMSG_LEN(t->predata);
414 		cm->cmsg_level = SOL_SOCKET;
415 		cm->cmsg_type = 0;
416 		cm = CMSG_NXTHDR(&msg, cm);
417 	}
418 
419 	cm->cmsg_len = CMSG_LEN(sizeof (int) * t->send);
420 	cm->cmsg_level = SOL_SOCKET;
421 	cm->cmsg_type = SCM_RIGHTS;
422 
423 	p = (int *)CMSG_DATA(cm);
424 	for (i = 0; i < t->send; i++) {
425 		int s = dup(testfd);
426 		if (s == -1)
427 			err(EXIT_FAILURE, "dup()");
428 		*p++ = s;
429 	}
430 
431 	if (debug)
432 		printf("Sending: controllen=%u\n", msg.msg_controllen);
433 
434 	nbytes = sendmsg(cfd, &msg, 0);
435 	if (nbytes == -1)
436 		err(EXIT_FAILURE, "sendmsg()");
437 
438 	p = (int *)CMSG_DATA(cm);
439 	for (i = 0; i < t->send; i++)
440 		(void) close(*p++);
441 }
442 
443 static int
444 server(const char *sockpath, pid_t pid)
445 {
446 	struct sockaddr_un addr;
447 	sigset_t set;
448 	cmsg_test_t *t;
449 
450 	sigemptyset(&set);
451 	sigaddset(&set, SIGUSR2);
452 	sigaddset(&set, SIGINT);
453 
454 	sock = socket(PF_LOCAL, SOCK_STREAM, 0);
455 	if (sock == -1)
456 		err(EXIT_FAILURE, "failed to create socket");
457 	addr.sun_family = AF_UNIX;
458 	strlcpy(addr.sun_path, sockpath, sizeof (addr.sun_path));
459 	if (bind(sock, (struct sockaddr *)&addr, sizeof (addr)) == -1)
460 		err(EXIT_FAILURE, "bind failed");
461 	if (listen(sock, 0) == -1)
462 		err(EXIT_FAILURE, "listen failed");
463 
464 	if ((testfd = open("/dev/null", O_RDONLY)) == -1)
465 		err(EXIT_FAILURE, "/dev/null");
466 
467 	check_fds("server");
468 
469 	/* Signal the child to connect to the socket */
470 	send_and_wait(pid, &set, SIGUSR1, SIGUSR2);
471 
472 	if ((cfd = accept(sock, NULL, 0)) == -1)
473 		err(EXIT_FAILURE, "accept failed");
474 
475 	for (t = tests; t->name != NULL; t++) {
476 		if (debug)
477 			printf("\n>>> Starting test %s\n", t->name);
478 
479 		sendtest(t);
480 		check_fds("server");
481 
482 		send_and_wait(pid, &set, SIGUSR1, SIGUSR2);
483 	}
484 
485 	close(cfd);
486 	close(testfd);
487 	close(sock);
488 
489 	return (0);
490 }
491 
/*
 * Result of the test currently being run by recvtest(): set to _B_TRUE at the
 * start of each test and cleared by check() or recvtest() on any failure.
 */
static boolean_t pass;
493 
494 static void
495 check(uint_t actual, uint_t expected, char *tag)
496 {
497 	if (actual != expected) {
498 		fprintf(stderr, "    !!!: "
499 		    "%1$s = %2$u(%2$#x) (expected %3$u(%3$#x))\n",
500 		    tag, actual, expected);
501 		pass = _B_FALSE;
502 	} else if (debug) {
503 		fprintf(stderr, "       : "
504 		    "%1$s = %2$u(%2$#x)\n",
505 		    tag, actual);
506 	}
507 }
508 
509 static boolean_t
510 recvtest(cmsg_test_t *t)
511 {
512 	struct msghdr msg;
513 	struct cmsghdr *cm;
514 	struct iovec iov;
515 	size_t bufsize;
516 	ssize_t nbytes;
517 	char c = '*';
518 
519 	bzero(&msg, sizeof (msg));
520 
521 	msg.msg_name = NULL;
522 	msg.msg_namelen = 0;
523 
524 	iov.iov_base = &c;
525 	iov.iov_len = sizeof (c);
526 	msg.msg_iov = &iov;
527 	msg.msg_iovlen = 1;
528 
529 	msg.msg_flags = 0;
530 
531 	/*
532 	 * If the test does not specify a receive buffer size, calculate one
533 	 * from the number of file descriptors to receive.
534 	 */
535 	if (t->bufsize == 0) {
536 		bufsize = sizeof (int) * t->recv;
537 		bufsize = CMSG_SPACE(bufsize);
538 	} else {
539 		/*
540 		 * Use the specific buffer size provided but add in
541 		 * space for the header
542 		 */
543 		bufsize = t->bufsize + CMSG_LEN(0);
544 	}
545 
546 	if (t->predata > 0) {
547 		/* A dummy cmsg will be found at the head of the data */
548 		bufsize += CMSG_SPACE(t->predata);
549 	}
550 
551 	msg.msg_controllen = bufsize;
552 	msg.msg_control = alloca(bufsize);
553 	bzero(msg.msg_control, msg.msg_controllen);
554 
555 	pass = _B_TRUE;
556 
557 	if (debug)
558 		printf("Receiving: controllen=%u, \n", msg.msg_controllen);
559 
560 	nbytes = recvmsg(sock, &msg, 0);
561 
562 	if (nbytes == -1) {
563 		pass = _B_FALSE;
564 		fprintf(stderr, "recvmsg() failed: %s\n", strerror(errno));
565 		goto out;
566 	}
567 
568 	if (debug) {
569 		printf("Received: controllen=%u, flags=%#x\n",
570 		    msg.msg_controllen, msg.msg_flags);
571 	}
572 
573 	check(msg.msg_flags, t->x_flags, "msg_flags");
574 	check(msg.msg_controllen, t->x_controllen, "msg_controllen");
575 
576 	for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
577 		void *data, *end;
578 
579 		if (debug) {
580 			printf("    >> : Got cmsg %x/%x - %u\n", cm->cmsg_level,
581 			    cm->cmsg_type, cm->cmsg_len);
582 		}
583 
584 		if (cm->cmsg_type != SCM_RIGHTS) {
585 			if (debug)
586 				printf("       : skipping cmsg\n");
587 			continue;
588 		}
589 
590 		check(cm->cmsg_len - CMSG_LEN(0),
591 		    t->x_cmsg_datalen, "cmsg_len");
592 
593 		/* Close any received file descriptors */
594 		data = CMSG_DATA(cm);
595 
596 		if ((msg.msg_flags & MSG_CTRUNC) &&
597 		    CMSG_NXTHDR(&msg, cm) == NULL) {
598 			/*
599 			 * illumos did not previously adjust cmsg_len on
600 			 * truncation. This is the last cmsg, derive the
601 			 * length from msg_controllen
602 			 */
603 			end = msg.msg_control + msg.msg_controllen;
604 		} else {
605 			end = data + cm->cmsg_len - CMSG_LEN(0);
606 		}
607 
608 		while (data <= end - sizeof (int)) {
609 			int *a = (int *)data;
610 			if (debug)
611 				printf("       : close(%d)\n", *a);
612 			if (close(*a) == -1) {
613 				pass = _B_FALSE;
614 				fprintf(stderr, "    !!!: "
615 				    "failed to close fd %d - %s\n", *a,
616 				    strerror(errno));
617 			}
618 			data += sizeof (int);
619 		}
620 	}
621 
622 out:
623 
624 	check_fds("client");
625 	check(fdcount, 0, "client descriptors");
626 	printf("     + : %s %s\n", pass ? "PASS" : "FAIL", t->name);
627 
628 	return (pass);
629 }
630 
631 static int
632 client(const char *sockpath, pid_t pid)
633 {
634 	struct sockaddr_un addr;
635 	sigset_t set;
636 	cmsg_test_t *t;
637 	int ret = 0;
638 
639 	sigemptyset(&set);
640 	sigaddset(&set, SIGUSR1);
641 	sigaddset(&set, SIGINT);
642 
643 	send_and_wait(pid, &set, 0, SIGUSR1);
644 
645 	sock = socket(PF_LOCAL, SOCK_STREAM, 0);
646 	if (sock == -1)
647 		err(EXIT_FAILURE, "failed to create socket");
648 	addr.sun_family = AF_UNIX;
649 	strlcpy(addr.sun_path, sockpath, sizeof (addr.sun_path));
650 	if (connect(sock, (struct sockaddr *)&addr, sizeof (addr)) == -1)
651 		err(EXIT_FAILURE, "could not connect to server socket");
652 
653 	for (t = tests; t->name != NULL; t++) {
654 		send_and_wait(pid, &set, SIGUSR2, SIGUSR1);
655 		if (!recvtest(t))
656 			ret = 1;
657 	}
658 
659 	close(sock);
660 
661 	return (ret);
662 }
663 
664 int
665 main(int argc, const char **argv)
666 {
667 	char sockpath[] = "/tmp/cmsg.testsock.XXXXXX";
668 	pid_t pid, ppid;
669 	sigset_t set;
670 	int ret = 0;
671 
672 	if (argc > 1 && strcmp(argv[1], "-d") == 0)
673 		debug = _B_TRUE;
674 
675 	sigfillset(&set);
676 	sigdelset(&set, SIGINT);
677 	sigdelset(&set, SIGTSTP);
678 	sigprocmask(SIG_BLOCK, &set, NULL);
679 
680 	if (mktemp(sockpath) == NULL)
681 		err(EXIT_FAILURE, "Failed to make temporary socket path");
682 
683 	ppid = getpid();
684 	pid = fork();
685 	switch (pid) {
686 	case -1:
687 		err(EXIT_FAILURE, "fork failed");
688 	case 0:
689 		return (server(sockpath, ppid));
690 	default:
691 		break;
692 	}
693 
694 	ret = client(sockpath, pid);
695 	kill(pid, SIGINT);
696 
697 	unlink(sockpath);
698 
699 	return (ret);
700 }
701