From 03c1fd712a7564c618abc976d66b6fd8cdb3bf63 Mon Sep 17 00:00:00 2001 From: Hendrik Hofstadt Date: Sun, 7 Oct 2018 20:04:24 +0200 Subject: [PATCH] Implement IPC PV and abstract socket signing --- privval/ipc.go | 218 ++++++++++++++++++++++++++++ privval/ipc_test.go | 134 +++++++++++++++++ privval/rs_client.go | 236 ++++++++++++++++++++++++++++++ privval/socket.go | 316 +++++++++-------------------------------- privval/socket_tcp.go | 32 ++--- privval/socket_test.go | 48 ++----- 6 files changed, 678 insertions(+), 306 deletions(-) create mode 100644 privval/ipc.go create mode 100644 privval/ipc_test.go create mode 100644 privval/rs_client.go diff --git a/privval/ipc.go b/privval/ipc.go new file mode 100644 index 00000000..1e7219fe --- /dev/null +++ b/privval/ipc.go @@ -0,0 +1,218 @@ +package privval + +import ( + "io" + "net" + "sync" + "time" + + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" +) + +// IPCPV implements PrivValidator, it uses a unix socket to request signatures +// from an external process. +type IPCPV struct { + cmn.BaseService + *RemoteSignerClient + + addr string + + connTimeout time.Duration + + lock sync.Mutex + cancelPing chan bool +} + +// Check that IPCPV implements PrivValidator. +var _ types.PrivValidator = (*IPCPV)(nil) + +// NewIPCPV returns an instance of IPCPV. +func NewIPCPV( + logger log.Logger, + socketAddr string, +) *IPCPV { + sc := &IPCPV{ + addr: socketAddr, + connTimeout: connTimeout, + } + + sc.BaseService = *cmn.NewBaseService(logger, "IPCPV", sc) + sc.RemoteSignerClient = NewRemoteSignerClient(logger, nil) + + return sc +} + +// OnStart implements cmn.Service. +func (sc *IPCPV) OnStart() error { + err := sc.connect() + if err != nil { + err = cmn.ErrorWrap(err, "failed to connect") + sc.Logger.Error( + "OnStart", + "err", err, + ) + + return err + } + + err = sc.RemoteSignerClient.Start() + if err != nil { + err = cmn.ErrorWrap(err, "failed to start RemoteSignerClient") + sc.Logger.Error( + "OnStart", + "err", err, + ) + + return err + } + + return nil +} + +// OnStop implements cmn.Service. +func (sc *IPCPV) OnStop() { + if err := sc.RemoteSignerClient.Stop(); err != nil { + err = cmn.ErrorWrap(err, "failed to stop RemoteSignerClient") + sc.Logger.Error( + "OnStop", + "err", err, + ) + } + + if sc.conn != nil { + if err := sc.conn.Close(); err != nil { + err = cmn.ErrorWrap(err, "failed to close connection") + sc.Logger.Error( + "OnStop", + "err", err, + ) + } + } +} + +func (sc *IPCPV) connect() error { + la, err := net.ResolveUnixAddr("unix", sc.addr) + if err != nil { + return err + } + + conn, err := net.DialUnix("unix", nil, la) + if err != nil { + return err + } + + // Wrap in a timeoutConn + sc.conn = newTimeoutConn(conn, sc.connTimeout) + + return nil +} + +//--------------------------------------------------------- + +// IPCRemoteSigner is a RPC implementation of PrivValidator that listens on a unix socket. +type IPCRemoteSigner struct { + cmn.BaseService + + addr string + chainID string + connDeadline time.Duration + connRetries int + privVal types.PrivValidator + + listener *net.UnixListener +} + +// NewIPCRemoteSigner returns an instance of IPCRemoteSigner. +func NewIPCRemoteSigner( + logger log.Logger, + chainID, socketAddr string, + privVal types.PrivValidator, +) *IPCRemoteSigner { + rs := &IPCRemoteSigner{ + addr: socketAddr, + chainID: chainID, + connDeadline: time.Second * defaultConnDeadlineSeconds, + connRetries: defaultDialRetries, + privVal: privVal, + } + + rs.BaseService = *cmn.NewBaseService(logger, "IPCRemoteSigner", rs) + + return rs +} + +// OnStart implements cmn.Service. +func (rs *IPCRemoteSigner) OnStart() error { + err := rs.listen() + if err != nil { + err = cmn.ErrorWrap(err, "listen") + rs.Logger.Error("OnStart", "err", err) + return err + } + + go func() { + for { + conn, err := rs.listener.AcceptUnix() + if err != nil { + return + } + go rs.handleConnection(conn) + } + }() + + return nil +} + +// OnStop implements cmn.Service. +func (rs *IPCRemoteSigner) OnStop() { + if rs.listener != nil { + if err := rs.listener.Close(); err != nil { + rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed")) + } + } +} + +func (rs *IPCRemoteSigner) listen() error { + la, err := net.ResolveUnixAddr("unix", rs.addr) + if err != nil { + return err + } + + rs.listener, err = net.ListenUnix("unix", la) + + return err +} + +func (rs *IPCRemoteSigner) handleConnection(conn net.Conn) { + for { + if !rs.IsRunning() { + return // Ignore error from listener closing. + } + + // Reset the connection deadline + conn.SetDeadline(time.Now().Add(rs.connDeadline)) + + req, err := readMsg(conn) + if err != nil { + if err != io.EOF { + rs.Logger.Error("handleConnection", "err", err) + } + return + } + + res, err := handleRequest(req, rs.chainID, rs.privVal) + + if err != nil { + // only log the error; we'll reply with an error in res + rs.Logger.Error("handleConnection", "err", err) + } + + err = writeMsg(conn, res) + if err != nil { + rs.Logger.Error("handleConnection", "err", err) + return + } + } +} diff --git a/privval/ipc_test.go b/privval/ipc_test.go new file mode 100644 index 00000000..eb2d987d --- /dev/null +++ b/privval/ipc_test.go @@ -0,0 +1,134 @@ +package privval + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" + "io/ioutil" + "os" + "testing" + "time" +) + +func TestIPCPVVote(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.VoteTypePrecommit + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func TestIPCPVVoteResetDeadline(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.VoteTypePrecommit + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + time.Sleep(3 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) + + // This would exceed the deadline if it was not extended by the previous message + time.Sleep(3 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func TestIPCPVVoteKeepalive(t *testing.T) { + var ( + chainID = cmn.RandStr(12) + sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV()) + + ts = time.Now() + vType = types.VoteTypePrecommit + want = &types.Vote{Timestamp: ts, Type: vType} + have = &types.Vote{Timestamp: ts, Type: vType} + ) + defer sc.Stop() + defer rs.Stop() + + time.Sleep(10 * time.Millisecond) + + require.NoError(t, rs.privVal.SignVote(chainID, want)) + require.NoError(t, sc.SignVote(chainID, have)) + assert.Equal(t, want.Signature, have.Signature) +} + +func testSetupIPCSocketPair( + t *testing.T, + chainID string, + privValidator types.PrivValidator, +) (*IPCPV, *IPCRemoteSigner) { + var ( + addr = testUnixAddr() + logger = log.TestingLogger() + privVal = privValidator + readyc = make(chan struct{}) + rs = NewIPCRemoteSigner( + logger, + chainID, + addr, + privVal, + ) + sc = NewIPCPV( + logger, + addr, + ) + ) + + rs.connDeadline = time.Millisecond * 5 + sc.connTimeout = time.Millisecond * 5 + sc.connHeartbeat = time.Millisecond + + testStartIPCRemoteSigner(t, readyc, rs) + + <-readyc + + require.NoError(t, sc.Start()) + assert.True(t, sc.IsRunning()) + + return sc, rs +} + +func testStartIPCRemoteSigner(t *testing.T, readyc chan struct{}, rs *IPCRemoteSigner) { + go func(rs *IPCRemoteSigner) { + require.NoError(t, rs.Start()) + assert.True(t, rs.IsRunning()) + + readyc <- struct{}{} + }(rs) +} + +func testUnixAddr() string { + f, err := ioutil.TempFile("/tmp", "nettest") + if err != nil { + panic(err) + } + addr := f.Name() + f.Close() + os.Remove(addr) + return addr +} diff --git a/privval/rs_client.go b/privval/rs_client.go new file mode 100644 index 00000000..a0ce9c84 --- /dev/null +++ b/privval/rs_client.go @@ -0,0 +1,236 @@ +package privval + +import ( + "fmt" + "github.com/tendermint/tendermint/crypto" + cmn "github.com/tendermint/tendermint/libs/common" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/types" + "net" + "sync" + "time" +) + +// RemoteSignerClient implements PrivValidator, it uses a socket to request signatures +// from an external process. +type RemoteSignerClient struct { + cmn.BaseService + + connHeartbeat time.Duration + + conn net.Conn + lock sync.Mutex + cancelPing chan bool +} + +// Check that RemoteSignerClient implements PrivValidator. +var _ types.PrivValidator = (*RemoteSignerClient)(nil) + +// NewRemoteSignerClient returns an instance of RemoteSignerClient. +func NewRemoteSignerClient( + logger log.Logger, + conn net.Conn, +) *RemoteSignerClient { + sc := &RemoteSignerClient{ + conn: conn, + connHeartbeat: connHeartbeat, + } + + sc.BaseService = *cmn.NewBaseService(logger, "RemoteSignerClient", sc) + + return sc +} + +// GetAddress implements PrivValidator. +func (sc *RemoteSignerClient) GetAddress() types.Address { + addr, err := sc.getAddress() + if err != nil { + panic(err) + } + + return addr +} + +// Address is an alias for PubKey().Address(). +func (sc *RemoteSignerClient) getAddress() (cmn.HexBytes, error) { + p, err := sc.getPubKey() + if err != nil { + return nil, err + } + + return p.Address(), nil +} + +// GetPubKey implements PrivValidator. +func (sc *RemoteSignerClient) GetPubKey() crypto.PubKey { + pubKey, err := sc.getPubKey() + if err != nil { + panic(err) + } + + return pubKey +} + +func (sc *RemoteSignerClient) getPubKey() (crypto.PubKey, error) { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &PubKeyMsg{}) + if err != nil { + return nil, err + } + + res, err := readMsg(sc.conn) + if err != nil { + return nil, err + } + + return res.(*PubKeyMsg).PubKey, nil +} + +// SignVote implements PrivValidator. +func (sc *RemoteSignerClient) SignVote(chainID string, vote *types.Vote) error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + + resp, ok := res.(*SignedVoteResponse) + if !ok { + return ErrUnexpectedResponse + } + if resp.Error != nil { + return fmt.Errorf("remote error occurred: code: %v, description: %s", + resp.Error.Code, + resp.Error.Description) + } + *vote = *resp.Vote + + return nil +} + +// SignProposal implements PrivValidator. +func (sc *RemoteSignerClient) SignProposal( + chainID string, + proposal *types.Proposal, +) error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + resp, ok := res.(*SignedProposalResponse) + if !ok { + return ErrUnexpectedResponse + } + if resp.Error != nil { + return fmt.Errorf("remote error occurred: code: %v, description: %s", + resp.Error.Code, + resp.Error.Description) + } + *proposal = *resp.Proposal + + return nil +} + +// SignHeartbeat implements PrivValidator. +func (sc *RemoteSignerClient) SignHeartbeat( + chainID string, + heartbeat *types.Heartbeat, +) error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + resp, ok := res.(*SignedHeartbeatResponse) + if !ok { + return ErrUnexpectedResponse + } + if resp.Error != nil { + return fmt.Errorf("remote error occurred: code: %v, description: %s", + resp.Error.Code, + resp.Error.Description) + } + *heartbeat = *resp.Heartbeat + + return nil +} + +// Ping is used to check connection health. +func (sc *RemoteSignerClient) Ping() error { + sc.lock.Lock() + defer sc.lock.Unlock() + + err := writeMsg(sc.conn, &PingRequest{}) + if err != nil { + return err + } + + res, err := readMsg(sc.conn) + if err != nil { + return err + } + _, ok := res.(*PingResponse) + if !ok { + return ErrUnexpectedResponse + } + + return nil +} + +// OnStart implements cmn.Service. +func (sc *RemoteSignerClient) OnStart() error { + // Start a routine to keep the connection alive + sc.cancelPing = make(chan bool, 1) + go func() { + for { + select { + case <-time.Tick(sc.connHeartbeat): + err := sc.Ping() + if err != nil { + sc.Logger.Error( + "Ping", + "err", err, + ) + } + case <-sc.cancelPing: + return + } + } + }() + + return nil +} + +// OnStop implements cmn.Service. +func (sc *RemoteSignerClient) OnStop() { + if sc.cancelPing != nil { + select { + case sc.cancelPing <- true: + default: + } + } +} diff --git a/privval/socket.go b/privval/socket.go index cbe339ba..1760745c 100644 --- a/privval/socket.go +++ b/privval/socket.go @@ -36,50 +36,22 @@ var ( var ( acceptDeadline = time.Second * defaultAcceptDeadlineSeconds - connDeadline = time.Second * defaultConnDeadlineSeconds + connTimeout = time.Second * defaultConnDeadlineSeconds connHeartbeat = time.Second * defaultConnHeartBeatSeconds ) -// SocketPVOption sets an optional parameter on the SocketPV. -type SocketPVOption func(*SocketPV) - -// SocketPVAcceptDeadline sets the deadline for the SocketPV listener. -// A zero time value disables the deadline. -func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.acceptDeadline = deadline } -} - -// SocketPVConnDeadline sets the read and write deadline for connections -// from external signing processes. -func SocketPVConnDeadline(deadline time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.connDeadline = deadline } -} - -// SocketPVHeartbeat sets the period on which to check the liveness of the -// connected Signer connections. -func SocketPVHeartbeat(period time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.connHeartbeat = period } -} - -// SocketPVConnWait sets the timeout duration before connection of external -// signing processes are considered to be unsuccessful. -func SocketPVConnWait(timeout time.Duration) SocketPVOption { - return func(sc *SocketPV) { sc.connWaitTimeout = timeout } -} - // SocketPV implements PrivValidator, it uses a socket to request signatures // from an external process. type SocketPV struct { cmn.BaseService + *RemoteSignerClient addr string acceptDeadline time.Duration - connDeadline time.Duration - connHeartbeat time.Duration + connTimeout time.Duration connWaitTimeout time.Duration privKey ed25519.PrivKeyEd25519 - conn net.Conn listener net.Listener lock sync.Mutex cancelPing chan bool @@ -97,177 +69,17 @@ func NewSocketPV( sc := &SocketPV{ addr: socketAddr, acceptDeadline: acceptDeadline, - connDeadline: connDeadline, - connHeartbeat: connHeartbeat, + connTimeout: connTimeout, connWaitTimeout: time.Second * defaultConnWaitSeconds, privKey: privKey, } sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc) + sc.RemoteSignerClient = NewRemoteSignerClient(sc.Logger, nil) return sc } -// GetAddress implements PrivValidator. -func (sc *SocketPV) GetAddress() types.Address { - addr, err := sc.getAddress() - if err != nil { - panic(err) - } - - return addr -} - -// Address is an alias for PubKey().Address(). -func (sc *SocketPV) getAddress() (cmn.HexBytes, error) { - p, err := sc.getPubKey() - if err != nil { - return nil, err - } - - return p.Address(), nil -} - -// GetPubKey implements PrivValidator. -func (sc *SocketPV) GetPubKey() crypto.PubKey { - pubKey, err := sc.getPubKey() - if err != nil { - panic(err) - } - - return pubKey -} - -func (sc *SocketPV) getPubKey() (crypto.PubKey, error) { - sc.lock.Lock() - defer sc.lock.Unlock() - - err := writeMsg(sc.conn, &PubKeyMsg{}) - if err != nil { - return nil, err - } - - res, err := readMsg(sc.conn) - if err != nil { - return nil, err - } - - return res.(*PubKeyMsg).PubKey, nil -} - -// SignVote implements PrivValidator. -func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error { - sc.lock.Lock() - defer sc.lock.Unlock() - - err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - - resp, ok := res.(*SignedVoteResponse) - if !ok { - return ErrUnexpectedResponse - } - if resp.Error != nil { - return fmt.Errorf("remote error occurred: code: %v, description: %s", - resp.Error.Code, - resp.Error.Description) - } - *vote = *resp.Vote - - return nil -} - -// SignProposal implements PrivValidator. -func (sc *SocketPV) SignProposal( - chainID string, - proposal *types.Proposal, -) error { - sc.lock.Lock() - defer sc.lock.Unlock() - - err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - resp, ok := res.(*SignedProposalResponse) - if !ok { - return ErrUnexpectedResponse - } - if resp.Error != nil { - return fmt.Errorf("remote error occurred: code: %v, description: %s", - resp.Error.Code, - resp.Error.Description) - } - *proposal = *resp.Proposal - - return nil -} - -// SignHeartbeat implements PrivValidator. -func (sc *SocketPV) SignHeartbeat( - chainID string, - heartbeat *types.Heartbeat, -) error { - sc.lock.Lock() - defer sc.lock.Unlock() - - err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - resp, ok := res.(*SignedHeartbeatResponse) - if !ok { - return ErrUnexpectedResponse - } - if resp.Error != nil { - return fmt.Errorf("remote error occurred: code: %v, description: %s", - resp.Error.Code, - resp.Error.Description) - } - *heartbeat = *resp.Heartbeat - - return nil -} - -// Ping is used to check connection health. -func (sc *SocketPV) Ping() error { - sc.lock.Lock() - defer sc.lock.Unlock() - - err := writeMsg(sc.conn, &PingRequest{}) - if err != nil { - return err - } - - res, err := readMsg(sc.conn) - if err != nil { - return err - } - _, ok := res.(*PingResponse) - if !ok { - return ErrUnexpectedResponse - } - - return nil -} - // OnStart implements cmn.Service. func (sc *SocketPV) OnStart() error { if err := sc.listen(); err != nil { @@ -290,37 +102,30 @@ func (sc *SocketPV) OnStart() error { return err } - // Start a routine to keep the connection alive - sc.cancelPing = make(chan bool, 1) - go func() { - for { - select { - case <-time.Tick(sc.connHeartbeat): - err := sc.Ping() - if err != nil { - sc.Logger.Error( - "Ping", - "err", err, - ) - } - case <-sc.cancelPing: - return - } - } - }() - sc.conn = conn + err = sc.RemoteSignerClient.Start() + if err != nil { + err = cmn.ErrorWrap(err, "failed to start RemoteSignerClient") + sc.Logger.Error( + "OnStart", + "err", err, + ) + + return err + } + return nil } // OnStop implements cmn.Service. func (sc *SocketPV) OnStop() { - if sc.cancelPing != nil { - select { - case sc.cancelPing <- true: - default: - } + if err := sc.RemoteSignerClient.Stop(); err != nil { + err = cmn.ErrorWrap(err, "failed to stop RemoteSignerClient") + sc.Logger.Error( + "OnStop", + "err", err, + ) } if sc.conn != nil { @@ -371,7 +176,7 @@ func (sc *SocketPV) listen() error { sc.listener = newTCPTimeoutListener( ln, sc.acceptDeadline, - sc.connDeadline, + sc.connTimeout, sc.connHeartbeat, ) @@ -504,7 +309,7 @@ func (rs *RemoteSigner) connect() (net.Conn, error) { continue } - if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil { + if err := conn.SetDeadline(time.Now().Add(connTimeout)); err != nil { err = cmn.ErrorWrap(err, "setting connection timeout failed") rs.Logger.Error( "connect", @@ -547,39 +352,7 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) { return } - var res SocketPVMsg - - switch r := req.(type) { - case *PubKeyMsg: - var p crypto.PubKey - p = rs.privVal.GetPubKey() - res = &PubKeyMsg{p} - case *SignVoteRequest: - err = rs.privVal.SignVote(rs.chainID, r.Vote) - if err != nil { - res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} - } else { - res = &SignedVoteResponse{r.Vote, nil} - } - case *SignProposalRequest: - err = rs.privVal.SignProposal(rs.chainID, r.Proposal) - if err != nil { - res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} - } else { - res = &SignedProposalResponse{r.Proposal, nil} - } - case *SignHeartbeatRequest: - err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat) - if err != nil { - res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}} - } else { - res = &SignedHeartbeatResponse{r.Heartbeat, nil} - } - case *PingRequest: - res = &PingResponse{} - default: - err = fmt.Errorf("unknown msg: %v", r) - } + res, err := handleRequest(req, rs.chainID, rs.privVal) if err != nil { // only log the error; we'll reply with an error in res @@ -594,6 +367,45 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) { } } +func handleRequest(req SocketPVMsg, chainID string, privVal types.PrivValidator) (SocketPVMsg, error) { + var res SocketPVMsg + var err error + + switch r := req.(type) { + case *PubKeyMsg: + var p crypto.PubKey + p = privVal.GetPubKey() + res = &PubKeyMsg{p} + case *SignVoteRequest: + err = privVal.SignVote(chainID, r.Vote) + if err != nil { + res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}} + } else { + res = &SignedVoteResponse{r.Vote, nil} + } + case *SignProposalRequest: + err = privVal.SignProposal(chainID, r.Proposal) + if err != nil { + res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}} + } else { + res = &SignedProposalResponse{r.Proposal, nil} + } + case *SignHeartbeatRequest: + err = privVal.SignHeartbeat(chainID, r.Heartbeat) + if err != nil { + res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}} + } else { + res = &SignedHeartbeatResponse{r.Heartbeat, nil} + } + case *PingRequest: + res = &PingResponse{} + default: + err = fmt.Errorf("unknown msg: %v", r) + } + + return res, err +} + //--------------------------------------------------------- // SocketPVMsg is sent between RemoteSigner and SocketPV. diff --git a/privval/socket_tcp.go b/privval/socket_tcp.go index c0dccdac..143096fa 100644 --- a/privval/socket_tcp.go +++ b/privval/socket_tcp.go @@ -24,9 +24,9 @@ type tcpTimeoutListener struct { period time.Duration } -// tcpTimeoutConn wraps a *net.TCPConn to standardise protocol timeouts / deadline resets. -type tcpTimeoutConn struct { - *net.TCPConn +// timeoutConn wraps a net.Conn to standardise protocol timeouts / deadline resets. +type timeoutConn struct { + net.Conn connDeadline time.Duration } @@ -45,11 +45,11 @@ func newTCPTimeoutListener( } } -// newTCPTimeoutConn returns an instance of newTCPTimeoutConn. -func newTCPTimeoutConn( - conn *net.TCPConn, - connDeadline time.Duration) *tcpTimeoutConn { - return &tcpTimeoutConn{ +// newTimeoutConn returns an instance of newTCPTimeoutConn. +func newTimeoutConn( + conn net.Conn, + connDeadline time.Duration) *timeoutConn { + return &timeoutConn{ conn, connDeadline, } @@ -67,24 +67,24 @@ func (ln tcpTimeoutListener) Accept() (net.Conn, error) { return nil, err } - // Wrap the TCPConn in our timeout wrapper - conn := newTCPTimeoutConn(tc, ln.connDeadline) + // Wrap the conn in our timeout wrapper + conn := newTimeoutConn(tc, ln.connDeadline) return conn, nil } // Read implements net.Listener. -func (c tcpTimeoutConn) Read(b []byte) (int, error) { +func (c timeoutConn) Read(b []byte) (int, error) { // Reset deadline - c.TCPConn.SetReadDeadline(time.Now().Add(c.connDeadline)) + c.Conn.SetReadDeadline(time.Now().Add(c.connDeadline)) - return c.TCPConn.Read(b) + return c.Conn.Read(b) } // Write implements net.Listener. -func (c tcpTimeoutConn) Write(b []byte) (int, error) { +func (c timeoutConn) Write(b []byte) (int, error) { // Reset deadline - c.TCPConn.SetWriteDeadline(time.Now().Add(c.connDeadline)) + c.Conn.SetWriteDeadline(time.Now().Add(c.connDeadline)) - return c.TCPConn.Write(b) + return c.Conn.Write(b) } diff --git a/privval/socket_test.go b/privval/socket_test.go index 26288b71..2750c9fb 100644 --- a/privval/socket_test.go +++ b/privval/socket_test.go @@ -104,14 +104,14 @@ func TestSocketPVVoteResetDeadline(t *testing.T) { defer sc.Stop() defer rs.Stop() - time.Sleep(800 * time.Microsecond) + time.Sleep(3 * time.Millisecond) require.NoError(t, rs.privVal.SignVote(chainID, want)) require.NoError(t, sc.SignVote(chainID, have)) assert.Equal(t, want.Signature, have.Signature) // This would exceed the deadline if it was not extended by the previous message - time.Sleep(800 * time.Microsecond) + time.Sleep(3 * time.Millisecond) require.NoError(t, rs.privVal.SignVote(chainID, want)) require.NoError(t, sc.SignVote(chainID, have)) @@ -131,7 +131,7 @@ func TestSocketPVVoteKeepalive(t *testing.T) { defer sc.Stop() defer rs.Stop() - time.Sleep(2 * time.Millisecond) + time.Sleep(10 * time.Millisecond) require.NoError(t, rs.privVal.SignVote(chainID, want)) require.NoError(t, sc.SignVote(chainID, have)) @@ -154,21 +154,6 @@ func TestSocketPVHeartbeat(t *testing.T) { assert.Equal(t, want.Signature, have.Signature) } -func TestSocketPVAcceptDeadline(t *testing.T) { - var ( - sc = NewSocketPV( - log.TestingLogger(), - "127.0.0.1:0", - ed25519.GenPrivKey(), - ) - ) - defer sc.Stop() - - SocketPVAcceptDeadline(time.Millisecond)(sc) - - assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout) -} - func TestSocketPVDeadline(t *testing.T) { var ( addr = testFreeAddr(t) @@ -180,8 +165,8 @@ func TestSocketPVDeadline(t *testing.T) { ) ) - SocketPVConnDeadline(100 * time.Millisecond)(sc) - SocketPVConnWait(500 * time.Millisecond)(sc) + sc.connTimeout = 100 * time.Millisecond + sc.connWaitTimeout = 500 * time.Millisecond go func(sc *SocketPV) { defer close(listenc) @@ -212,19 +197,6 @@ func TestSocketPVDeadline(t *testing.T) { assert.Equal(t, err.(cmn.Error).Data(), ErrConnTimeout) } -func TestSocketPVWait(t *testing.T) { - sc := NewSocketPV( - log.TestingLogger(), - "127.0.0.1:0", - ed25519.GenPrivKey(), - ) - defer sc.Stop() - - SocketPVConnWait(time.Millisecond)(sc) - - assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout) -} - func TestRemoteSignerRetry(t *testing.T) { var ( attemptc = make(chan int) @@ -447,13 +419,13 @@ func testSetupSocketPair( ) ) - testStartSocketPV(t, readyc, sc) - - SocketPVConnDeadline(time.Millisecond)(sc) - SocketPVHeartbeat(500 * time.Microsecond)(sc) - RemoteSignerConnDeadline(time.Millisecond)(rs) + sc.connTimeout = 5 * time.Millisecond + sc.connHeartbeat = 2 * time.Millisecond + RemoteSignerConnDeadline(5 * time.Millisecond)(rs) RemoteSignerConnRetries(1e6)(rs) + testStartSocketPV(t, readyc, sc) + require.NoError(t, rs.Start()) assert.True(t, rs.IsRunning())