nps/server.go

194 lines
3.9 KiB
Go
Raw Normal View History

2018-11-04 07:19:22 -08:00
package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"sync"
"time"
)
// TRPServer accepts authenticated proxy clients on a TCP port and forwards
// HTTP requests received on a separate port to those clients.
type TRPServer struct {
	// tcpPort is where proxy clients connect; httpPort is where the HTTP
	// requests to be forwarded arrive.
	tcpPort, httpPort int
	// listener accepts proxy clients; it is nil until Start creates it and
	// is reset to nil by Close.
	listener *net.TCPListener
	// connList is the pool of verified client connections waiting to carry
	// a forwarded request.
	connList chan net.Conn
	// NOTE(review): embedded lock not referenced by the methods visible in
	// this file; confirm whether other code relies on it.
	sync.RWMutex
}
// NewRPServer returns a server configured for the given client (tcpPort) and
// HTTP (httpPort) ports. The connection pool is a channel buffered to 1000
// entries; the TCP listener itself is only created later by Start.
func NewRPServer(tcpPort, httpPort int) *TRPServer {
	return &TRPServer{
		tcpPort:  tcpPort,
		httpPort: httpPort,
		connList: make(chan net.Conn, 1000),
	}
}
// Start listens for proxy clients on 0.0.0.0:s.tcpPort, launches the HTTP
// front end in a background goroutine, and then runs the client accept loop
// on the calling goroutine.
//
// It returns the listen error if the TCP listener cannot be created;
// otherwise it returns only when the accept loop (tcpserver) returns.
func (s *TRPServer) Start() error {
	var err error
	// Keyed fields: unkeyed literals of imported struct types are flagged by
	// go vet and break if the struct gains or reorders fields.
	s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: s.tcpPort})
	if err != nil {
		return err
	}
	go s.httpserver()
	return s.tcpserver()
}
// Close closes the proxy-client listener, clears s.listener, and returns the
// result of closing it. If there is no listener (Start never created one, or
// Close already ran) it returns an error instead.
func (s *TRPServer) Close() error {
	ln := s.listener
	if ln == nil {
		return errors.New("TCP实例未创建")
	}
	closeErr := ln.Close()
	s.listener = nil
	return closeErr
}
// tcpserver accepts proxy clients on the listener created by Start and hands
// each one to cliProcess in its own goroutine.
//
// Temporary accept errors are logged and retried after a short pause so the
// loop cannot spin. Any other error, including the one produced once Close
// shuts the listener, ends the loop and is returned. Previously every error
// was retried, so a closed listener made this loop spin forever.
func (s *TRPServer) tcpserver() error {
	// Capture the listener once: Close resets s.listener to nil, and we want
	// the closed listener's error rather than a nil-receiver error.
	ln := s.listener
	if ln == nil {
		return errors.New("TCP实例未创建")
	}
	for {
		conn, err := ln.AcceptTCP()
		if err != nil {
			// Same classification net/http's Server.Serve uses for its
			// accept loop; works on every Go version.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				log.Println(err)
				time.Sleep(50 * time.Millisecond)
				continue
			}
			return err
		}
		go s.cliProcess(conn)
	}
}
// badRequest replies to the client with a plain-text 400 Bad Request.
func badRequest(w http.ResponseWriter) {
	const code = http.StatusBadRequest
	http.Error(w, http.StatusText(code), code)
}
// httpserver serves HTTP on :s.httpPort. Each request is written to a pooled
// client connection (s.write) and the client's reply is copied back (s.read).
// When either step fails, that connection is closed and the next pooled
// connection is tried. When the pool is empty the request gets a 400 reply.
// A connection that completes a round trip is returned to the pool.
//
// The handler is registered on http.DefaultServeMux, so calling httpserver
// twice panics on the duplicate "/" pattern. The process exits via
// log.Fatalln when ListenAndServe fails.
func (s *TRPServer) httpserver() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		for {
			// Non-blocking receive. The old len() check followed by a
			// blocking receive raced with concurrent requests: another
			// handler could take the last connection in between, leaving
			// this one blocked until a new client connected.
			var conn net.Conn
			select {
			case conn = <-s.connList:
			default:
				badRequest(w)
				return
			}
			log.Println(r.RequestURI)
			if err := s.write(r, conn); err != nil {
				log.Println(err)
				conn.Close()
				continue
			}
			if err := s.read(w, conn); err != nil {
				log.Println(err)
				conn.Close()
				continue
			}
			s.connList <- conn
			return
		}
	})
	log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil))
}
// cliProcess verifies a newly accepted client and, if it passes, enables TCP
// keep-alive and adds the connection to the pool used by httpserver.
//
// The client has 5 seconds to send 20 bytes that must equal getverifyval().
// On a read failure the connection is closed and the read error is returned.
// On a mismatch the client is sent "vkey", the connection is closed, and a
// non-nil error is returned. Previously this path returned nil.
func (s *TRPServer) cliProcess(conn *net.TCPConn) error {
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	vval := make([]byte, 20)
	// io.ReadFull: a single Read on a TCP stream may return fewer than 20
	// bytes, which made legitimate clients fail verification.
	if _, err := io.ReadFull(conn, vval); err != nil {
		log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr())
		conn.Close()
		return err
	}
	if !bytes.Equal(vval, getverifyval()[:]) {
		log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr())
		conn.Write([]byte("vkey"))
		conn.Close()
		return errors.New("客户端校验失败")
	}
	// Clear the handshake deadline; pooled connections must not time out
	// while idle.
	conn.SetReadDeadline(time.Time{})
	log.Println("连接新的客户端:", conn.RemoteAddr())
	conn.SetKeepAlive(true)
	conn.SetKeepAlivePeriod(2 * time.Second)
	s.connList <- conn
	return nil
}
// write serialises r with EncodeRequest and sends the resulting bytes to
// conn. An encode or write error is returned as-is. A write that sends fewer
// bytes than encoded is reported as an error.
func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
	payload, encErr := EncodeRequest(r)
	if encErr != nil {
		return encErr
	}
	n, wErr := conn.Write(payload)
	switch {
	case wErr != nil:
		return wErr
	case n != len(payload):
		return errors.New("写出长度与字节长度不一致。")
	default:
		return nil
	}
}
// read reads one reply frame from the client connection conn and, for a
// "sign" frame, writes the decoded HTTP response to w.
//
// Frame layout as parsed here: a 4-byte ASCII flag. For "sign", the flag is
// followed by a 4-byte little-endian payload length and then that many
// payload bytes, which are passed to DecodeResponse. "msg0" carries no
// payload and writes nothing to w. Any other flag is returned as an error:
// the stream is no longer frame-aligned, so the caller should discard conn
// rather than return it to the pool.
func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) error {
	val := make([]byte, 4)
	// io.ReadFull everywhere: a bare Read may legally return a short count.
	if _, err := io.ReadFull(conn, val); err != nil {
		return err
	}
	flags := string(val)
	switch flags {
	case "sign":
		if _, err := io.ReadFull(conn, val); err != nil {
			return err
		}
		nlen := int(binary.LittleEndian.Uint32(val))
		if nlen == 0 {
			return errors.New("读取客户端长度错误。")
		}
		log.Println("收到客户端数据,需要读取长度:", nlen)
		// Read exactly nlen bytes. The old 1 KiB loop could read past the
		// frame into the next one, and spun forever on EOF.
		raw := make([]byte, nlen)
		if c, err := io.ReadFull(conn, raw); err != nil {
			return fmt.Errorf("已读取长度错误,已读取%dbyte需要读取%dbyte: %v", c, nlen, err)
		}
		resp, err := DecodeResponse(raw)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		// Buffer the whole body before touching w. If reading fails, nothing
		// has been written yet and the caller can retry on another connection.
		bodyBytes, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		// Add, not Set: Set kept only the last value of multi-valued
		// headers such as Set-Cookie.
		for k, vs := range resp.Header {
			for _, v := range vs {
				w.Header().Add(k, v)
			}
		}
		w.WriteHeader(resp.StatusCode)
		w.Write(bodyBytes)
		return nil
	case "msg0":
		return nil
	default:
		return fmt.Errorf("无法解析此错误: %q", flags)
	}
}