From a732febf3b7ef84c810417157e25ce4fa8726fa4 Mon Sep 17 00:00:00 2001 From: snowie2000 Date: Wed, 15 Apr 2020 10:59:48 +0800 Subject: [PATCH 1/4] fixed typo in test.go replaced self-made http reverseproxy with a more robust and versatile one. dynamically generate cert for client-server tls encryption --- cmd/nps/nps.go | 20 ++- lib/crypt/tls.go | 63 ++++++- server/proxy/http.go | 316 +++++++++++++++++------------------ server/proxy/reverseproxy.go | 136 +++++++++++++++ server/test/test.go | 4 +- 5 files changed, 360 insertions(+), 179 deletions(-) create mode 100644 server/proxy/reverseproxy.go diff --git a/cmd/nps/nps.go b/cmd/nps/nps.go index baa930b..c3b4d33 100644 --- a/cmd/nps/nps.go +++ b/cmd/nps/nps.go @@ -1,14 +1,6 @@ package main import ( - "ehang.io/nps/lib/crypt" - "ehang.io/nps/lib/file" - "ehang.io/nps/lib/install" - "ehang.io/nps/lib/version" - "ehang.io/nps/server" - "ehang.io/nps/server/connection" - "ehang.io/nps/server/tool" - "ehang.io/nps/web/routers" "flag" "log" "os" @@ -18,7 +10,16 @@ import ( "strings" "sync" + "ehang.io/nps/lib/file" + "ehang.io/nps/lib/install" + "ehang.io/nps/lib/version" + "ehang.io/nps/server" + "ehang.io/nps/server/connection" + "ehang.io/nps/server/tool" + "ehang.io/nps/web/routers" + "ehang.io/nps/lib/common" + "ehang.io/nps/lib/crypt" "ehang.io/nps/lib/daemon" "github.com/astaxie/beego" "github.com/astaxie/beego/logs" @@ -200,7 +201,8 @@ func run() { } logs.Info("the version of server is %s ,allow client core version to be %s", version.VERSION, version.GetVersion()) connection.InitConnectionService() - crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key")) + //crypt.InitTls(filepath.Join(common.GetRunPath(), "conf", "server.pem"), filepath.Join(common.GetRunPath(), "conf", "server.key")) + crypt.InitTls() tool.InitAllowPort() tool.StartSystemInfo() go server.StartNewServer(bridgePort, task, beego.AppConfig.String("bridge_type")) diff --git a/lib/crypt/tls.go b/lib/crypt/tls.go index 35a0a74..c301be8 100644 --- a/lib/crypt/tls.go +++ b/lib/crypt/tls.go @@ -1,22 +1,37 @@ package crypt import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "log" + "math/big" "net" "os" + "time" "github.com/astaxie/beego/logs" ) -var pemPath, keyPath string +var ( + cert tls.Certificate +) -func InitTls(pem, key string) { - pemPath = pem - keyPath = key +func InitTls() { + c, k, err := generateKeyPair("NPS Corp,.Inc") + if err == nil { + cert, err = tls.X509KeyPair(c, k) + } + if err != nil { + log.Fatalln("Error initializing crypto certs", err) + } } func NewTlsServerConn(conn net.Conn) net.Conn { - cert, err := tls.LoadX509KeyPair(pemPath, keyPath) + var err error if err != nil { logs.Error(err) os.Exit(0) @@ -32,3 +47,41 @@ func NewTlsClientConn(conn net.Conn) net.Conn { } return tls.Client(conn, conf) } + +func generateKeyPair(CommonName string) (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Company Name LTD."}, + CommonName: CommonName, + Country: []string{"US"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} diff --git a/server/proxy/http.go b/server/proxy/http.go index af4ad80..224183a 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -1,7 +1,7 @@ package proxy import ( - "bufio" + "context" "crypto/tls" "io" "net" @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "strconv" - "strings" "sync" + "time" "ehang.io/nps/bridge" "ehang.io/nps/lib/cache" @@ -101,174 +101,164 @@ func (s *httpServer) Close() error { return nil } -func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) { - hijacker, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Hijacking not supported", http.StatusInternalServerError) - return - } - c, _, err := hijacker.Hijack() - if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - } - s.handleHttp(conn.NewConn(c), r) -} - -func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) { - var ( - host *file.Host - target net.Conn - err error - connClient io.ReadWriteCloser - scheme = r.URL.Scheme - lk *conn.Link - targetAddr string - lenConn *conn.LenConn - isReset bool - wg sync.WaitGroup - ) - defer func() { - if connClient != nil { - connClient.Close() - }else { - s.writeConnFail(c.Conn) - } - c.Close() - }() -reset: - if isReset { - host.Client.AddConn() - } - if host, err = file.GetDb().GetInfoByHost(r.Host, r); err != nil { - logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) - return - } - if err := s.CheckFlowAndConnNum(host.Client); err != nil { - logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error()) - return - } - if !isReset { - defer host.Client.AddConn() - } - if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil { - logs.Warn("auth error", err, r.RemoteAddr) - return - } - if targetAddr, err = host.Target.GetRandomTarget(); err != nil { - logs.Warn(err.Error()) - return - } - lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) - if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { - logs.Notice("connect to target %s error %s", lk.Host, err) - return - } - connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) - - //read from inc-client - go func() { - wg.Add(1) - isReset = false - defer connClient.Close() - defer func() { - wg.Done() - if !isReset { - c.Close() - } - }() - for { - if resp, err := http.ReadResponse(bufio.NewReader(connClient), r); err != nil || resp == nil { - return - } else { - //if the cache is start and the response is in the extension,store the response to the cache list - if s.useCache && r.URL != nil && strings.Contains(r.URL.Path, ".") { - b, err := httputil.DumpResponse(resp, true) - if err != nil { - return - } - c.Write(b) - host.Flow.Add(0, int64(len(b))) - s.cache.Add(filepath.Join(host.Host, r.URL.Path), b) - } else { - lenConn := conn.NewLenConn(c) - if err := resp.Write(lenConn); err != nil { - logs.Error(err) - return - } - host.Flow.Add(0, int64(lenConn.Len)) - } - } - } - }() - - for { - //if the cache start and the request is in the cache list, return the cache - if s.useCache { - if v, ok := s.cache.Get(filepath.Join(host.Host, r.URL.Path)); ok { - n, err := c.Write(v.([]byte)) - if err != nil { - break - } - logs.Trace("%s request, method %s, host %s, url %s, remote address %s, return cache", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String()) - host.Flow.Add(0, int64(n)) - //if return cache and does not create a new conn with client and Connection is not set or close, close the connection. - if strings.ToLower(r.Header.Get("Connection")) == "close" || strings.ToLower(r.Header.Get("Connection")) == "" { - break - } - goto readReq - } - } - - //change the host and header and set proxy setting - common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String(), s.addOrigin) - logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String(), lk.Host) - //write - lenConn = conn.NewLenConn(connClient) - if err := r.Write(lenConn); err != nil { - logs.Error(err) - break - } - host.Flow.Add(int64(lenConn.Len), 0) - - readReq: - //read req from connection - if r, err = http.ReadRequest(bufio.NewReader(c)); err != nil { - break - } - r.URL.Scheme = scheme - //What happened ,Why one character less??? - r.Method = resetReqMethod(r.Method) - if hostTmp, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil { - logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) - break - } else if host != hostTmp { - host = hostTmp - isReset = true - connClient.Close() - goto reset - } - } - wg.Wait() -} - -func resetReqMethod(method string) string { - if method == "ET" { - return "GET" - } - if method == "OST" { - return "POST" - } - return method -} - func (s *httpServer) NewServer(port int, scheme string) *http.Server { + rProxy := NewHttpReverseProxy(s) return &http.Server{ Addr: ":" + strconv.Itoa(port), Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = scheme - s.handleTunneling(w, r) + rProxy.ServeHTTP(w, r) }), // Disable HTTP/2. TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } } + +type HttpReverseProxy struct { + proxy *ReverseProxy + + responseHeaderTimeout time.Duration +} + +func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + var ( + host *file.Host + targetAddr string + err error + ) + if host, err = file.GetDb().GetInfoByHost(req.Host, req); err != nil { + rw.WriteHeader(http.StatusNotFound) + rw.Write([]byte(req.Host + " not found")) + return + } + if host.Client.Cnf.U != "" && host.Client.Cnf.P != "" && !common.CheckAuth(req, host.Client.Cnf.U, host.Client.Cnf.P) { + rw.WriteHeader(http.StatusUnauthorized) + rw.Write([]byte("Unauthorized")) + return + } + if targetAddr, err = host.Target.GetRandomTarget(); err != nil { + rw.WriteHeader(http.StatusBadGateway) + rw.Write([]byte("502 Bad Gateway")) + return + } + req = req.WithContext(context.WithValue(req.Context(), "host", host)) + req = req.WithContext(context.WithValue(req.Context(), "target", targetAddr)) + req = req.WithContext(context.WithValue(req.Context(), "req", req)) + + rp.proxy.ServeHTTP(rw, req) +} + +func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy { + rp := &HttpReverseProxy{ + responseHeaderTimeout: 30 * time.Second, + } + local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") + proxy := NewReverseProxy(&httputil.ReverseProxy{ + Director: func(r *http.Request) { + r.URL.Host = r.Host + if host, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil { + logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) + return + } else { + common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false) + } + }, + Transport: &http.Transport{ + ResponseHeaderTimeout: rp.responseHeaderTimeout, + DisableKeepAlives: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var ( + host *file.Host + target net.Conn + err error + connClient io.ReadWriteCloser + targetAddr string + lk *conn.Link + ) + + r := ctx.Value("req").(*http.Request) + host = ctx.Value("host").(*file.Host) + targetAddr = ctx.Value("target").(string) + + lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) + if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { + logs.Notice("connect to target %s error %s", lk.Host, err) + return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the server") + } + connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + return &flowConn{ + ReadWriteCloser: connClient, + fakeAddr: local, + host: host, + }, nil + }, + }, + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + logs.Warn("do http proxy request error: %v", err) + rw.WriteHeader(http.StatusNotFound) + }, + }) + proxy.WebSocketDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var ( + host *file.Host + target net.Conn + err error + connClient io.ReadWriteCloser + targetAddr string + lk *conn.Link + ) + r := ctx.Value("req").(*http.Request) + host = ctx.Value("host").(*file.Host) + targetAddr = ctx.Value("target").(string) + + lk = conn.NewLink("tcp", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) + if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { + logs.Notice("connect to target %s error %s", lk.Host, err) + return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the target") + } + connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + return &flowConn{ + ReadWriteCloser: connClient, + fakeAddr: local, + host: host, + }, nil + } + rp.proxy = proxy + return rp +} + +type flowConn struct { + io.ReadWriteCloser + fakeAddr net.Addr + host *file.Host + flowIn int64 + flowOut int64 + once sync.Once +} + +func (c *flowConn) Read(p []byte) (n int, err error) { + n, err = c.ReadWriteCloser.Read(p) + c.flowIn += int64(n) + return n, err +} + +func (c *flowConn) Write(p []byte) (n int, err error) { + n, err = c.ReadWriteCloser.Write(p) + c.flowOut += int64(n) + return n, err +} + +func (c *flowConn) Close() error { + c.once.Do(func() { c.host.Flow.Add(c.flowIn, c.flowOut) }) + return c.ReadWriteCloser.Close() +} + +func (c *flowConn) LocalAddr() net.Addr { return c.fakeAddr } + +func (c *flowConn) RemoteAddr() net.Addr { return c.fakeAddr } + +func (*flowConn) SetDeadline(t time.Time) error { return nil } + +func (*flowConn) SetReadDeadline(t time.Time) error { return nil } + +func (*flowConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/server/proxy/reverseproxy.go b/server/proxy/reverseproxy.go new file mode 100644 index 0000000..df7e866 --- /dev/null +++ b/server/proxy/reverseproxy.go @@ -0,0 +1,136 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package proxy + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "sync" +) + +type HTTPError struct { + error + HTTPCode int +} + +func NewHTTPError(code int, errmsg string) error { + return &HTTPError{ + error: errors.New(errmsg), + HTTPCode: code, + } +} + +type ReverseProxy struct { + *httputil.ReverseProxy + WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error) +} + +func IsWebsocketRequest(req *http.Request) bool { + containsHeader := func(name, value string) bool { + items := strings.Split(req.Header.Get(name), ",") + for _, item := range items { + if value == strings.ToLower(strings.TrimSpace(item)) { + return true + } + } + return false + } + return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket") +} + +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + rp := &ReverseProxy{ + ReverseProxy: httputil.NewSingleHostReverseProxy(target), + WebSocketDialContext: nil, + } + rp.ErrorHandler = rp.errHandler + return rp +} + +func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy { + rp := &ReverseProxy{ + ReverseProxy: orp, + WebSocketDialContext: nil, + } + rp.ErrorHandler = rp.errHandler + return rp +} + +func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) { + if e == io.EOF { + rw.WriteHeader(521) + //rw.Write(getWaitingPageContent()) + } else { + if httperr, ok := e.(*HTTPError); ok { + rw.WriteHeader(httperr.HTTPCode) + } else { + rw.WriteHeader(http.StatusNotFound) + } + rw.Write([]byte("error: " + e.Error())) + } +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if IsWebsocketRequest(req) { + p.serveWebSocket(rw, req) + } else { + p.ReverseProxy.ServeHTTP(rw, req) + } +} + +func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) { + if p.WebSocketDialContext == nil { + rw.WriteHeader(500) + return + } + targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "") + if err != nil { + rw.WriteHeader(501) + return + } + defer targetConn.Close() + + p.Director(req) + + hijacker, ok := rw.(http.Hijacker) + if !ok { + rw.WriteHeader(500) + return + } + conn, _, errHijack := hijacker.Hijack() + if errHijack != nil { + rw.WriteHeader(500) + return + } + defer conn.Close() + + req.Write(targetConn) + Join(conn, targetConn) +} + +func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) { + var wait sync.WaitGroup + pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) { + defer to.Close() + defer from.Close() + defer wait.Done() + + *count, _ = io.Copy(to, from) + } + + wait.Add(2) + go pipe(c1, c2, &inCount) + go pipe(c2, c1, &outCount) + wait.Wait() + return +} diff --git a/server/test/test.go b/server/test/test.go index a30d03d..3fce9ee 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -52,10 +52,10 @@ func TestServerConfig() { if port, err := strconv.Atoi(p); err != nil { log.Fatalln("get https port error", err) } else { - if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) { + if beego.AppConfig.String("pemPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("pemPath"))) { log.Fatalf("ssl certFile %s is not exist", beego.AppConfig.String("pemPath")) } - if !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("ketPath"))) { + if beego.AppConfig.String("keyPath") != "" && !common.FileExists(filepath.Join(common.GetRunPath(), beego.AppConfig.String("keyPath"))) { log.Fatalf("ssl keyFile %s is not exist", beego.AppConfig.String("pemPath")) } isInArr(&postTcpArr, port, "http port", "tcp") From 89e2a4c2ebfef20569d9d2b9230ad4ef37f6bd06 Mon Sep 17 00:00:00 2001 From: snowie2000 Date: Wed, 15 Apr 2020 11:08:05 +0800 Subject: [PATCH 2/4] removed unnecessary host fetch in revereproxy.Director --- server/proxy/http.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/server/proxy/http.go b/server/proxy/http.go index 224183a..4d6ae81 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -155,13 +155,8 @@ func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy { local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") proxy := NewReverseProxy(&httputil.ReverseProxy{ Director: func(r *http.Request) { - r.URL.Host = r.Host - if host, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil { - logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) - return - } else { - common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false) - } + host := r.Context().Value("host").(*file.Host) + common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false) }, Transport: &http.Transport{ ResponseHeaderTimeout: rp.responseHeaderTimeout, From 16be6d1b55d3eaa6fcd3ae9d839d3255d1e89ce8 Mon Sep 17 00:00:00 2001 From: snowie2000 Date: Wed, 15 Apr 2020 11:25:13 +0800 Subject: [PATCH 3/4] Revert http reverse proxy changes --- server/proxy/http.go | 311 ++++++++++++++++++----------------- server/proxy/reverseproxy.go | 136 --------------- 2 files changed, 163 insertions(+), 284 deletions(-) delete mode 100644 server/proxy/reverseproxy.go diff --git a/server/proxy/http.go b/server/proxy/http.go index 4d6ae81..78a26b6 100644 --- a/server/proxy/http.go +++ b/server/proxy/http.go @@ -1,7 +1,7 @@ package proxy import ( - "context" + "bufio" "crypto/tls" "io" "net" @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "strconv" + "strings" "sync" - "time" "ehang.io/nps/bridge" "ehang.io/nps/lib/cache" @@ -101,159 +101,174 @@ func (s *httpServer) Close() error { return nil } +func (s *httpServer) handleTunneling(w http.ResponseWriter, r *http.Request) { + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Hijacking not supported", http.StatusInternalServerError) + return + } + c, _, err := hijacker.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusServiceUnavailable) + } + s.handleHttp(conn.NewConn(c), r) +} + +func (s *httpServer) handleHttp(c *conn.Conn, r *http.Request) { + var ( + host *file.Host + target net.Conn + err error + connClient io.ReadWriteCloser + scheme = r.URL.Scheme + lk *conn.Link + targetAddr string + lenConn *conn.LenConn + isReset bool + wg sync.WaitGroup + ) + defer func() { + if connClient != nil { + connClient.Close() + } else { + s.writeConnFail(c.Conn) + } + c.Close() + }() +reset: + if isReset { + host.Client.AddConn() + } + if host, err = file.GetDb().GetInfoByHost(r.Host, r); err != nil { + logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) + return + } + if err := s.CheckFlowAndConnNum(host.Client); err != nil { + logs.Warn("client id %d, host id %d, error %s, when https connection", host.Client.Id, host.Id, err.Error()) + return + } + if !isReset { + defer host.Client.AddConn() + } + if err = s.auth(r, c, host.Client.Cnf.U, host.Client.Cnf.P); err != nil { + logs.Warn("auth error", err, r.RemoteAddr) + return + } + if targetAddr, err = host.Target.GetRandomTarget(); err != nil { + logs.Warn(err.Error()) + return + } + lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) + if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { + logs.Notice("connect to target %s error %s", lk.Host, err) + return + } + connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) + + //read from inc-client + go func() { + wg.Add(1) + isReset = false + defer connClient.Close() + defer func() { + wg.Done() + if !isReset { + c.Close() + } + }() + for { + if resp, err := http.ReadResponse(bufio.NewReader(connClient), r); err != nil || resp == nil { + return + } else { + //if the cache is start and the response is in the extension,store the response to the cache list + if s.useCache && r.URL != nil && strings.Contains(r.URL.Path, ".") { + b, err := httputil.DumpResponse(resp, true) + if err != nil { + return + } + c.Write(b) + host.Flow.Add(0, int64(len(b))) + s.cache.Add(filepath.Join(host.Host, r.URL.Path), b) + } else { + lenConn := conn.NewLenConn(c) + if err := resp.Write(lenConn); err != nil { + logs.Error(err) + return + } + host.Flow.Add(0, int64(lenConn.Len)) + } + } + } + }() + + for { + //if the cache start and the request is in the cache list, return the cache + if s.useCache { + if v, ok := s.cache.Get(filepath.Join(host.Host, r.URL.Path)); ok { + n, err := c.Write(v.([]byte)) + if err != nil { + break + } + logs.Trace("%s request, method %s, host %s, url %s, remote address %s, return cache", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String()) + host.Flow.Add(0, int64(n)) + //if return cache and does not create a new conn with client and Connection is not set or close, close the connection. + if strings.ToLower(r.Header.Get("Connection")) == "close" || strings.ToLower(r.Header.Get("Connection")) == "" { + break + } + goto readReq + } + } + + //change the host and header and set proxy setting + common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, c.Conn.RemoteAddr().String(), s.addOrigin) + logs.Trace("%s request, method %s, host %s, url %s, remote address %s, target %s", r.URL.Scheme, r.Method, r.Host, r.URL.Path, c.RemoteAddr().String(), lk.Host) + //write + lenConn = conn.NewLenConn(connClient) + if err := r.Write(lenConn); err != nil { + logs.Error(err) + break + } + host.Flow.Add(int64(lenConn.Len), 0) + + readReq: + //read req from connection + if r, err = http.ReadRequest(bufio.NewReader(c)); err != nil { + break + } + r.URL.Scheme = scheme + //What happened ,Why one character less??? + r.Method = resetReqMethod(r.Method) + if hostTmp, err := file.GetDb().GetInfoByHost(r.Host, r); err != nil { + logs.Notice("the url %s %s %s can't be parsed!", r.URL.Scheme, r.Host, r.RequestURI) + break + } else if host != hostTmp { + host = hostTmp + isReset = true + connClient.Close() + goto reset + } + } + wg.Wait() +} + +func resetReqMethod(method string) string { + if method == "ET" { + return "GET" + } + if method == "OST" { + return "POST" + } + return method +} + func (s *httpServer) NewServer(port int, scheme string) *http.Server { - rProxy := NewHttpReverseProxy(s) return &http.Server{ Addr: ":" + strconv.Itoa(port), Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Scheme = scheme - rProxy.ServeHTTP(w, r) + s.handleTunneling(w, r) }), // Disable HTTP/2. TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } } - -type HttpReverseProxy struct { - proxy *ReverseProxy - - responseHeaderTimeout time.Duration -} - -func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - var ( - host *file.Host - targetAddr string - err error - ) - if host, err = file.GetDb().GetInfoByHost(req.Host, req); err != nil { - rw.WriteHeader(http.StatusNotFound) - rw.Write([]byte(req.Host + " not found")) - return - } - if host.Client.Cnf.U != "" && host.Client.Cnf.P != "" && !common.CheckAuth(req, host.Client.Cnf.U, host.Client.Cnf.P) { - rw.WriteHeader(http.StatusUnauthorized) - rw.Write([]byte("Unauthorized")) - return - } - if targetAddr, err = host.Target.GetRandomTarget(); err != nil { - rw.WriteHeader(http.StatusBadGateway) - rw.Write([]byte("502 Bad Gateway")) - return - } - req = req.WithContext(context.WithValue(req.Context(), "host", host)) - req = req.WithContext(context.WithValue(req.Context(), "target", targetAddr)) - req = req.WithContext(context.WithValue(req.Context(), "req", req)) - - rp.proxy.ServeHTTP(rw, req) -} - -func NewHttpReverseProxy(s *httpServer) *HttpReverseProxy { - rp := &HttpReverseProxy{ - responseHeaderTimeout: 30 * time.Second, - } - local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") - proxy := NewReverseProxy(&httputil.ReverseProxy{ - Director: func(r *http.Request) { - host := r.Context().Value("host").(*file.Host) - common.ChangeHostAndHeader(r, host.HostChange, host.HeaderChange, "", false) - }, - Transport: &http.Transport{ - ResponseHeaderTimeout: rp.responseHeaderTimeout, - DisableKeepAlives: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - var ( - host *file.Host - target net.Conn - err error - connClient io.ReadWriteCloser - targetAddr string - lk *conn.Link - ) - - r := ctx.Value("req").(*http.Request) - host = ctx.Value("host").(*file.Host) - targetAddr = ctx.Value("target").(string) - - lk = conn.NewLink("http", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) - if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { - logs.Notice("connect to target %s error %s", lk.Host, err) - return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the server") - } - connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) - return &flowConn{ - ReadWriteCloser: connClient, - fakeAddr: local, - host: host, - }, nil - }, - }, - ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - logs.Warn("do http proxy request error: %v", err) - rw.WriteHeader(http.StatusNotFound) - }, - }) - proxy.WebSocketDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - var ( - host *file.Host - target net.Conn - err error - connClient io.ReadWriteCloser - targetAddr string - lk *conn.Link - ) - r := ctx.Value("req").(*http.Request) - host = ctx.Value("host").(*file.Host) - targetAddr = ctx.Value("target").(string) - - lk = conn.NewLink("tcp", targetAddr, host.Client.Cnf.Crypt, host.Client.Cnf.Compress, r.RemoteAddr, host.Target.LocalProxy) - if target, err = s.bridge.SendLinkInfo(host.Client.Id, lk, nil); err != nil { - logs.Notice("connect to target %s error %s", lk.Host, err) - return nil, NewHTTPError(http.StatusBadGateway, "Cannot connect to the target") - } - connClient = conn.GetConn(target, lk.Crypt, lk.Compress, host.Client.Rate, true) - return &flowConn{ - ReadWriteCloser: connClient, - fakeAddr: local, - host: host, - }, nil - } - rp.proxy = proxy - return rp -} - -type flowConn struct { - io.ReadWriteCloser - fakeAddr net.Addr - host *file.Host - flowIn int64 - flowOut int64 - once sync.Once -} - -func (c *flowConn) Read(p []byte) (n int, err error) { - n, err = c.ReadWriteCloser.Read(p) - c.flowIn += int64(n) - return n, err -} - -func (c *flowConn) Write(p []byte) (n int, err error) { - n, err = c.ReadWriteCloser.Write(p) - c.flowOut += int64(n) - return n, err -} - -func (c *flowConn) Close() error { - c.once.Do(func() { c.host.Flow.Add(c.flowIn, c.flowOut) }) - return c.ReadWriteCloser.Close() -} - -func (c *flowConn) LocalAddr() net.Addr { return c.fakeAddr } - -func (c *flowConn) RemoteAddr() net.Addr { return c.fakeAddr } - -func (*flowConn) SetDeadline(t time.Time) error { return nil } - -func (*flowConn) SetReadDeadline(t time.Time) error { return nil } - -func (*flowConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/server/proxy/reverseproxy.go b/server/proxy/reverseproxy.go deleted file mode 100644 index df7e866..0000000 --- a/server/proxy/reverseproxy.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// HTTP reverse proxy handler - -package proxy - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "net/http/httputil" - "net/url" - "strings" - "sync" -) - -type HTTPError struct { - error - HTTPCode int -} - -func NewHTTPError(code int, errmsg string) error { - return &HTTPError{ - error: errors.New(errmsg), - HTTPCode: code, - } -} - -type ReverseProxy struct { - *httputil.ReverseProxy - WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error) -} - -func IsWebsocketRequest(req *http.Request) bool { - containsHeader := func(name, value string) bool { - items := strings.Split(req.Header.Get(name), ",") - for _, item := range items { - if value == strings.ToLower(strings.TrimSpace(item)) { - return true - } - } - return false - } - return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket") -} - -func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { - rp := &ReverseProxy{ - ReverseProxy: httputil.NewSingleHostReverseProxy(target), - WebSocketDialContext: nil, - } - rp.ErrorHandler = rp.errHandler - return rp -} - -func NewReverseProxy(orp *httputil.ReverseProxy) *ReverseProxy { - rp := &ReverseProxy{ - ReverseProxy: orp, - WebSocketDialContext: nil, - } - rp.ErrorHandler = rp.errHandler - return rp -} - -func (p *ReverseProxy) errHandler(rw http.ResponseWriter, r *http.Request, e error) { - if e == io.EOF { - rw.WriteHeader(521) - //rw.Write(getWaitingPageContent()) - } else { - if httperr, ok := e.(*HTTPError); ok { - rw.WriteHeader(httperr.HTTPCode) - } else { - rw.WriteHeader(http.StatusNotFound) - } - rw.Write([]byte("error: " + e.Error())) - } -} - -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if IsWebsocketRequest(req) { - p.serveWebSocket(rw, req) - } else { - p.ReverseProxy.ServeHTTP(rw, req) - } -} - -func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) { - if p.WebSocketDialContext == nil { - rw.WriteHeader(500) - return - } - targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "") - if err != nil { - rw.WriteHeader(501) - return - } - defer targetConn.Close() - - p.Director(req) - - hijacker, ok := rw.(http.Hijacker) - if !ok { - rw.WriteHeader(500) - return - } - conn, _, errHijack := hijacker.Hijack() - if errHijack != nil { - rw.WriteHeader(500) - return - } - defer conn.Close() - - req.Write(targetConn) - Join(conn, targetConn) -} - -func Join(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) (inCount int64, outCount int64) { - var wait sync.WaitGroup - pipe := func(to io.ReadWriteCloser, from io.ReadWriteCloser, count *int64) { - defer to.Close() - defer from.Close() - defer wait.Done() - - *count, _ = io.Copy(to, from) - } - - wait.Add(2) - go pipe(c1, c2, &inCount) - go pipe(c2, c1, &outCount) - wait.Wait() - return -} From 7e60ed14b5bc3fe28a6ef4f77e24c550a358276e Mon Sep 17 00:00:00 2001 From: ffdfgdfg Date: Thu, 30 Apr 2020 23:25:59 +0800 Subject: [PATCH 4/4] change tls key pair name --- lib/crypt/tls.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crypt/tls.go b/lib/crypt/tls.go index c301be8..799a1c0 100644 --- a/lib/crypt/tls.go +++ b/lib/crypt/tls.go @@ -21,7 +21,7 @@ var ( ) func InitTls() { - c, k, err := generateKeyPair("NPS Corp,.Inc") + c, k, err := generateKeyPair("NPS Org") if err == nil { cert, err = tls.X509KeyPair(c, k) }