diff --git a/example/python/app.py b/example/python/app.py index a9e9d3b7..2875fcd6 100644 --- a/example/python/app.py +++ b/example/python/app.py @@ -41,7 +41,7 @@ class CounterAppContext(): txByteArray = bytearray(txBytes) if len(txBytes) >= 2 and txBytes[:2] == "0x": txByteArray = hex2bytes(txBytes[2:]) - txValue = decode_big_endian(BytesReader(txByteArray), len(txBytes)) + txValue = decode_big_endian(BytesBuffer(txByteArray), len(txBytes)) if txValue != self.txCount: return None, 1 self.txCount += 1 diff --git a/example/python/tmsp/reader.py b/example/python/tmsp/reader.py index f1b3dfae..3b1f87fc 100644 --- a/example/python/tmsp/reader.py +++ b/example/python/tmsp/reader.py @@ -1,14 +1,32 @@ # Simple read() method around a bytearray -class BytesReader(): +class BytesBuffer(): def __init__(self, b): self.buf = b + self.readCount = 0 + + def count(self): + return self.readCount + + def reset_count(self): + self.readCount = 0 + + def size(self): + return len(self.buf) + + def peek(self): + return self.buf[0] + + def write(self, b): + # b should be castable to byte array + self.buf += bytearray(b) def read(self, n): if len(self.buf) < n: print "reader err: buf less than n" # TODO: exception return + self.readCount += n r = self.buf[:n] self.buf = self.buf[n:] return r diff --git a/example/python/tmsp/server.py b/example/python/tmsp/server.py index e25e4e1f..0beb59d1 100644 --- a/example/python/tmsp/server.py +++ b/example/python/tmsp/server.py @@ -8,11 +8,29 @@ from wire import * from reader import * from msg import * +# hold the asyncronous state of a connection +# ie. we may not get enough bytes on one read to decode the message +class Connection(): + def __init__(self, fd, appCtx): + self.fd = fd + self.appCtx = appCtx + self.recBuf = BytesBuffer(bytearray()) + self.resBuf = BytesBuffer(bytearray()) + self.msgLength = 0 + self.decoder = RequestDecoder(self.recBuf) + self.inProgress = False # are we in the middle of a message + + def recv(this): + data = this.fd.recv(1024) + if not data: # what about len(data) == 0 + raise IOError("dead connection") + this.recBuf.write(data) + # TMSP server responds to messges by calling methods on the app class TMSPServer(): def __init__(self, app, port=5410): self.app = app - self.appMap = {} # map conn file descriptors to (appContext, msgDecoder) + self.appMap = {} # map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder) self.port = port self.listen_backlog = 10 @@ -31,12 +49,13 @@ class TMSPServer(): def handle_new_connection(self, r): new_fd, new_addr = r.accept() + new_fd.setblocking(0) # non-blocking self.read_list.append(new_fd) self.write_list.append(new_fd) print 'new connection to', new_addr appContext = self.app.open() - self.appMap[new_fd] = (appContext, RequestDecoder(ConnReader(new_fd))) + self.appMap[new_fd] = Connection(new_fd, appContext) def handle_conn_closed(self, r): self.read_list.remove(r) @@ -45,25 +64,63 @@ class TMSPServer(): print "connection closed" def handle_recv(self, r): - appCtx, conn = self.appMap[r] - response = bytearray() +# appCtx, recBuf, resBuf, conn + conn = self.appMap[r] while True: try: - # first read the request type and get the msg decoder - typeByte = conn.reader.read(1) + print "recv loop" + # check if we need more data first + if conn.inProgress: + if conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength: + conn.recv() + else: + if conn.recBuf.size() == 0: + conn.recv() + + conn.inProgress = True + + # see if we have enough to get the message length + if conn.msgLength == 0: + ll = conn.recBuf.peek() + if conn.recBuf.size() < 1 + ll: + # we don't have enough bytes to read the length yet + return + print "decoding msg length" + conn.msgLength = decode_varint(conn.recBuf) + + # see if we have enough to decode the message + if conn.recBuf.size() < conn.msgLength: + return + + # now we can decode the message + + # first read the request type and get the particular msg decoder + typeByte = conn.recBuf.read(1) typeByte = int(typeByte[0]) resTypeByte = typeByte+0x10 req_type = message_types[typeByte] if req_type == "flush": - response += bytearray([resTypeByte]) - sent = r.send(str(response)) + # messages are length prefixed + conn.resBuf.write(encode(1)) + conn.resBuf.write([resTypeByte]) + sent = conn.fd.send(str(conn.resBuf.buf)) + conn.msgLength = 0 + conn.inProgress = False + conn.resBuf = BytesBuffer(bytearray()) return - decoder = getattr(conn, req_type) + decoder = getattr(conn.decoder, req_type) + print "decoding args" req_args = decoder() - req_f = getattr(appCtx, req_type) + print "got args", req_args + + # done decoding message + conn.msgLength = 0 + conn.inProgress = False + + req_f = getattr(conn.appCtx, req_type) if req_args == None: res = req_f() elif isinstance(req_args, tuple): @@ -82,9 +139,18 @@ class TMSPServer(): print "non-zero retcode:", ret_code if req_type in ("echo", "info"): # these dont return a ret code - response += bytearray([resTypeByte]) + encode(res) + enc = encode(res) + # messages are length prefixed + conn.resBuf.write(encode(len(enc) + 1)) + conn.resBuf.write([resTypeByte]) + conn.resBuf.write(enc) else: - response += bytearray([resTypeByte]) + encode(ret_code) + encode(res) + enc, encRet = encode(res), encode(ret_code) + # messages are length prefixed + conn.resBuf.write(encode(len(enc)+len(encRet)+1)) + conn.resBuf.write([resTypeByte]) + conn.resBuf.write(encRet) + conn.resBuf.write(enc) except TypeError as e: print "TypeError on reading from connection:", e self.handle_conn_closed(r) @@ -97,8 +163,8 @@ class TMSPServer(): print "IOError on reading from connection:", e self.handle_conn_closed(r) return - except: - print "error reading from connection", sys.exc_info()[0] # TODO better + except Exception as e: + print "error reading from connection", str(e) # sys.exc_info()[0] # TODO better self.handle_conn_closed(r) return @@ -112,7 +178,7 @@ class TMSPServer(): self.handle_new_connection(r) # undo adding to read list ... - except NameError as e: + except rameError as e: print "Could not connect due to NameError:", e except TypeError as e: print "Could not connect due to TypeError:", e