1 /*
2  * Copyright 2016 Jakub Klama <jceel@FreeBSD.org>
3  * All rights reserved
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted providing that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
18  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
22  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
23  * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24  * POSSIBILITY OF SUCH DAMAGE.
25  *
26  * Copyright 2021 Joyent, Inc.
27  */
28 
29 #include <stdlib.h>
30 #include <errno.h>
31 #include <string.h>
32 #include <unistd.h>
33 #include <pthread.h>
34 #include <assert.h>
35 #include <sys/types.h>
36 #ifdef __APPLE__
37 # include "../apple_endian.h"
38 #elif __illumos__
39 # include <sys/param.h>
40 # include <port.h>
41 # include "../illumos_endian.h"
42 #else
43 # include <sys/endian.h>
44 #endif
45 #include <sys/socket.h>
46 #ifndef __illumos__
47 # include <sys/event.h>
48 #endif
49 #include <sys/uio.h>
50 #include <netdb.h>
51 #include "../lib9p.h"
52 #include "../lib9p_impl.h"
53 #include "../log.h"
54 #include "socket.h"
55 
/*
 * Per-connection state for the socket transport.  One instance is
 * allocated per accepted connection and handed to lib9p as the
 * transport's aux argument (lt_aux).
 */
struct l9p_socket_softc
{
	/* lib9p connection that messages read from ls_fd are fed into. */
	struct l9p_connection *ls_conn;
	/*
	 * Peer address fields; not populated or read by the code in this file.
	 * NOTE(review): struct sockaddr cannot hold an IPv6 address; use
	 * struct sockaddr_storage if these are ever filled in.
	 */
	struct sockaddr ls_sockaddr;
	socklen_t ls_socklen;
	/* Service thread handle slot; nothing in this file reads it. */
	pthread_t ls_thread;
	/* Connected socket used for both reading requests and writing replies. */
	int ls_fd;
};
64 
/*
 * Platform-specific state for multiplexing the listening sockets.
 */
#ifdef __FreeBSD__
struct event_svr {
	/* Change list: one EVFILT_READ registration per listening socket. */
	struct kevent *ev_kev;
	/* Receive buffer for events returned by kevent(). */
	struct kevent *ev_event;
	/* kqueue descriptor. */
	int ev_kq;
};
#elif __illumos__
struct event_svr {
	/* Receive buffer for events returned by port_getn(). */
	port_event_t *ev_pe;
	/* Event port descriptor from port_create(). */
	int ev_port;
};
#else
#error "No event server defined"
#endif
79 
80 static int l9p_init_event_svr(struct event_svr *, uint_t);
81 static uint_t l9p_get_server_addrs(const char *, const char *,
82     struct addrinfo **);
83 static uint_t l9p_bind_addrs(struct event_svr *, struct addrinfo *, uint_t,
84     int **);
85 static int l9p_event_get(struct l9p_server *, struct event_svr *, uint_t,
86     void (*cb)(struct l9p_server *, int));
87 static int l9p_socket_readmsg(struct l9p_socket_softc *, void **, size_t *);
88 static int l9p_socket_get_response_buffer(struct l9p_request *,
89     struct iovec *, size_t *, void *);
90 static int l9p_socket_send_response(struct l9p_request *, const struct iovec *,
91     const size_t, const size_t, void *);
92 static void l9p_socket_drop_response(struct l9p_request *, const struct iovec *,
93     size_t, void *);
94 static void *l9p_socket_thread(void *);
95 static ssize_t xread(int, void *, size_t);
96 static ssize_t xwrite(int, void *, size_t);
97 
98 int
l9p_start_server(struct l9p_server * server,const char * host,const char * port)99 l9p_start_server(struct l9p_server *server, const char *host, const char *port)
100 {
101 	struct addrinfo *res = NULL;
102 	int *sockets = NULL;
103 	uint_t naddrs = 0;
104 	uint_t nsockets = 0;
105 	uint_t i;
106 	struct event_svr esvr;
107 
108 	naddrs = l9p_get_server_addrs(host, port, &res);
109 	if (naddrs == 0)
110 		return (-1);
111 
112 	if (l9p_init_event_svr(&esvr, naddrs) != 0) {
113 		freeaddrinfo(res);
114 		return (-1);
115 	}
116 
117 	nsockets = l9p_bind_addrs(&esvr, res, naddrs, &sockets);
118 
119 	/*
120 	 * We don't need res, after this, so free it and NULL it to prevent
121 	 * any possible use after free.
122 	 */
123 	freeaddrinfo(res);
124 	res = NULL;
125 
126 	if (nsockets == 0)
127 		goto fail;
128 
129 	for (;;) {
130 		if (l9p_event_get(server, &esvr, nsockets,
131 		    l9p_socket_accept) < 0)
132 			break;
133 	}
134 
135 	/* We get here if something failed */
136 	for (i = 0; i < nsockets; i++)
137 		close(sockets[i]);
138 
139 fail:
140 	free(sockets);
141 
142 #ifdef __FreeBSD__
143 	close(esvr.ev_kq);
144 	free(esvr.ev_kev);
145 	free(esvr.ev_event);
146 #elif __illumos__
147 	close(esvr.ev_port);
148 	free(esvr.ev_pe);
149 #else
150 #error "Port me"
151 #endif
152 
153 	return (-1);
154 }
155 
/*
 * Resolve `host' and `port' into a list of stream-socket addresses.
 *
 * On success, stores the getaddrinfo() result list in *resp (the caller
 * must release it with freeaddrinfo()) and returns the number of entries.
 * On failure, logs the error, sets *resp to NULL and returns 0.
 */
