1#! /usr/bin/env python
2
3"""
4handle plan9 server <-> client connections
5
6(We can act as either server or client.)
7
8This code needs some doctests or other unit tests...
9"""
10
11import collections
12import errno
13import logging
14import math
15import os
16import socket
17import stat
18import struct
19import sys
20import threading
21import time
22
23import lerrno
24import numalloc
25import p9err
26import pfod
27import protocol
28
# Timespec-style timestamps, when the server supplies them, carry
# both whole seconds and nanoseconds.
Timespec = collections.namedtuple('Timespec', ['sec', 'nsec'])

# File attributes as returned by Tgetattr or handed to Tsetattr.
# (These might belong in protocol.py.)  pfod is used rather than
# namedtuple because it makes building an all-None instance easy.
Fileattrs = pfod.pfod('Fileattrs',
                      'ino mode uid gid nlink rdev size blksize blocks '
                      'atime mtime ctime btime gen data_version')

# Short alias for protocol.qid_type2name.
qt2n = protocol.qid_type2name

# Standard 9P port; the default port for connect() below.
STD_P9_PORT = 564
44
class P9Error(Exception):
    "Common base class for the 9P errors defined in this module."
    pass

class RemoteError(P9Error):
    """
    Used when the remote returns an error.  We track the client
    (connection instance), the operation being attempted, the
    message, and an error number and type.  The message may be
    from the Rerror reply, or from converting the errno in a dot-L
    or dot-u Rerror reply.  The error number may be None if the
    type is 'Rerror' rather than 'Rlerror'.  The message may be
    None or empty string if a non-None errno supplies the error
    instead.

    The exception's args are just (message,).
    """
    def __init__(self, client, op, msg, etype, errno):
        self.client = str(client)
        self.op = op
        self.msg = msg
        self.etype = etype # 'Rerror' or 'Rlerror'
        self.errno = errno # may be None
        self.message = self._get_message()
        # Pass only the message: also passing self would put the
        # exception object inside its own args tuple.
        super(RemoteError, self).__init__(self.message)

    def __repr__(self):
        # Fields appear in constructor order, each repr()'d, so the
        # output reads like the call that built the exception.
        return '{0}({1!r}, {2!r}, {3!r}, {4!r}, {5!r})'.format(
            self.__class__.__name__, self.client, self.op,
            self.msg, self.etype, self.errno)

    def __str__(self):
        prefix = '{0}: {1}: '.format(self.client, self.op)
        if self.errno: # check for "is not None", or just non-false-y?
            name = {'Rerror': '.u', 'Rlerror': 'Linux'}[self.etype]
            middle = '[{0} error {1}] '.format(name, self.errno)
        else:
            middle = ''
        return '{0}{1}{2}'.format(prefix, middle, self.message)

    def is_ENOTSUP(self):
        "True if the errno is EOPNOTSUPP (Linux numbering for Rlerror)."
        if self.etype == 'Rlerror':
            return self.errno == lerrno.EOPNOTSUPP
        return self.errno == errno.EOPNOTSUPP

    def _get_message(self):
        "get message based on self.msg or self.errno"
        if self.errno is not None:
            return {
                'Rlerror': p9err.dotl_strerror,
                'Rerror' : p9err.dotu_strerror,
            }[self.etype](self.errno)
        return self.msg

class LocalError(P9Error):
    "An error detected on our side (bad state, limits, or unexpected reply)."
    pass

class TEError(LocalError):
    "Raised when a request gets no reply (timeout) or the connection hit EOF."
    pass
100
class P9SockIO(object):
    """
    Common base for server and client: sends and receives
    size-prefixed packets over a communications channel.

    Only the logger is needed up front; the channel is typically
    connected later (see connect() and declare_connected()), but if
    a server is given here we connect to it immediately.

    Size limits: self.maxio is the maximum message size, including
    the 4-byte little-endian size field, or None while unset;
    self.max_payload is the corresponding limit excluding that field.
    """
    def __init__(self, logger, name=None, server=None, port=STD_P9_PORT):
        self.logger = logger
        self.channel = None
        self.name = name
        self.maxio = None
        self.size_coder = struct.Struct('<I')
        # Set the default (no maxio) payload limit *before* any connect
        # attempt, so the object is complete whatever connect() does.
        self.max_payload = 2**32 - self.size_coder.size
        if server is not None:
            self.connect(server, port)

    def __str__(self):
        if self.name:
            return self.name
        return repr(self)

    def get_recommended_maxio(self):
        "suggest a max I/O size, for when self.maxio is 0 / unset"
        return 16 * 4096

    def min_maxio(self):
        "return a minimum size below which we refuse to work"
        return self.size_coder.size + 100

    def connect(self, server, port=STD_P9_PORT):
        """
        Connect (TCP over IPv4) to given server name / IP address.

        If self.name was None, sets self.name on success: to just the
        server for the standard port, otherwise to server:port.

        Raises LocalError if already connected.  Errors from the
        connection attempt propagate, after the new socket is closed.
        """
        if self.is_connected():
            raise LocalError('already connected')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        try:
            sock.connect((server, port))
        except Exception:
            # don't leak the socket when the connection attempt fails
            sock.close()
            raise
        if self.name is None:
            if port == STD_P9_PORT:
                name = server
            else:
                name = '{0}:{1}'.format(server, port)
        else:
            name = None
        self.declare_connected(sock, name, None)

    def is_connected(self):
        "predicate: are we connected?"
        return self.channel is not None

    def declare_connected(self, chan, name, maxio):
        """
        Now available for normal protocol (size-prefixed) I/O.

        Replaces chan and name and adjusts maxio, if those
        parameters are not None.  Raises LocalError if a non-zero
        maxio is below min_maxio().
        """
        if maxio:
            minio = self.min_maxio()
            if maxio < minio:
                raise LocalError('maxio={0} < minimum {1}'.format(maxio, minio))
        if chan is not None:
            self.channel = chan
        if name is not None:
            self.name = name
        if maxio is not None:
            self.maxio = maxio
            self.max_payload = maxio - self.size_coder.size

    def reduce_maxio(self, maxio):
        """
        Reduce maximum I/O size per other-side request.

        Raises LocalError if maxio is below min_maxio(), if no maxio
        is currently set, or if maxio exceeds the current maxio.
        """
        minio = self.min_maxio()
        if maxio < minio:
            raise LocalError('new maxio={0} < minimum {1}'.format(maxio, minio))
        # With no current maxio there is nothing to reduce; test for
        # None explicitly (Python 3 cannot compare an int with None).
        if self.maxio is None or maxio > self.maxio:
            raise LocalError('new maxio={0} > current {1}'.format(maxio,
                                                                  self.maxio))
        self.maxio = maxio
        self.max_payload = maxio - self.size_coder.size

    def declare_disconnected(self):
        "Declare comm channel dead (note: leaves self.name set!)"
        self.channel = None
        self.maxio = None
        # maxio is unset again, so drop the payload limit derived from it
        self.max_payload = 2**32 - self.size_coder.size

    def shutwrite(self):
        "Do a SHUT_WR on the outbound channel - can't send more"
        chan = self.channel
        # We're racing other threads here: the channel may already be
        # None (AttributeError) or shut down / closed (socket.error,
        # which is not an OSError subclass on Python 2).
        try:
            chan.shutdown(socket.SHUT_WR)
        except (socket.error, OSError, AttributeError):
            pass

    def shutdown(self):
        "Shut down comm channel"
        chan = self.channel
        if chan is not None:
            try:
                chan.shutdown(socket.SHUT_RDWR)
            except socket.error:
                pass
            try:
                chan.close()
            finally:
                self.declare_disconnected()

    def read(self):
        """
        Try to read a complete packet.

        Returns b'' for EOF, as read() usually does.

        If we can't even get the size, this still returns b''.
        If we get a sensible size but are missing some data,
        we can return a short packet.  Since we know if we did
        this, we also return a boolean: True means "really got a
        complete packet."

        Note that b'' EOF always returns False: EOF is never a
        complete packet.  The returned data excludes the size field.
        """
        if self.channel is None:
            return b'', False
        size_field = self.xread(self.size_coder.size)
        if len(size_field) < self.size_coder.size:
            if len(size_field) == 0:
                self.logger.log(logging.INFO, '%s: normal EOF', self)
            else:
                self.logger.log(logging.ERROR,
                               '%s: EOF while reading size (got %d bytes)',
                               self, len(size_field))
                # should we raise an error here?
            return b'', False

        size = self.size_coder.unpack(size_field)[0] - self.size_coder.size
        if size <= 0 or size > self.max_payload:
            self.logger.log(logging.ERROR,
                            '%s: incoming size %d is insane '
                            '(max payload is %d)',
                            self, size, self.max_payload)
            # indicate EOF - should we raise an error instead, here?
            return b'', False
        data = self.xread(size)
        return data, len(data) == size

    def xread(self, nbytes):
        """
        Read nbytes bytes, looping if necessary.  Return b'' for
        EOF; may return a short count if we get some data, then
        EOF.
        """
        assert nbytes > 0
        # Try to get everything at once (should usually succeed).
        # Return immediately for EOF or got-all-data.
        data = self.channel.recv(nbytes)
        if data == b'' or len(data) == nbytes:
            return data

        # Gather data fragments into an array, then join it all at
        # the end.
        count = len(data)
        data = [data]
        while count < nbytes:
            more = self.channel.recv(nbytes - count)
            if more == b'':
                break
            count += len(more)
            data.append(more)
        return b''.join(data)

    def write(self, data):
        """
        Write all the data, in the usual encoding.  Note that
        the length of the data, including the length of the length
        itself, is already encoded in the first 4 bytes of the
        data.

        Raises IOError / socket.error if we can't write everything.

        Raises LocalError if len(data) exceeds the maximum message
        size, i.e., max_payload plus the size field (the same limit
        read() applies to incoming messages).
        """
        size = len(data)
        assert size >= 4
        # data includes the size field but max_payload excludes it.
        limit = self.max_payload + self.size_coder.size
        if size > limit:
            raise LocalError('data length {0} exceeds '
                             'maximum {1}'.format(size, limit))
        self.channel.sendall(data)
