Merge pull request #6 from readevalprint/pep8

Python Pep8 style fixes
This commit is contained in:
Jae Kwon 2016-01-05 14:54:12 -08:00
commit 57ce0c12e8
6 changed files with 358 additions and 329 deletions

View File

@ -1,82 +1,85 @@
import sys import sys
sys.path.insert(0, './tmsp')
from wire import * from tmsp.wire import hex2bytes, decode_big_endian, encode_big_endian
from server import * from tmsp.server import TMSPServer
from tmsp.reader import BytesBuffer
# tmsp application interface
class CounterApplication(): class CounterApplication():
def __init__(self): def __init__(self):
self.hashCount = 0 self.hashCount = 0
self.txCount = 0 self.txCount = 0
self.commitCount = 0 self.commitCount = 0
def open(self): def open(self):
return CounterAppContext(self) return CounterAppContext(self)
class CounterAppContext(): class CounterAppContext():
def __init__(self, app):
self.app = app
self.hashCount = app.hashCount
self.txCount = app.txCount
self.commitCount = app.commitCount
self.serial = False
def echo(self, msg):
return msg, 0
def info(self): def __init__(self, app):
return ["hash, tx, commit counts:%d, %d, %d"%(self.hashCount, self.txCount, self.commitCount)], 0 self.app = app
self.hashCount = app.hashCount
self.txCount = app.txCount
self.commitCount = app.commitCount
self.serial = False
def set_option(self, key, value): def echo(self, msg):
if key == "serial" and value == "on": return msg, 0
self.serial = True
return 0
def append_tx(self, txBytes): def info(self):
if self.serial: return ["hash, tx, commit counts:%d, %d, %d" % (self.hashCount,
txByteArray = bytearray(txBytes) self.txCount,
if len(txBytes) >= 2 and txBytes[:2] == "0x": self.commitCount)], 0
txByteArray = hex2bytes(txBytes[2:])
txValue = decode_big_endian(BytesBuffer(txByteArray), len(txBytes))
if txValue != self.txCount:
return None, 1
self.txCount += 1
return None, 0
def get_hash(self): def set_option(self, key, value):
self.hashCount += 1 if key == "serial" and value == "on":
if self.txCount == 0: self.serial = True
return "", 0 return 0
h = encode_big_endian(self.txCount, 8)
h.reverse()
return str(h), 0
def commit(self): def append_tx(self, txBytes):
self.commitCount += 1 if self.serial:
return 0 txByteArray = bytearray(txBytes)
if len(txBytes) >= 2 and txBytes[:2] == "0x":
txByteArray = hex2bytes(txBytes[2:])
txValue = decode_big_endian(
BytesBuffer(txByteArray), len(txBytes))
if txValue != self.txCount:
return None, 1
self.txCount += 1
return None, 0
def rollback(self): def get_hash(self):
return 0 self.hashCount += 1
if self.txCount == 0:
return "", 0
h = encode_big_endian(self.txCount, 8)
h.reverse()
return str(h), 0
def add_listener(self): def commit(self):
return 0 self.commitCount += 1
return 0
def rm_listener(self): def rollback(self):
return 0 return 0
def add_listener(self):
return 0
def rm_listener(self):
return 0
def event(self):
return
def event(self):
return
if __name__ == '__main__': if __name__ == '__main__':
l = len(sys.argv) l = len(sys.argv)
if l == 1: if l == 1:
port = 46658 port = 46658
elif l == 2: elif l == 2:
port = int(sys.argv[1]) port = int(sys.argv[1])
else: else:
print "too many arguments" print "too many arguments"
@ -84,6 +87,6 @@ if __name__ == '__main__':
print 'TMSP Demo APP (Python)' print 'TMSP Demo APP (Python)'
app = CounterApplication() app = CounterApplication()
server = TMSPServer(app, port) server = TMSPServer(app, port)
server.main_loop() server.main_loop()

View File

View File

