Implement IPC PV and abstract socket signing

This commit is contained in:
Hendrik Hofstadt 2018-10-07 20:04:24 +02:00
parent 2e4cae1fdb
commit 03c1fd712a
6 changed files with 678 additions and 306 deletions

218
privval/ipc.go Normal file
View File

@ -0,0 +1,218 @@
package privval
import (
"io"
"net"
"sync"
"time"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
)
// IPCPV implements PrivValidator, it uses a unix socket to request signatures
// from an external process.
type IPCPV struct {
cmn.BaseService
*RemoteSignerClient
addr string
connTimeout time.Duration
lock sync.Mutex
cancelPing chan bool
}
// Check that IPCPV implements PrivValidator.
var _ types.PrivValidator = (*IPCPV)(nil)
// NewIPCPV returns an instance of IPCPV.
func NewIPCPV(
logger log.Logger,
socketAddr string,
) *IPCPV {
sc := &IPCPV{
addr: socketAddr,
connTimeout: connTimeout,
}
sc.BaseService = *cmn.NewBaseService(logger, "IPCPV", sc)
sc.RemoteSignerClient = NewRemoteSignerClient(logger, nil)
return sc
}
// OnStart implements cmn.Service.
func (sc *IPCPV) OnStart() error {
err := sc.connect()
if err != nil {
err = cmn.ErrorWrap(err, "failed to connect")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
err = sc.RemoteSignerClient.Start()
if err != nil {
err = cmn.ErrorWrap(err, "failed to start RemoteSignerClient")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
return nil
}
// OnStop implements cmn.Service.
func (sc *IPCPV) OnStop() {
if err := sc.RemoteSignerClient.Stop(); err != nil {
err = cmn.ErrorWrap(err, "failed to stop RemoteSignerClient")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
if sc.conn != nil {
if err := sc.conn.Close(); err != nil {
err = cmn.ErrorWrap(err, "failed to close connection")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
}
}
func (sc *IPCPV) connect() error {
la, err := net.ResolveUnixAddr("unix", sc.addr)
if err != nil {
return err
}
conn, err := net.DialUnix("unix", nil, la)
if err != nil {
return err
}
// Wrap in a timeoutConn
sc.conn = newTimeoutConn(conn, sc.connTimeout)
return nil
}
//---------------------------------------------------------
// IPCRemoteSigner is a RPC implementation of PrivValidator that listens on a unix socket.
type IPCRemoteSigner struct {
cmn.BaseService
addr string
chainID string
connDeadline time.Duration
connRetries int
privVal types.PrivValidator
listener *net.UnixListener
}
// NewIPCRemoteSigner returns an instance of IPCRemoteSigner.
func NewIPCRemoteSigner(
logger log.Logger,
chainID, socketAddr string,
privVal types.PrivValidator,
) *IPCRemoteSigner {
rs := &IPCRemoteSigner{
addr: socketAddr,
chainID: chainID,
connDeadline: time.Second * defaultConnDeadlineSeconds,
connRetries: defaultDialRetries,
privVal: privVal,
}
rs.BaseService = *cmn.NewBaseService(logger, "IPCRemoteSigner", rs)
return rs
}
// OnStart implements cmn.Service.
func (rs *IPCRemoteSigner) OnStart() error {
err := rs.listen()
if err != nil {
err = cmn.ErrorWrap(err, "listen")
rs.Logger.Error("OnStart", "err", err)
return err
}
go func() {
for {
conn, err := rs.listener.AcceptUnix()
if err != nil {
return
}
go rs.handleConnection(conn)
}
}()
return nil
}
// OnStop implements cmn.Service.
func (rs *IPCRemoteSigner) OnStop() {
if rs.listener != nil {
if err := rs.listener.Close(); err != nil {
rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
}
}
}
func (rs *IPCRemoteSigner) listen() error {
la, err := net.ResolveUnixAddr("unix", rs.addr)
if err != nil {
return err
}
rs.listener, err = net.ListenUnix("unix", la)
return err
}
func (rs *IPCRemoteSigner) handleConnection(conn net.Conn) {
for {
if !rs.IsRunning() {
return // Ignore error from listener closing.
}
// Reset the connection deadline
conn.SetDeadline(time.Now().Add(rs.connDeadline))
req, err := readMsg(conn)
if err != nil {
if err != io.EOF {
rs.Logger.Error("handleConnection", "err", err)
}
return
}
res, err := handleRequest(req, rs.chainID, rs.privVal)
if err != nil {
// only log the error; we'll reply with an error in res
rs.Logger.Error("handleConnection", "err", err)
}
err = writeMsg(conn, res)
if err != nil {
rs.Logger.Error("handleConnection", "err", err)
return
}
}
}

134
privval/ipc_test.go Normal file
View File

@ -0,0 +1,134 @@
package privval
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
"io/ioutil"
"os"
"testing"
"time"
)
func TestIPCPVVote(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV())
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func TestIPCPVVoteResetDeadline(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV())
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
// This would exceed the deadline if it was not extended by the previous message
time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func TestIPCPVVoteKeepalive(t *testing.T) {
var (
chainID = cmn.RandStr(12)
sc, rs = testSetupIPCSocketPair(t, chainID, types.NewMockPV())
ts = time.Now()
vType = types.VoteTypePrecommit
want = &types.Vote{Timestamp: ts, Type: vType}
have = &types.Vote{Timestamp: ts, Type: vType}
)
defer sc.Stop()
defer rs.Stop()
time.Sleep(10 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
}
func testSetupIPCSocketPair(
t *testing.T,
chainID string,
privValidator types.PrivValidator,
) (*IPCPV, *IPCRemoteSigner) {
var (
addr = testUnixAddr()
logger = log.TestingLogger()
privVal = privValidator
readyc = make(chan struct{})
rs = NewIPCRemoteSigner(
logger,
chainID,
addr,
privVal,
)
sc = NewIPCPV(
logger,
addr,
)
)
rs.connDeadline = time.Millisecond * 5
sc.connTimeout = time.Millisecond * 5
sc.connHeartbeat = time.Millisecond
testStartIPCRemoteSigner(t, readyc, rs)
<-readyc
require.NoError(t, sc.Start())
assert.True(t, sc.IsRunning())
return sc, rs
}
func testStartIPCRemoteSigner(t *testing.T, readyc chan struct{}, rs *IPCRemoteSigner) {
go func(rs *IPCRemoteSigner) {
require.NoError(t, rs.Start())
assert.True(t, rs.IsRunning())
readyc <- struct{}{}
}(rs)
}
func testUnixAddr() string {
f, err := ioutil.TempFile("/tmp", "nettest")
if err != nil {
panic(err)
}
addr := f.Name()
f.Close()
os.Remove(addr)
return addr
}

236
privval/rs_client.go Normal file
View File

@ -0,0 +1,236 @@
package privval
import (
"fmt"
"github.com/tendermint/tendermint/crypto"
cmn "github.com/tendermint/tendermint/libs/common"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/types"
"net"
"sync"
"time"
)
// RemoteSignerClient implements PrivValidator, it uses a socket to request signatures
// from an external process.
type RemoteSignerClient struct {
cmn.BaseService
connHeartbeat time.Duration
conn net.Conn
lock sync.Mutex
cancelPing chan bool
}
// Check that RemoteSignerClient implements PrivValidator.
var _ types.PrivValidator = (*RemoteSignerClient)(nil)
// NewRemoteSignerClient returns an instance of RemoteSignerClient.
func NewRemoteSignerClient(
logger log.Logger,
conn net.Conn,
) *RemoteSignerClient {
sc := &RemoteSignerClient{
conn: conn,
connHeartbeat: connHeartbeat,
}
sc.BaseService = *cmn.NewBaseService(logger, "RemoteSignerClient", sc)
return sc
}
// GetAddress implements PrivValidator.
func (sc *RemoteSignerClient) GetAddress() types.Address {
addr, err := sc.getAddress()
if err != nil {
panic(err)
}
return addr
}
// Address is an alias for PubKey().Address().
func (sc *RemoteSignerClient) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
}
// GetPubKey implements PrivValidator.
func (sc *RemoteSignerClient) GetPubKey() crypto.PubKey {
pubKey, err := sc.getPubKey()
if err != nil {
panic(err)
}
return pubKey
}
func (sc *RemoteSignerClient) getPubKey() (crypto.PubKey, error) {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PubKeyMsg{})
if err != nil {
return nil, err
}
res, err := readMsg(sc.conn)
if err != nil {
return nil, err
}
return res.(*PubKeyMsg).PubKey, nil
}
// SignVote implements PrivValidator.
func (sc *RemoteSignerClient) SignVote(chainID string, vote *types.Vote) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedVoteResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*vote = *resp.Vote
return nil
}
// SignProposal implements PrivValidator.
func (sc *RemoteSignerClient) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedProposalResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*proposal = *resp.Proposal
return nil
}
// SignHeartbeat implements PrivValidator.
func (sc *RemoteSignerClient) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedHeartbeatResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*heartbeat = *resp.Heartbeat
return nil
}
// Ping is used to check connection health.
func (sc *RemoteSignerClient) Ping() error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PingRequest{})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
_, ok := res.(*PingResponse)
if !ok {
return ErrUnexpectedResponse
}
return nil
}
// OnStart implements cmn.Service.
func (sc *RemoteSignerClient) OnStart() error {
// Start a routine to keep the connection alive
sc.cancelPing = make(chan bool, 1)
go func() {
for {
select {
case <-time.Tick(sc.connHeartbeat):
err := sc.Ping()
if err != nil {
sc.Logger.Error(
"Ping",
"err", err,
)
}
case <-sc.cancelPing:
return
}
}
}()
return nil
}
// OnStop implements cmn.Service.
func (sc *RemoteSignerClient) OnStop() {
if sc.cancelPing != nil {
select {
case sc.cancelPing <- true:
default:
}
}
}