static uint_t
l9p_get_server_addrs(const char *host, const char *port, struct addrinfo **resp)
{
	struct addrinfo *res, hints;
	uint_t naddrs;
	int rc;

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	rc = getaddrinfo(host, port, &hints, resp);
	/*
	 * getaddrinfo() reports failure with any non-zero EAI_* code; on
	 * some C libraries those codes are negative, so test for != 0.
	 * *resp is unspecified on failure, so clear it for the caller.
	 */
	if (rc != 0) {
		L9P_LOG(L9P_ERROR, "getaddrinfo(): %s", gai_strerror(rc));
		*resp = NULL;
		return (0);
	}

	naddrs = 0;
	for (res = *resp; res != NULL; res = res->ai_next)
		naddrs++;

	if (naddrs == 0) {
		/* host/port may legitimately be NULL; never pass NULL to %s. */
		L9P_LOG(L9P_ERROR, "no addresses found for %s:%s",
		    host != NULL ? host : "(null)",
		    port != NULL ? port : "(null)");
	}

	return (naddrs);
}
182 
183 #ifdef __FreeBSD__
184 static int
l9p_init_event_svr(struct event_svr * svr,uint_t nsockets)185 l9p_init_event_svr(struct event_svr *svr, uint_t nsockets)
186 {
187 	svr->ev_kev = calloc(nsockets, sizeof(struct kevent));
188 	if (svr->ev_kev == NULL) {
189 		L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
190 		return (-1);
191 	}
192 
193 	svr->ev_event = calloc(nsockets, sizeof(struct kevent));
194 	if (svr->ev_event == NULL) {
195 		L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
196 		free(svr->ev_key);
197 		svr->ev_key = NULL;
198 		return (-1);
199 	}
200 
201 	svr->ev_kq = kqueue();
202 	if (svr->ev_kq == -1) {
203 		L9P_LOG(L9P_ERROR, "kqueue(): %s", strerror(errno));
204 		free(svr->ev_kev);
205 		free(svr->ev_event);
206 		svr->ev_kev = NULL;
207 		svr->ev_event = NULL;
208 		return (-1);
209 	}
210 
211 	return (0);
212 }
213 #elif __illumos__
214 static int
l9p_init_event_svr(struct event_svr * svr,uint_t nsockets)215 l9p_init_event_svr(struct event_svr *svr, uint_t nsockets)
216 {
217 	svr->ev_pe = calloc(nsockets, sizeof(port_event_t));
218 	if (svr->ev_pe == NULL) {
219 		L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
220 		return (-1);
221 	}
222 
223 	svr->ev_port = port_create();
224 	if (svr->ev_port == -1) {
225 		L9P_LOG(L9P_ERROR, "port_create(): %s", strerror(errno));
226 		return (-1);
227 	}
228 
229 	return (0);
230 }
231 #else
232 #error "No event server defined"
233 #endif
234 
235 static uint_t
l9p_bind_addrs(struct event_svr * svr,struct addrinfo * addrs,uint_t naddrs,int ** socketsp)236 l9p_bind_addrs(struct event_svr *svr, struct addrinfo *addrs, uint_t naddrs,
237     int **socketsp)
238 {
239 	struct addrinfo *addr;
240 	uint_t i, j;
241 
242 	*socketsp = calloc(naddrs, sizeof(int));
243 	if (*socketsp == NULL) {
244 		L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
245 		return (0);
246 	}
247 
248 	for (i = 0, addr = addrs; addr != NULL; addr = addr->ai_next) {
249 		int s;
250 		int val = 1;
251 
252 		s = socket(addr->ai_family, addr->ai_socktype,
253 		    addr->ai_protocol);
254 		if (s == -1) {
255 			L9P_LOG(L9P_ERROR, "socket(): %s", strerror(errno));
256 			continue;
257 		}
258 
259 		if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val,
260 		    sizeof(val)) < 0) {
261 			L9P_LOG(L9P_ERROR, "setsockopt(): %s", strerror(errno));
262 			close(s);
263 			continue;
264 		}
265 
266 		if (bind(s, addr->ai_addr, addr->ai_addrlen) < 0) {
267 			L9P_LOG(L9P_ERROR, "bind(): %s", strerror(errno));
268 			close(s);
269 			continue;
270 		}
271 
272 		if (listen(s, 10) < 0) {
273 			L9P_LOG(L9P_ERROR, "listen(): %s", strerror(errno));
274 			close(s);
275 			continue;
276 		}
277 
278 #ifdef __FreeBSD__
279 		EV_SET(&svr->ev_kev[i], s, EVFILT_READ, EV_ADD | EV_ENABLE, 0,
280 		    0, 0);
281 #elif __illumos__
282 		if (port_associate(svr->ev_port, PORT_SOURCE_FD, s,
283 		    POLLIN|POLLHUP, NULL) < 0) {
284 			L9P_LOG(L9P_ERROR, "port_associate(%d): %s", s,
285 			    strerror(errno));
286 			close(s);
287 			continue;
288 		}
289 #else
290 #error "Port me"
291 #endif
292 
293 		*socketsp[i++] = s;
294 	}
295 
296 	if (i < 1) {
297 		free(*socketsp);
298 		*socketsp = NULL;
299 		return (0);
300 	}
301 
302 	for (j = i; j < naddrs; j++)
303 		*socketsp[j++] = -1;
304 
305 #ifdef __FreeBSD__
306 	if (kevent(svr->ev_kq, svr->ev_kev, i, NULL, 0, NULL) < 0) {
307 		L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
308 
309 		for (j = 0; j < i; j++)
310 			close(j);
311 
312 		free(*socketsp);
313 		*socketsp = NULL;
314 
315 		return (0);
316 	}
317 #endif
318 
319 	return (i);
320 }
321 
322 #ifdef __FreeBSD__
323 static int
l9p_event_get(struct l9p_server * l9svr,struct event_svr * esvr,uint_t nsockets,void (* cb)(struct l9p_server *,int))324 l9p_event_get(struct l9p_server *l9svr, struct event_svr *esvr, uint_t nsockets,
325     void (*cb)(struct l9p_server *, int))
326 {
327 	int i, evs;
328 
329 	evs = kevent(esvr->ev_kq, NULL, 0, esvr->ev_event, nsockets, NULL);
330 	if (evs < 0) {
331 		if (errno == EINTR)
332 			return (0);
333 		L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
334 		return (-1);
335 	}
336 
337 	for (i = 0; i < evs; i++)
338 		cb(l9svr, (int)sevr->ev_event[i].ident);
339 
340 	return (0);
341 }
342 #elif __illumos__
343 static int
l9p_event_get(struct l9p_server * l9svr,struct event_svr * esvr,uint_t nsockets,void (* cb)(struct l9p_server *,int))344 l9p_event_get(struct l9p_server *l9svr, struct event_svr *esvr, uint_t nsockets,
345     void (*cb)(struct l9p_server *, int))
346 {
347 	uint_t evs = 1;
348 	int i;
349 
350 	if (port_getn(esvr->ev_port, esvr->ev_pe, nsockets, &evs, NULL) < 0) {
351 		if (errno == EINTR)
352 			return (0);
353 		L9P_LOG(L9P_ERROR, "port_getn(): %s", strerror(errno));
354 		return (-1);
355 	}
356 
357 	for (i = 0; i < evs; i++) {
358 		if (esvr->ev_pe[i].portev_source != PORT_SOURCE_FD)
359 			continue;
360 
361 		cb(l9svr, (int)esvr->ev_pe[i].portev_object);
362 	}
363 
364 	return (0);
365 }
366 #else
367 #error "Port me"
368 #endif
369 
370 void
l9p_socket_accept(struct l9p_server * server,int svr_fd)371 l9p_socket_accept(struct l9p_server *server, int svr_fd)
372 {
373 	struct l9p_socket_softc *sc;
374 	struct l9p_connection *conn;
375 	char host[NI_MAXHOST + 1];
376 	char serv[NI_MAXSERV + 1];
377 	struct sockaddr client_addr;
378 	socklen_t client_addr_len = sizeof(client_addr);
379 	int conn_fd, err;
380 
381 	conn_fd = accept(svr_fd, &client_addr, &client_addr_len);
382 	if (conn_fd < 0) {
383 		L9P_LOG(L9P_WARNING, "accept(): %s", strerror(errno));
384 		return;
385 	}
386 
387 	err = getnameinfo(&client_addr, client_addr_len, host, NI_MAXHOST,
388 	    serv, NI_MAXSERV, NI_NUMERICHOST | NI_NUMERICSERV);
389 
390 	if (err != 0) {
391 		L9P_LOG(L9P_WARNING, "cannot look up client name: %s",
392 		    gai_strerror(err));
393 	} else {
394 		L9P_LOG(L9P_INFO, "new connection from %s:%s", host, serv);
395 	}
396 
397 	if (l9p_connection_init(server, &conn) != 0) {
398 		L9P_LOG(L9P_ERROR, "cannot create new connection");
399 		return;
400 	}
401 
402 	sc = l9p_calloc(1, sizeof(*sc));
403 	sc->ls_conn = conn;
404 	sc->ls_fd = conn_fd;
405 
406 	/*
407 	 * Fill in transport handler functions and aux argument.
408 	 */
409 	conn->lc_lt.lt_aux = sc;
410 	conn->lc_lt.lt_get_response_buffer = l9p_socket_get_response_buffer;
411 	conn->lc_lt.lt_send_response = l9p_socket_send_response;
412 	conn->lc_lt.lt_drop_response = l9p_socket_drop_response;
413 
414 	err = pthread_create(&sc->ls_thread, NULL, l9p_socket_thread, sc);
415 	if (err) {
416 		L9P_LOG(L9P_ERROR,
417 		    "pthread_create (for connection from %s:%s): error %s",
418 		    host, serv, strerror(err));
419 		l9p_connection_close(sc->ls_conn);
420 		free(sc);
421 	}
422 }
423 
424 static void *
l9p_socket_thread(void * arg)425 l9p_socket_thread(void *arg)
426 {
427 	struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
428 	struct iovec iov;
429 	void *buf;
430 	size_t length;
431 
432 	for (;;) {
433 		if (l9p_socket_readmsg(sc, &buf, &length) != 0)
434 			break;
435 
436 		iov.iov_base = buf;
437 		iov.iov_len = length;
438 		l9p_connection_recv(sc->ls_conn, &iov, 1, NULL);
439 		free(buf);
440 	}
441 
442 	L9P_LOG(L9P_INFO, "connection closed");
443 	l9p_connection_close(sc->ls_conn);
444 	free(sc);
445 	return (NULL);
446 }
447 
448 static int
l9p_socket_readmsg(struct l9p_socket_softc * sc,void ** buf,size_t * size)449 l9p_socket_readmsg(struct l9p_socket_softc *sc, void **buf, size_t *size)
450 {
451 	uint32_t msize;
452 	size_t toread;
453 	ssize_t ret;
454 	void *buffer;
455 	int fd = sc->ls_fd;
456 
457 	assert(fd > 0);
458 
459 	buffer = l9p_malloc(sizeof(uint32_t));
460 
461 	ret = xread(fd, buffer, sizeof(uint32_t));
462 	if (ret < 0) {
463 		L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
464 		return (-1);
465 	}
466 
467 	if (ret != sizeof(uint32_t)) {
468 		if (ret == 0) {
469 			L9P_LOG(L9P_DEBUG, "%p: EOF", (void *)sc->ls_conn);
470 		} else {
471 			L9P_LOG(L9P_ERROR,
472 			    "short read: %zd bytes of %zd expected",
473 			    ret, sizeof(uint32_t));
474 		}
475 		return (-1);
476 	}
477 
478 	msize = le32toh(*(uint32_t *)buffer);
479 	toread = msize - sizeof(uint32_t);
480 	buffer = l9p_realloc(buffer, msize);
481 
482 	ret = xread(fd, (char *)buffer + sizeof(uint32_t), toread);
483 	if (ret < 0) {
484 		L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
485 		return (-1);
486 	}
487 
488 	if (ret != (ssize_t)toread) {
489 		L9P_LOG(L9P_ERROR, "short read: %zd bytes of %zd expected",
490 		    ret, toread);
491 		return (-1);
492 	}
493 
494 	*size = msize;
495 	*buf = buffer;
496 	L9P_LOG(L9P_INFO, "%p: read complete message, buf=%p size=%d",
497 	    (void *)sc->ls_conn, buffer, msize);
498 
499 	return (0);
500 }
501 
502 static int
l9p_socket_get_response_buffer(struct l9p_request * req,struct iovec * iov,size_t * niovp,void * arg __unused)503 l9p_socket_get_response_buffer(struct l9p_request *req, struct iovec *iov,
504     size_t *niovp, void *arg __unused)
505 {
506 	size_t size = req->lr_conn->lc_msize;
507 	void *buf;
508 
509 	buf = l9p_malloc(size);
510 	iov[0].iov_base = buf;
511 	iov[0].iov_len = size;
512 
513 	*niovp = 1;
514 	return (0);
515 }
516 
517 static int
l9p_socket_send_response(struct l9p_request * req __unused,const struct iovec * iov,const size_t niov __unused,const size_t iolen,void * arg)518 l9p_socket_send_response(struct l9p_request *req __unused,
519     const struct iovec *iov, const size_t niov __unused, const size_t iolen,
520     void *arg)
521 {
522 	struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
523 
524 	assert(sc->ls_fd >= 0);
525 
526 	L9P_LOG(L9P_DEBUG, "%p: sending reply, buf=%p, size=%d", arg,
527 	    iov[0].iov_base, iolen);
528 
529 	if (xwrite(sc->ls_fd, iov[0].iov_base, iolen) != (int)iolen) {
530 		L9P_LOG(L9P_ERROR, "short write: %s", strerror(errno));
531 		return (-1);
532 	}
533 
534 	free(iov[0].iov_base);
535 	return (0);
536 }
537 
538 static void
l9p_socket_drop_response(struct l9p_request * req __unused,const struct iovec * iov,size_t niov __unused,void * arg)539 l9p_socket_drop_response(struct l9p_request *req __unused,
540     const struct iovec *iov, size_t niov __unused, void *arg)
541 {
542 
543 	L9P_LOG(L9P_DEBUG, "%p: drop buf=%p", arg, iov[0].iov_base);
544 	free(iov[0].iov_base);
545 }
546 
/*
 * Read up to `count' bytes from `fd' into `buf', retrying after EINTR and
 * after partial reads.
 *
 * Returns the number of bytes read, which is less than `count' only if
 * end-of-file was hit.  Returns -1 (errno set by read(2)) on any other
 * error.
 */
