Extend the SecureConn socket to extend its deadline

This commit is contained in:
Hendrik Hofstadt 2018-10-07 16:15:57 +02:00
parent 80562669bf
commit cf779809a7
2 changed files with 39 additions and 14 deletions

View File

@ -24,6 +24,13 @@ type tcpTimeoutListener struct {
period time.Duration
}
// tcpTimeoutConn wraps a *net.TCPConn to standardise protocol timeouts / deadline resets.
type tcpTimeoutConn struct {
*net.TCPConn
connDeadline time.Duration
}
// newTCPTimeoutListener returns an instance of tcpTimeoutListener.
func newTCPTimeoutListener(
ln net.Listener,
@ -38,6 +45,16 @@ func newTCPTimeoutListener(
}
}
// newTCPTimeoutConn returns an instance of newTCPTimeoutConn.
func newTCPTimeoutConn(
conn *net.TCPConn,
connDeadline time.Duration) *tcpTimeoutConn {
return &tcpTimeoutConn{
conn,
connDeadline,
}
}
// Accept implements net.Listener.
func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline))
@ -50,17 +67,24 @@ func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
return nil, err
}
if err := tc.SetDeadline(time.Now().Add(ln.connDeadline)); err != nil {
return nil, err
}
// Wrap the TCPConn in our timeout wrapper
conn := newTCPTimeoutConn(tc, ln.connDeadline)
if err := tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err := tc.SetKeepAlivePeriod(ln.period); err != nil {
return nil, err
}
return tc, nil
return conn, nil
}
// Read implements net.Listener.
func (c tcpTimeoutConn) Read(b []byte) (int, error) {
// Reset deadline
c.TCPConn.SetReadDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Read(b)
}
// Write implements net.Listener.
func (c tcpTimeoutConn) Write(b []byte) (int, error) {
// Reset deadline
c.TCPConn.SetWriteDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Write(b)
}

View File

@ -44,13 +44,14 @@ func TestTCPTimeoutListenerConnDeadline(t *testing.T) {
time.Sleep(2 * time.Millisecond)
_, err = c.Write([]byte("foo"))
msg := make([]byte, 200)
_, err = c.Read(msg)
opErr, ok := err.(*net.OpError)
if !ok {
t.Fatalf("have %v, want *net.OpError", err)
}
if have, want := opErr.Op, "write"; have != want {
if have, want := opErr.Op, "read"; have != want {
t.Errorf("have %v, want %v", have, want)
}
}(ln)