more tests for db, improvements to Trade and DB classes

This commit is contained in:
James Prestwich 2017-09-15 13:29:07 -06:00
parent e375537af4
commit 8dc7f77bae
No known key found for this signature in database
GPG Key ID: 519E010A79028CCC
6 changed files with 105 additions and 42 deletions

View File

@ -7,6 +7,7 @@ from xcat.db import DB
import xcat.userInput as userInput import xcat.userInput as userInput
import xcat.utils as utils import xcat.utils as utils
from xcat.protocol import Protocol from xcat.protocol import Protocol
from xcat.trades import Trade
def save_state(trade, tradeid): def save_state(trade, tradeid):
@ -26,9 +27,11 @@ def checkSellStatus(tradeid):
if status == 'init': if status == 'init':
userInput.authorize_fund_sell(trade) userInput.authorize_fund_sell(trade)
fund_tx = protocol.fund_sell_contract(trade) fund_tx = protocol.fund_sell_contract(trade)
print("Sent fund_tx", fund_tx) print("Sent fund_tx", fund_tx)
trade.sell.fund_tx = fund_tx trade.sell.fund_tx = fund_tx
save_state(trade, tradeid) save_state(trade, tradeid)
elif status == 'buyerFunded': elif status == 'buyerFunded':
secret = db.get_secret(tradeid) secret = db.get_secret(tradeid)
print("Retrieved secret to redeem funds for " print("Retrieved secret to redeem funds for "
@ -45,10 +48,12 @@ def checkSellStatus(tradeid):
save_state(trade, tradeid) save_state(trade, tradeid)
# Remove from db? Or just from temporary file storage # Remove from db? Or just from temporary file storage
utils.cleanup(tradeid) utils.cleanup(tradeid)
elif status == 'sellerFunded': elif status == 'sellerFunded':
print("Buyer has not yet funded the contract where you offered to " print("Buyer has not yet funded the contract where you offered to "
"buy {0}, please wait for them to complete " "buy {0}, please wait for them to complete "
"their part.".format(trade.buy.currency)) "their part.".format(trade.buy.currency))
elif status == 'sellerRedeemed': elif status == 'sellerRedeemed':
print("You have already redeemed the p2sh on the second chain of " print("You have already redeemed the p2sh on the second chain of "
"this trade.") "this trade.")
@ -137,10 +142,9 @@ def checkBuyStatus(tradeid):
# Import a trade in hex, and save to db # Import a trade in hex, and save to db
def importtrade(tradeid, hexstr=''): def importtrade(tradeid, hexstr=''):
db = DB()
protocol = Protocol() protocol = Protocol()
trade = utils.x2s(hexstr) trade = utils.x2s(hexstr)
trade = db.instantiate(trade) trade = Trade(trade)
protocol.import_addrs(trade) protocol.import_addrs(trade)
print(trade.toJSON()) print(trade.toJSON())
save_state(trade, tradeid) save_state(trade, tradeid)

View File

@ -16,36 +16,32 @@ class DB():
# Takes dict or obj, saves json str as bytes # Takes dict or obj, saves json str as bytes
def create(self, trade, tradeid): def create(self, trade, tradeid):
if type(trade) == dict: if isinstance(trade, dict):
trade = json.dumps(trade) trade = json.dumps(trade, sort_keys=True, indent=4)
else: elif isinstance(trade, Trade):
trade = trade.toJSON() trade = trade.toJSON()
else:
raise ValueError('Expected dictionary or Trade object')
self.db.put(utils.b(tradeid), utils.b(trade)) self.db.put(utils.b(tradeid), utils.b(trade))
# Uses the funding txid as the key to save trade # Uses the funding txid as the key to save trade
def createByFundtx(self, trade): def createByFundtx(self, trade):
trade = trade.toJSON() if isinstance(trade, dict):
# # Save trade by initiating txid txid = trade['sell']['fund_tx']
jt = json.loads(trade) trade = json.dumps(trade, sort_keys=True, indent=4)
txid = jt['sell']['fund_tx'] elif isinstance(trade, Trade):
txid = trade.sell.fund_tx
trade = trade.toJSON()
else:
raise ValueError('Expected dictionary or Trade object')
self.db.put(utils.b(txid), utils.b(trade)) self.db.put(utils.b(txid), utils.b(trade))
def get(self, tradeid): def get(self, tradeid):
rawtrade = self.db.get(utils.b(tradeid)) rawtrade = self.db.get(utils.b(tradeid))
tradestr = str(rawtrade, 'utf-8') tradestr = str(rawtrade, 'utf-8')
trade = self.instantiate(tradestr) trade = Trade(fromJSON=tradestr)
return trade return trade
@staticmethod
def instantiate(trade):
if type(trade) == str:
tradestr = json.loads(trade)
trade = Trade(
buy=Contract(tradestr['buy']),
sell=Contract(tradestr['sell']),
commitment=tradestr['commitment'])
return trade
############################################# #############################################
###### Preimages stored by tradeid ########## ###### Preimages stored by tradeid ##########
############################################# #############################################

View File

@ -1,8 +1,8 @@
import unittest import unittest
import xcat.cli as cli import xcat.cli as cli
import xcat.tests.utils as testutils
from xcat.db import DB from xcat.db import DB
from xcat.protocol import Protocol from xcat.protocol import Protocol
import xcat.tests.utils as testutils
from xcat.trades import Trade # , Contract from xcat.trades import Trade # , Contract

View File

@ -1,5 +1,6 @@
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
import json
import xcat.db as db import xcat.db as db
import xcat.tests.utils as utils import xcat.tests.utils as utils
@ -14,32 +15,63 @@ class TestDB(unittest.TestCase):
self.assertIsInstance(self.db.db, mock.Mock) self.assertIsInstance(self.db.db, mock.Mock)
self.assertIsInstance(self.db.preimageDB, mock.Mock) self.assertIsInstance(self.db.preimageDB, mock.Mock)
@mock.patch('xcat.db.json') def test_create_with_dict(self):
def test_create_with_dict(self, mock_json):
test_id = 'test trade id' test_id = 'test trade id'
trade_string = 'trade string'
mock_json.dumps.return_value = trade_string
test_trade = utils.test_trade
self.db.create(test_trade, test_id) self.db.create(utils.test_trade_dict, test_id)
mock_json.dumps.assert_called_with(test_trade)
self.db.db.put.assert_called_with( self.db.db.put.assert_called_with(
str.encode(test_id), str.encode(test_id),
str.encode(trade_string)) str.encode(str(utils.test_trade)))
def test_create_with_trade(self): def test_create_with_trade(self):
pass test_id = 'test trade id'
def test_createByFundtx(self): self.db.create(utils.test_trade, test_id)
pass
self.db.db.put.assert_called_with(
str.encode(test_id),
str.encode(json.dumps(utils.test_trade_dict,
sort_keys=True,
indent=4)))
def test_create_with_error(self):
with self.assertRaises(ValueError) as context:
self.db.create('this is not valid input', 'trade_id')
self.assertTrue(
'Expected dictionary or Trade object'
in str(context.exception))
def test_createByFundtx_with_dict(self):
self.db.createByFundtx(utils.test_trade_dict)
self.db.db.put.assert_called_with(
str.encode('5c5e91a89a08b2d6698f50c9fd9bb2fa22da6c74e226c3dd63d'
'59511566a2fdb'),
str.encode(str(utils.test_trade)))
def test_createByFundtx_with_trade(self):
self.db.createByFundtx(utils.test_trade)
self.db.db.put.assert_called_with(
str.encode('5c5e91a89a08b2d6698f50c9fd9bb2fa22da6c74e226c3dd63d'
'59511566a2fdb'),
str.encode(json.dumps(utils.test_trade_dict,
sort_keys=True,
indent=4)))
def test_createByFundtx_with_error(self):
with self.assertRaises(ValueError) as context:
self.db.createByFundtx('this is not valid input')
self.assertTrue(
'Expected dictionary or Trade object'
in str(context.exception))
def test_get(self): def test_get(self):
pass pass
def test_instantiate(self):
pass
def test_save_secret(self): def test_save_secret(self):
pass pass

View File

@ -1,6 +1,7 @@
from xcat.db import DB from xcat.db import DB
from xcat.trades import Contract, Trade
test_trade = { test_trade_dict = {
"sell": { "sell": {
"amount": 3.5, "amount": 3.5,
"redeemScript": "63a82003d58daab37238604b3e57d4a8bdcffa401dc497a9c1aa4f08ffac81616c22b68876a9147788b4511a25fba1092e67b307a6dcdb6da125d967022a04b17576a914c7043e62a7391596116f54f6a64c8548e97d3fd96888ac", "redeemScript": "63a82003d58daab37238604b3e57d4a8bdcffa401dc497a9c1aa4f08ffac81616c22b68876a9147788b4511a25fba1092e67b307a6dcdb6da125d967022a04b17576a914c7043e62a7391596116f54f6a64c8548e97d3fd96888ac",
@ -21,6 +22,12 @@ test_trade = {
"fulfiller": "tmTjZSg4pX2Us6V5HttiwFZwj464fD2ZgpY"}, "fulfiller": "tmTjZSg4pX2Us6V5HttiwFZwj464fD2ZgpY"},
"commitment": "03d58daab37238604b3e57d4a8bdcffa401dc497a9c1aa4f08ffac81616c22b6"} "commitment": "03d58daab37238604b3e57d4a8bdcffa401dc497a9c1aa4f08ffac81616c22b6"}
test_sell_contract = Contract(test_trade_dict['sell'])
test_buy_contract = Contract(test_trade_dict['buy'])
test_trade = Trade(sell=test_sell_contract,
buy=test_buy_contract,
commitment=test_trade_dict['commitment'])
def mktrade(): def mktrade():
db = DB() db = DB()

View File

@ -1,19 +1,43 @@
import json import json
class Trade(object): class Trade():
def __init__(self, sell=None, buy=None, commitment=None): def __init__(self, sell=None, buy=None, commitment=None,
fromJSON=None, fromDict=None):
'''Create a new trade with buy and sell contracts across two chains''' '''Create a new trade with buy and sell contracts across two chains'''
self.sell = sell
self.buy = buy if fromJSON is not None and fromDict is None:
self.commitment = commitment if isinstance(fromJSON, str):
fromDict = json.loads(fromJSON)
else:
raise ValueError('Expected json string')
if fromDict is not None:
self.sell = Contract(fromDict['sell'])
self.buy = Contract(fromDict['buy'])
self.commitment = fromDict['commitment']
else:
self.sell = sell
self.buy = buy
self.commitment = commitment
def toJSON(self): def toJSON(self):
return json.dumps( return json.dumps(
self, default=lambda o: o.__dict__, sort_keys=True, indent=4) self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
def __str__(self):
return self.toJSON()
class Contract(object): def __repr__(self):
return 'Trade:\n{0} {1} from {2}\nfor\n{3} {4} from {5}'.format(
self.sell.amount,
self.sell.currency,
self.sell.initiator,
self.buy.amount,
self.buy.currency,
self.buy.initiator)
class Contract():
def __init__(self, data): def __init__(self, data):
allowed = ('fulfiller', 'initiator', 'currency', 'p2sh', 'amount', allowed = ('fulfiller', 'initiator', 'currency', 'p2sh', 'amount',
'fund_tx', 'redeem_tx', 'secret', 'redeemScript', 'fund_tx', 'redeem_tx', 'secret', 'redeemScript',