static ssize_t
xread(int fd, void *buf, size_t count)
{
	char *cursor = buf;
	size_t remaining = count;

	while (remaining > 0) {
		ssize_t n = read(fd, cursor, remaining);

		if (n == 0)
			break;		/* EOF: report what we have so far */
		if (n < 0) {
			if (errno != EINTR)
				return (-1);
			continue;	/* interrupted by a signal: retry */
		}

		cursor += n;
		remaining -= (size_t)n;
	}

	return ((ssize_t)(count - remaining));
}
570 
/*
 * Write `count' bytes from `buf' to `fd', retrying after EINTR and after
 * partial writes.
 *
 * Returns the number of bytes written, which is less than `count' only if
 * write(2) reported a zero-length transfer.  Returns -1 (errno set by
 * write(2)) on any other error.
 */
static ssize_t
xwrite(int fd, void *buf, size_t count)
{
	char *cursor = buf;
	size_t remaining = count;

	while (remaining > 0) {
		ssize_t n = write(fd, cursor, remaining);

		if (n == 0)
			break;		/* no progress: report what was written */
		if (n < 0) {
			if (errno != EINTR)
				return (-1);
			continue;	/* interrupted by a signal: retry */
		}

		cursor += n;
		remaining -= (size_t)n;
	}

	return ((ssize_t)(count - remaining));
}
594