diff --git a/client/client.go b/client/client.go index 07da0e65..adb2c463 100644 --- a/client/client.go +++ b/client/client.go @@ -12,6 +12,7 @@ import ( "github.com/tendermint/tmsp/types" ) +const reqQueueSize = 256 // TODO make configurable const maxResponseSize = 1048576 // 1MB TODO make configurable const flushThrottleMS = 20 // Don't wait longer than... @@ -24,20 +25,25 @@ type TMSPClient struct { QuitService sync.Mutex // [EB]: is this even used? - reqQueue chan *reqRes + reqQueue chan *ReqRes flushTimer *ThrottleTimer mtx sync.Mutex + addr string conn net.Conn bufWriter *bufio.Writer err error reqSent *list.List - resCb func(*types.Request, *types.Response) + resCb func(*types.Request, *types.Response) // listens to all callbacks } -func NewTMSPClient(conn net.Conn, bufferSize int) *TMSPClient { +func NewTMSPClient(addr string) (*TMSPClient, error) { + conn, err := Connect(addr) + if err != nil { + return nil, err + } cli := &TMSPClient{ - reqQueue: make(chan *reqRes, bufferSize), + reqQueue: make(chan *ReqRes, reqQueueSize), flushTimer: NewThrottleTimer("TMSPClient", flushThrottleMS), conn: conn, @@ -47,7 +53,7 @@ func NewTMSPClient(conn net.Conn, bufferSize int) *TMSPClient { } cli.QuitService = *NewQuitService(nil, "TMSPClient", cli) cli.Start() // Just start it, it's confusing for callers to remember to start. - return cli + return cli, nil } func (cli *TMSPClient) OnStart() error { @@ -62,6 +68,7 @@ func (cli *TMSPClient) OnStop() { cli.conn.Close() } +// Set listener for all responses // NOTE: callback may get internally generated flush responses. func (cli *TMSPClient) SetResponseCallback(resCb Callback) { cli.mtx.Lock() @@ -140,7 +147,7 @@ func (cli *TMSPClient) recvResponseRoutine() { } } -func (cli *TMSPClient) willSendReq(reqres *reqRes) { +func (cli *TMSPClient) willSendReq(reqres *ReqRes) { cli.mtx.Lock() defer cli.mtx.Unlock() cli.reqSent.PushBack(reqres) @@ -150,12 +157,12 @@ func (cli *TMSPClient) didRecvResponse(res *types.Response) error { cli.mtx.Lock() defer cli.mtx.Unlock() - // Get the first reqRes + // Get the first ReqRes next := cli.reqSent.Front() if next == nil { return fmt.Errorf("Unexpected result type %v when nothing expected", res.Type) } - reqres := next.Value.(*reqRes) + reqres := next.Value.(*ReqRes) if !resMatchesReq(reqres.Request, res) { return fmt.Errorf("Unexpected result type %v when response to %v expected", res.Type, reqres.Request.Type) @@ -165,7 +172,12 @@ func (cli *TMSPClient) didRecvResponse(res *types.Response) error { reqres.Done() // Release waiters cli.reqSent.Remove(next) // Pop first item from linked list - // Callback if there is a listener + // Notify reqRes listener if set + if cb := reqres.GetCallback(); cb != nil { + cb(res) + } + + // Notify client listener if set if cli.resCb != nil { cli.resCb(reqres.Request, res) } @@ -175,32 +187,32 @@ func (cli *TMSPClient) didRecvResponse(res *types.Response) error { //---------------------------------------- -func (cli *TMSPClient) EchoAsync(msg string) { - cli.queueRequest(types.RequestEcho(msg)) +func (cli *TMSPClient) EchoAsync(msg string) *ReqRes { + return cli.queueRequest(types.RequestEcho(msg)) } -func (cli *TMSPClient) FlushAsync() { - cli.queueRequest(types.RequestFlush()) +func (cli *TMSPClient) FlushAsync() *ReqRes { + return cli.queueRequest(types.RequestFlush()) } -func (cli *TMSPClient) SetOptionAsync(key string, value string) { - cli.queueRequest(types.RequestSetOption(key, value)) +func (cli *TMSPClient) SetOptionAsync(key string, value string) *ReqRes { + return cli.queueRequest(types.RequestSetOption(key, value)) } -func (cli *TMSPClient) AppendTxAsync(tx []byte) { - cli.queueRequest(types.RequestAppendTx(tx)) +func (cli *TMSPClient) AppendTxAsync(tx []byte) *ReqRes { + return cli.queueRequest(types.RequestAppendTx(tx)) } -func (cli *TMSPClient) CheckTxAsync(tx []byte) { - cli.queueRequest(types.RequestCheckTx(tx)) +func (cli *TMSPClient) CheckTxAsync(tx []byte) *ReqRes { + return cli.queueRequest(types.RequestCheckTx(tx)) } -func (cli *TMSPClient) GetHashAsync() { - cli.queueRequest(types.RequestGetHash()) +func (cli *TMSPClient) GetHashAsync() *ReqRes { + return cli.queueRequest(types.RequestGetHash()) } -func (cli *TMSPClient) QueryAsync(query []byte) { - cli.queueRequest(types.RequestQuery(query)) +func (cli *TMSPClient) QueryAsync(query []byte) *ReqRes { + return cli.queueRequest(types.RequestQuery(query)) } //---------------------------------------- @@ -261,7 +273,7 @@ func (cli *TMSPClient) QuerySync(query []byte) (code types.CodeType, result []by //---------------------------------------- -func (cli *TMSPClient) queueRequest(req *types.Request) *reqRes { +func (cli *TMSPClient) queueRequest(req *types.Request) *ReqRes { reqres := newReqRes(req) // TODO: set cli.err if reqQueue times out cli.reqQueue <- reqres @@ -283,20 +295,57 @@ func resMatchesReq(req *types.Request, res *types.Response) (ok bool) { return req.Type == res.Type } -type reqRes struct { +type ReqRes struct { *types.Request *sync.WaitGroup *types.Response // Not set atomically, so be sure to use WaitGroup. + + mtx sync.Mutex + done bool // Gets set to true once *after* WaitGroup.Done(). + cb func(*types.Response) // A single callback that may be set. } -func newReqRes(req *types.Request) *reqRes { - return &reqRes{ +func newReqRes(req *types.Request) *ReqRes { + return &ReqRes{ Request: req, WaitGroup: waitGroup1(), Response: nil, + + done: false, + cb: nil, } } +// Sets the callback for this ReqRes atomically. +// If reqRes is already done, calls cb immediately. +// NOTE: reqRes.cb should not change if reqRes.done. +// NOTE: only one callback is supported. +func (reqRes *ReqRes) SetCallback(cb func(res *types.Response)) { + reqRes.mtx.Lock() + + if reqRes.done { + reqRes.mtx.Unlock() + cb(reqRes.Response) + return + } + + defer reqRes.mtx.Unlock() + reqRes.cb = cb +} + +func (reqRes *ReqRes) GetCallback() func(*types.Response) { + reqRes.mtx.Lock() + defer reqRes.mtx.Unlock() + return reqRes.cb +} + +// NOTE: it should be safe to read reqRes.cb without locks after this. +func (reqRes *ReqRes) SetDone() { + reqRes.mtx.Lock() + reqRes.done = true + reqRes.mtx.Unlock() +} + func waitGroup1() (wg *sync.WaitGroup) { wg = &sync.WaitGroup{} wg.Add(1)