Implement @melekes comments

This commit is contained in:
Hendrik Hofstadt 2018-10-16 23:44:09 +02:00
parent ec299c2b84
commit 45cd8ebf80
5 changed files with 22 additions and 39 deletions

View File

@ -37,6 +37,7 @@ type IPCVal struct {
conn net.Conn conn net.Conn
cancelPing chan struct{} cancelPing chan struct{}
pingTicker *time.Ticker
} }
// Check that IPCVal implements PrivValidator. // Check that IPCVal implements PrivValidator.
@ -70,15 +71,17 @@ func (sc *IPCVal) OnStart() error {
// Start a routine to keep the connection alive // Start a routine to keep the connection alive
sc.cancelPing = make(chan struct{}, 1) sc.cancelPing = make(chan struct{}, 1)
sc.pingTicker = time.NewTicker(sc.connHeartbeat)
go func() { go func() {
for { for {
select { select {
case <-time.Tick(sc.connHeartbeat): case <-sc.pingTicker.C:
err := sc.Ping() err := sc.Ping()
if err != nil { if err != nil {
sc.Logger.Error("Ping", "err", err) sc.Logger.Error("Ping", "err", err)
} }
case <-sc.cancelPing: case <-sc.cancelPing:
sc.pingTicker.Stop()
return return
} }
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
cmn "github.com/tendermint/tendermint/libs/common" cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types" "github.com/tendermint/tendermint/types"

View File

@ -34,22 +34,12 @@ func NewRemoteSignerClient(
// GetAddress implements PrivValidator. // GetAddress implements PrivValidator.
func (sc *RemoteSignerClient) GetAddress() types.Address { func (sc *RemoteSignerClient) GetAddress() types.Address {
addr, err := sc.getAddress() pubKey, err := sc.getPubKey()
if err != nil { if err != nil {
panic(err) panic(err)
} }
return addr return pubKey.Address()
}
// Address is an alias for PubKey().Address().
func (sc *RemoteSignerClient) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
} }
// GetPubKey implements PrivValidator. // GetPubKey implements PrivValidator.

View File

@ -16,14 +16,12 @@ const (
defaultAcceptDeadlineSeconds = 3 defaultAcceptDeadlineSeconds = 3
defaultConnDeadlineSeconds = 3 defaultConnDeadlineSeconds = 3
defaultConnHeartBeatSeconds = 2 defaultConnHeartBeatSeconds = 2
defaultConnWaitSeconds = 60
defaultDialRetries = 10 defaultDialRetries = 10
) )
// Socket errors. // Socket errors.
var ( var (
ErrDialRetryMax = errors.New("dialed maximum retries") ErrDialRetryMax = errors.New("dialed maximum retries")
ErrConnWaitTimeout = errors.New("waited for remote signer for too long")
ErrConnTimeout = errors.New("remote signer timed out") ErrConnTimeout = errors.New("remote signer timed out")
ErrUnexpectedResponse = errors.New("received unexpected response") ErrUnexpectedResponse = errors.New("received unexpected response")
) )
@ -55,28 +53,22 @@ func TCPValHeartbeat(period time.Duration) TCPValOption {
return func(sc *TCPVal) { sc.connHeartbeat = period } return func(sc *TCPVal) { sc.connHeartbeat = period }
} }
// TCPValConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func TCPValConnWait(timeout time.Duration) TCPValOption {
return func(sc *TCPVal) { sc.connWaitTimeout = timeout }
}
// TCPVal implements PrivValidator, it uses a socket to request signatures // TCPVal implements PrivValidator, it uses a socket to request signatures
// from an external process. // from an external process.
type TCPVal struct { type TCPVal struct {
cmn.BaseService cmn.BaseService
*RemoteSignerClient *RemoteSignerClient
addr string addr string
acceptDeadline time.Duration acceptDeadline time.Duration
connTimeout time.Duration connTimeout time.Duration
connHeartbeat time.Duration connHeartbeat time.Duration
connWaitTimeout time.Duration privKey ed25519.PrivKeyEd25519
privKey ed25519.PrivKeyEd25519
conn net.Conn conn net.Conn
listener net.Listener listener net.Listener
cancelPing chan struct{} cancelPing chan struct{}
pingTicker *time.Ticker
} }
// Check that TCPVal implements PrivValidator. // Check that TCPVal implements PrivValidator.
@ -89,12 +81,11 @@ func NewTCPVal(
privKey ed25519.PrivKeyEd25519, privKey ed25519.PrivKeyEd25519,
) *TCPVal { ) *TCPVal {
sc := &TCPVal{ sc := &TCPVal{
addr: socketAddr, addr: socketAddr,
acceptDeadline: acceptDeadline, acceptDeadline: acceptDeadline,
connTimeout: connTimeout, connTimeout: connTimeout,
connHeartbeat: connHeartbeat, connHeartbeat: connHeartbeat,
connWaitTimeout: time.Second * defaultConnWaitSeconds, privKey: privKey,
privKey: privKey,
} }
sc.BaseService = *cmn.NewBaseService(logger, "TCPVal", sc) sc.BaseService = *cmn.NewBaseService(logger, "TCPVal", sc)
@ -121,10 +112,11 @@ func (sc *TCPVal) OnStart() error {
// Start a routine to keep the connection alive // Start a routine to keep the connection alive
sc.cancelPing = make(chan struct{}, 1) sc.cancelPing = make(chan struct{}, 1)
sc.pingTicker = time.NewTicker(sc.connHeartbeat)
go func() { go func() {
for { for {
select { select {
case <-time.Tick(sc.connHeartbeat): case <-sc.pingTicker.C:
err := sc.Ping() err := sc.Ping()
if err != nil { if err != nil {
sc.Logger.Error( sc.Logger.Error(
@ -133,6 +125,7 @@ func (sc *TCPVal) OnStart() error {
) )
} }
case <-sc.cancelPing: case <-sc.cancelPing:
sc.pingTicker.Stop()
return return
} }
} }
@ -217,7 +210,5 @@ func (sc *TCPVal) waitConnection() (net.Conn, error) {
return conn, nil return conn, nil
case err := <-errc: case err := <-errc:
return nil, err return nil, err
case <-time.After(sc.connWaitTimeout):
return nil, ErrConnWaitTimeout
} }
} }

View File

@ -27,8 +27,7 @@ func TestSocketPVAddress(t *testing.T) {
serverAddr := rs.privVal.GetAddress() serverAddr := rs.privVal.GetAddress()
clientAddr, err := sc.getAddress() clientAddr := sc.GetAddress()
require.NoError(t, err)
assert.Equal(t, serverAddr, clientAddr) assert.Equal(t, serverAddr, clientAddr)
@ -166,7 +165,6 @@ func TestSocketPVDeadline(t *testing.T) {
) )
TCPValConnTimeout(100 * time.Millisecond)(sc) TCPValConnTimeout(100 * time.Millisecond)(sc)
TCPValConnWait(500 * time.Millisecond)(sc)
go func(sc *TCPVal) { go func(sc *TCPVal) {
defer close(listenc) defer close(listenc)