View File

@ -36,50 +36,22 @@ var (
var (
acceptDeadline = time.Second * defaultAcceptDeadlineSeconds
connDeadline = time.Second * defaultConnDeadlineSeconds
connTimeout = time.Second * defaultConnDeadlineSeconds
connHeartbeat = time.Second * defaultConnHeartBeatSeconds
)
// SocketPVOption sets an optional parameter on the SocketPV.
type SocketPVOption func(*SocketPV)
// SocketPVAcceptDeadline sets the deadline for the SocketPV listener.
// A zero time value disables the deadline.
func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.acceptDeadline = deadline }
}
// SocketPVConnDeadline sets the read and write deadline for connections
// from external signing processes.
func SocketPVConnDeadline(deadline time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connDeadline = deadline }
}
// SocketPVHeartbeat sets the period on which to check the liveness of the
// connected Signer connections.
func SocketPVHeartbeat(period time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connHeartbeat = period }
}
// SocketPVConnWait sets the timeout duration before connection of external
// signing processes are considered to be unsuccessful.
func SocketPVConnWait(timeout time.Duration) SocketPVOption {
return func(sc *SocketPV) { sc.connWaitTimeout = timeout }
}
// SocketPV implements PrivValidator, it uses a socket to request signatures
// from an external process.
type SocketPV struct {
cmn.BaseService
*RemoteSignerClient
addr string
acceptDeadline time.Duration
connDeadline time.Duration
connHeartbeat time.Duration
connTimeout time.Duration
connWaitTimeout time.Duration
privKey ed25519.PrivKeyEd25519
conn net.Conn
listener net.Listener
lock sync.Mutex
cancelPing chan bool
@ -97,177 +69,17 @@ func NewSocketPV(
sc := &SocketPV{
addr: socketAddr,
acceptDeadline: acceptDeadline,
connDeadline: connDeadline,
connHeartbeat: connHeartbeat,
connTimeout: connTimeout,
connWaitTimeout: time.Second * defaultConnWaitSeconds,
privKey: privKey,
}
sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc)
sc.RemoteSignerClient = NewRemoteSignerClient(sc.Logger, nil)
return sc
}
// GetAddress implements PrivValidator.
func (sc *SocketPV) GetAddress() types.Address {
addr, err := sc.getAddress()
if err != nil {
panic(err)
}
return addr
}
// Address is an alias for PubKey().Address().
func (sc *SocketPV) getAddress() (cmn.HexBytes, error) {
p, err := sc.getPubKey()
if err != nil {
return nil, err
}
return p.Address(), nil
}
// GetPubKey implements PrivValidator.
func (sc *SocketPV) GetPubKey() crypto.PubKey {
pubKey, err := sc.getPubKey()
if err != nil {
panic(err)
}
return pubKey
}
func (sc *SocketPV) getPubKey() (crypto.PubKey, error) {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PubKeyMsg{})
if err != nil {
return nil, err
}
res, err := readMsg(sc.conn)
if err != nil {
return nil, err
}
return res.(*PubKeyMsg).PubKey, nil
}
// SignVote implements PrivValidator.
func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedVoteResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*vote = *resp.Vote
return nil
}
// SignProposal implements PrivValidator.
func (sc *SocketPV) SignProposal(
chainID string,
proposal *types.Proposal,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedProposalResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*proposal = *resp.Proposal
return nil
}
// SignHeartbeat implements PrivValidator.
func (sc *SocketPV) SignHeartbeat(
chainID string,
heartbeat *types.Heartbeat,
) error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: heartbeat})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
resp, ok := res.(*SignedHeartbeatResponse)
if !ok {
return ErrUnexpectedResponse
}
if resp.Error != nil {
return fmt.Errorf("remote error occurred: code: %v, description: %s",
resp.Error.Code,
resp.Error.Description)
}
*heartbeat = *resp.Heartbeat
return nil
}
// Ping is used to check connection health.
func (sc *SocketPV) Ping() error {
sc.lock.Lock()
defer sc.lock.Unlock()
err := writeMsg(sc.conn, &PingRequest{})
if err != nil {
return err
}
res, err := readMsg(sc.conn)
if err != nil {
return err
}
_, ok := res.(*PingResponse)
if !ok {
return ErrUnexpectedResponse
}
return nil
}
// OnStart implements cmn.Service.
func (sc *SocketPV) OnStart() error {
if err := sc.listen(); err != nil {
@ -290,37 +102,30 @@ func (sc *SocketPV) OnStart() error {
return err
}
// Start a routine to keep the connection alive
sc.cancelPing = make(chan bool, 1)
go func() {
for {
select {
case <-time.Tick(sc.connHeartbeat):
err := sc.Ping()
if err != nil {
sc.Logger.Error(
"Ping",
"err", err,
)
}
case <-sc.cancelPing:
return
}
}
}()
sc.conn = conn
err = sc.RemoteSignerClient.Start()
if err != nil {
err = cmn.ErrorWrap(err, "failed to start RemoteSignerClient")
sc.Logger.Error(
"OnStart",
"err", err,
)
return err
}
return nil
}
// OnStop implements cmn.Service.
func (sc *SocketPV) OnStop() {
if sc.cancelPing != nil {
select {
case sc.cancelPing <- true:
default:
}
if err := sc.RemoteSignerClient.Stop(); err != nil {
err = cmn.ErrorWrap(err, "failed to stop RemoteSignerClient")
sc.Logger.Error(
"OnStop",
"err", err,
)
}
if sc.conn != nil {
@ -371,7 +176,7 @@ func (sc *SocketPV) listen() error {
sc.listener = newTCPTimeoutListener(
ln,
sc.acceptDeadline,
sc.connDeadline,
sc.connTimeout,
sc.connHeartbeat,
)
@ -504,7 +309,7 @@ func (rs *RemoteSigner) connect() (net.Conn, error) {
continue
}
if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
if err := conn.SetDeadline(time.Now().Add(connTimeout)); err != nil {
err = cmn.ErrorWrap(err, "setting connection timeout failed")
rs.Logger.Error(
"connect",
@ -547,39 +352,7 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) {
return
}
var res SocketPVMsg
switch r := req.(type) {
case *PubKeyMsg:
var p crypto.PubKey
p = rs.privVal.GetPubKey()
res = &PubKeyMsg{p}
case *SignVoteRequest:
err = rs.privVal.SignVote(rs.chainID, r.Vote)
if err != nil {
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedVoteResponse{r.Vote, nil}
}
case *SignProposalRequest:
err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
if err != nil {
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedProposalResponse{r.Proposal, nil}
}
case *SignHeartbeatRequest:
err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
if err != nil {
res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedHeartbeatResponse{r.Heartbeat, nil}
}
case *PingRequest:
res = &PingResponse{}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
res, err := handleRequest(req, rs.chainID, rs.privVal)
if err != nil {
// only log the error; we'll reply with an error in res
@ -594,6 +367,45 @@ func (rs *RemoteSigner) handleConnection(conn net.Conn) {
}
}
func handleRequest(req SocketPVMsg, chainID string, privVal types.PrivValidator) (SocketPVMsg, error) {
var res SocketPVMsg
var err error
switch r := req.(type) {
case *PubKeyMsg:
var p crypto.PubKey
p = privVal.GetPubKey()
res = &PubKeyMsg{p}
case *SignVoteRequest:
err = privVal.SignVote(chainID, r.Vote)
if err != nil {
res = &SignedVoteResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedVoteResponse{r.Vote, nil}
}
case *SignProposalRequest:
err = privVal.SignProposal(chainID, r.Proposal)
if err != nil {
res = &SignedProposalResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedProposalResponse{r.Proposal, nil}
}
case *SignHeartbeatRequest:
err = privVal.SignHeartbeat(chainID, r.Heartbeat)
if err != nil {
res = &SignedHeartbeatResponse{nil, &RemoteSignerError{0, err.Error()}}
} else {
res = &SignedHeartbeatResponse{r.Heartbeat, nil}
}
case *PingRequest:
res = &PingResponse{}
default:
err = fmt.Errorf("unknown msg: %v", r)
}
return res, err
}
//---------------------------------------------------------
// SocketPVMsg is sent between RemoteSigner and SocketPV.

View File

@ -24,9 +24,9 @@ type tcpTimeoutListener struct {
period time.Duration
}
// tcpTimeoutConn wraps a *net.TCPConn to standardise protocol timeouts / deadline resets.
type tcpTimeoutConn struct {
*net.TCPConn
// timeoutConn wraps a net.Conn to standardise protocol timeouts / deadline resets.
type timeoutConn struct {
net.Conn
connDeadline time.Duration
}
@ -45,11 +45,11 @@ func newTCPTimeoutListener(
}
}
// newTCPTimeoutConn returns an instance of newTCPTimeoutConn.
func newTCPTimeoutConn(
conn *net.TCPConn,
connDeadline time.Duration) *tcpTimeoutConn {
return &tcpTimeoutConn{
// newTimeoutConn returns an instance of newTCPTimeoutConn.
func newTimeoutConn(
conn net.Conn,
connDeadline time.Duration) *timeoutConn {
return &timeoutConn{
conn,
connDeadline,
}
@ -67,24 +67,24 @@ func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
return nil, err
}
// Wrap the TCPConn in our timeout wrapper
conn := newTCPTimeoutConn(tc, ln.connDeadline)
// Wrap the conn in our timeout wrapper
conn := newTimeoutConn(tc, ln.connDeadline)
return conn, nil
}
// Read implements net.Listener.
func (c tcpTimeoutConn) Read(b []byte) (int, error) {
func (c timeoutConn) Read(b []byte) (int, error) {
// Reset deadline
c.TCPConn.SetReadDeadline(time.Now().Add(c.connDeadline))
c.Conn.SetReadDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Read(b)
return c.Conn.Read(b)
}
// Write implements net.Listener.
func (c tcpTimeoutConn) Write(b []byte) (int, error) {
func (c timeoutConn) Write(b []byte) (int, error) {
// Reset deadline
c.TCPConn.SetWriteDeadline(time.Now().Add(c.connDeadline))
c.Conn.SetWriteDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Write(b)
return c.Conn.Write(b)
}

View File

@ -104,14 +104,14 @@ func TestSocketPVVoteResetDeadline(t *testing.T) {
defer sc.Stop()
defer rs.Stop()
time.Sleep(800 * time.Microsecond)
time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
assert.Equal(t, want.Signature, have.Signature)
// This would exceed the deadline if it was not extended by the previous message
time.Sleep(800 * time.Microsecond)
time.Sleep(3 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
@ -131,7 +131,7 @@ func TestSocketPVVoteKeepalive(t *testing.T) {
defer sc.Stop()
defer rs.Stop()
time.Sleep(2 * time.Millisecond)
time.Sleep(10 * time.Millisecond)
require.NoError(t, rs.privVal.SignVote(chainID, want))
require.NoError(t, sc.SignVote(chainID, have))
@ -154,21 +154,6 @@ func TestSocketPVHeartbeat(t *testing.T) {
assert.Equal(t, want.Signature, have.Signature)
}
func TestSocketPVAcceptDeadline(t *testing.T) {
var (
sc = NewSocketPV(
log.TestingLogger(),
"127.0.0.1:0",
ed25519.GenPrivKey(),
)
)
defer sc.Stop()
SocketPVAcceptDeadline(time.Millisecond)(sc)
assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout)
}
func TestSocketPVDeadline(t *testing.T) {
var (
addr = testFreeAddr(t)
@ -180,8 +165,8 @@ func TestSocketPVDeadline(t *testing.T) {
)
)
SocketPVConnDeadline(100 * time.Millisecond)(sc)
SocketPVConnWait(500 * time.Millisecond)(sc)
sc.connTimeout = 100 * time.Millisecond
sc.connWaitTimeout = 500 * time.Millisecond
go func(sc *SocketPV) {
defer close(listenc)
@ -212,19 +197,6 @@ func TestSocketPVDeadline(t *testing.T) {
assert.Equal(t, err.(cmn.Error).Data(), ErrConnTimeout)
}
func TestSocketPVWait(t *testing.T) {
sc := NewSocketPV(
log.TestingLogger(),
"127.0.0.1:0",
ed25519.GenPrivKey(),
)
defer sc.Stop()
SocketPVConnWait(time.Millisecond)(sc)
assert.Equal(t, sc.Start().(cmn.Error).Data(), ErrConnWaitTimeout)
}
func TestRemoteSignerRetry(t *testing.T) {
var (
attemptc = make(chan int)
@ -447,13 +419,13 @@ func testSetupSocketPair(
)
)
testStartSocketPV(t, readyc, sc)
SocketPVConnDeadline(time.Millisecond)(sc)
SocketPVHeartbeat(500 * time.Microsecond)(sc)
RemoteSignerConnDeadline(time.Millisecond)(rs)
sc.connTimeout = 5 * time.Millisecond
sc.connHeartbeat = 2 * time.Millisecond
RemoteSignerConnDeadline(5 * time.Millisecond)(rs)
RemoteSignerConnRetries(1e6)(rs)
testStartSocketPV(t, readyc, sc)
require.NoError(t, rs.Start())
assert.True(t, rs.IsRunning())