diff --git a/client.go b/client.go index 296ccb9..6443024 100755 --- a/client.go +++ b/client.go @@ -103,6 +103,7 @@ func (s *TRPClient) dealChan() error { //与目标建立连接 server, err := net.Dial("tcp", host) if err != nil { + fmt.Println(err) return err } //创建成功后io.copy diff --git a/conn.go b/conn.go index 97f256b..41df325 100644 --- a/conn.go +++ b/conn.go @@ -7,6 +7,9 @@ import ( "fmt" "io" "net" + "net/url" + "regexp" + "strings" "time" ) @@ -95,6 +98,38 @@ func (s *Conn) SetAlive() { conn.SetKeepAlivePeriod(time.Duration(2 * time.Second)) } +//从tcp报文中解析出host +func (s *Conn) GetHost() (method, address string, rb []byte, err error) { + var b [2048]byte + var n int + var host string + if n, err = s.Read(b[:]); err != nil { + return + } + rb = b[:n] + //TODO:某些不规范报文可能会有问题 + fmt.Sscanf(string(b[:n]), "%s", &method) + reg, err := regexp.Compile(`(\w+:\/\/)([^/:]+)(:\d*)?`) + if err != nil { + return + } + host = string(reg.Find(b[:])) + hostPortURL, err := url.Parse(host) + if err != nil { + return + } + if hostPortURL.Opaque == "443" { //https访问 + address = hostPortURL.Scheme + ":443" + } else { //http访问 + if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80 + address = hostPortURL.Host + ":80" + } else { + address = hostPortURL.Host + } + } + return +} + func (s *Conn) Close() error { return s.conn.Close() } diff --git a/main.go b/main.go index 951a38e..905ca8b 100755 --- a/main.go +++ b/main.go @@ -47,13 +47,14 @@ func main() { log.Fatalln(err) } } else if *rpMode == "tunnelServer" { - svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget) - if err := svr.Start(); err != nil { - log.Fatalln(err) - } + svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessTunnel) + svr.Start() } else if *rpMode == "sock5Server" { svr := NewSock5ModeServer(*tcpPort, *httpPort) svr.Start() + } else if *rpMode == "httpProxyServer" { + svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget, ProcessHttp) + svr.Start() } } } diff --git a/server.go b/server.go index b70d858..c4b5f19 100755 --- a/server.go +++ b/server.go @@ -125,19 +125,22 @@ func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error { return nil } +type process func(c *Conn, s *TunnelModeServer) error type TunnelModeServer struct { Tunnel httpPort int tunnelTarget string + process process } -func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string) *TunnelModeServer { +func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string, process process) *TunnelModeServer { s := new(TunnelModeServer) s.tunnelPort = tcpPort s.httpPort = httpPort s.tunnelTarget = tunnelTarget s.tunnelList = make(chan *Conn, 1000) s.signalList = make(chan *Conn, 10) + s.process = process return s } @@ -164,21 +167,43 @@ func (s *TunnelModeServer) startTunnelServer() { log.Println(err) continue } - go s.process(NewConn(conn)) + go s.process(NewConn(conn), s) } } -//监听连接处理 -func (s *TunnelModeServer) process(c *Conn) error { +//TODO:这种实现方式…… +//tcp隧道模式 +func ProcessTunnel(c *Conn, s *TunnelModeServer) error { retry: - if len(s.tunnelList) < 10 { //新建通道 - go s.newChan() - } - link := <-s.tunnelList + link := s.GetTunnel() if _, err := link.WriteHost(s.tunnelTarget); err != nil { + link.Close() goto retry } go relay(link.conn, c.conn) relay(c.conn, link.conn) return nil } + +//http代理模式 +func ProcessHttp(c *Conn, s *TunnelModeServer) error { + method, addr, rb, err := c.GetHost() + if err != nil { + c.Close() + return err + } +retry: + link := s.GetTunnel() + if _, err := link.WriteHost(addr); err != nil { + link.Close() + goto retry + } + if method == "CONNECT" { + fmt.Fprint(c, "HTTP/1.1 200 Connection established\r\n") + } else { + link.Write(rb) + } + go relay(link.conn, c.conn) + relay(c.conn, link.conn) + return nil +} diff --git a/tunnel.go b/tunnel.go index ec7f177..dfd57c2 100644 --- a/tunnel.go +++ b/tunnel.go @@ -3,7 +3,6 @@ package main import ( "bytes" "errors" - "fmt" "log" "net" "sync" @@ -91,8 +90,15 @@ retry: connPass := <-s.signalList _, err := connPass.conn.Write([]byte("chan")) if err != nil { - fmt.Println(err) + log.Println(err) goto retry } s.signalList <- connPass } + +func (s *Tunnel) GetTunnel() *Conn { + if len(s.tunnelList) < 10 { //新建通道 + go s.newChan() + } + return <-s.tunnelList +} diff --git a/util.go b/util.go index f44b2b0..49f4e12 100644 --- a/util.go +++ b/util.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/http" "net/http/httputil" @@ -21,15 +20,10 @@ var ( disabledRedirect = errors.New("disabled redirect.") ) - - - func BadRequest(w http.ResponseWriter) { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) } - - //发送请求并转为bytes func GetEncodeResponse(req *http.Request) ([]byte, error) { var respBytes []byte @@ -52,7 +46,6 @@ func GetEncodeResponse(req *http.Request) ([]byte, error) { return respBytes, nil } - // 将request 的处理 func EncodeRequest(r *http.Request) ([]byte, error) { raw := bytes.NewBuffer([]byte{}) @@ -91,6 +84,7 @@ func DecodeRequest(data []byte) (*http.Request, error) { scheme = "https" } req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI)) + fmt.Println(req.URL) req.RequestURI = "" return req, nil } @@ -151,8 +145,6 @@ func replaceHost(resp []byte) []byte { } func relay(in, out net.Conn) { - if _, err := io.Copy(in, out); err != nil { - log.Println("copy error:", err) - } - in.Close() // + io.Copy(in, out); + in.Close() }