290
291def _pathcat(prefix, suffix):
292    """
293    Concatenate paths we are using on the server side.  This is
294    basically just prefix + / + suffix, with two complications:
295
296    It's possible we don't have a prefix path, in which case
297    we want the suffix without a leading slash.
298
299    It's possible that the prefix is just b'/', in which case we
300    want prefix + suffix.
301    """
302    if prefix:
303        if prefix == b'/':  # or prefix.endswith(b'/')?
304            return prefix + suffix
305        return prefix + b'/' + suffix
306    return suffix
307
308class P9Client(P9SockIO):
309    """
310    Act as client.
311
    We need a logger (see logging), a timeout, and a protocol
313    version to request.  By default, we will downgrade to a lower
314    version if asked.
315
316    If server and port are supplied, they are remembered and become
317    the default for .connect() (which is still deferred).
318
319    Note that we keep a table of fid-to-path in self.live_fids,
320    but at any time (except while holding the lock) a fid can
321    be deleted entirely, and the table entry may just be True
322    if we have no path name.  In general, we update the name
323    when we can.
324    """
325    def __init__(self, logger, timeout, version, may_downgrade=True,
326                 server=None, port=None):
327        super(P9Client, self).__init__(logger)
328        self.timeout = timeout
329        self.iproto = protocol.p9_version(version)
330        self.may_downgrade = may_downgrade
331        self.tagalloc = numalloc.NumAlloc(0, 65534)
332        self.tagstate = {}
333        # The next bit is slighlty dirty: perhaps we should just
334        # allocate NOFID out of the 2**32-1 range, so as to avoid
335        # "knowing" that it's 2**32-1.
336        self.fidalloc = numalloc.NumAlloc(0, protocol.td.NOFID - 1)
337        self.live_fids = {}
338        self.rootfid = None
339        self.rootqid = None
340        self.rthread = None
341        self.lock = threading.Lock()
342        self.new_replies = threading.Condition(self.lock)
343        self._monkeywrench = {}
344        self._server = server
345        self._port = port
346        self._unsup = {}
347
348    def get_monkey(self, what):
349        "check for a monkey-wrench"
350        with self.lock:
351            wrench = self._monkeywrench.get(what)
352            if wrench is None:
353                return None
354            if isinstance(wrench, list):
355                # repeats wrench[0] times, or forever if that's 0
356                ret = wrench[1]
357                if wrench[0] > 0:
358                    wrench[0] -= 1
359                    if wrench[0] == 0:
360                        del self._monkeywrench[what]
361            else:
362                ret = wrench
363                del self._monkeywrench[what]
364        return ret
365
366    def set_monkey(self, what, how, repeat=None):
367        """
368        Set a monkey-wrench.  If repeat is not None it is the number of
369        times the wrench is applied (0 means forever, or until you call
370        set again with how=None).  What is what to monkey-wrench, which
371        depends on the op.  How is generally a replacement value.
372        """
373        if how is None:
374            with self.lock:
375                try:
376                    del self._monkeywrench[what]
377                except KeyError:
378                    pass
379            return
380        if repeat is not None:
381            how = [repeat, how]
382        with self.lock:
383            self._monkeywrench[what] = how
384
385    def get_tag(self, for_Tversion=False):
386        "get next available tag ID"
387        with self.lock:
388            if for_Tversion:
389                tag = 65535
390            else:
391                tag = self.tagalloc.alloc()
392            if tag is None:
393                raise LocalError('all tags in use')
394            self.tagstate[tag] = True # ie, in use, still waiting
395        return tag
396
397    def set_tag(self, tag, reply):
398        "set the reply info for the given tag"
399        assert tag >= 0 and tag < 65536
400        with self.lock:
401            # check whether we're still waiting for the tag
402            state = self.tagstate.get(tag)
403            if state is True:
404                self.tagstate[tag] = reply # i.e., here's the answer
405                self.new_replies.notify_all()
406                return
407            # state must be one of these...
408            if state is False:
409                # We gave up on this tag.  Reply came anyway.
410                self.logger.log(logging.INFO,
411                                '%s: got tag %d = %r after timing out on it',
412                                self, tag, reply)
413                self.retire_tag_locked(tag)
414                return
415            if state is None:
416                # We got a tag back from the server that was not
417                # outstanding!
418                self.logger.log(logging.WARNING,
419                                '%s: got tag %d = %r when tag %d not in use!',
420                                self, tag, reply, tag)
421                return
422            # We got a second reply before handling the first reply!
423            self.logger.log(logging.WARNING,
424                            '%s: got tag %d = %r when tag %d = %r!',
425                            self, tag, reply, tag, state)
426            return
427
428    def retire_tag(self, tag):
429        "retire the given tag - only used by the thread that handled the result"
430        if tag == 65535:
431            return
432        assert tag >= 0 and tag < 65535
433        with self.lock:
434            self.retire_tag_locked(tag)
435
436    def retire_tag_locked(self, tag):
437        "retire the given tag while holding self.lock"
438        # must check "in tagstate" because we can race
439        # with retire_all_tags.
440        if tag in self.tagstate:
441            del self.tagstate[tag]
442            self.tagalloc.free(tag)
443
444    def retire_all_tags(self):
445        "retire all tags, after connection drop"
446        with self.lock:
447            # release all tags in any state (waiting, answered, timedout)
448            self.tagalloc.free_multi(self.tagstate.keys())
449            self.tagstate = {}
450            self.new_replies.notify_all()
451
452    def alloc_fid(self):
453        "allocate new fid"
454        with self.lock:
455            fid = self.fidalloc.alloc()
456            self.live_fids[fid] = True
457        return fid
458
459    def getpath(self, fid):
460        "get path from fid, or return None if no path known, or not valid"
461        with self.lock:
462            path = self.live_fids.get(fid)
463        if path is True:
464            path = None
465        return path
466
467    def getpathX(self, fid):
468        """
469        Much like getpath, but return <fid N, unknown path> if necessary.
470        If we do have a path, return its repr().
471        """
472        path = self.getpath(fid)
473        if path is None:
474            return '<fid {0}, unknown path>'.format(fid)
475        return repr(path)
476
477    def setpath(self, fid, path):
478        "associate fid with new path (possibly from another fid)"
479        with self.lock:
480            if isinstance(path, int):
481                path = self.live_fids.get(path)
482            # path might now be None (not a live fid after all), or
483            # True (we have no path name), or potentially even the
484            # empty string (invalid for our purposes).  Treat all of
485            # those as True, meaning "no known path".
486            if not path:
487                path = True
488            if self.live_fids.get(fid):
489                # Existing fid maps to either True or its old path.
490                # Set the new path (which may be just a placeholder).
491                self.live_fids[fid] = path
492
493    def did_rename(self, fid, ncomp, newdir=None):
494        """
495        Announce that we renamed using a fid - we'll try to update
496        other fids based on this (we can't really do it perfectly).
497
498        NOTE: caller must provide a final-component.
499        The caller can supply the new path (and should
500        do so if the rename is not based on the retained path
501        for the supplied fid, i.e., for rename ops where fid
502        can move across directories).  The rules:
503
504         - If newdir is None (default), we use stored path.
505         - Otherwise, newdir provides the best approximation
506           we have to the path that needs ncomp appended.
507
508        (This is based on the fact that renames happen via Twstat
509        or Trename, or Trenameat, which change just one tail component,
510        but the path names vary.)
511        """
512        if ncomp is None:
513            return
514        opath = self.getpath(fid)
515        if newdir is None:
516            if opath is None:
517                return
518            ocomps = opath.split(b'/')
519            ncomps = ocomps[0:-1]
520        else:
521            ocomps = None           # well, none yet anyway
522            ncomps = newdir.split(b'/')
523        ncomps.append(ncomp)
524        if opath is None or opath[0] != '/':
525            # We don't have enough information to fix anything else.
526            # Just store the new path and return.  We have at least
527            # a partial path now, which is no worse than before.
528            npath = b'/'.join(ncomps)
529            with self.lock:
530                if fid in self.live_fids:
531                    self.live_fids[fid] = npath
532            return
533        if ocomps is None:
534            ocomps = opath.split(b'/')
535        olen = len(ocomps)
536        ofinal = ocomps[olen - 1]
537        # Old paths is full path.  Find any other fids that start
538        # with some or all the components in ocomps.  Note that if
539        # we renamed /one/two/three to /four/five this winds up
540        # renaming files /one/a to /four/a, /one/two/b to /four/five/b,
541        # and so on.
542        with self.lock:
543            for fid2, path2 in self.live_fids.iteritems():
544                # Skip fids without byte-string paths
545                if not isinstance(path2, bytes):
546                    continue
547                # Before splitting (which is a bit expensive), try
548                # a straightforward prefix match.  This might give
549                # some false hits, e.g., prefix /one/two/threepenny
550                # starts with /one/two/three, but it quickly eliminates
551                # /raz/baz/mataz and the like.
552                if not path2.startswith(opath):
553                    continue
554                # Split up the path, and use that to make sure that
555                # the final component is a full match.
556                parts2 = path2.split(b'/')
557                if parts2[olen - 1] != ofinal:
558                    continue
559                # OK, path2 starts with the old (renamed) sequence.
560                # Replace the old components with the new ones.
561                # This updates the renamed fid when we come across
562                # it!  It also handles a change in the number of
563                # components, thanks to Python's slice assignment.
564                parts2[0:olen] = ncomps
565                self.live_fids[fid2] = b'/'.join(parts2)
566
567    def retire_fid(self, fid):
568        "retire one fid"
569        with self.lock:
570            self.fidalloc.free(fid)
571            del self.live_fids[fid]
572
573    def retire_all_fids(self):
574        "return live fids to pool"
575        # this is useful for debugging fid leaks:
576        #for fid in self.live_fids:
577        #    print 'retiring', fid, self.getpathX(fid)
578        with self.lock:
579            self.fidalloc.free_multi(self.live_fids.keys())
580            self.live_fids = {}
581
582    def read_responses(self):
583        "Read responses.  This gets spun off as a thread."
584        while self.is_connected():
585            pkt, is_full = super(P9Client, self).read()
586            if pkt == b'':
587                self.shutwrite()
588                self.retire_all_tags()
589                return
590            if not is_full:
591                self.logger.log(logging.WARNING, '%s: got short packet', self)
592            try:
593                # We have one special case: if we're not yet connected
594                # with a version, we must unpack *as if* it's a plain
595                # 9P2000 response.
596                if self.have_version:
597                    resp = self.proto.unpack(pkt)
598                else:
599                    resp = protocol.plain.unpack(pkt)
600            except protocol.SequenceError as err:
601                self.logger.log(logging.ERROR, '%s: bad response: %s',
602                                self, err)
603                try:
604                    resp = self.proto.unpack(pkt, noerror=True)
605                except protocol.SequenceError:
606                    header = self.proto.unpack_header(pkt, noerror=True)
607                    self.logger.log(logging.ERROR,
608                                    '%s: (not even raw-decodable)', self)
609                    self.logger.log(logging.ERROR,
610                                    '%s: header decode produced %r',
611                                    self, header)
612                else:
613                    self.logger.log(logging.ERROR,
614                                    '%s: raw decode produced %r',
615                                    self, resp)
616                # after this kind of problem, probably need to
617                # shut down, but let's leave that out for a bit
618            else:
619                # NB: all protocol responses have a "tag",
620                # so resp['tag'] always exists.
621                self.logger.log(logging.DEBUG, "read_resp: tag %d resp %r", resp.tag, resp)
622                self.set_tag(resp.tag, resp)
623
624    def wait_for(self, tag):
625        """
626        Wait for a response to the given tag.  Return the response,
627        releasing the tag.  If self.timeout is not None, wait at most
628        that long (and release the tag even if there's no reply), else
629        wait forever.
630
631        If this returns None, either the tag was bad initially, or
632        a timeout occurred, or the connection got shut down.
633        """
634        self.logger.log(logging.DEBUG, "wait_for: tag %d", tag)
635        if self.timeout is None:
636            deadline = None
637        else:
638            deadline = time.time() + self.timeout
639        with self.lock:
640            while True:
641                # tagstate is True (waiting) or False (timedout) or
642                # a valid response, or None if we've reset the tag
643                # states (retire_all_tags, after connection drop).
644                resp = self.tagstate.get(tag, None)
645                if resp is None:
646                    # out of sync, exit loop
647                    break
648                if resp is True:
649                    # still waiting for a response - wait some more
650                    self.new_replies.wait(self.timeout)
651                    if deadline and time.time() > deadline:
652                        # Halt the waiting, but go around once more.
653                        # Note we may have killed the tag by now though.
654                        if tag in self.tagstate:
655                            self.tagstate[tag] = False
656                    continue
657                # resp is either False (timeout) or a reply.
658                # If resp is False, change it to None; the tag
659                # is now dead until we get a reply (then we
660                # just toss the reply).
661                # Otherwise, we're done with the tag: free it.
662                # In either case, stop now.
663                if resp is False:
664                    resp = None
665                else:
666                    self.tagalloc.free(tag)
667                    del self.tagstate[tag]
668                break
669        return resp
670
671    def badresp(self, req, resp):
672        """
673        Complain that a response was not something expected.
674        """
675        if resp is None:
676            self.shutdown()
677            raise TEError('{0}: {1}: timeout or EOF'.format(self, req))
678        if isinstance(resp, protocol.rrd.Rlerror):
679            raise RemoteError(self, req, None, 'Rlerror', resp.ecode)
680        if isinstance(resp, protocol.rrd.Rerror):
681            if resp.errnum is None:
682                raise RemoteError(self, req, resp.errstr, 'Rerror', None)
683            raise RemoteError(self, req, None, 'Rerror', resp.errnum)
684        raise LocalError('{0}: {1} got response {2!r}'.format(self, req, resp))
685
686    def supports(self, req_code):
687        """
688        Test self.proto.support(req_code) unless we've recorded that
689        while the protocol supports it, the client does not.
690        """
691        return req_code not in self._unsup and self.proto.supports(req_code)
692
693    def supports_all(self, *req_codes):
694        "basically just all(supports(...))"
695        return all(self.supports(code) for code in req_codes)
696
697    def unsupported(self, req_code):
698        """
699        Record an ENOTSUP (RemoteError was ENOTSUP) for a request.
700        Must be called from the op, this does not happen automatically.
701        (It's just an optimization.)
702        """
703        self._unsup[req_code] = True
704
705    def connect(self, server=None, port=None):
706        """
707        Connect to given server/port pair.
708
709        The server and port are remembered.  If given as None,
710        the last remembered values are used.  The initial
711        remembered values are from the creation of this client
712        instance.
713
714        New values are only remembered here on a *successful*
715        connect, however.
716        """
717        if server is None:
718            server = self._server
719            if server is None:
720                raise LocalError('connect: no server specified and no default')
721        if port is None:
722            port = self._port
723            if port is None:
724                port = STD_P9_PORT
725        self.name = None            # wipe out previous name, if any
726        super(P9Client, self).connect(server, port)
727        maxio = self.get_recommended_maxio()
728        self.declare_connected(None, None, maxio)
729        self.proto = self.iproto    # revert to initial protocol
730        self.have_version = False
731        self.rthread = threading.Thread(target=self.read_responses)
732        self.rthread.start()
733        tag = self.get_tag(for_Tversion=True)
734        req = protocol.rrd.Tversion(tag=tag, msize=maxio,
735                                    version=self.get_monkey('version'))
736        super(P9Client, self).write(self.proto.pack_from(req))
737        resp = self.wait_for(tag)
738        if not isinstance(resp, protocol.rrd.Rversion):
739            self.shutdown()
740            if isinstance(resp, protocol.rrd.Rerror):
741                version = req.version or self.proto.get_version()
742                # for python3, we need to convert version to string
743                if not isinstance(version, str):
744                    version = version.decode('utf-8', 'surrogateescape')
745                raise RemoteError(self, 'version ' + version,
746                                  resp.errstr, 'Rerror', None)
747            self.badresp('version', resp)
748        their_maxio = resp.msize
749        try:
750            self.reduce_maxio(their_maxio)
751        except LocalError as err:
752            raise LocalError('{0}: sent maxio={1}, they tried {2}: '
753                             '{3}'.format(self, maxio, their_maxio,
754                                          err.args[0]))
755        if resp.version != self.proto.get_version():
756            if not self.may_downgrade:
757                self.shutdown()
758                raise LocalError('{0}: they only support '
759                                 'version {1!r}'.format(self, resp.version))
760            # raises LocalError if the version is bad
761            # (should we wrap it with a connect-to-{0} msg?)
762            self.proto = self.proto.downgrade_to(resp.version)
763        self._server = server
764        self._port = port
765        self.have_version = True
766
767    def attach(self, afid, uname, aname, n_uname):
768        """
769        Attach.
770
771        Currently we don't know how to do authentication,
772        but we'll pass any provided afid through.
773        """
774        if afid is None:
775            afid = protocol.td.NOFID
776        if uname is None:
777            uname = ''
778        if aname is None:
779            aname = ''
780        if n_uname is None:
781            n_uname = protocol.td.NONUNAME
782        tag = self.get_tag()
783        fid = self.alloc_fid()
784        pkt = self.proto.Tattach(tag=tag, fid=fid, afid=afid,
785                                 uname=uname, aname=aname,
786                                 n_uname=n_uname)
787        super(P9Client, self).write(pkt)
788        resp = self.wait_for(tag)
789        if not isinstance(resp, protocol.rrd.Rattach):
790            self.retire_fid(fid)
791            self.badresp('attach', resp)
792        # probably should check resp.qid
793        self.rootfid = fid
794        self.rootqid = resp.qid
795        self.setpath(fid, b'/')
796
797    def shutdown(self):
798        "disconnect from server"
799        if self.rootfid is not None:
800            self.clunk(self.rootfid, ignore_error=True)
801        self.retire_all_tags()
802        self.retire_all_fids()
803        self.rootfid = None
804        self.rootqid = None
805        super(P9Client, self).shutdown()
806        if self.rthread:
807            self.rthread.join()
808            self.rthread = None
809
810    def dupfid(self, fid):
811        """
812        Copy existing fid to a new fid.
813        """
814        tag = self.get_tag()
815        newfid = self.alloc_fid()
816        pkt = self.proto.Twalk(tag=tag, fid=fid, newfid=newfid, nwname=0,
817                               wname=[])
818        super(P9Client, self).write(pkt)
819        resp = self.wait_for(tag)
820        if not isinstance(resp, protocol.rrd.Rwalk):
821            self.retire_fid(newfid)
822            self.badresp('walk {0}'.format(self.getpathX(fid)), resp)
823        # Copy path too
824        self.setpath(newfid, fid)
825        return newfid
826
827    def lookup(self, fid, components):
828        """
829        Do Twalk.  Caller must provide a starting fid, which should
830        be rootfid to look up from '/' - we do not do / vs . here.
831        Caller must also provide a component-ized path (on purpose,
832        so that caller can provide invalid components like '' or '/').
833        The components must be byte-strings as well, for the same
834        reason.
835
836        We do allocate the new fid ourselves here, though.
837
838        There's no logic here to split up long walks (yet?).
839        """
840        # these are too easy to screw up, so check
841        if self.rootfid is None:
842            raise LocalError('{0}: not attached'.format(self))
843        if (isinstance(components, (str, bytes) or
844            not all(isinstance(i, bytes) for i in components))):
845            raise LocalError('{0}: lookup: invalid '
846                             'components {1!r}'.format(self, components))
847        tag = self.get_tag()
848        newfid = self.alloc_fid()
849        startpath = self.getpath(fid)
850        pkt = self.proto.Twalk(tag=tag, fid=fid, newfid=newfid,
851                               nwname=len(components), wname=components)
852        super(P9Client, self).write(pkt)
853        resp = self.wait_for(tag)
854        if not isinstance(resp, protocol.rrd.Rwalk):
855            self.retire_fid(newfid)
856            self.badresp('walk {0} in '
857                         '{1}'.format(components, self.getpathX(fid)),
858                         resp)
859        # Just because we got Rwalk does not mean we got ALL the
860        # way down the path.  Raise OSError(ENOENT) if we're short.
861        if resp.nwqid > len(components):
862            # ??? this should be impossible. Local error?  Remote error?
863            # OS Error?
864            self.clunk(newfid, ignore_error=True)
865            raise LocalError('{0}: walk {1} in {2} returned {3} '
866                             'items'.format(self, components,
867                                            self.getpathX(fid), resp.nwqid))
868        if resp.nwqid < len(components):
869            self.clunk(newfid, ignore_error=True)
870            # Looking up a/b/c and got just a/b, c is what's missing.
871            # Looking up a/b/c and got just a, b is what's missing.
872            missing = components[resp.nwqid]
873            within = _pathcat(startpath, b'/'.join(components[:resp.nwqid]))
874            raise OSError(errno.ENOENT,
875                          '{0}: {1} in {2}'.format(os.strerror(errno.ENOENT),
876                                                   missing, within))
877        self.setpath(newfid, _pathcat(startpath, b'/'.join(components)))
878        return newfid, resp.wqid
879
880    def lookup_last(self, fid, components):
881        """
882        Like lookup, but return only the last component's qid.
883        As a special case, if components is an empty list, we
884        handle that.
885        """
886        rfid, wqid = self.lookup(fid, components)
887        if len(wqid):
888            return rfid, wqid[-1]
889        if fid == self.rootfid:         # usually true, if we get here at all
890            return rfid, self.rootqid
891        tag = self.get_tag()
892        pkt = self.proto.Tstat(tag=tag, fid=rfid)
893        super(P9Client, self).write(pkt)
894        resp = self.wait_for(tag)
895        if not isinstance(resp, protocol.rrd.Rstat):
896            self.badresp('stat {0}'.format(self.getpathX(fid)), resp)
897        statval = self.proto.unpack_wirestat(resp.data)
898        return rfid, statval.qid
899
900    def clunk(self, fid, ignore_error=False):
901        "issue clunk(fid)"
902        tag = self.get_tag()
903        pkt = self.proto.Tclunk(tag=tag, fid=fid)
904        super(P9Client, self).write(pkt)
905        resp = self.wait_for(tag)
906        if not isinstance(resp, protocol.rrd.Rclunk):
907            if ignore_error:
908                return
909            self.badresp('clunk {0}'.format(self.getpathX(fid)), resp)
910        self.retire_fid(fid)
911
912    def remove(self, fid, ignore_error=False):
913        "issue remove (old style), which also clunks fid"
914        tag = self.get_tag()
915        pkt = self.proto.Tremove(tag=tag, fid=fid)
916        super(P9Client, self).write(pkt)
917        resp = self.wait_for(tag)
918        if not isinstance(resp, protocol.rrd.Rremove):
919            if ignore_error:
920                # remove failed: still need to clunk the fid
921                self.clunk(fid, True)
922                return
923            self.badresp('remove {0}'.format(self.getpathX(fid)), resp)
924        self.retire_fid(fid)
925
926    def create(self, fid, name, perm, mode, filetype=None, extension=b''):
927        """
928        Issue create op (note that this may be mkdir, symlink, etc).
929        fid is the directory in which the create happens, and for
930        regular files, it becomes, on success, a fid referring to
931        the now-open file.  perm is, e.g., 0644, 0755, etc.,
932        optionally with additional high bits.  mode is a mode
933        byte (e.g., protocol.td.ORDWR, or OWRONLY|OTRUNC, etc.).
934
935        As a service to callers, we take two optional arguments
936        specifying the file type ('dir', 'symlink', 'device',
937        'fifo', or 'socket') and additional info if needed.
938        The additional info for a symlink is the target of the
939        link (a byte string), and the additional info for a device
940        is a byte string with "b <major> <minor>" or "c <major> <minor>".
941
942        Otherwise, callers can leave filetype=None and encode the bits
943        into the mode (caller must still provide extension if needed).
944
945        We do NOT check whether the extension matches extra DM bits,
946        or that there's only one DM bit set, or whatever, since this
947        is a testing setup.
948        """
949        tag = self.get_tag()
950        if filetype is not None:
951            perm |= {
952                'dir': protocol.td.DMDIR,
953                'symlink': protocol.td.DMSYMLINK,
954                'device': protocol.td.DMDEVICE,
955                'fifo': protocol.td.DMNAMEDPIPE,
956                'socket': protocol.td.DMSOCKET,
957            }[filetype]
958        pkt = self.proto.Tcreate(tag=tag, fid=fid, name=name,
959            perm=perm, mode=mode, extension=extension)
960        super(P9Client, self).write(pkt)
961        resp = self.wait_for(tag)
962        if not isinstance(resp, protocol.rrd.Rcreate):
963            self.badresp('create {0} in {1}'.format(name, self.getpathX(fid)),
964                         resp)
965        if resp.qid.type == protocol.td.QTFILE:
966            # Creating a regular file opens the file,
967            # thus changing the fid's path.
968            self.setpath(fid, _pathcat(self.getpath(fid), name))
969        return resp.qid, resp.iounit
970
971    def open(self, fid, mode):
972        "use Topen to open file or directory fid (mode is 1 byte)"
973        tag = self.get_tag()
974        pkt = self.proto.Topen(tag=tag, fid=fid, mode=mode)
975        super(P9Client, self).write(pkt)
976        resp = self.wait_for(tag)
977        if not isinstance(resp, protocol.rrd.Ropen):
978            self.badresp('open {0}'.format(self.getpathX(fid)), resp)
979        return resp.qid, resp.iounit
980
981    def lopen(self, fid, flags):
982        "use Tlopen to open file or directory fid (flags from L_O_*)"
983        tag = self.get_tag()
984        pkt = self.proto.Tlopen(tag=tag, fid=fid, flags=flags)
985        super(P9Client, self).write(pkt)
986        resp = self.wait_for(tag)
987        if not isinstance(resp, protocol.rrd.Rlopen):
988            self.badresp('lopen {0}'.format(self.getpathX(fid)), resp)
989        return resp.qid, resp.iounit
990
991    def read(self, fid, offset, count):
992        "read (up to) count bytes from offset, given open fid"
993        tag = self.get_tag()
994        pkt = self.proto.Tread(tag=tag, fid=fid, offset=offset, count=count)
995        super(P9Client, self).write(pkt)
996        resp = self.wait_for(tag)
997        if not isinstance(resp, protocol.rrd.Rread):
998            self.badresp('read {0} bytes at offset {1} in '
999                         '{2}'.format(count, offset, self.getpathX(fid)),
1000                         resp)
1001        return resp.data
1002
1003    def write(self, fid, offset, data):
1004        "write (up to) count bytes to offset, given open fid"
1005        tag = self.get_tag()
1006        pkt = self.proto.Twrite(tag=tag, fid=fid, offset=offset,
1007                                count=len(data), data=data)
1008        super(P9Client, self).write(pkt)
1009        resp = self.wait_for(tag)
1010        if not isinstance(resp, protocol.rrd.Rwrite):
1011            self.badresp('write {0} bytes at offset {1} in '
1012                         '{2}'.format(len(data), offset, self.getpathX(fid)),
1013                         resp)
1014        return resp.count
1015
1016    # Caller may
1017    #  - pass an actual stat object, or
1018    #  - pass in all the individual to-set items by keyword, or
1019    #  - mix and match a bit: get an existing stat, then use
1020    #    keywords to override fields.
1021    # We convert "None"s to the internal "do not change" values,
1022    # and for diagnostic purposes, can turn "do not change" back
1023    # to None at the end, too.
1024    def wstat(self, fid, statobj=None, **kwargs):
1025        if statobj is None:
1026            statobj = protocol.td.stat()
1027        else:
1028            statobj = statobj._copy()
1029        # Fields in stat that you can't send as a wstat: the
1030        # type and qid are informative.  Similarly, the
1031        # 'extension' is an input when creating a file but
1032        # read-only when stat-ing.
1033        #
1034        # It's not clear what it means to set dev, but we'll leave
1035        # it in as an optional parameter here.  fs/backend.c just
1036        # errors out on an attempt to change it.
1037        if self.proto == protocol.plain:
1038            forbid = ('type', 'qid', 'extension',
1039                      'n_uid', 'n_gid', 'n_muid')
1040        else:
1041            forbid = ('type', 'qid', 'extension')
1042        nochange = {
1043            'type': 0,
1044            'qid': protocol.td.qid(0, 0, 0),
1045            'dev': 2**32 - 1,
1046            'mode': 2**32 - 1,
1047            'atime': 2**32 - 1,
1048            'mtime': 2**32 - 1,
1049            'length': 2**64 - 1,
1050            'name': b'',
1051            'uid': b'',
1052            'gid': b'',
1053            'muid': b'',
1054            'extension': b'',
1055            'n_uid': 2**32 - 1,
1056            'n_gid': 2**32 - 1,
1057            'n_muid': 2**32 - 1,
1058        }
1059        for field in statobj._fields:
1060            if field in kwargs:
1061                if field in forbid:
1062                    raise ValueError('cannot wstat a stat.{0}'.format(field))
1063                statobj[field] = kwargs.pop(field)
1064            else:
1065                if field in forbid or statobj[field] is None:
1066                    statobj[field] = nochange[field]
1067        if kwargs:
1068            raise TypeError('wstat() got an unexpected keyword argument '
1069                            '{0!r}'.format(kwargs.popitem()))
1070
1071        data = self.proto.pack_wirestat(statobj)
1072        tag = self.get_tag()
1073        pkt = self.proto.Twstat(tag=tag, fid=fid, data=data)
1074        super(P9Client, self).write(pkt)
1075        resp = self.wait_for(tag)
1076        if not isinstance(resp, protocol.rrd.Rwstat):
1077            # For error viewing, switch all the do-not-change
1078            # and can't-change fields to None.
1079            statobj.qid = None
1080            for field in statobj._fields:
1081                if field in forbid:
1082                    statobj[field] = None
1083                elif field in nochange and statobj[field] == nochange[field]:
1084                    statobj[field] = None
1085            self.badresp('wstat {0}={1}'.format(self.getpathX(fid), statobj),
1086                         resp)
1087        # wstat worked - change path names if needed
1088        if statobj.name != b'':
1089            self.did_rename(fid, statobj.name)
1090
1091    def readdir(self, fid, offset, count):
1092        "read (up to) count bytes of dir data from offset, given open fid"
1093        tag = self.get_tag()
1094        pkt = self.proto.Treaddir(tag=tag, fid=fid, offset=offset, count=count)
1095        super(P9Client, self).write(pkt)
1096        resp = self.wait_for(tag)
1097        if not isinstance(resp, protocol.rrd.Rreaddir):
1098            self.badresp('readdir {0} bytes at offset {1} in '
1099                         '{2}'.format(count, offset, self.getpathX(fid)),
1100                         resp)
1101        return resp.data
1102
1103    def rename(self, fid, dfid, name):
1104        "invoke Trename: rename file <fid> to <dfid>/name"
1105        tag = self.get_tag()
1106        pkt = self.proto.Trename(tag=tag, fid=fid, dfid=dfid, name=name)
1107        super(P9Client, self).write(pkt)
1108        resp = self.wait_for(tag)
1109        if not isinstance(resp, protocol.rrd.Rrename):
1110            self.badresp('rename {0} to {2} in '
1111                         '{1}'.format(self.getpathX(fid),
1112                                      self.getpathX(dfid), name),
1113                         resp)
1114        self.did_rename(fid, name, self.getpath(dfid))
1115
1116    def renameat(self, olddirfid, oldname, newdirfid, newname):
1117        "invoke Trenameat: rename <olddirfid>/oldname to <newdirfid>/newname"
1118        tag = self.get_tag()
1119        pkt = self.proto.Trenameat(tag=tag,
1120                                   olddirfid=olddirfid, oldname=oldname,
1121                                   newdirfid=newdirfid, newname=newname)
1122        super(P9Client, self).write(pkt)
1123        resp = self.wait_for(tag)
1124        if not isinstance(resp, protocol.rrd.Rrenameat):
1125            self.badresp('rename {1} in {0} to {3} in '
1126                         '{2}'.format(oldname, self.getpathX(olddirfid),
1127                                      newname, self.getpathX(newdirdfid)),
1128                         resp)
1129        # There's no renamed *fid*, just a renamed file!  So no
1130        # call to self.did_rename().
1131
1132    def unlinkat(self, dirfd, name, flags):
1133        "invoke Tunlinkat - flags should be 0 or protocol.td.AT_REMOVEDIR"
1134        tag = self.get_tag()
1135        pkt = self.proto.Tunlinkat(tag=tag, dirfd=dirfd,
1136                                   name=name, flags=flags)
1137        super(P9Client, self).write(pkt)
1138        resp = self.wait_for(tag)
1139        if not isinstance(resp, protocol.rrd.Runlinkat):
1140            self.badresp('unlinkat {0} in '
1141                         '{1}'.format(name, self.getpathX(dirfd)), resp)
1142
1143    def decode_stat_objects(self, bstring, noerror=False):
1144        """
1145        Read on a directory returns an array of stat objects.
1146        Note that for .u these encode extra data.
1147
1148        It's possible for this to produce a SequenceError, if
1149        the data are incorrect, unless you pass noerror=True.
1150        """
1151        objlist = []
1152        offset = 0
1153        while offset < len(bstring):
1154            obj, offset = self.proto.unpack_wirestat(bstring, offset, noerror)
1155            objlist.append(obj)
1156        return objlist
1157
1158    def decode_readdir_dirents(self, bstring, noerror=False):
1159        """
1160        Readdir on a directory returns an array of dirent objects.
1161
1162        It's possible for this to produce a SequenceError, if
1163        the data are incorrect, unless you pass noerror=True.
1164        """
1165        objlist = []
1166        offset = 0
1167        while offset < len(bstring):
1168            obj, offset = self.proto.unpack_dirent(bstring, offset, noerror)
1169            objlist.append(obj)
1170        return objlist
1171
1172    def lcreate(self, fid, name, lflags, mode, gid):
1173        "issue lcreate (.L)"
1174        tag = self.get_tag()
1175        pkt = self.proto.Tlcreate(tag=tag, fid=fid, name=name,
1176                                  flags=lflags, mode=mode, gid=gid)
1177        super(P9Client, self).write(pkt)
1178        resp = self.wait_for(tag)
1179        if not isinstance(resp, protocol.rrd.Rlcreate):
1180            self.badresp('create {0} in '
1181                         '{1}'.format(name, self.getpathX(fid)), resp)
1182        # Creating a file opens the file,
1183        # thus changing the fid's path.
1184        self.setpath(fid, _pathcat(self.getpath(fid), name))
1185        return resp.qid, resp.iounit
1186
1187    def mkdir(self, dfid, name, mode, gid):
1188        "issue mkdir (.L)"
1189        tag = self.get_tag()
1190        pkt = self.proto.Tmkdir(tag=tag, dfid=dfid, name=name,
1191                                mode=mode, gid=gid)
1192        super(P9Client, self).write(pkt)
1193        resp = self.wait_for(tag)
1194        if not isinstance(resp, protocol.rrd.Rmkdir):
1195            self.badresp('mkdir {0} in '
1196                         '{1}'.format(name, self.getpathX(dfid)), resp)
1197        return resp.qid
1198
1199    # We don't call this getattr(), for the obvious reason.
1200    def Tgetattr(self, fid, request_mask=protocol.td.GETATTR_ALL):
1201        "issue Tgetattr.L - get what you ask for, or everything by default"
1202        tag = self.get_tag()
1203        pkt = self.proto.Tgetattr(tag=tag, fid=fid, request_mask=request_mask)
1204        super(P9Client, self).write(pkt)
1205        resp = self.wait_for(tag)
1206        if not isinstance(resp, protocol.rrd.Rgetattr):
1207            self.badresp('Tgetattr {0} of '
1208                         '{1}'.format(request_mask, self.getpathX(fid)), resp)
1209        attrs = Fileattrs()
1210        # Handle the simplest valid-bit tests:
1211        for name in ('mode', 'nlink', 'uid', 'gid', 'rdev',
1212                     'size', 'blocks', 'gen', 'data_version'):
1213            bit = getattr(protocol.td, 'GETATTR_' + name.upper())
1214            if resp.valid & bit:
1215                attrs[name] = resp[name]
1216        # Handle the timestamps, which are timespec pairs
1217        for name in ('atime', 'mtime', 'ctime', 'btime'):
1218            bit = getattr(protocol.td, 'GETATTR_' + name.upper())
1219            if resp.valid & bit:
1220                attrs[name] = Timespec(sec=resp[name + '_sec'],
1221                                       nsec=resp[name + '_nsec'])
1222        # There is no control bit for blksize; qemu and Linux always
1223        # provide one.
1224        attrs.blksize = resp.blksize
1225        # Handle ino, which comes out of qid.path
1226        if resp.valid & protocol.td.GETATTR_INO:
1227            attrs.ino = resp.qid.path
1228        return attrs
1229
1230    # We don't call this setattr(), for the obvious reason.
1231    # See wstat for usage.  Note that time fields can be set
1232    # with either second or nanosecond resolutions, and some
1233    # can be set without supplying an actual timestamp, so
1234    # this is all pretty ad-hoc.
1235    #
1236    # There's also one keyword-only argument, ctime=<anything>,
1237    # which means "set SETATTR_CTIME".  This has the same effect
1238    # as supplying valid=protocol.td.SETATTR_CTIME.
1239    def Tsetattr(self, fid, valid=0, attrs=None, **kwargs):
1240        if attrs is None:
1241            attrs = Fileattrs()
1242        else:
1243            attrs = attrs._copy()
1244
1245        # Start with an empty (all-zero) Tsetattr instance.  We
1246        # don't really need to zero out tag and fid, but it doesn't
1247        # hurt.  Note that if caller says, e.g., valid=SETATTR_SIZE
1248        # but does not supply an incoming size (via "attrs" or a size=
1249        # argument), we'll ask to set that field to 0.
1250        attrobj = protocol.rrd.Tsetattr()
1251        for field in attrobj._fields:
1252            attrobj[field] = 0
1253
1254        # In this case, forbid means "only as kwargs": these values
1255        # in an incoming attrs object are merely ignored.
1256        forbid = ('ino', 'nlink', 'rdev', 'blksize', 'blocks', 'btime',
1257                  'gen', 'data_version')
1258        for field in attrs._fields:
1259            if field in kwargs:
1260                if field in forbid:
1261                    raise ValueError('cannot Tsetattr {0}'.format(field))
1262                attrs[field] = kwargs.pop(field)
1263            elif attrs[field] is None:
1264                continue
1265            # OK, we're setting this attribute.  Many are just
1266            # numeric - if that's the case, we're good, set the
1267            # field and the appropriate bit.
1268            bitname = 'SETATTR_' + field.upper()
1269            bit = getattr(protocol.td, bitname)
1270            if field in ('mode', 'uid', 'gid', 'size'):
1271                valid |= bit
1272                attrobj[field] = attrs[field]
1273                continue
1274            # Timestamps are special:  The value may be given as
1275            # an integer (seconds), or as a float (we convert to
1276            # (we convert to sec+nsec), or as a timespec (sec+nsec).
1277            # If specified as 0, we mean "we are not providing the
1278            # actual time, use the server's time."
1279            #
1280            # The ctime field's value, if any, is *ignored*.
1281            if field in ('atime', 'mtime'):
1282                value = attrs[field]
1283                if hasattr(value, '__len__'):
1284                    if len(value) != 2:
1285                        raise ValueError('invalid {0}={1!r}'.format(field,
1286                                                                    value))
1287                    sec = value[0]
1288                    nsec = value[1]
1289                else:
1290                    sec = value
1291                    if isinstance(sec, float):
1292                        nsec, sec = math.modf(sec)
1293                        nsec = int(round(nsec * 1000000000))
1294                    else:
1295                        nsec = 0
1296                valid |= bit
1297                attrobj[field + '_sec'] = sec
1298                attrobj[field + '_nsec'] = nsec
1299                if sec != 0 or nsec != 0:
1300                    # Add SETATTR_ATIME_SET or SETATTR_MTIME_SET
1301                    # as appropriate, to tell the server to *this
1302                    # specific* time, instead of just "server now".
1303                    bit = getattr(protocol.td, bitname + '_SET')
1304                    valid |= bit
1305        if 'ctime' in kwargs:
1306            kwargs.pop('ctime')
1307            valid |= protocol.td.SETATTR_CTIME
1308        if kwargs:
1309            raise TypeError('Tsetattr() got an unexpected keyword argument '
1310                            '{0!r}'.format(kwargs.popitem()))
1311
1312        tag = self.get_tag()
1313        attrobj.valid = valid
1314        attrobj.tag = tag
1315        attrobj.fid = fid
1316        pkt = self.proto.pack(attrobj)
1317        super(P9Client, self).write(pkt)
1318        resp = self.wait_for(tag)
1319        if not isinstance(resp, protocol.rrd.Rsetattr):
1320            self.badresp('Tsetattr {0} {1} of '
1321                         '{2}'.format(valid, attrs, self.getpathX(fid)), resp)
1322
1323    def xattrwalk(self, fid, name=None):
1324        "walk one name or all names: caller should read() the returned fid"
1325        tag = self.get_tag()
1326        newfid = self.alloc_fid()
1327        pkt = self.proto.Txattrwalk(tag=tag, fid=fid, newfid=newfid,
1328                                    name=name or '')
1329        super(P9Client, self).write(pkt)
1330        resp = self.wait_for(tag)
1331        if not isinstance(resp, protocol.rrd.Rxattrwalk):
1332            self.retire_fid(newfid)
1333            self.badresp('Txattrwalk {0} of '
1334                         '{1}'.format(name, self.getpathX(fid)), resp)
1335        if name:
1336            self.setpath(newfid, 'xattr:' + name)
1337        else:
1338            self.setpath(newfid, 'xattr')
1339        return newfid, resp.size
1340
1341    def _pathsplit(self, path, startdir, allow_empty=False):
1342        "common code for uxlookup and uxopen"
1343        if self.rootfid is None:
1344            raise LocalError('{0}: not attached'.format(self))
1345        if path.startswith(b'/') or startdir is None:
1346            startdir = self.rootfid
1347        components = [i for i in path.split(b'/') if i != b'']
1348        if len(components) == 0 and not allow_empty:
1349            raise LocalError('{0}: {1!r}: empty path'.format(self, path))
1350        return components, startdir
1351
1352    def uxlookup(self, path, startdir=None):
1353        """
1354        Unix-style lookup.  That is, lookup('/foo/bar') or
1355        lookup('foo/bar').  If startdir is not None and the
1356        path does not start with '/' we look up from there.
1357        """
1358        components, startdir = self._pathsplit(path, startdir, allow_empty=True)
1359        return self.lookup_last(startdir, components)
1360
1361    def uxopen(self, path, oflags=0, perm=None, gid=None,
1362               startdir=None, filetype=None):
1363        """
1364        Unix-style open()-with-option-to-create, or mkdir().
1365        oflags is 0/1/2 with optional os.O_CREAT, perm defaults
1366        to 0o666 (files) or 0o777 (directories).  If we use
1367        a Linux create or mkdir op, we will need a gid, but it's
1368        not required if you are opening an existing file.
1369
1370        Adds a final boolean value for "did we actually create".
1371        Raises OSError if you ask for a directory but it's a file,
1372        or vice versa.  (??? reconsider this later)
1373
1374        Note that this does not handle other file types, only
1375        directories.
1376        """
1377        needtype = {
1378            'dir': protocol.td.QTDIR,
1379            None: protocol.td.QTFILE,
1380        }[filetype]
1381        omode_byte = oflags & 3 # cheating
1382        # allow looking up /, but not creating /
1383        allow_empty = (oflags & os.O_CREAT) == 0
1384        components, startdir = self._pathsplit(path, startdir,
1385                                               allow_empty=allow_empty)
1386        if not (oflags & os.O_CREAT):
1387            # Not creating, i.e., just look up and open existing file/dir.
1388            fid, qid = self.lookup_last(startdir, components)
1389            # If we got this far, use Topen on the fid; we did not
1390            # create the file.
1391            return self._uxopen2(path, needtype, fid, qid, omode_byte, False)
1392
1393        # Only used if using dot-L, but make sure it's always provided
1394        # since this is generic.
1395        if gid is None:
1396            raise ValueError('gid is required when creating file or dir')
1397
1398        if len(components) > 1:
1399            # Look up all but last component; this part must succeed.
1400            fid, _ = self.lookup(startdir, components[:-1])
1401
1402            # Now proceed with the final component, using fid
1403            # as the start dir.  Remember to clunk it!
1404            startdir = fid
1405            clunk_startdir = True
1406            components = components[-1:]
1407        else:
1408            # Use startdir as the start dir, and get a new fid.
1409            # Do not clunk startdir!
1410            clunk_startdir = False
1411            fid = self.alloc_fid()
1412
1413        # Now look up the (single) component.  If this fails,
1414        # assume the file or directory needs to be created.
1415        tag = self.get_tag()
1416        pkt = self.proto.Twalk(tag=tag, fid=startdir, newfid=fid,
1417                               nwname=1, wname=components)
1418        super(P9Client, self).write(pkt)
1419        resp = self.wait_for(tag)
1420        if isinstance(resp, protocol.rrd.Rwalk):
1421            if clunk_startdir:
1422                self.clunk(startdir, ignore_error=True)
1423            # fid successfully walked to refer to final component.
1424            # Just need to actually open the file.
1425            self.setpath(fid, _pathcat(self.getpath(startdir), components[0]))
1426            qid = resp.wqid[0]
1427            return self._uxopen2(needtype, fid, qid, omode_byte, False)
1428
1429        # Walk failed.  If we allocated a fid, retire it.  Then set
1430        # up a fid that points to the parent directory in which to
1431        # create the file or directory.  Note that if we're creating
1432        # a file, this fid will get changed so that it points to the
1433        # file instead of the directory, but if we're creating a
1434        # directory, it will be unchanged.
1435        if fid != startdir:
1436            self.retire_fid(fid)
1437        fid = self.dupfid(startdir)
1438
1439        try:
1440            qid, iounit = self._uxcreate(filetype, fid, components[0],
1441                                         oflags, omode_byte, perm, gid)
1442
1443            # Success.  If we created an ordinary file, we have everything
1444            # now as create alters the incoming (dir) fid to open the file.
1445            # Otherwise (mkdir), we need to open the file, as with
1446            # a successful lookup.
1447            #
1448            # Note that qid type should match "needtype".
1449            if filetype != 'dir':
1450                if qid.type == needtype:
1451                    return fid, qid, iounit, True
1452                self.clunk(fid, ignore_error=True)
1453                raise OSError(_wrong_file_type(qid),
1454                             '{0}: server told to create {1} but '
1455                             'created {2} instead'.format(path,
1456                                                          qt2n(needtype),
1457                                                          qt2n(qid.type)))
1458
1459            # Success: created dir; but now need to walk to and open it.
1460            fid = self.alloc_fid()
1461            tag = self.get_tag()
1462            pkt = self.proto.Twalk(tag=tag, fid=startdir, newfid=fid,
1463                                   nwname=1, wname=components)
1464            super(P9Client, self).write(pkt)
1465            resp = self.wait_for(tag)
1466            if not isinstance(resp, protocol.rrd.Rwalk):
1467                self.clunk(fid, ignore_error=True)
1468                raise OSError(errno.ENOENT,
1469                              '{0}: server made dir but then failed to '
1470                              'find it again'.format(path))
1471                self.setpath(fid, _pathcat(self.getpath(fid), components[0]))
1472            return self._uxopen2(needtype, fid, qid, omode_byte, True)
1473        finally:
1474            # Regardless of success/failure/exception, make sure
1475            # we clunk startdir if needed.
1476            if clunk_startdir:
1477                self.clunk(startdir, ignore_error=True)
1478
1479    def _uxcreate(self, filetype, fid, name, oflags, omode_byte, perm, gid):
1480        """
1481        Helper for creating dir-or-file.  The fid argument is the
1482        parent directory on input, but will point to the file (if
1483        we're creating a file) on return.  oflags only applies if
1484        we're creating a file (even then we use omode_byte if we
1485        are using the plan9 create op).
1486        """
1487        # Try to create or mkdir as appropriate.
1488        if self.supports_all(protocol.td.Tlcreate, protocol.td.Tmkdir):
1489            # Use Linux style create / mkdir.
1490            if filetype == 'dir':
1491                if perm is None:
1492                    perm = 0o777
1493                return self.mkdir(startdir, name, perm, gid), None
1494            if perm is None:
1495                perm = 0o666
1496            lflags = flags_to_linux_flags(oflags)
1497            return self.lcreate(fid, name, lflags, perm, gid)
1498
1499        if filetype == 'dir':
1500            if perm is None:
1501                perm = protocol.td.DMDIR | 0o777
1502            else:
1503                perm |= protocol.td.DMDIR
1504        else:
1505            if perm is None:
1506                perm = 0o666
1507        return self.create(fid, name, perm, omode_byte)
1508
1509    def _uxopen2(self, needtype, fid, qid, omode_byte, didcreate):
1510        "common code for finishing up uxopen"
1511        if qid.type != needtype:
1512            self.clunk(fid, ignore_error=True)
1513            raise OSError(_wrong_file_type(qid),
1514                          '{0}: is {1}, expected '
1515                          '{2}'.format(path, qt2n(qid.type), qt2n(needtype)))
1516        qid, iounit = self.open(fid, omode_byte)
1517        # ? should we re-check qid? it should not have changed
1518        return fid, qid, iounit, didcreate
1519
1520    def uxmkdir(self, path, perm, gid, startdir=None):
1521        """
1522        Unix-style mkdir.
1523
1524        The gid is only applied if we are using .L style mkdir.
1525        """
1526        components, startdir = self._pathsplit(path, startdir)
1527        clunkme = None
1528        if len(components) > 1:
1529            fid, _ = self.lookup(startdir, components[:-1])
1530            startdir = fid
1531            clunkme = fid
1532            components = components[-1:]
1533        try:
1534            if self.supports(protocol.td.Tmkdir):
1535                qid = self.mkdir(startdir, components[0], perm, gid)
1536            else:
1537                qid, _ = self.create(startdir, components[0],
1538                                     protocol.td.DMDIR | perm,
1539                                     protocol.td.OREAD)
1540                # Should we chown/chgrp the dir?
1541        finally:
1542            if clunkme:
1543                self.clunk(clunkme, ignore_error=True)
1544        return qid
1545
1546    def uxreaddir(self, path, startdir=None, no_dotl=False):
1547        """
1548        Read a directory to get a list of names (which may or may not
1549        include '.' and '..').
1550
1551        If no_dotl is True (or anything non-false-y), this uses the
1552        plain or .u readdir format, otherwise it uses dot-L readdir
1553        if possible.
1554        """
1555        components, startdir = self._pathsplit(path, startdir, allow_empty=True)
1556        fid, qid = self.lookup_last(startdir, components)
1557        try:
1558            if qid.type != protocol.td.QTDIR:
1559                raise OSError(errno.ENOTDIR,
1560                              '{0}: {1}'.format(self.getpathX(fid),
1561                                                os.strerror(errno.ENOTDIR)))
1562            # We need both Tlopen and Treaddir to use Treaddir.
1563            if not self.supports_all(protocol.td.Tlopen, protocol.td.Treaddir):
1564                no_dotl = True
1565            if no_dotl:
1566                statvals = self.uxreaddir_stat_fid(fid)
1567                return [i.name for i in statvals]
1568
1569            dirents = self.uxreaddir_dotl_fid(fid)
1570            return [dirent.name for dirent in dirents]
1571        finally:
1572            self.clunk(fid, ignore_error=True)
1573
1574    def uxreaddir_stat(self, path, startdir=None):
1575        """
1576        Use directory read to get plan9 style stat data (plain or .u readdir).
1577
1578        Note that this gets a fid, then opens it, reads, then clunks
1579        the fid.  If you already have a fid, you may want to use
1580        uxreaddir_stat_fid (but note that this opens, yet does not
1581        clunk, the fid).
1582
1583        We return the qid plus the list of the contents.  If the
1584        target is not a directory, the qid will not have type QTDIR
1585        and the contents list will be empty.
1586
1587        Raises OSError if this is applied to a non-directory.
1588        """
1589        components, startdir = self._pathsplit(path, startdir)
1590        fid, qid = self.lookup_last(startdir, components)
1591        try:
1592            if qid.type != protocol.td.QTDIR:
1593                raise OSError(errno.ENOTDIR,
1594                              '{0}: {1}'.format(self.getpathX(fid),
1595                                                os.strerror(errno.ENOTDIR)))
1596            statvals = self.ux_readdir_stat_fid(fid)
1597            return qid, statvals
1598        finally:
1599            self.clunk(fid, ignore_error=True)
1600
1601    def uxreaddir_stat_fid(self, fid):
1602        """
1603        Implement readdir loop that extracts stat values.
1604        This opens, but does not clunk, the given fid.
1605
1606        Unlike uxreaddir_stat(), if this is applied to a file,
1607        rather than a directory, it just returns no entries.
1608        """
1609        statvals = []
1610        qid, iounit = self.open(fid, protocol.td.OREAD)
1611        # ?? is a zero iounit allowed? if so, what do we use here?
1612        if qid.type == protocol.td.QTDIR:
1613            if iounit <= 0:
1614                iounit = 512 # probably good enough
1615            offset = 0
1616            while True:
1617                bstring = self.read(fid, offset, iounit)
1618                if bstring == b'':
1619                    break
1620                statvals.extend(self.decode_stat_objects(bstring))
1621                offset += len(bstring)
1622        return statvals
1623
1624    def uxreaddir_dotl_fid(self, fid):
1625        """
1626        Implement readdir loop that uses dot-L style dirents.
1627        This opens, but does not clunk, the given fid.
1628
1629        If applied to a file, the lopen should fail, because of the
1630        L_O_DIRECTORY flag.
1631        """
1632        dirents = []
1633        qid, iounit = self.lopen(fid, protocol.td.OREAD |
1634                                      protocol.td.L_O_DIRECTORY)
1635        # ?? is a zero iounit allowed? if so, what do we use here?
1636        # but, we want a minimum of over 256 anyway, let's go for 512
1637        if iounit < 512:
1638            iounit = 512
1639        offset = 0
1640        while True:
1641            bstring = self.readdir(fid, offset, iounit)
1642            if bstring == b'':
1643                break
1644            ents = self.decode_readdir_dirents(bstring)
1645            if len(ents) == 0:
1646                break               # ???
1647            dirents.extend(ents)
1648            offset = ents[-1].offset
1649        return dirents
1650
1651    def uxremove(self, path, startdir=None, filetype=None,
1652                 force=False, recurse=False):
1653        """
1654        Implement rm / rmdir, with optional -rf.
1655        if filetype is None, remove dir or file.  If 'dir' or 'file'
1656        remove only if it's one of those.  If force is set, ignore
1657        failures to remove.  If recurse is True, remove contents of
1658        directories (recursively).
1659
1660        File type mismatches (when filetype!=None) raise OSError (?).
1661        """
1662        components, startdir = self._pathsplit(path, startdir, allow_empty=True)
1663        # Look up all components. If
1664        # we get an error we'll just assume the file does not
1665        # exist (is this good?).
1666        try:
1667            fid, qid = self.lookup_last(startdir, components)
1668        except RemoteError:
1669            return
1670        if qid.type == protocol.td.QTDIR:
1671            # it's a directory, remove only if allowed.
1672            # Note that we must check for "rm -r /" (len(components)==0).
1673            if filetype == 'file':
1674                self.clunk(fid, ignore_error=True)
1675                raise OSError(_wrong_file_type(qid),
1676                              '{0}: is dir, expected file'.format(path))
1677            isroot = len(components) == 0
1678            closer = self.clunk if isroot else self.remove
1679            if recurse:
1680                # NB: _rm_recursive does not clunk fid
1681                self._rm_recursive(fid, filetype, force)
1682            # This will fail if the directory is non-empty, unless of
1683            # course we tell it to ignore error.
1684            closer(fid, ignore_error=force)
1685            return
1686        # Not a directory, call it a file (even if socket or fifo etc).
1687        if filetype == 'dir':
1688            self.clunk(fid, ignore_error=True)
1689            raise OSError(_wrong_file_type(qid),
1690                          '{0}: is file, expected dir'.format(path))
1691        self.remove(fid, ignore_error=force)
1692
1693    def _rm_file_by_dfid(self, dfid, name, force=False):
1694        """
1695        Remove a file whose name is <name> (no path, just a component
1696        name) whose parent directory is <dfid>.  We may assume that the
1697        file really is a file (or a socket, or fifo, or some such, but
1698        definitely not a directory).
1699
1700        If force is set, ignore failures.
1701        """
1702        # If we have unlinkat, that's the fast way.  But it may
1703        # return an ENOTSUP error.  If it does we shouldn't bother
1704        # doing this again.
1705        if self.supports(protocol.td.Tunlinkat):
1706            try:
1707                self.unlinkat(dfid, name, 0)
1708                return
1709            except RemoteError as err:
1710                if not err.is_ENOTSUP():
1711                    raise
1712                self.unsupported(protocol.td.Tunlinkat)
1713                # fall through to remove() op
1714        # Fall back to lookup + remove.
1715        try:
1716            fid, qid = self.lookup_last(dfid, [name])
1717        except RemoteError:
1718            # If this has an errno we could tell ENOENT from EPERM,
1719            # and actually raise an error for the latter.  Should we?
1720            return
1721        self.remove(fid, ignore_error=force)
1722
1723    def _rm_recursive(self, dfid, filetype, force):
1724        """
1725        Recursively remove a directory.  filetype is probably None,
1726        but if it's 'dir' we fail if the directory contains non-dir
1727        files.
1728
1729        If force is set, ignore failures.
1730
1731        Although we open dfid (via the readdir.*_fid calls) we
1732        do not clunk it here; that's the caller's job.
1733        """
1734        # first, remove contents
1735        if self.supports_all(protocol.td.Tlopen, protocol.td.Treaddir):
1736            for entry in self.uxreaddir_dotl_fid(dfid):
1737                if entry.name in (b'.', b'..'):
1738                    continue
1739                fid, qid = self.lookup(dfid, [entry.name])
1740                try:
1741                    attrs = self.Tgetattr(fid, protocol.td.GETATTR_MODE)
1742                    if stat.S_ISDIR(attrs.mode):
1743                        self.uxremove(entry.name, dfid, filetype, force, True)
1744                    else:
1745                        self.remove(fid)
1746                        fid = None
1747                finally:
1748                    if fid is not None:
1749                        self.clunk(fid, ignore_error=True)
1750        else:
1751            for statobj in self.uxreaddir_stat_fid(dfid):
1752                # skip . and ..
1753                name = statobj.name
1754                if name in (b'.', b'..'):
1755                    continue
1756                if statobj.qid.type == protocol.td.QTDIR:
1757                    self.uxremove(name, dfid, filetype, force, True)
1758                else:
1759                    self._rm_file_by_dfid(dfid, name, force)
1760
def _wrong_file_type(qid):
    """
    Pick the errno describing a file-type mismatch, for passing to OSError.

    If qid says the object is a directory, the caller wanted something
    else, so the answer is EISDIR; otherwise it is ENOTDIR.
    """
    is_dir = qid.type == protocol.td.QTDIR
    return errno.EISDIR if is_dir else errno.ENOTDIR
