From 00900c13156af77cd8919168f67ad8662888c077 Mon Sep 17 00:00:00 2001 From: ffdfgdfg Date: Fri, 23 Aug 2019 18:53:36 +0800 Subject: [PATCH] mux test --- lib/common/const.go | 12 +++ lib/common/netpackager.go | 175 ++++++++++++++++++++++++++++++++++++ lib/mux/bytes.go | 2 +- lib/mux/conn.go | 42 ++++----- lib/mux/mux.go | 185 ++++++++++++++++++-------------------- lib/mux/mux_test.go | 4 +- lib/pool/pool.go | 32 ++++++- 7 files changed, 324 insertions(+), 128 deletions(-) create mode 100644 lib/common/netpackager.go diff --git a/lib/common/const.go b/lib/common/const.go index ffb2fa6..d77f16b 100644 --- a/lib/common/const.go +++ b/lib/common/const.go @@ -36,3 +36,15 @@ WWW-Authenticate: Basic realm="easyProxy" ` ) + +const ( + MUX_PING_FLAG uint8 = iota + MUX_NEW_CONN_OK + MUX_NEW_CONN_Fail + MUX_NEW_MSG + MUX_MSG_SEND_OK + MUX_NEW_CONN + MUX_CONN_CLOSE + MUX_PING_RETURN + MUX_PING int32 = -1 +) diff --git a/lib/common/netpackager.go b/lib/common/netpackager.go new file mode 100644 index 0000000..315f645 --- /dev/null +++ b/lib/common/netpackager.go @@ -0,0 +1,175 @@ +package common + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "github.com/cnlh/nps/lib/pool" + "io" + "strings" +) + +type NetPackager interface { + Pack(writer io.Writer) (err error) + UnPack(reader io.Reader) (err error) +} + +type BasePackager struct { + Length uint32 + Content []byte +} + +func (Self *BasePackager) NewPac(contents ...interface{}) (err error) { + Self.clean() + for _, content := range contents { + switch content.(type) { + case nil: + Self.Content = Self.Content[:0] + case []byte: + Self.Content = append(Self.Content, content.([]byte)...) + case string: + Self.Content = append(Self.Content, []byte(content.(string))...) + Self.Content = append(Self.Content, []byte(CONN_DATA_SEQ)...) + default: + err = Self.marshal(content) + } + } + Self.setLength() + return +} + +//似乎这里涉及到父类作用域问题,当子类调用父类的方法时,其struct仅仅为父类的 +func (Self *BasePackager) Pack(writer io.Writer) (err error) { + err = binary.Write(writer, binary.LittleEndian, Self.Length) + if err != nil { + return + } + err = binary.Write(writer, binary.LittleEndian, Self.Content) + //logs.Warn(Self.Length, string(Self.Content)) + return +} + +//Unpack 会导致传入的数字类型转化成float64!! +//主要原因是json unmarshal并未传入正确的数据类型 +func (Self *BasePackager) UnPack(reader io.Reader) (err error) { + err = binary.Read(reader, binary.LittleEndian, &Self.Length) + if err != nil { + return + } + Self.Content = pool.GetBufPoolCopy() + Self.Content = Self.Content[:Self.Length] + //n, err := io.ReadFull(reader, Self.Content) + //if n != int(Self.Length) { + // err = io.ErrUnexpectedEOF + //} + err = binary.Read(reader, binary.LittleEndian, &Self.Content) + return +} + +func (Self *BasePackager) marshal(content interface{}) (err error) { + tmp, err := json.Marshal(content) + if err != nil { + return err + } + Self.Content = append(Self.Content, tmp...) + return +} + +func (Self *BasePackager) Unmarshal(content interface{}) (err error) { + err = json.Unmarshal(Self.Content, content) + if err != nil { + return err + } + return +} + +func (Self *BasePackager) setLength() { + Self.Length = uint32(len(Self.Content)) + return +} + +func (Self *BasePackager) clean() { + Self.Length = 0 + Self.Content = Self.Content[:0] +} + +func (Self *BasePackager) Split() (strList []string) { + n := bytes.IndexByte(Self.Content, 0) + strList = strings.Split(string(Self.Content[:n]), CONN_DATA_SEQ) + strList = strList[0 : len(strList)-1] + return +} + +type ConnPackager struct { // Todo + ConnType uint8 + BasePackager +} + +func (Self *ConnPackager) NewPac(connType uint8, content ...interface{}) (err error) { + Self.ConnType = connType + err = Self.BasePackager.NewPac(content...) + return +} + +func (Self *ConnPackager) Pack(writer io.Writer) (err error) { + err = binary.Write(writer, binary.LittleEndian, Self.ConnType) + if err != nil { + return + } + err = Self.BasePackager.Pack(writer) + return +} + +func (Self *ConnPackager) UnPack(reader io.Reader) (err error) { + err = binary.Read(reader, binary.LittleEndian, &Self.ConnType) + if err != nil && err != io.EOF { + return + } + err = Self.BasePackager.UnPack(reader) + return +} + +type MuxPackager struct { + Flag uint8 + Id int32 + BasePackager +} + +func (Self *MuxPackager) NewPac(flag uint8, id int32, content ...interface{}) (err error) { + Self.Flag = flag + Self.Id = id + if flag == MUX_NEW_MSG { + err = Self.BasePackager.NewPac(content...) + } + return +} + +func (Self *MuxPackager) Pack(writer io.Writer) (err error) { + err = binary.Write(writer, binary.LittleEndian, Self.Flag) + if err != nil { + return + } + err = binary.Write(writer, binary.LittleEndian, Self.Id) + if err != nil { + return + } + if Self.Flag == MUX_NEW_MSG { + err = Self.BasePackager.Pack(writer) + } + return +} + +func (Self *MuxPackager) UnPack(reader io.Reader) (err error) { + err = binary.Read(reader, binary.LittleEndian, &Self.Flag) + if err != nil { + return + } + err = binary.Read(reader, binary.LittleEndian, &Self.Id) + if err != nil { + return + } + if Self.Flag == MUX_NEW_MSG { + err = Self.BasePackager.UnPack(reader) + } + return +} diff --git a/lib/mux/bytes.go b/lib/mux/bytes.go index a7e17f7..c44bad4 100644 --- a/lib/mux/bytes.go +++ b/lib/mux/bytes.go @@ -20,7 +20,7 @@ func WriteLenBytes(buf []byte, w io.Writer) (int, error) { //read bytes by length func ReadLenBytes(buf []byte, r io.Reader) (int, error) { - var l int32 + var l uint32 var err error if binary.Read(r, binary.LittleEndian, &l) != nil { return 0, err diff --git a/lib/mux/conn.go b/lib/mux/conn.go index 9e66577..a14e98d 100644 --- a/lib/mux/conn.go +++ b/lib/mux/conn.go @@ -2,10 +2,11 @@ package mux import ( "errors" + "github.com/cnlh/nps/lib/common" "github.com/cnlh/nps/lib/pool" + "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" "io" "net" - "sync" "time" ) @@ -30,8 +31,6 @@ type conn struct { mux *Mux } -var connPool = sync.Pool{} - func NewConn(connId int32, mux *Mux) *conn { c := &conn{ readCh: make(chan struct{}), @@ -73,9 +72,15 @@ func (s *conn) Read(buf []byte) (n int, err error) { return 0, io.EOF } else { pool.PutBufPoolCopy(s.readBuffer) - s.readBuffer = node.val - s.endRead = node.l - s.startRead = 0 + if node.val == nil { + //close + s.Close() + logs.Warn("close from read msg ", s.connId) + } else { + s.readBuffer = node.val + s.endRead = node.l + s.startRead = 0 + } } } if len(buf) < s.endRead-s.startRead { @@ -84,12 +89,11 @@ func (s *conn) Read(buf []byte) (n int, err error) { } else { n = copy(buf, s.readBuffer[s.startRead:s.endRead]) s.startRead += n - s.mux.sendInfo(MUX_MSG_SEND_OK, s.connId, nil) } return } -func (s *conn) Write(buf []byte) (int, error) { +func (s *conn) Write(buf []byte) (n int, err error) { if s.isClose { return 0, errors.New("the conn has closed") } @@ -115,15 +119,11 @@ func (s *conn) write(buf []byte, ch chan struct{}) { start := 0 l := len(buf) for { - if s.hasWrite > 50 { - <-s.getStatusCh - } - s.hasWrite++ if l-start > pool.PoolSizeCopy { - s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:start+pool.PoolSizeCopy]) + s.mux.sendInfo(common.MUX_NEW_MSG, s.connId, buf[start:start+pool.PoolSizeCopy]) start += pool.PoolSizeCopy } else { - s.mux.sendInfo(MUX_NEW_MSG, s.connId, buf[start:l]) + s.mux.sendInfo(common.MUX_NEW_MSG, s.connId, buf[start:l]) break } } @@ -132,16 +132,7 @@ func (s *conn) write(buf []byte, ch chan struct{}) { func (s *conn) Close() error { if s.isClose { - return errors.New("the conn has closed") - } - times := 0 -retry: - if s.waitQueue.Size() > 0 && times < 600 { - time.Sleep(time.Millisecond * 100) - times++ - goto retry - } - if s.isClose { + logs.Warn("already closed", s.connId) return errors.New("the conn has closed") } s.isClose = true @@ -152,9 +143,8 @@ retry: s.waitQueue.Clear() s.mux.connMap.Delete(s.connId) if !s.mux.IsClose { - s.mux.sendInfo(MUX_CONN_CLOSE, s.connId, nil) + s.mux.sendInfo(common.MUX_CONN_CLOSE, s.connId, nil) } - connPool.Put(s) return nil } diff --git a/lib/mux/mux.go b/lib/mux/mux.go index 315bc68..bfd82ff 100644 --- a/lib/mux/mux.go +++ b/lib/mux/mux.go @@ -1,10 +1,10 @@ package mux import ( - "bytes" - "encoding/binary" "errors" + "github.com/cnlh/nps/lib/common" "github.com/cnlh/nps/lib/pool" + "github.com/cnlh/nps/vender/github.com/astaxie/beego/logs" "math" "net" "sync" @@ -12,40 +12,30 @@ import ( "time" ) -const ( - MUX_PING_FLAG int32 = iota - MUX_NEW_CONN_OK - MUX_NEW_CONN_Fail - MUX_NEW_MSG - MUX_MSG_SEND_OK - MUX_NEW_CONN - MUX_PING - MUX_CONN_CLOSE - MUX_PING_RETURN -) - type Mux struct { net.Listener - conn net.Conn - connMap *connMap - newConnCh chan *conn - id int32 - closeChan chan struct{} - IsClose bool - pingOk int - connType string + conn net.Conn + connMap *connMap + newConnCh chan *conn + id int32 + closeChan chan struct{} + IsClose bool + pingOk int + connType string + writeQueue *sliceEntry sync.Mutex } func NewMux(c net.Conn, connType string) *Mux { m := &Mux{ - conn: c, - connMap: NewConnMap(), - id: 0, - closeChan: make(chan struct{}), - newConnCh: make(chan *conn), - IsClose: false, - connType: connType, + conn: c, + connMap: NewConnMap(), + id: 0, + closeChan: make(chan struct{}), + newConnCh: make(chan *conn), + IsClose: false, + connType: connType, + writeQueue: NewQueue(), } //read session by flag go m.readSession() @@ -61,7 +51,7 @@ func (s *Mux) NewConn() (*conn, error) { conn := NewConn(s.getId(), s) //it must be set before send s.connMap.Set(conn.connId, conn) - if err := s.sendInfo(MUX_NEW_CONN, conn.connId, nil); err != nil { + if err := s.sendInfo(common.MUX_NEW_CONN, conn.connId, nil); err != nil { return nil, err } //set a timer timeout 30 second @@ -91,19 +81,28 @@ func (s *Mux) Addr() net.Addr { return s.conn.LocalAddr() } -func (s *Mux) sendInfo(flag int32, id int32, content []byte) error { - raw := bytes.NewBuffer([]byte{}) - binary.Write(raw, binary.LittleEndian, flag) - binary.Write(raw, binary.LittleEndian, id) - if content != nil && len(content) > 0 { - binary.Write(raw, binary.LittleEndian, int32(len(content))) - binary.Write(raw, binary.LittleEndian, content) - } - if _, err := s.conn.Write(raw.Bytes()); err != nil { +func (s *Mux) sendInfo(flag uint8, id int32, content []byte) (err error) { + buf := pool.BuffPool.Get() + defer pool.BuffPool.Put(buf) + pack := common.MuxPackager{} + err = pack.NewPac(flag, id, content) + if err != nil { s.Close() - return err + logs.Warn("new pack err", err) + return } - return nil + err = pack.Pack(buf) + if err != nil { + s.Close() + logs.Warn("pack err", err) + return + } + _, err = buf.WriteTo(s.conn) + if err != nil { + s.Close() + logs.Warn("write err", err) + } + return } func (s *Mux) ping() { @@ -117,7 +116,7 @@ func (s *Mux) ping() { if (math.MaxInt32 - s.id) < 10000 { s.id = 0 } - if err := s.sendInfo(MUX_PING_FLAG, MUX_PING, nil); err != nil || (s.pingOk > 10 && s.connType == "kcp") { + if err := s.sendInfo(common.MUX_PING_FLAG, common.MUX_PING, nil); err != nil || (s.pingOk > 10 && s.connType == "kcp") { s.Close() break } @@ -130,65 +129,48 @@ func (s *Mux) ping() { } func (s *Mux) readSession() { - var buf []byte + var pack common.MuxPackager go func() { for { - var flag, i int32 - var n int - var err error - if binary.Read(s.conn, binary.LittleEndian, &flag) == nil { - if binary.Read(s.conn, binary.LittleEndian, &i) != nil { - break - } - s.pingOk = 0 - switch flag { - case MUX_NEW_CONN: //new conn - conn := NewConn(i, s) - s.connMap.Set(i, conn) //it has been set before send ok - s.newConnCh <- conn - s.sendInfo(MUX_NEW_CONN_OK, i, nil) - continue - case MUX_PING_FLAG: //ping - s.sendInfo(MUX_PING_RETURN, MUX_PING, nil) - continue - case MUX_PING_RETURN: - continue - case MUX_NEW_MSG: - buf = pool.GetBufPoolCopy() - if n, err = ReadLenBytes(buf, s.conn); err != nil { - break - } - } - if conn, ok := s.connMap.Get(i); ok && !conn.isClose { - switch flag { - case MUX_NEW_MSG: //new msg from remote conn - //insert wait queue - conn.waitQueue.Push(NewBufNode(buf, n)) - //judge len if >xxx ,send stop - if conn.readWait { - conn.readWait = false - conn.readCh <- struct{}{} - } - case MUX_MSG_SEND_OK: //the remote has read - select { - case conn.getStatusCh <- struct{}{}: - default: - } - conn.hasWrite -- - case MUX_NEW_CONN_OK: //conn ok - conn.connStatusOkCh <- struct{}{} - case MUX_NEW_CONN_Fail: - conn.connStatusFailCh <- struct{}{} - case MUX_CONN_CLOSE: //close the connection - go conn.Close() - s.connMap.Delete(i) - } - } else if flag == MUX_NEW_MSG { - pool.PutBufPoolCopy(buf) - } - } else { + if pack.UnPack(s.conn) != nil { break } + s.pingOk = 0 + switch pack.Flag { + case common.MUX_NEW_CONN: //new conn + logs.Warn("mux new conn", pack.Id) + conn := NewConn(pack.Id, s) + s.connMap.Set(pack.Id, conn) //it has been set before send ok + s.newConnCh <- conn + s.sendInfo(common.MUX_NEW_CONN_OK, pack.Id, nil) + continue + case common.MUX_PING_FLAG: //ping + s.sendInfo(common.MUX_PING_RETURN, common.MUX_PING, nil) + continue + case common.MUX_PING_RETURN: + continue + } + if conn, ok := s.connMap.Get(pack.Id); ok && !conn.isClose { + switch pack.Flag { + case common.MUX_NEW_MSG: //new msg from remote conn + //insert wait queue + conn.waitQueue.Push(NewBufNode(pack.Content, int(pack.Length))) + //judge len if >xxx ,send stop + if conn.readWait { + conn.readWait = false + conn.readCh <- struct{}{} + } + case common.MUX_NEW_CONN_OK: //conn ok + conn.connStatusOkCh <- struct{}{} + case common.MUX_NEW_CONN_Fail: + conn.connStatusFailCh <- struct{}{} + case common.MUX_CONN_CLOSE: //close the connection + conn.waitQueue.Push(NewBufNode(nil, 0)) + s.connMap.Delete(pack.Id) + } + } else if pack.Flag == common.MUX_NEW_MSG { + pool.PutBufPoolCopy(pack.Content) + } } s.Close() }() @@ -198,6 +180,7 @@ func (s *Mux) readSession() { } func (s *Mux) Close() error { + logs.Warn("close mux") if s.IsClose { return errors.New("the mux has closed") } @@ -214,6 +197,10 @@ func (s *Mux) Close() error { } //get new connId as unique flag -func (s *Mux) getId() int32 { - return atomic.AddInt32(&s.id, 1) +func (s *Mux) getId() (id int32) { + id = atomic.AddInt32(&s.id, 1) + if _, ok := s.connMap.Get(id); ok { + s.getId() + } + return } diff --git a/lib/mux/mux_test.go b/lib/mux/mux_test.go index f84e378..067e939 100644 --- a/lib/mux/mux_test.go +++ b/lib/mux/mux_test.go @@ -32,13 +32,14 @@ func TestNewMux(t *testing.T) { log.Fatalln(err) } go func(c net.Conn) { - c2, err := net.Dial("tcp", "127.0.0.1:8082") + c2, err := net.Dial("tcp", "127.0.0.1:80") if err != nil { log.Fatalln(err) } go common.CopyBuffer(c2, c) common.CopyBuffer(c, c2) c.Close() + //logs.Warn("close from out npc ") c2.Close() }(c) } @@ -64,6 +65,7 @@ func TestNewMux(t *testing.T) { common.CopyBuffer(conn, tmpCpnn) conn.Close() tmpCpnn.Close() + logs.Warn("close from out nps ", tmpCpnn.connId) }(conn) } }() diff --git a/lib/pool/pool.go b/lib/pool/pool.go index 70e0477..fb337a2 100644 --- a/lib/pool/pool.go +++ b/lib/pool/pool.go @@ -1,6 +1,7 @@ package pool import ( + "bytes" "sync" ) @@ -36,6 +37,7 @@ var BufPoolCopy = sync.Pool{ return &buf }, } + func PutBufPoolUdp(buf []byte) { if cap(buf) == PoolSizeUdp { BufPoolUdp.Put(buf[:PoolSizeUdp]) @@ -48,7 +50,7 @@ func PutBufPoolCopy(buf []byte) { } } -func GetBufPoolCopy() ([]byte) { +func GetBufPoolCopy() []byte { return (*BufPoolCopy.Get().(*[]byte))[:PoolSizeCopy] } @@ -57,3 +59,31 @@ func PutBufPoolMax(buf []byte) { BufPoolMax.Put(buf[:PoolSize]) } } + +type BufferPool struct { + pool sync.Pool +} + +func (Self *BufferPool) New() { + Self.pool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + } +} + +func (Self *BufferPool) Get() *bytes.Buffer { + return Self.pool.Get().(*bytes.Buffer) +} + +func (Self *BufferPool) Put(x *bytes.Buffer) { + x.Reset() + Self.pool.Put(x) +} + +var once = sync.Once{} +var BuffPool = BufferPool{} + +func init() { + once.Do(BuffPool.New) +}