// Mirror of https://github.com/qwqdanchun/nps.git (Go source, 194 lines, 3.9 KiB).
package main
|
|||
|
|
|||
|
import (
|
|||
|
"bytes"
|
|||
|
"encoding/binary"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"io"
|
|||
|
"io/ioutil"
|
|||
|
"log"
|
|||
|
"net"
|
|||
|
"net/http"
|
|||
|
"sync"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
// TRPServer relays HTTP requests received on httpPort to verified client
// connections accepted on tcpPort, and relays the clients' replies back.
type TRPServer struct {
	// Port of the client-facing TCP listener and of the public HTTP server.
	tcpPort, httpPort int

	// listener is opened by Start and released (set to nil) by Close.
	listener *net.TCPListener

	// connList is the pool of idle, verified client connections.
	connList chan net.Conn

	// NOTE(review): no code in this file visibly locks this mutex; it is kept
	// embedded because it is part of TRPServer's method set.
	sync.RWMutex
}
|
|||
|
|
|||
|
func NewRPServer(tcpPort, httpPort int) *TRPServer {
|
|||
|
s := new(TRPServer)
|
|||
|
s.tcpPort = tcpPort
|
|||
|
s.httpPort = httpPort
|
|||
|
s.connList = make(chan net.Conn, 1000)
|
|||
|
return s
|
|||
|
}
|
|||
|
|
|||
|
func (s *TRPServer) Start() error {
|
|||
|
var err error
|
|||
|
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tcpPort, ""})
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
go s.httpserver()
|
|||
|
return s.tcpserver()
|
|||
|
}
|
|||
|
|
|||
|
func (s *TRPServer) Close() error {
|
|||
|
if s.listener != nil {
|
|||
|
err := s.listener.Close()
|
|||
|
s.listener = nil
|
|||
|
return err
|
|||
|
}
|
|||
|
return errors.New("TCP实例未创建!")
|
|||
|
}
|
|||
|
|
|||
|
func (s *TRPServer) tcpserver() error {
|
|||
|
var err error
|
|||
|
for {
|
|||
|
conn, err := s.listener.AcceptTCP()
|
|||
|
if err != nil {
|
|||
|
log.Println(err)
|
|||
|
continue
|
|||
|
}
|
|||
|
go s.cliProcess(conn)
|
|||
|
}
|
|||
|
return err
|
|||
|
}
|
|||
|
|
|||
|
// badRequest replies to the request with HTTP 400 and the standard
// "Bad Request" status text as a plain-text body.
func badRequest(w http.ResponseWriter) {
	const code = http.StatusBadRequest
	msg := http.StatusText(code)
	http.Error(w, msg, code)
}
|
|||
|
|
|||
|
func (s *TRPServer) httpserver() {
|
|||
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|||
|
retry:
|
|||
|
if len(s.connList) == 0 {
|
|||
|
badRequest(w)
|
|||
|
return
|
|||
|
}
|
|||
|
conn := <-s.connList
|
|||
|
log.Println(r.RequestURI)
|
|||
|
err := s.write(r, conn)
|
|||
|
if err != nil {
|
|||
|
log.Println(err)
|
|||
|
conn.Close()
|
|||
|
goto retry
|
|||
|
return
|
|||
|
}
|
|||
|
err = s.read(w, conn)
|
|||
|
if err != nil {
|
|||
|
log.Println(err)
|
|||
|
conn.Close()
|
|||
|
goto retry
|
|||
|
return
|
|||
|
}
|
|||
|
s.connList <- conn
|
|||
|
conn = nil
|
|||
|
})
|
|||
|
log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil))
|
|||
|
}
|
|||
|
|
|||
|
func (s *TRPServer) cliProcess(conn *net.TCPConn) error {
|
|||
|
conn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
|
|||
|
vval := make([]byte, 20)
|
|||
|
_, err := conn.Read(vval)
|
|||
|
if err != nil {
|
|||
|
log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr())
|
|||
|
conn.Close()
|
|||
|
return err
|
|||
|
}
|
|||
|
if bytes.Compare(vval, getverifyval()[:]) != 0 {
|
|||
|
log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr())
|
|||
|
conn.Write([]byte("vkey"))
|
|||
|
conn.Close()
|
|||
|
return err
|
|||
|
}
|
|||
|
conn.SetReadDeadline(time.Time{})
|
|||
|
log.Println("连接新的客户端:", conn.RemoteAddr())
|
|||
|
conn.SetKeepAlive(true)
|
|||
|
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
|
|||
|
s.connList <- conn
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
|
|||
|
raw, err := EncodeRequest(r)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
c, err := conn.Write(raw)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
if c != len(raw) {
|
|||
|
return errors.New("写出长度与字节长度不一致。")
|
|||
|
}
|
|||
|
return nil
|
|||
|
}
|
|||
|
|
|||
|
func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
|
|||
|
val := make([]byte, 4)
|
|||
|
_, err := conn.Read(val)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
flags := string(val)
|
|||
|
switch flags {
|
|||
|
case "sign":
|
|||
|
_, err = conn.Read(val)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
nlen := int(binary.LittleEndian.Uint32(val))
|
|||
|
if nlen == 0 {
|
|||
|
return errors.New("读取客户端长度错误。")
|
|||
|
}
|
|||
|
log.Println("收到客户端数据,需要读取长度:", nlen)
|
|||
|
raw := make([]byte, 0)
|
|||
|
buff := make([]byte, 1024)
|
|||
|
c := 0
|
|||
|
for {
|
|||
|
clen, err := conn.Read(buff)
|
|||
|
if err != nil && err != io.EOF {
|
|||
|
return err
|
|||
|
}
|
|||
|
raw = append(raw, buff[:clen]...)
|
|||
|
c += clen
|
|||
|
if c >= nlen {
|
|||
|
break
|
|||
|
}
|
|||
|
}
|
|||
|
log.Println("读取完成,长度:", c, "实际raw长度:", len(raw))
|
|||
|
if c != nlen {
|
|||
|
return fmt.Errorf("已读取长度错误,已读取%dbyte,需要读取%dbyte。", c, nlen)
|
|||
|
}
|
|||
|
resp, err := DecodeResponse(raw)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
bodyBytes, err := ioutil.ReadAll(resp.Body)
|
|||
|
if err != nil {
|
|||
|
return err
|
|||
|
}
|
|||
|
for k, v := range resp.Header {
|
|||
|
for _, v2 := range v {
|
|||
|
w.Header().Set(k, v2)
|
|||
|
}
|
|||
|
}
|
|||
|
w.WriteHeader(resp.StatusCode)
|
|||
|
w.Write(bodyBytes)
|
|||
|
case "msg0":
|
|||
|
return nil
|
|||
|
default:
|
|||
|
log.Println("无法解析此错误", string(val))
|
|||
|
}
|
|||
|
return nil
|
|||
|
}
|