@ -1,54 +1,55 @@
from wire import * from wire import decode_string
# map type_byte to message name # map type_byte to message name
message_types = { message_types = {
0x01 : "echo", 0x01: "echo",
0x02 : "flush", 0x02: "flush",
0x03 : "info", 0x03: "info",
0x04 : "set_option", 0x04: "set_option",
0x21 : "append_tx", 0x21: "append_tx",
0x22 : "get_hash", 0x22: "get_hash",
0x23 : "commit", 0x23: "commit",
0x24 : "rollback", 0x24: "rollback",
0x25 : "add_listener", 0x25: "add_listener",
0x26 : "rm_listener", 0x26: "rm_listener",
} }
# return the decoded arguments of tmsp messages # return the decoded arguments of tmsp messages
class RequestDecoder(): class RequestDecoder():
def __init__(self, reader):
self.reader = reader
def echo(self): def __init__(self, reader):
return decode_string(self.reader) self.reader = reader
def flush(self): def echo(self):
return return decode_string(self.reader)
def info(self): def flush(self):
return return
def set_option(self): def info(self):
return decode_string(self.reader), decode_string(self.reader) return
def append_tx(self): def set_option(self):
return decode_string(self.reader) return decode_string(self.reader), decode_string(self.reader)
def get_hash(self): def append_tx(self):
return return decode_string(self.reader)
def commit(self): def get_hash(self):
return return
def rollback(self): def commit(self):
return return
def add_listener(self): def rollback(self):
# TODO return
return
def rm_listener(self):
# TODO
return
def add_listener(self):
# TODO
return
def rm_listener(self):
# TODO
return

View File

@ -1,50 +1,56 @@
# Simple read() method around a bytearray # Simple read() method around a bytearray
class BytesBuffer(): class BytesBuffer():
def __init__(self, b):
self.buf = b
self.readCount = 0
def count(self): def __init__(self, b):
return self.readCount self.buf = b
self.readCount = 0
def reset_count(self): def count(self):
self.readCount = 0 return self.readCount
def size(self):
return len(self.buf)
def peek(self): def reset_count(self):
return self.buf[0] self.readCount = 0
def write(self, b): def size(self):
# b should be castable to byte array return len(self.buf)
self.buf += bytearray(b)
def read(self, n): def peek(self):
if len(self.buf) < n: return self.buf[0]
print "reader err: buf less than n"
# TODO: exception def write(self, b):
return # b should be castable to byte array
self.readCount += n self.buf += bytearray(b)
r = self.buf[:n]
self.buf = self.buf[n:] def read(self, n):
return r if len(self.buf) < n:
print "reader err: buf less than n"
# TODO: exception
return
self.readCount += n
r = self.buf[:n]
self.buf = self.buf[n:]
return r
# Buffer bytes off a tcp connection and read them off in chunks # Buffer bytes off a tcp connection and read them off in chunks
class ConnReader(): class ConnReader():
def __init__(self, conn):
self.conn = conn
self.buf = bytearray()
# blocking def __init__(self, conn):
def read(self, n): self.conn = conn
while n > len(self.buf): self.buf = bytearray()
moreBuf = self.conn.recv(1024)
if not moreBuf:
raise IOError("dead connection")
self.buf = self.buf + bytearray(moreBuf)
r = self.buf[:n] # blocking
self.buf = self.buf[n:] def read(self, n):
return r while n > len(self.buf):
moreBuf = self.conn.recv(1024)
if not moreBuf:
raise IOError("dead connection")
self.buf = self.buf + bytearray(moreBuf)
r = self.buf[:n]
self.buf = self.buf[n:]
return r

View File

