Extend the SecureConn socket to extend its deadline

This commit is contained in:
Hendrik Hofstadt 2018-10-07 16:15:57 +02:00
parent 80562669bf
commit cf779809a7
2 changed files with 39 additions and 14 deletions

View File

@ -24,6 +24,13 @@ type tcpTimeoutListener struct {
period time.Duration period time.Duration
} }
// tcpTimeoutConn wraps a *net.TCPConn to standardise protocol timeouts / deadline resets.
type tcpTimeoutConn struct {
*net.TCPConn
connDeadline time.Duration
}
// newTCPTimeoutListener returns an instance of tcpTimeoutListener. // newTCPTimeoutListener returns an instance of tcpTimeoutListener.
func newTCPTimeoutListener( func newTCPTimeoutListener(
ln net.Listener, ln net.Listener,
@ -38,6 +45,16 @@ func newTCPTimeoutListener(
} }
} }
// newTCPTimeoutConn returns an instance of newTCPTimeoutConn.
func newTCPTimeoutConn(
conn *net.TCPConn,
connDeadline time.Duration) *tcpTimeoutConn {
return &tcpTimeoutConn{
conn,
connDeadline,
}
}
// Accept implements net.Listener. // Accept implements net.Listener.
func (ln tcpTimeoutListener) Accept() (net.Conn, error) { func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline)) err := ln.SetDeadline(time.Now().Add(ln.acceptDeadline))
@ -50,17 +67,24 @@ func (ln tcpTimeoutListener) Accept() (net.Conn, error) {
return nil, err return nil, err
} }
if err := tc.SetDeadline(time.Now().Add(ln.connDeadline)); err != nil { // Wrap the TCPConn in our timeout wrapper
return nil, err conn := newTCPTimeoutConn(tc, ln.connDeadline)
}
if err := tc.SetKeepAlive(true); err != nil { return conn, nil
return nil, err }
}
// Read implements net.Listener.
if err := tc.SetKeepAlivePeriod(ln.period); err != nil { func (c tcpTimeoutConn) Read(b []byte) (int, error) {
return nil, err // Reset deadline
} c.TCPConn.SetReadDeadline(time.Now().Add(c.connDeadline))
return tc, nil return c.TCPConn.Read(b)
}
// Write implements net.Listener.
func (c tcpTimeoutConn) Write(b []byte) (int, error) {
// Reset deadline
c.TCPConn.SetWriteDeadline(time.Now().Add(c.connDeadline))
return c.TCPConn.Write(b)
} }

View File

@ -44,13 +44,14 @@ func TestTCPTimeoutListenerConnDeadline(t *testing.T) {
time.Sleep(2 * time.Millisecond) time.Sleep(2 * time.Millisecond)
_, err = c.Write([]byte("foo")) msg := make([]byte, 200)
_, err = c.Read(msg)
opErr, ok := err.(*net.OpError) opErr, ok := err.(*net.OpError)
if !ok { if !ok {
t.Fatalf("have %v, want *net.OpError", err) t.Fatalf("have %v, want *net.OpError", err)
} }
if have, want := opErr.Op, "write"; have != want { if have, want := opErr.Op, "read"; have != want {
t.Errorf("have %v, want %v", have, want) t.Errorf("have %v, want %v", have, want)
} }
}(ln) }(ln)