Invert privVal socket communication

Follow-up to #1255 aligning with the expectation that the external
signing process connects to the node. The SocketClient will block on
start until one connection has been established, support for multiple
signers connected simultaneously is a planned future extension.

* SocketClient accepts connection
* PrivValSocketServer renamed to RemoteSigner
* extend tests
This commit is contained in:
Alexander Simmerl 2018-03-06 19:54:49 +01:00
parent 2ce57a65ff
commit 589781721a
No known key found for this signature in database
GPG Key ID: 4694E95C9CC61BDA
3 changed files with 359 additions and 200 deletions

View File

@ -12,36 +12,41 @@ import (
func main() {
var (
addr = flag.String("addr", ":46659", "Address of client to connect to")
chainID = flag.String("chain-id", "mychain", "chain id")
listenAddr = flag.String("laddr", ":46659", "Validator listen address (0.0.0.0:0 means any interface, any port")
maxConn = flag.Int("clients", 3, "maximum of concurrent connections")
privValPath = flag.String("priv", "", "priv val file path")
logger = log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val")
logger = log.NewTMLogger(
log.NewSyncWriter(os.Stdout),
).With("module", "priv_val")
)
flag.Parse()
logger.Info(
"Starting private validator",
"addr", *addr,
"chainID", *chainID,
"listenAddr", *listenAddr,
"maxConn", *maxConn,
"privPath", *privValPath,
)
privVal := priv_val.LoadPrivValidatorJSON(*privValPath)
pvss := priv_val.NewPrivValidatorSocketServer(
rs := priv_val.NewRemoteSigner(
logger,
*chainID,
*listenAddr,
*maxConn,
*addr,
privVal,
nil,
)
pvss.Start()
err := rs.Start()
if err != nil {
panic(err)
}
cmn.TrapSignal(func() {
pvss.Stop()
err := rs.Stop()
if err != nil {
panic(err)
}
})
}

View File

@ -19,12 +19,16 @@ import (
const (
defaultConnDeadlineSeconds = 3
defaultDialRetryMax = 10
defaultConnWaitSeconds = 60
defaultDialRetries = 10
defaultSignersMax = 1
)
// Socket errors.
var (
ErrDialRetryMax = errors.New("Error max client retries")
ErrDialRetryMax = errors.New("Error max client retries")
ErrConnWaitTimeout = errors.New("Error waiting for external connection")
ErrConnTimeout = errors.New("Error connection timed out")
)
var (
@ -34,10 +38,16 @@ var (
// SocketClientOption sets an optional parameter on the SocketClient.
type SocketClientOption func(*SocketClient)
// SocketClientTimeout sets the timeout for connecting to the external socket
// address.
func SocketClientTimeout(timeout time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connectTimeout = timeout }
// SocketClientConnDeadline sets the read and write deadline for connections
// from external signing processes.
func SocketClientConnDeadline(deadline time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connDeadline = deadline }
}
// SocketClientConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func SocketClientConnWait(timeout time.Duration) SocketClientOption {
return func(sc *SocketClient) { sc.connWaitTimeout = timeout }
}
// SocketClient implements PrivValidator, it uses a socket to request signatures
@ -45,11 +55,13 @@ func SocketClientTimeout(timeout time.Duration) SocketClientOption {
type SocketClient struct {
cmn.BaseService
conn net.Conn
privKey *crypto.PrivKeyEd25519
addr string
connDeadline time.Duration
connWaitTimeout time.Duration
privKey *crypto.PrivKeyEd25519
addr string
connectTimeout time.Duration
conn net.Conn
listener net.Listener
}
// Check that SocketClient implements PrivValidator2.
@ -62,24 +74,37 @@ func NewSocketClient(
privKey *crypto.PrivKeyEd25519,
) *SocketClient {
sc := &SocketClient{
addr: socketAddr,
connectTimeout: time.Second * defaultConnDeadlineSeconds,
privKey: privKey,
addr: socketAddr,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey,
}
sc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", sc)
sc.BaseService = *cmn.NewBaseService(logger, "SocketClient", sc)
return sc
}
// OnStart implements cmn.Service.
func (sc *SocketClient) OnStart() error {
if err := sc.BaseService.OnStart(); err != nil {
return err
if sc.listener == nil {
if err := sc.listen(); err != nil {
sc.Logger.Error(
"OnStart",
"err", errors.Wrap(err, "failed to listen"),
)
return err
}
}
conn, err := sc.connect()
conn, err := sc.waitConnection()
if err != nil {
sc.Logger.Error(
"OnStart",
"err", errors.Wrap(err, "failed to accept connection"),
)
return err
}
@ -93,7 +118,21 @@ func (sc *SocketClient) OnStop() {
sc.BaseService.OnStop()
if sc.conn != nil {
sc.conn.Close()
if err := sc.conn.Close(); err != nil {
sc.Logger.Error(
"OnStop",
"err", errors.Wrap(err, "failed to close connection"),
)
}
}
if sc.listener != nil {
if err := sc.listener.Close(); err != nil {
sc.Logger.Error(
"OnStop",
"err", errors.Wrap(err, "failed to close listener"),
)
}
}
}
@ -162,7 +201,10 @@ func (sc *SocketClient) SignVote(chainID string, vote *types.Vote) error {
}
// SignProposal implements PrivValidator2.
func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) error {
func (sc *SocketClient) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal})
if err != nil {
return err
@ -179,7 +221,10 @@ func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) e
}
// SignHeartbeat implements PrivValidator2.
func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error {
func (sc *SocketClient) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
if err != nil {
return err
@ -195,22 +240,164 @@ func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat
return nil
}
func (sc *SocketClient) connect() (net.Conn, error) {
retries := defaultDialRetryMax
func (sc *SocketClient) acceptConnection() (net.Conn, error) {
conn, err := sc.listener.Accept()
if err != nil {
if !sc.IsRunning() {
return nil, nil // Ignore error from listener closing.
}
return nil, err
}
if err := conn.SetDeadline(time.Now().Add(sc.connDeadline)); err != nil {
return nil, err
}
if sc.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap())
if err != nil {
return nil, err
}
}
return conn, nil
}
func (sc *SocketClient) listen() error {
ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr))
if err != nil {
return err
}
sc.listener = netutil.LimitListener(ln, defaultSignersMax)
return nil
}
// waitConnection uses the configured wait timeout to error if no external
// process connects in the time period.
func (sc *SocketClient) waitConnection() (net.Conn, error) {
var (
connc = make(chan net.Conn, 1)
errc = make(chan error, 1)
)
go func(connc chan<- net.Conn, errc chan<- error) {
conn, err := sc.acceptConnection()
if err != nil {
errc <- err
return
}
connc <- conn
}(connc, errc)
select {
case conn := <-connc:
return conn, nil
case err := <-errc:
return nil, err
case <-time.After(sc.connWaitTimeout):
return nil, ErrConnWaitTimeout
}
}
//---------------------------------------------------------
// RemoteSignerOption sets an optional parameter on the RemoteSigner.
type RemoteSignerOption func(*RemoteSigner)
// RemoteSignerConnDeadline sets the read and write deadline for connections
// from external signing processes.
func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption {
return func(ss *RemoteSigner) { ss.connDeadline = deadline }
}
// RemoteSignerConnRetries sets the amount of attempted retries to connect.
func RemoteSignerConnRetries(retries int) RemoteSignerOption {
return func(ss *RemoteSigner) { ss.connRetries = retries }
}
// RemoteSigner implements PrivValidator.
// It responds to requests over a socket
type RemoteSigner struct {
cmn.BaseService
addr string
chainID string
connDeadline time.Duration
connRetries int
privKey *crypto.PrivKeyEd25519
privVal PrivValidator
conn net.Conn
}
// NewRemoteSigner returns an instance of
// RemoteSigner.
func NewRemoteSigner(
logger log.Logger,
chainID, socketAddr string,
privVal PrivValidator,
privKey *crypto.PrivKeyEd25519,
) *RemoteSigner {
rs := &RemoteSigner{
addr: socketAddr,
chainID: chainID,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connRetries: defaultDialRetries,
privKey: privKey,
privVal: privVal,
}
rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs)
return rs
}
// OnStart implements cmn.Service.
func (rs *RemoteSigner) OnStart() error {
conn, err := rs.connect()
if err != nil {
rs.Logger.Error("OnStart", "err", errors.Wrap(err, "connect"))
return err
}
go rs.handleConnection(conn)
return nil
}
// OnStop implements cmn.Service.
func (rs *RemoteSigner) OnStop() {
if rs.conn == nil {
return
}
if err := rs.conn.Close(); err != nil {
rs.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed"))
}
}
func (rs *RemoteSigner) connect() (net.Conn, error) {
retries := defaultDialRetries
RETRY_LOOP:
for retries > 0 {
if retries != defaultDialRetryMax {
time.Sleep(sc.connectTimeout)
// Don't sleep if it is the first retry.
if retries != defaultDialRetries {
time.Sleep(rs.connDeadline)
}
retries--
conn, err := cmn.Connect(sc.addr)
conn, err := cmn.Connect(rs.addr)
if err != nil {
sc.Logger.Error(
"sc connect",
"addr", sc.addr,
rs.Logger.Error(
"connect",
"addr", rs.addr,
"err", errors.Wrap(err, "connection failed"),
)
@ -218,17 +405,17 @@ RETRY_LOOP:
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
sc.Logger.Error(
"sc connect",
rs.Logger.Error(
"connect",
"err", errors.Wrap(err, "setting connection timeout failed"),
)
continue
}
if sc.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap())
if rs.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap())
if err != nil {
sc.Logger.Error(
rs.Logger.Error(
"sc connect",
"err", errors.Wrap(err, "encrypting connection failed"),
)
@ -243,118 +430,16 @@ RETRY_LOOP:
return nil, ErrDialRetryMax
}
//---------------------------------------------------------
// PrivValidatorSocketServer implements PrivValidator.
// It responds to requests over a socket
type PrivValidatorSocketServer struct {
cmn.BaseService
proto, addr string
listener net.Listener
maxConnections int
privKey *crypto.PrivKeyEd25519
privVal PrivValidator
chainID string
}
// NewPrivValidatorSocketServer returns an instance of
// PrivValidatorSocketServer.
func NewPrivValidatorSocketServer(
logger log.Logger,
chainID, socketAddr string,
maxConnections int,
privVal PrivValidator,
privKey *crypto.PrivKeyEd25519,
) *PrivValidatorSocketServer {
proto, addr := cmn.ProtocolAndAddress(socketAddr)
pvss := &PrivValidatorSocketServer{
proto: proto,
addr: addr,
maxConnections: maxConnections,
privKey: privKey,
privVal: privVal,
chainID: chainID,
}
pvss.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketServer", pvss)
return pvss
}
// OnStart implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStart() error {
ln, err := net.Listen(pvss.proto, pvss.addr)
if err != nil {
return err
}
pvss.listener = netutil.LimitListener(ln, pvss.maxConnections)
go pvss.acceptConnections()
return nil
}
// OnStop implements cmn.Service.
func (pvss *PrivValidatorSocketServer) OnStop() {
if pvss.listener == nil {
return
}
if err := pvss.listener.Close(); err != nil {
pvss.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed"))
}
}
func (pvss *PrivValidatorSocketServer) acceptConnections() {
func (rs *RemoteSigner) handleConnection(conn net.Conn) {
for {
conn, err := pvss.listener.Accept()
if err != nil {
if !pvss.IsRunning() {
return // Ignore error from listener closing.
}
pvss.Logger.Error(
"acceptConnections",
"err", errors.Wrap(err, "failed to accept connection"),
)
continue
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
pvss.Logger.Error(
"acceptConnetions",
"err", errors.Wrap(err, "setting connection timeout failed"),
)
continue
}
if pvss.privKey != nil {
conn, err = p2pconn.MakeSecretConnection(conn, pvss.privKey.Wrap())
if err != nil {
pvss.Logger.Error(
"acceptConnections",
"err", errors.Wrap(err, "secret connection failed"),
)
continue
}
}
go pvss.handleConnection(conn)
}
}
func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
defer conn.Close()
for {
if !pvss.IsRunning() {
if !rs.IsRunning() {
return // Ignore error from listener closing.
}
req, err := readMsg(conn)
if err != nil {
if err != io.EOF {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
}
return
}
@ -365,29 +450,29 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) {
case *PubKeyMsg:
var p crypto.PubKey
p, err = pvss.privVal.PubKey()
p, err = rs.privVal.PubKey()
res = &PubKeyMsg{p}
case *SignVoteMsg:
err = pvss.privVal.SignVote(pvss.chainID, r.Vote)
err = rs.privVal.SignVote(rs.chainID, r.Vote)
res = &SignVoteMsg{r.Vote}
case *SignProposalMsg:
err = pvss.privVal.SignProposal(pvss.chainID, r.Proposal)
err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
res = &SignProposalMsg{r.Proposal}
case *SignHeartbeatMsg:
err = pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat)
err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
res = &SignHeartbeatMsg{r.Heartbeat}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
return
}
err = writeMsg(conn, res)
if err != nil {
pvss.Logger.Error("handleConnection", "err", err)
rs.Logger.Error("handleConnection", "err", err)
return
}
}
@ -442,6 +527,10 @@ func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) {
read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
return nil, errors.Wrapf(ErrConnTimeout, opErr.Addr.String())
}
return nil, err
}
@ -461,6 +550,9 @@ func writeMsg(w io.Writer, msg interface{}) error {
// TODO(xla): This extra wrap should be gone with the sdk-2 update.
wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err)
if opErr, ok := err.(*net.OpError); ok {
return errors.Wrapf(ErrConnTimeout, opErr.Addr.String())
}
return err
}

