package nat import ( "net" "sync" "time" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/utils/wrappers" ) const ( mapTimeout = 30 * time.Minute mapUpdateTimeout = mapTimeout / 2 maxRetries = 20 ) type NATRouter interface { MapPort(protocol string, intport, extport uint16, desc string, duration time.Duration) error UnmapPort(protocol string, extport uint16) error ExternalIP() (net.IP, error) GetPortMappingEntry(extport uint16, protocol string) ( InternalIP string, InternalPort uint16, Description string, err error, ) } func GetNATRouter() NATRouter { //TODO add PMP support if r := getUPnPRouter(); r != nil { return r } return NewNoRouter() } type Mapper struct { log logging.Logger r NATRouter closer chan struct{} wg sync.WaitGroup errLock sync.Mutex errs wrappers.Errs } func NewPortMapper(log logging.Logger, r NATRouter) Mapper { return Mapper{ log: log, r: r, closer: make(chan struct{}), } } // Map sets up port mapping using given protocol, internal and external ports // and returns the final port mapped. It returns 0 if mapping failed after the // maximun number of retries func (dev *Mapper) Map(protocol string, intport, extport uint16, desc string) uint16 { mappedPort := make(chan uint16) dev.wg.Add(1) go dev.keepPortMapping(mappedPort, protocol, intport, extport, desc) return <-mappedPort } // keepPortMapping runs in the background to keep a port mapped. It renews the // the port mapping in mapUpdateTimeout. func (dev *Mapper) keepPortMapping(mappedPort chan<- uint16, protocol string, intport, extport uint16, desc string) { updateTimer := time.NewTimer(mapUpdateTimeout) var port uint16 = 0 defer func() { updateTimer.Stop() dev.log.Info("Unmap protocol %s external port %d", protocol, port) if port > 0 { dev.errLock.Lock() dev.errs.Add(dev.r.UnmapPort(protocol, port)) dev.errLock.Unlock() } dev.wg.Done() }() for i := 0; i < maxRetries; i++ { port = extport + uint16(i) if intaddr, intport, desc, err := dev.r.GetPortMappingEntry(port, protocol); err == nil { dev.log.Info("Port %d is mapped to %s:%d: %s, retry with the next port", port, intaddr, intport, desc) continue } if err := dev.r.MapPort(protocol, intport, port, desc, mapTimeout); err != nil { dev.log.Error("Map port failed. Protocol %s Internal %d External %d. %s", protocol, intport, port, err) dev.errLock.Lock() dev.errs.Add(err) dev.errLock.Unlock() } else { dev.log.Info("Mapped Protocol %s Internal %d External %d.", protocol, intport, port) mappedPort <- port break } } if port == 0 { dev.log.Error("Unable to map port %d", extport) mappedPort <- port return } for { select { case <-updateTimer.C: if err := dev.r.MapPort(protocol, intport, port, desc, mapTimeout); err != nil { dev.log.Error("Renew port mapping failed. Protocol %s Internal %d External %d. %s", protocol, intport, port, err) } else { dev.log.Info("Renew port mapping Protocol %s Internal %d External %d.", protocol, intport, port) } updateTimer.Reset(mapUpdateTimeout) case _, _ = <-dev.closer: return } } } func (dev *Mapper) UnmapAllPorts() error { close(dev.closer) dev.wg.Wait() dev.log.Info("Unmapped all ports") return dev.errs.Err }