Implement IPC PV and abstract socket signing

This commit is contained in:
Hendrik Hofstadt 2018-10-07 20:04:24 +02:00
parent 2e4cae1fdb
commit 03c1fd712a
6 changed files with 678 additions and 306 deletions

218
privval/ipc.go Normal file
View File

@ -0,0 +1,218 @@
package privval
import (
"io"
"net"
"sync"
"time"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
)
// IPCPV implements PrivValidator, it uses a unix socket to request signatures
// from an external process.
type IPCPV struct {
cmn.BaseService
*RemoteSignerClient
addr string
connTimeout time.Duration
lock sync.Mutex
cancelPing chan bool
}
// Check that IPCPV implements PrivValidator.
var _ types.PrivValidator = (*IPCPV)(nil)
// NewIPCPV returns an instance of IPCPV.
func NewIPCPV(
logger log.Logger,
socketAddr string,
) *IPCPV {
sc := &IPCPV{
addr: socketAddr,
connTimeout: connTimeout,
}
sc.BaseService = *cmn.NewBaseService(logger, "IPCPV", sc)
sc.RemoteSignerClient = NewRemoteSignerClient(logger, nil)
return sc
}
// OnStart implements cmn.Service.
func (sc *IPCPV) OnStart() error {
err := sc.connect()
if err != nil {
err = cmn.ErrorWrap(err, "failed to connect")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
err = sc.RemoteSignerClient.Start()
if err != nil {
err = cmn.ErrorWrap(err, "failed to start RemoteSignerClient")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
return nil
}
// OnStop implements cmn.Service.
func (sc *IPCPV) OnStop() {
if err := sc.RemoteSignerClient.Stop(); err != nil {
err = cmn.ErrorWrap(err, "failed to stop RemoteSignerClient")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
if sc.conn != nil {
if err := sc.conn.Close(); err != nil {
err = cmn.ErrorWrap(err, "failed to close connection")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
}
}
func (sc *IPCPV) connect() error {
la, err := net.ResolveUnixAddr("unix", sc.addr)
if err != nil {
return err
}
conn, err := net.DialUnix("unix", nil, la)
if err != nil {
return err
}
// Wrap in a timeoutConn
sc.conn = newTimeoutConn(conn, sc.connTimeout)
return nil
}
//---------------------------------------------------------
// IPCRemoteSigner is a RPC implementation of PrivValidator that listens on a unix socket.
type IPCRemoteSigner struct {
cmn.BaseService
addr string
chainID string
connDeadline time.Duration
connRetries int
privVal types.PrivValidator
listener *net.UnixListener
}
// NewIPCRemoteSigner returns an instance of IPCRemoteSigner.
func NewIPCRemoteSigner(
logger log.Logger,
chainID, socketAddr string,
privVal types.PrivValidator,
) *IPCRemoteSigner {
rs := &IPCRemoteSigner{
addr: socketAddr,
chainID: chainID,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connRetries: defaultDialRetries,
privVal: privVal,
}
rs.BaseService = *cmn.NewBaseService(logger, "IPCRemoteSigner", rs)
return rs
}
// OnStart implements cmn.Service.
func (rs *IPCRemoteSigner) OnStart() error {
err := rs.listen()
if err != nil {
err = cmn.ErrorWrap(err, "listen")
rs.Logger.Error("OnStart", "err", err)
return err
}
go func() {
for {
conn, err := rs.listener.AcceptUnix()
if err != nil {
return
}
go rs.handleConnection(conn)
}
}()
return nil
}
// OnStop implements cmn.Service.
func (rs *IPCRemoteSigner) OnStop() {
if rs.listener != nil {
if err := rs.listener.Close(); err != nil {
rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
}
}
}
func (rs *IPCRemoteSigner) listen() error {
la, err := net.ResolveUnixAddr("unix", rs.addr)
if err != nil {
return err
}
rs.listener, err = net.ListenUnix("unix", la)
return err
}
func (rs *IPCRemoteSigner) handleConnection(conn net.Conn) {
for {
if !rs.IsRunning() {
return // Ignore error from listener closing.
}
// Reset the connection deadline
conn.SetDeadline(time.Now().Add(rs.connDeadline))
req, err := readMsg(conn)
if err != nil {
if err != io.EOF {
rs.Logger.Error("handleConnection", "err", err)
}
return
}
res, err := handleRequest(req, rs.chainID, rs.privVal)
if err != nil {
// only log the error; we'll reply with an error in res
rs.Logger.Error("handleConnection", "err", err)
}
err = writeMsg(conn, res)
if err != nil {
rs.Logger.Error("handleConnection", "err", err)
return
}
}
}

134
privval/ipc_test.go Normal file
View File

@ -0,0 +1,134 @@
package privval
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
"io/ioutil"
"os"
"testing"
"time"
)
func TestIPCPVVote(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV())
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func TestIPCPVVoteResetDeadline(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV())
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
// This would exceed the deadline if it was not extended by the previous message
time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func TestIPCPVVoteKeepalive(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV())
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
time.Sleep(10 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func testSetupIPCSocketPair(
t *testing.T,
chainID string,
privValidator types.PrivValidator,
) (*IPCPV, *IPCRemoteSigner) {
var (
addr = testUnixAddr()
logger = log.TestingLogger()
privVal = privValidator
readyc = make(chan struct{})
rs = NewIPCRemoteSigner(
logger,
chainID,
addr,
privVal,
)
sc = NewIPCPV(
logger,
addr,
)
)
rs.connDeadline = time.Millisecond * 5
sc.connTimeout = time.Millisecond * 5
sc.connHeartbeat = time.Millisecond
testStartIPCRemoteSigner(t, readyc, rs)
<-readyc
require.NoError(t, sc.Start())
assert.True(t, sc.IsRunning())
return sc, rs
}
func testStartIPCRemoteSigner(t *testing.T, readyc chan struct{}, rs *IPCRemoteSigner) {
go func(rs *IPCRemoteSigner) {
require.NoError(t, rs.Start())
assert.True(t, rs.IsRunning())
readyc <- struct{}{}
}(rs)
}
func testUnixAddr() string {
f, err := ioutil.TempFile("/tmp", "nettest")
if err != nil {
panic(err)
}
addr := f.Name()
f.Close()
os.Remove(addr)
return addr
}

236
privval/rs_client.go Normal file
View File

@ -0,0 +1,236 @@
package privval
import (
"fmt"
"github.com/tendermint/tendermint/crypto"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
"net"
"sync"
"time"
)
// RemoteSignerClient implements PrivValidator, it uses a socket to request signatures
// from an external process.
type RemoteSignerClient struct {
cmn.BaseService
connHeartbeat time.Duration
conn net.Conn
lock sync.Mutex
cancelPing chan bool
}
// Check that RemoteSignerClient implements PrivValidator.
var _ types.PrivValidator = (*RemoteSignerClient)(nil)
// NewRemoteSignerClient returns an instance of RemoteSignerClient.
func NewRemoteSignerClient(
logger log.Logger,
conn net.Conn,
) *RemoteSignerClient {
sc := &RemoteSignerClient{
conn: conn,
connHeartbeat: connHeartbeat,
}
sc.BaseService = *cmn.NewBaseService(logger, "RemoteSignerClient", sc)
return sc
}
// GetAddress implements PrivValidator.
func (sc *RemoteSignerClient) GetAddress() types.Address {
addr, err := sc.getAddress()
if err != nil {
panic(err)
}
return addr
}
// Address is an alias for PubKey().Address().
func (sc *RemoteSignerClient) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
}
// GetPubKey implements PrivValidator.
func (sc *RemoteSignerClient) GetPubKey() crypto.PubKey {
pubKey, err := sc.getPubKey()
if err != nil {
panic(err)
}
return pubKey
}
func (sc *RemoteSignerClient) getPubKey() (crypto.PubKey, error) {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PubKeyMsg{})
if err != nil {
return nil, err
}
res, err := readMsg(sc.conn)
if err != nil {
return nil, err
}
return res.(*PubKeyMsg).PubKey, nil
}
// SignVote implements PrivValidator.
func (sc *RemoteSignerClient) SignVote(chainID string, vote *types.Vote) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedVoteResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*vote = *resp.Vote
return nil
}
// SignProposal implements PrivValidator.
func (sc *RemoteSignerClient) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedProposalResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*proposal = *resp.Proposal
return nil
}
// SignHeartbeat implements PrivValidator.
func (sc *RemoteSignerClient) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedHeartbeatResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*heartbeat = *resp.Heartbeat
return nil
}
// Ping is used to check connection health.
func (sc *RemoteSignerClient) Ping() error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PingRequest{})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
_, ok := res.(*PingResponse)
if !ok {
return ErrUnexpectedResponse
}
return nil
}
// OnStart implements cmn.Service.
func (sc *RemoteSignerClient) OnStart() error {
// Start a routine to keep the connection alive
sc.cancelPing = make(chan bool, 1)
go func() {
for {
select {
case <-time.Tick(sc.connHeartbeat):
err := sc.Ping()
if err != nil {
sc.Logger.Error(
"Ping",
"err", err,
)
}
case <-sc.cancelPing:
return
}
}
}()
return nil
}
// OnStop implements cmn.Service.
func (sc *RemoteSignerClient) OnStop() {
if sc.cancelPing != nil {
select {
case sc.cancelPing <- true:
default:
}
}
}