@ -1,38 +1,44 @@
import socket import socket
import select import select
import sys import sys
import os
from wire import * from wire import decode_varint, encode
from reader import * from reader import BytesBuffer
from msg import * from msg import RequestDecoder, message_types
# hold the asyncronous state of a connection # hold the asyncronous state of a connection
# ie. we may not get enough bytes on one read to decode the message # ie. we may not get enough bytes on one read to decode the message
class Connection(): class Connection():
def __init__(self, fd, appCtx):
self.fd = fd def __init__(self, fd, appCtx):
self.appCtx = appCtx self.fd = fd
self.recBuf = BytesBuffer(bytearray()) self.appCtx = appCtx
self.resBuf = BytesBuffer(bytearray()) self.recBuf = BytesBuffer(bytearray())
self.msgLength = 0 self.resBuf = BytesBuffer(bytearray())
self.decoder = RequestDecoder(self.recBuf) self.msgLength = 0
self.inProgress = False # are we in the middle of a message self.decoder = RequestDecoder(self.recBuf)
self.inProgress = False # are we in the middle of a message
def recv(this):
data = this.fd.recv(1024) def recv(this):
if not data: # what about len(data) == 0 data = this.fd.recv(1024)
raise IOError("dead connection") if not data: # what about len(data) == 0
this.recBuf.write(data) raise IOError("dead connection")
this.recBuf.write(data)
# TMSP server responds to messges by calling methods on the app # TMSP server responds to messges by calling methods on the app
class TMSPServer():
def __init__(self, app, port=5410):
self.app = app
self.appMap = {} # map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder)
self.port = port
class TMSPServer():
def __init__(self, app, port=5410):
self.app = app
# map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder)
self.appMap = {}
self.port = port
self.listen_backlog = 10 self.listen_backlog = 10
self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -49,13 +55,13 @@ class TMSPServer():
def handle_new_connection(self, r): def handle_new_connection(self, r):
new_fd, new_addr = r.accept() new_fd, new_addr = r.accept()
new_fd.setblocking(0) # non-blocking new_fd.setblocking(0) # non-blocking
self.read_list.append(new_fd) self.read_list.append(new_fd)
self.write_list.append(new_fd) self.write_list.append(new_fd)
print 'new connection to', new_addr print 'new connection to', new_addr
appContext = self.app.open() appContext = self.app.open()
self.appMap[new_fd] = Connection(new_fd, appContext) self.appMap[new_fd] = Connection(new_fd, appContext)
def handle_conn_closed(self, r): def handle_conn_closed(self, r):
self.read_list.remove(r) self.read_list.remove(r)
@ -64,137 +70,137 @@ class TMSPServer():
print "connection closed" print "connection closed"
def handle_recv(self, r): def handle_recv(self, r):
# appCtx, recBuf, resBuf, conn # appCtx, recBuf, resBuf, conn
conn = self.appMap[r] conn = self.appMap[r]
while True: while True:
try: try:
print "recv loop" print "recv loop"
# check if we need more data first # check if we need more data first
if conn.inProgress: if conn.inProgress:
if conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength: if (conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength):
conn.recv() conn.recv()
else: else:
if conn.recBuf.size() == 0: if conn.recBuf.size() == 0:
conn.recv() conn.recv()
conn.inProgress = True conn.inProgress = True
# see if we have enough to get the message length # see if we have enough to get the message length
if conn.msgLength == 0: if conn.msgLength == 0:
ll = conn.recBuf.peek() ll = conn.recBuf.peek()
if conn.recBuf.size() < 1 + ll: if conn.recBuf.size() < 1 + ll:
# we don't have enough bytes to read the length yet # we don't have enough bytes to read the length yet
return return
print "decoding msg length" print "decoding msg length"
conn.msgLength = decode_varint(conn.recBuf) conn.msgLength = decode_varint(conn.recBuf)
# see if we have enough to decode the message # see if we have enough to decode the message
if conn.recBuf.size() < conn.msgLength: if conn.recBuf.size() < conn.msgLength:
return return
# now we can decode the message # now we can decode the message
# first read the request type and get the particular msg decoder # first read the request type and get the particular msg
typeByte = conn.recBuf.read(1) # decoder
typeByte = int(typeByte[0]) typeByte = conn.recBuf.read(1)
resTypeByte = typeByte+0x10 typeByte = int(typeByte[0])
req_type = message_types[typeByte] resTypeByte = typeByte + 0x10
req_type = message_types[typeByte]
if req_type == "flush": if req_type == "flush":
# messages are length prefixed # messages are length prefixed
conn.resBuf.write(encode(1)) conn.resBuf.write(encode(1))
conn.resBuf.write([resTypeByte]) conn.resBuf.write([resTypeByte])
sent = conn.fd.send(str(conn.resBuf.buf)) conn.fd.send(str(conn.resBuf.buf))
conn.msgLength = 0 conn.msgLength = 0
conn.inProgress = False conn.inProgress = False
conn.resBuf = BytesBuffer(bytearray()) conn.resBuf = BytesBuffer(bytearray())
return return
decoder = getattr(conn.decoder, req_type) decoder = getattr(conn.decoder, req_type)
print "decoding args" print "decoding args"
req_args = decoder() req_args = decoder()
print "got args", req_args print "got args", req_args
# done decoding message # done decoding message
conn.msgLength = 0 conn.msgLength = 0
conn.inProgress = False conn.inProgress = False
req_f = getattr(conn.appCtx, req_type) req_f = getattr(conn.appCtx, req_type)
if req_args == None: if req_args is None:
res = req_f() res = req_f()
elif isinstance(req_args, tuple): elif isinstance(req_args, tuple):
res = req_f(*req_args) res = req_f(*req_args)
else: else:
res = req_f(req_args) res = req_f(req_args)
if isinstance(res, tuple): if isinstance(res, tuple):
res, ret_code = res res, ret_code = res
else: else:
ret_code = res ret_code = res
res = None res = None
print "called", req_type, "ret code:", ret_code print "called", req_type, "ret code:", ret_code
if ret_code != 0: if ret_code != 0:
print "non-zero retcode:", ret_code print "non-zero retcode:", ret_code
if req_type in ("echo", "info"): # these dont return a ret code if req_type in ("echo", "info"): # these dont return a ret code
enc = encode(res) enc = encode(res)
# messages are length prefixed # messages are length prefixed
conn.resBuf.write(encode(len(enc) + 1)) conn.resBuf.write(encode(len(enc) + 1))
conn.resBuf.write([resTypeByte]) conn.resBuf.write([resTypeByte])
conn.resBuf.write(enc) conn.resBuf.write(enc)
else: else:
enc, encRet = encode(res), encode(ret_code) enc, encRet = encode(res), encode(ret_code)
# messages are length prefixed # messages are length prefixed
conn.resBuf.write(encode(len(enc)+len(encRet)+1)) conn.resBuf.write(encode(len(enc) + len(encRet) + 1))
conn.resBuf.write([resTypeByte]) conn.resBuf.write([resTypeByte])
conn.resBuf.write(encRet) conn.resBuf.write(encRet)
conn.resBuf.write(enc) conn.resBuf.write(enc)
except TypeError as e: except TypeError as e:
print "TypeError on reading from connection:", e print "TypeError on reading from connection:", e
self.handle_conn_closed(r) self.handle_conn_closed(r)
return return
except ValueError as e: except ValueError as e:
print "ValueError on reading from connection:", e print "ValueError on reading from connection:", e
self.handle_conn_closed(r) self.handle_conn_closed(r)
return return
except IOError as e: except IOError as e:
print "IOError on reading from connection:", e print "IOError on reading from connection:", e
self.handle_conn_closed(r) self.handle_conn_closed(r)
return return
except Exception as e: except Exception as e:
print "error reading from connection", str(e) # sys.exc_info()[0] # TODO better # sys.exc_info()[0] # TODO better
self.handle_conn_closed(r) print "error reading from connection", str(e)
return self.handle_conn_closed(r)
return
def main_loop(self): def main_loop(self):
while not self.shutdown: while not self.shutdown:
r_list, w_list, _ = select.select(self.read_list, self.write_list, [], 2.5) r_list, w_list, _ = select.select(
self.read_list, self.write_list, [], 2.5)
for r in r_list: for r in r_list:
if (r == self.listener): if (r == self.listener):
try: try:
self.handle_new_connection(r) self.handle_new_connection(r)
# undo adding to read list ...
# undo adding to read list ... except NameError as e:
except rameError as e: print "Could not connect due to NameError:", e
print "Could not connect due to NameError:", e except TypeError as e:
except TypeError as e: print "Could not connect due to TypeError:", e
print "Could not connect due to TypeError:", e except:
except: print "Could not connect due to unexpected error:", sys.exc_info()[0]
print "Could not connect due to unexpected error:", sys.exc_info()[0] else:
else:
self.handle_recv(r) self.handle_recv(r)
def handle_shutdown(self): def handle_shutdown(self):
for r in self.read_list: for r in self.read_list:
r.close() r.close()
for w in self.write_list: for w in self.write_list:
try: try:
w.close() w.close()
except: pass except Exception as e:
print(e) # TODO: add logging
self.shutdown = True self.shutdown = True