View File

@ -4,10 +4,12 @@ import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
crypto "github.com/tendermint/go-crypto"
cmn "github.com/tendermint/tmlibs/common"
"github.com/tendermint/tmlibs/log"
"github.com/tendermint/tendermint/types"
@ -16,13 +18,13 @@ import (
func TestSocketClientAddress(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
serverAddr, err := pvss.privVal.Address()
serverAddr, err := rs.privVal.Address()
require.NoError(err)
clientAddr, err := sc.Address()
@ -38,16 +40,16 @@ func TestSocketClientAddress(t *testing.T) {
func TestSocketClientPubKey(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
clientKey, err := sc.PubKey()
require.NoError(err)
privKey, err := pvss.privVal.PubKey()
privKey, err := rs.privVal.PubKey()
require.NoError(err)
assert.Equal(privKey, clientKey)
@ -59,17 +61,17 @@ func TestSocketClientPubKey(t *testing.T) {
func TestSocketClientProposal(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
privProposal = &types.Proposal{Timestamp: ts}
clientProposal = &types.Proposal{Timestamp: ts}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignProposal(chainID, privProposal))
require.NoError(rs.privVal.SignProposal(chainID, privProposal))
require.NoError(sc.SignProposal(chainID, clientProposal))
assert.Equal(privProposal.Signature, clientProposal.Signature)
}
@ -77,8 +79,8 @@ func TestSocketClientProposal(t *testing.T) {
func TestSocketClientVote(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
ts = time.Now()
vType = types.VoteTypePrecommit
@ -86,9 +88,9 @@ func TestSocketClientVote(t *testing.T) {
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignVote(chainID, want))
require.NoError(rs.privVal.SignVote(chainID, want))
require.NoError(sc.SignVote(chainID, have))
assert.Equal(want.Signature, have.Signature)
}
@ -96,69 +98,129 @@ func TestSocketClientVote(t *testing.T) {
func TestSocketClientHeartbeat(t *testing.T) {
var (
assert, require = assert.New(t), require.New(t)
chainID = "test-chain-secret"
sc, pvss = testSetupSocketPair(t, chainID)
chainID = cmn.RandStr(12)
sc, rs = testSetupSocketPair(t, chainID)
want = &types.Heartbeat{}
have = &types.Heartbeat{}
)
defer sc.Stop()
defer pvss.Stop()
defer rs.Stop()
require.NoError(pvss.privVal.SignHeartbeat(chainID, want))
require.NoError(rs.privVal.SignHeartbeat(chainID, want))
require.NoError(sc.SignHeartbeat(chainID, have))
assert.Equal(want.Signature, have.Signature)
}
func TestSocketClientConnectRetryMax(t *testing.T) {
func TestSocketClientDeadline(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
logger = log.TestingLogger()
clientPrivKey = crypto.GenPrivKeyEd25519()
sc = NewSocketClient(
logger,
assert, require = assert.New(t), require.New(t)
readyc = make(chan struct{})
sc = NewSocketClient(
log.TestingLogger(),
"127.0.0.1:0",
&clientPrivKey,
nil,
)
)
defer sc.Stop()
SocketClientTimeout(time.Millisecond)(sc)
SocketClientConnDeadline(time.Millisecond)(sc)
assert.EqualError(sc.Start(), ErrDialRetryMax.Error())
require.NoError(sc.listen())
go func(sc *SocketClient) {
require.NoError(sc.Start())
assert.True(sc.IsRunning())
readyc <- struct{}{}
}(sc)
_, err := cmn.Connect(sc.listener.Addr().String())
require.NoError(err)
<-readyc
_, err = sc.PubKey()
assert.Equal(errors.Cause(err), ErrConnTimeout)
}
func testSetupSocketPair(t *testing.T, chainID string) (*SocketClient, *PrivValidatorSocketServer) {
func TestSocketClientWait(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
logger = log.TestingLogger()
privKey = crypto.GenPrivKeyEd25519()
sc = NewSocketClient(
logger,
"127.0.0.1:0",
&privKey,
)
)
defer sc.Stop()
SocketClientConnWait(time.Millisecond)(sc)
assert.EqualError(sc.Start(), ErrConnWaitTimeout.Error())
}
func TestRemoteSignerRetry(t *testing.T) {
var (
assert, _ = assert.New(t), require.New(t)
privKey = crypto.GenPrivKeyEd25519()
rs = NewRemoteSigner(
log.TestingLogger(),
cmn.RandStr(12),
"127.0.0.1:0",
NewTestPrivValidator(types.GenSigner()),
&privKey,
)
)
defer rs.Stop()
RemoteSignerConnDeadline(time.Millisecond)(rs)
RemoteSignerConnRetries(2)(rs)
assert.EqualError(rs.Start(), ErrDialRetryMax.Error())
}
func testSetupSocketPair(
t *testing.T,
chainID string,
) (*SocketClient, *RemoteSigner) {
var (
assert, require = assert.New(t), require.New(t)
logger = log.TestingLogger()
signer = types.GenSigner()
clientPrivKey = crypto.GenPrivKeyEd25519()
serverPrivKey = crypto.GenPrivKeyEd25519()
remotePrivKey = crypto.GenPrivKeyEd25519()
privVal = NewTestPrivValidator(signer)
pvss = NewPrivValidatorSocketServer(
readyc = make(chan struct{})
sc = NewSocketClient(
logger,
chainID,
"127.0.0.1:0",
1,
privVal,
&serverPrivKey,
&clientPrivKey,
)
)
err := pvss.Start()
require.NoError(err)
assert.True(pvss.IsRunning())
require.NoError(sc.listen())
sc := NewSocketClient(
go func(sc *SocketClient) {
require.NoError(sc.Start())
assert.True(sc.IsRunning())
readyc <- struct{}{}
}(sc)
rs := NewRemoteSigner(
logger,
pvss.listener.Addr().String(),
&clientPrivKey,
chainID,
sc.listener.Addr().String(),
privVal,
&remotePrivKey,
)
require.NoError(rs.Start())
assert.True(rs.IsRunning())
err = sc.Start()
require.NoError(err)
assert.True(sc.IsRunning())
<-readyc
return sc, pvss
return sc, rs
}