View File

@ -36,50 +36,22 @@ var (
var ( var (
acceptDeadline = time.Second * defaultAcceptDeadlineSeconds acceptDeadline = time.Second * defaultAcceptDeadlineSeconds
connDeadline = time.Second * defaultConnDeadlineSeconds connTimeout = time.Second * defaultConnDeadlineSeconds
connHeartbeat = time.Second * defaultConnHeartBeatSeconds connHeartbeat = time.Second * defaultConnHeartBeatSeconds
) )
// SocketPVOption sets an optional parameter on the SocketPV.
type SocketPVOption func(*SocketPV)
// SocketPVAcceptDeadline sets the deadline for the SocketPV listener.
// A zero time value disables the deadline.
func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.acceptDeadline = deadline }
}
// SocketPVConnDeadline sets the read and write deadline for connections
// from external signing processes.
func SocketPVConnDeadline(deadline time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connDeadline = deadline }
}
// SocketPVHeartbeat sets the period on which to check the liveness of the
// connected Signer connections.
func SocketPVHeartbeat(period time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connHeartbeat = period }
}
// SocketPVConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func SocketPVConnWait(timeout time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connWaitTimeout = timeout }
}
// SocketPV implements PrivValidator, it uses a socket to request signatures // SocketPV implements PrivValidator, it uses a socket to request signatures
// from an external process. // from an external process.
type SocketPV struct { type SocketPV struct {
cmn.BaseService cmn.BaseService
*RemoteSignerClient
addr string addr string
acceptDeadline time.Duration acceptDeadline time.Duration
connDeadline time.Duration connTimeout time.Duration
connHeartbeat time.Duration
connWaitTimeout time.Duration connWaitTimeout time.Duration
privKey ed25519.PrivKeyEd25519 privKey ed25519.PrivKeyEd25519
conn net.Conn
listener net.Listener listener net.Listener
lock sync.Mutex lock sync.Mutex
cancelPing chan bool cancelPing chan bool
@ -97,177 +69,17 @@ func NewSocketPV(
sc := &SocketPV{ sc := &SocketPV{
addr: socketAddr, addr: socketAddr,
acceptDeadline: acceptDeadline, acceptDeadline: acceptDeadline,
connDeadline: connDeadline, connTimeout: connTimeout,
connHeartbeat: connHeartbeat,
connWaitTimeout: time.Second * defaultConnWaitSeconds, connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey, privKey: privKey,
} }
sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc) sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc)
sc.RemoteSignerClient = NewRemoteSignerClient(sc.Logger, nil)
return sc return sc
} }
// GetAddress implements PrivValidator.
func (sc *SocketPV) GetAddress() types.Address {
addr, err := sc.getAddress()
if err != nil {
panic(err)
}
return addr
}
// Address is an alias for PubKey().Address().
func (sc *SocketPV) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
}
// GetPubKey implements PrivValidator.
func (sc *SocketPV) GetPubKey() crypto.PubKey {
pubKey, err := sc.getPubKey()
if err != nil {
panic(err)
}
return pubKey
}
func (sc *SocketPV) getPubKey() (crypto.PubKey, error) {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PubKeyMsg{})
if err != nil {
return nil, err
}
res, err := readMsg(sc.conn)
if err != nil {
return nil, err
}
return res.(*PubKeyMsg).PubKey, nil
}
// SignVote implements PrivValidator.
func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedVoteResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*vote = *resp.Vote
return nil
}
// SignProposal implements PrivValidator.
func (sc *SocketPV) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedProposalResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*proposal = *resp.Proposal
return nil
}
// SignHeartbeat implements PrivValidator.
func (sc *SocketPV) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedHeartbeatResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*heartbeat = *resp.Heartbeat
return nil
}
// Ping is used to check connection health.
func (sc *SocketPV) Ping() error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PingRequest{})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
_, ok := res.(*PingResponse)
if !ok {
return ErrUnexpectedResponse
}
return nil
}
// OnStart implements cmn.Service. // OnStart implements cmn.Service.
func (sc *SocketPV) OnStart() error { func (sc *SocketPV) OnStart() error {
if err := sc.listen(); err != nil { if err := sc.listen(); err != nil {
@ -290,37 +102,30 @@ func (sc *SocketPV) OnStart() error {
return err return err
} }
// Start a routine to keep the connection alive sc.conn = conn
sc.cancelPing = make(chan bool, 1)
go func() { err = sc.RemoteSignerClient.Start()
for {
select {
case <-time.Tick(sc.connHeartbeat):
err := sc.Ping()
if err != nil { if err != nil {
err = cmn.ErrorWrap(err, "failed to start RemoteSignerClient")
sc.Logger.Error( sc.Logger.Error(
"Ping", "OnStart",
"err", err, "err", err,
) )
}
case <-sc.cancelPing:
return
}
}
}()
sc.conn = conn return err
}
return nil return nil
} }
// OnStop implements cmn.Service. // OnStop implements cmn.Service.
func (sc *SocketPV) OnStop() { func (sc *SocketPV) OnStop() {
if sc.cancelPing != nil { if err := sc.RemoteSignerClient.Stop(); err != nil {
select { err = cmn.ErrorWrap(err, "failed to stop RemoteSignerClient")
case sc.cancelPing <- true: sc.Logger.Error(
default: "OnStop",
} "err", err,
)
} }
if sc.conn != nil { if sc.conn != nil {
@ -371,7 +176,7 @@ func (sc *SocketPV) listen() error {
sc.listener = newTCPTimeoutListener( sc.listener = newTCPTimeoutListener(
ln, ln,
sc.acceptDeadline, sc.acceptDeadline,
sc.connDeadline, sc.connTimeout,
sc.connHeartbeat, sc.connHeartbeat,
) )
@ -504,7 +309,7 @@ func (rs *RemoteSigner) connect() (net.Conn, error) {
continue continue
} }
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil { if err := conn.SetDeadline(time.Now().Add(connTimeout)); err != nil {
err = cmn.ErrorWrap(err, "setting connection timeout failed") err = cmn.ErrorWrap(err, "setting connection timeout failed")
rs.Logger.Error( rs.Logger.Error(
"connect", "connect",
@ -547,39 +352,7 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) {
return return
} }
var res SocketPVMsg res, err := handleRequest(req, rs.chainID, rs.privVal)
switch r := req.(type) {
case *PubKeyMsg:
var p crypto.PubKey
p = rs.privVal.GetPubKey()
res = &PubKeyMsg{p}
case *SignVoteRequest:
err = rs.privVal.SignVote(rs.chainID, r.Vote)
if err != nil {
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedVoteResponse{r.Vote, nil}
}
case *SignProposalRequest:
err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
if err != nil {
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedProposalResponse{r.Proposal, nil}
}
case *SignHeartbeatRequest:
err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
if err != nil {
res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedHeartbeatResponse{r.Heartbeat, nil}
}
case *PingRequest:
res = &PingResponse{}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
if err != nil { if err != nil {
// only log the error; we'll reply with an error in res // only log the error; we'll reply with an error in res
@ -594,6 +367,45 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) {
} }
} }
func handleRequest(req SocketPVMsg, chainID string, privVal types.PrivValidator) (SocketPVMsg, error) {
var res SocketPVMsg
var err error
switch r := req.(type) {
case *PubKeyMsg:
var p crypto.PubKey
p = privVal.GetPubKey()
res = &PubKeyMsg{p}
case *SignVoteRequest:
err = privVal.SignVote(chainID, r.Vote)
if err != nil {
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedVoteResponse{r.Vote, nil}
}
case *SignProposalRequest:
err = privVal.SignProposal(chainID, r.Proposal)
if err != nil {
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedProposalResponse{r.Proposal, nil}
}
case *SignHeartbeatRequest:
err = privVal.SignHeartbeat(chainID, r.Heartbeat)
if err != nil {
res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedHeartbeatResponse{r.Heartbeat, nil}
}
case *PingRequest:
res = &PingResponse{}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
return res, err
}
//--------------------------------------------------------- //---------------------------------------------------------
// SocketPVMsg is sent between RemoteSigner and SocketPV. // SocketPVMsg is sent between RemoteSigner and SocketPV.

View File

@ -24,9 +24,9 @@ type tcpTimeoutListener struct {
period time.Duration period time.Duration
} }
// tcpTimeoutConn wraps a *net.TCPConn to standardise protocol timeouts / deadline resets. // timeoutConn wraps a net.Conn to standardise protocol timeouts / deadline resets.
type tcpTimeoutConn struct { type timeoutConn struct {
*net.TCPConn net.Conn
connDeadline time.Duration connDeadline time.Duration
} }
@ -45,11 +45,11 @@ func newTCPTimeoutListener(
} }
} }
// newTCPTimeoutConn returns an instance of newTCPTimeoutConn. // newTimeoutConn returns an instance of newTCPTimeoutConn.
func newTCPTimeoutConn( func newTimeoutConn(
conn *net.TCPConn, conn net.Conn,
connDeadline time.Duration) *tcpTimeoutConn { connDeadline time.Duration) *timeoutConn {
return &tcpTimeoutConn{ return &timeoutConn{
conn, conn,
connDeadline, connDeadline,
} }
@ -67,24 +67,24 @@ func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
return nil, err return nil, err
} }
// Wrap the TCPConn in our timeout wrapper // Wrap the conn in our timeout wrapper
conn := newTCPTimeoutConn(tc, ln.connDeadline) conn := newTimeoutConn(tc, ln.connDeadline)
return conn, nil return conn, nil
} }
// Read implements net.Listener. // Read implements net.Listener.
func (c tcpTimeoutConn) Read(b []byte) (int, error) { func (c timeoutConn) Read(b []byte) (int, error) {
// Reset deadline // Reset deadline
c.TCPConn.SetReadDeadline(time.Now().Add(c.connDeadline)) c.Conn.SetReadDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Read(b) return c.Conn.Read(b)
} }
// Write implements net.Listener. // Write implements net.Listener.
func (c tcpTimeoutConn) Write(b []byte) (int, error) { func (c timeoutConn) Write(b []byte) (int, error) {
// Reset deadline // Reset deadline
c.TCPConn.SetWriteDeadline(time.Now().Add(c.connDeadline)) c.Conn.SetWriteDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Write(b) return c.Conn.Write(b)
} }

View File

@ -104,14 +104,14 @@ func TestSocketPVVoteResetDeadline(t *testing.T) {
defer sc.Stop() defer sc.Stop()
defer rs.Stop() defer rs.Stop()
time.Sleep(800 * time.Microsecond) time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want)) require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have)) require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature) assert.Equal(t, want.Signature, have.Signature)
// This would exceed the deadline if it was not extended by the previous message // This would exceed the deadline if it was not extended by the previous message
time.Sleep(800 * time.Microsecond) time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want)) require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have)) require.NoError(t, sc.SignVote(chainID, have))
@ -131,7 +131,7 @@ func TestSocketPVVoteKeepalive(t *testing.T) {
defer sc.Stop() defer sc.Stop()
defer rs.Stop() defer rs.Stop()
time.Sleep(2 * time.Millisecond) time.Sleep(10 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want)) require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have)) require.NoError(t, sc.SignVote(chainID, have))
@ -154,21 +154,6 @@ func TestSocketPVHeartbeat(t *testing.T) {
assert.Equal(t, want.Signature, have.Signature) assert.Equal(t, want.Signature, have.Signature)
} }
func TestSocketPVAcceptDeadline(t *testing.T) {
var (
sc = NewSocketPV(
log.TestingLogger(),
"127.0.0.1:0",
ed25519.GenPrivKey(),
)
)
defer sc.Stop()
SocketPVAcceptDeadline(time.Millisecond)(sc)
assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout)
}
func TestSocketPVDeadline(t *testing.T) { func TestSocketPVDeadline(t *testing.T) {
var ( var (
addr = testFreeAddr(t) addr = testFreeAddr(t)
@ -180,8 +165,8 @@ func TestSocketPVDeadline(t *testing.T) {
) )
) )
SocketPVConnDeadline(100 * time.Millisecond)(sc) sc.connTimeout = 100 * time.Millisecond
SocketPVConnWait(500 * time.Millisecond)(sc) sc.connWaitTimeout = 500 * time.Millisecond
go func(sc *SocketPV) { go func(sc *SocketPV) {
defer close(listenc) defer close(listenc)
@ -212,19 +197,6 @@ func TestSocketPVDeadline(t *testing.T) {
assert.Equal(t, err.(cmn.Error).Data(), ErrConnTimeout) assert.Equal(t, err.(cmn.Error).Data(), ErrConnTimeout)
} }
func TestSocketPVWait(t *testing.T) {
sc := NewSocketPV(
log.TestingLogger(),
"127.0.0.1:0",
ed25519.GenPrivKey(),
)
defer sc.Stop()
SocketPVConnWait(time.Millisecond)(sc)
assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout)
}
func TestRemoteSignerRetry(t *testing.T) { func TestRemoteSignerRetry(t *testing.T) {
var ( var (
attemptc = make(chan int) attemptc = make(chan int)
@ -447,13 +419,13 @@ func testSetupSocketPair(
) )
) )
testStartSocketPV(t, readyc, sc) sc.connTimeout = 5 * time.Millisecond
sc.connHeartbeat = 2 * time.Millisecond
SocketPVConnDeadline(time.Millisecond)(sc) RemoteSignerConnDeadline(5 * time.Millisecond)(rs)
SocketPVHeartbeat(500 * time.Microsecond)(sc)
RemoteSignerConnDeadline(time.Millisecond)(rs)
RemoteSignerConnRetries(1e6)(rs) RemoteSignerConnRetries(1e6)(rs)
testStartSocketPV(t, readyc, sc)
require.NoError(t, rs.Start()) require.NoError(t, rs.Start())
assert.True(t, rs.IsRunning()) assert.True(t, rs.IsRunning())