View File

@ -2,101 +2,114 @@
# the decoder works off a reader # the decoder works off a reader
# the encoder returns bytearray # the encoder returns bytearray
def hex2bytes(h): def hex2bytes(h):
return bytearray(h.decode('hex')) return bytearray(h.decode('hex'))
def bytes2hex(b): def bytes2hex(b):
if type(b) in (str, unicode): if type(b) in (str, unicode):
return "".join([hex(ord(c))[2:].zfill(2) for c in b]) return "".join([hex(ord(c))[2:].zfill(2) for c in b])
else: else:
return bytes2hex(b.decode()) return bytes2hex(b.decode())
# expects uvarint64 (no crazy big nums!) # expects uvarint64 (no crazy big nums!)
def uvarint_size(i): def uvarint_size(i):
if i == 0: if i == 0:
return 0 return 0
for j in xrange(1, 8): for j in xrange(1, 8):
if i < 1<<j*8: if i < 1 << j * 8:
return j return j
return 8 return 8
# expects i < 2**size # expects i < 2**size
def encode_big_endian(i, size): def encode_big_endian(i, size):
if size == 0: if size == 0:
return bytearray() return bytearray()
return encode_big_endian(i/256, size-1) + bytearray([i%256]) return encode_big_endian(i / 256, size - 1) + bytearray([i % 256])
def decode_big_endian(reader, size): def decode_big_endian(reader, size):
if size == 0: if size == 0:
return 0 return 0
firstByte = reader.read(1)[0] firstByte = reader.read(1)[0]
return firstByte*(256**(size-1)) + decode_big_endian(reader, size-1) return firstByte * (256 ** (size - 1)) + decode_big_endian(reader, size - 1)
# ints are max 16 bytes long # ints are max 16 bytes long
def encode_varint(i): def encode_varint(i):
negate = False negate = False
if i < 0: if i < 0:
negate = True negate = True
i = -i i = -i
size = uvarint_size(i) size = uvarint_size(i)
if size == 0: if size == 0:
return bytearray([0]) return bytearray([0])
big_end = encode_big_endian(i, size) big_end = encode_big_endian(i, size)
if negate: if negate:
size += 0xF0 size += 0xF0
return bytearray([size]) + big_end return bytearray([size]) + big_end
# returns the int and whats left of the byte array # returns the int and whats left of the byte array
def decode_varint(reader):
size = reader.read(1)[0]
if size == 0:
return 0
negate = True if size > int(0xF0) else False
if negate: size = size -0xF0 def decode_varint(reader):
i = decode_big_endian(reader, size) size = reader.read(1)[0]
if negate: i = i*(-1) if size == 0:
return i return 0
negate = True if size > int(0xF0) else False
if negate:
size = size - 0xF0
i = decode_big_endian(reader, size)
if negate:
i = i * (-1)
return i
def encode_string(s): def encode_string(s):
size = encode_varint(len(s)) size = encode_varint(len(s))
return size + bytearray(s) return size + bytearray(s)
def decode_string(reader): def decode_string(reader):
length = decode_varint(reader) length = decode_varint(reader)
return str(reader.read(length)) return str(reader.read(length))
def encode_list(s): def encode_list(s):
b = bytearray() b = bytearray()
map(b.extend, map(encode, s)) map(b.extend, map(encode, s))
return encode_varint(len(s)) + b return encode_varint(len(s)) + b
def encode(s): def encode(s):
if s == None: if s is None:
return bytearray() return bytearray()
if isinstance(s, int): if isinstance(s, int):
return encode_varint(s) return encode_varint(s)
elif isinstance(s, str): elif isinstance(s, str):
return encode_string(s) return encode_string(s)
elif isinstance(s, list): elif isinstance(s, list):
return encode_list(s) return encode_list(s)
else: else:
print "UNSUPPORTED TYPE!", type(s), s print "UNSUPPORTED TYPE!", type(s), s
import binascii
if __name__ == '__main__': if __name__ == '__main__':
ns = [100,100,1000,256] ns = [100, 100, 1000, 256]
ss = [2,5,5,2] ss = [2, 5, 5, 2]
bs = map(encode_big_endian, ns,ss) bs = map(encode_big_endian, ns, ss)
ds = map(decode_big_endian, bs,ss) ds = map(decode_big_endian, bs, ss)
print ns print ns
print [i[0] for i in ds] print [i[0] for i in ds]
ss = ["abc", "hi there jim", "ok now what"] ss = ["abc", "hi there jim", "ok now what"]
e = map(encode_string, ss) e = map(encode_string, ss)
d = map(decode_string, e) d = map(decode_string, e)
print ss print ss
print [i[0] for i in d] print [i[0] for i in d]