1766
def flags_to_linux_flags(flags):
    """
    Convert OS flags (O_CREAT etc) to Linux flags (protocol.td.L_O_CREAT etc).

    The access mode (O_RDONLY / O_WRONLY / O_RDWR) is passed through
    unchanged; this assumes the host's access-mode values match the
    Linux ones (as on Linux itself).  The other recognized bits are
    translated one by one via the table below.

    Raises ValueError if any bits remain that we cannot translate.
    """
    flagmap = {
        os.O_CREAT: protocol.td.L_O_CREAT,
        os.O_EXCL: protocol.td.L_O_EXCL,
        os.O_NOCTTY: protocol.td.L_O_NOCTTY,
        os.O_TRUNC: protocol.td.L_O_TRUNC,
        os.O_APPEND: protocol.td.L_O_APPEND,
        os.O_DIRECTORY: protocol.td.L_O_DIRECTORY,
    }

    # Mask the whole access-mode field, not just O_RDWR.  Masking with
    # O_RDWR alone (2 on Linux) would drop O_WRONLY (1) from the result
    # and then reject it as an untranslated bit.
    accmode = os.O_RDONLY | os.O_WRONLY | os.O_RDWR
    result = flags & accmode
    flags &= ~accmode
    # items() works on both Python 2 and 3 (iteritems() is 2-only).
    for key, value in flagmap.items():
        if flags & key:
            result |= value
            flags &= ~key
    if flags:
        raise ValueError('untranslated bits 0x{0:x} in os flags'.format(flags))
    return result
1789