Implement @melekes comments

This commit is contained in:
Hendrik Hofstadt 2018-10-16 23:44:09 +02:00
parent ec299c2b84
commit 45cd8ebf80
5 changed files with 22 additions and 39 deletions

View File

@ -37,6 +37,7 @@ type IPCVal struct {
conn net.Conn
cancelPing chan struct{}
pingTicker *time.Ticker
}
// Check that IPCVal implements PrivValidator.
@ -70,15 +71,17 @@ func (sc *IPCVal) OnStart() error {
// Start a routine to keep the connection alive
sc.cancelPing = make(chan struct{}, 1)
sc.pingTicker = time.NewTicker(sc.connHeartbeat)
go func() {
for {
select {
case <-time.Tick(sc.connHeartbeat):
case <-sc.pingTicker.C:
err := sc.Ping()
if err != nil {
sc.Logger.Error("Ping", "err", err)
}
case <-sc.cancelPing:
sc.pingTicker.Stop()
return
}
}

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"

View File

@ -34,22 +34,12 @@ func NewRemoteSignerClient(
// GetAddress implements PrivValidator.
func (sc *RemoteSignerClient) GetAddress() types.Address {
addr, err := sc.getAddress()
pubKey, err := sc.getPubKey()
if err != nil {
panic(err)
}
return addr
}
// Address is an alias for PubKey().Address().
func (sc *RemoteSignerClient) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
return pubKey.Address()
}
// GetPubKey implements PrivValidator.

View File

@ -16,14 +16,12 @@ const (
defaultAcceptDeadlineSeconds = 3
defaultConnDeadlineSeconds = 3
defaultConnHeartBeatSeconds = 2
defaultConnWaitSeconds = 60
defaultDialRetries = 10
)
// Socket errors.
var (
ErrDialRetryMax = errors.New("dialed maximum retries")
ErrConnWaitTimeout = errors.New("waited for remote signer for too long")
ErrConnTimeout = errors.New("remote signer timed out")
ErrUnexpectedResponse = errors.New("received unexpected response")
)
@ -55,28 +53,22 @@ func TCPValHeartbeat(period time.Duration) TCPValOption {
return func(sc *TCPVal) { sc.connHeartbeat = period }
}
// TCPValConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func TCPValConnWait(timeout time.Duration) TCPValOption {
return func(sc *TCPVal) { sc.connWaitTimeout = timeout }
}
// TCPVal implements PrivValidator, it uses a socket to request signatures
// from an external process.
type TCPVal struct {
cmn.BaseService
*RemoteSignerClient
addr string
acceptDeadline time.Duration
connTimeout time.Duration
connHeartbeat time.Duration
connWaitTimeout time.Duration
privKey ed25519.PrivKeyEd25519
addr string
acceptDeadline time.Duration
connTimeout time.Duration
connHeartbeat time.Duration
privKey ed25519.PrivKeyEd25519
conn net.Conn
listener net.Listener
cancelPing chan struct{}
pingTicker *time.Ticker
}
// Check that TCPVal implements PrivValidator.
@ -89,12 +81,11 @@ func NewTCPVal(
privKey ed25519.PrivKeyEd25519,
) *TCPVal {
sc := &TCPVal{
addr: socketAddr,
acceptDeadline: acceptDeadline,
connTimeout: connTimeout,
connHeartbeat: connHeartbeat,
connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey,
addr: socketAddr,
acceptDeadline: acceptDeadline,
connTimeout: connTimeout,
connHeartbeat: connHeartbeat,
privKey: privKey,
}
sc.BaseService = *cmn.NewBaseService(logger, "TCPVal", sc)
@ -121,10 +112,11 @@ func (sc *TCPVal) OnStart() error {
// Start a routine to keep the connection alive
sc.cancelPing = make(chan struct{}, 1)
sc.pingTicker = time.NewTicker(sc.connHeartbeat)
go func() {
for {
select {
case <-time.Tick(sc.connHeartbeat):
case <-sc.pingTicker.C:
err := sc.Ping()
if err != nil {
sc.Logger.Error(
@ -133,6 +125,7 @@ func (sc *TCPVal) OnStart() error {
)
}
case <-sc.cancelPing:
sc.pingTicker.Stop()
return
}
}
@ -217,7 +210,5 @@ func (sc *TCPVal) waitConnection() (net.Conn, error) {
return conn, nil
case err := <-errc:
return nil, err
case <-time.After(sc.connWaitTimeout):
return nil, ErrConnWaitTimeout
}
}

View File

@ -27,8 +27,7 @@ func TestSocketPVAddress(t *testing.T) {
serverAddr := rs.privVal.GetAddress()
clientAddr, err := sc.getAddress()
require.NoError(t, err)
clientAddr := sc.GetAddress()
assert.Equal(t, serverAddr, clientAddr)
@ -166,7 +165,6 @@ func TestSocketPVDeadline(t *testing.T) {
)
TCPValConnTimeout(100 * time.Millisecond)(sc)
TCPValConnWait(500 * time.Millisecond)(sc)
go func(sc *TCPVal) {
defer close(listenc)