diff --git a/nat/nat.go b/nat/nat.go index 76beb2b..9fe97f6 100644 --- a/nat/nat.go +++ b/nat/nat.go @@ -1,3 +1,6 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + package nat import ( @@ -10,14 +13,14 @@ import ( ) const ( - mapTimeout = 30 * time.Minute + mapTimeout = 30 * time.Second mapUpdateTimeout = mapTimeout / 2 maxRetries = 20 ) type NATRouter interface { MapPort(protocol string, intPort, extPort uint16, desc string, duration time.Duration) error - UnmapPort(protocol string, extPort uint16) error + UnmapPort(protocol string, intPort, extPort uint16) error ExternalIP() (net.IP, error) GetPortMappingEntry(extPort uint16, protocol string) ( InternalIP string, @@ -28,10 +31,12 @@ type NATRouter interface { } func GetNATRouter() NATRouter { - //TODO add PMP support if r := getUPnPRouter(); r != nil { return r } + if r := getPMPRouter(); r != nil { + return r + } return NewNoRouter() } @@ -95,11 +100,9 @@ func (dev *Mapper) keepPortMapping(mappedPort chan<- uint16, protocol string, updateTimer.Stop() dev.log.Info("Unmap protocol %s external port %d", protocol, port) - if port > 0 { - dev.errLock.Lock() - dev.errs.Add(dev.r.UnmapPort(protocol, port)) - dev.errLock.Unlock() - } + dev.errLock.Lock() + dev.errs.Add(dev.r.UnmapPort(protocol, intPort, port)) + dev.errLock.Unlock() dev.wg.Done() }(port) @@ -120,7 +123,6 @@ func (dev *Mapper) keepPortMapping(mappedPort chan<- uint16, protocol string, return } } - break } } diff --git a/nat/no_router.go b/nat/no_router.go index 07ac025..7a15601 100644 --- a/nat/no_router.go +++ b/nat/no_router.go @@ -1,3 +1,6 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + package nat import ( @@ -12,14 +15,14 @@ type noRouter struct { ip net.IP } -func (noRouter) MapPort(protocol string, intport, extport uint16, desc string, duration time.Duration) error { - if intport != extport { - return fmt.Errorf("cannot map port %d to %d", intport, extport) +func (noRouter) MapPort(_ string, intPort, extPort uint16, _ string, _ time.Duration) error { + if intPort != extPort { + return fmt.Errorf("cannot map port %d to %d", intPort, extPort) } return nil } -func (noRouter) UnmapPort(protocol string, extport uint16) error { +func (noRouter) UnmapPort(string, uint16, uint16) error { return nil } diff --git a/nat/pmp.go b/nat/pmp.go new file mode 100644 index 0000000..7c5a800 --- /dev/null +++ b/nat/pmp.go @@ -0,0 +1,78 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package nat + +import ( + "fmt" + "net" + "time" + + "github.com/jackpal/gateway" + "github.com/jackpal/go-nat-pmp" +) + +var ( + pmpClientTimeout = 500 * time.Millisecond +) + +// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to +// the common interface. +type pmpRouter struct { + client *natpmp.Client +} + +func (pmp *pmpRouter) MapPort( + networkProtocol string, + newInternalPort uint16, + newExternalPort uint16, + mappingName string, + mappingDuration time.Duration) error { + protocol := string(networkProtocol) + internalPort := int(newInternalPort) + externalPort := int(newExternalPort) + // go-nat-pmp uses seconds to denote their lifetime + lifetime := int(mappingDuration / time.Second) + + _, err := pmp.client.AddPortMapping(protocol, internalPort, externalPort, lifetime) + return err +} + +func (pmp *pmpRouter) UnmapPort( + networkProtocol string, + internalPort uint16, + _ uint16) error { + protocol := string(networkProtocol) + internalPortInt := int(internalPort) + + _, err := pmp.client.AddPortMapping(protocol, internalPortInt, 0, 0) + return err +} + +func (pmp *pmpRouter) ExternalIP() (net.IP, error) { + response, err := pmp.client.GetExternalAddress() + if err != nil { + return nil, err + } + return response.ExternalIPAddress[:], nil +} + +// go-nat-pmp does not support port mapping entry query +func (pmp *pmpRouter) GetPortMappingEntry(externalPort uint16, protocol string) ( + string, uint16, string, error) { + return "", 0, "", fmt.Errorf("port mapping entry not found") +} + +func getPMPRouter() *pmpRouter { + gatewayIP, err := gateway.DiscoverGateway() + if err != nil { + return nil + } + + pmp := &pmpRouter{natpmp.NewClientWithTimeout(gatewayIP, pmpClientTimeout)} + if _, err := pmp.ExternalIP(); err != nil { + return nil + } + + return pmp +} diff --git a/nat/upnp.go b/nat/upnp.go index 863544e..3191bc8 100644 --- a/nat/upnp.go +++ b/nat/upnp.go @@ -1,3 +1,6 @@ +// (c) 2019-2020, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + package nat import ( @@ -119,7 +122,7 @@ func (r *upnpRouter) MapPort(protocol string, intPort, extPort uint16, ip.String(), true, desc, lifetime) } -func (r *upnpRouter) UnmapPort(protocol string, extPort uint16) error { +func (r *upnpRouter) UnmapPort(protocol string, _, extPort uint16) error { return r.client.DeletePortMapping("", extPort, protocol) }