diff --git a/cmd/dss/main.go b/cmd/dss/main.go index f3ae8b8..36f46c6 100644 --- a/cmd/dss/main.go +++ b/cmd/dss/main.go @@ -23,7 +23,7 @@ var ( Use: "dss", Short: "Scan a domain's DNS records.", Long: "Scan a domain's DNS records.\nhttps://github.com/GlobalCyberAlliance/domain-security-scanner", - Version: "2.4.4", + Version: "2.4.5", PersistentPreRun: func(cmd *cobra.Command, args []string) { if debug { log = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.DebugLevel) diff --git a/cmd/dss/scan.go b/cmd/dss/scan.go index 262cb74..319abe3 100644 --- a/cmd/dss/scan.go +++ b/cmd/dss/scan.go @@ -22,25 +22,26 @@ var cmdScan = &cobra.Command{ Long: "Scan DNS records for one or multiple domains.\nBy default, the command will listen on STDIN, allowing you to type or pipe multiple domains.", Run: func(command *cobra.Command, args []string) { opts := []scanner.ScannerOption{ - scanner.ConcurrentScans(concurrent), - scanner.UseCache(cacheEnabled), - scanner.UseNameservers(nameservers), + scanner.WithCache(cacheEnabled), + scanner.WithConcurrentScans(concurrent), + scanner.WithDKIMSelectors(dkimSelector...), scanner.WithDNSBuffer(dnsBuffer), + scanner.WithNameservers(nameservers), scanner.WithTimeout(time.Duration(timeout) * time.Second), } var source scanner.Source if len(args) == 0 && zoneFile { - source = scanner.ZonefileSource(os.Stdin) + source = scanner.NewSource(os.Stdin, scanner.ZonefileSourceType) } else if len(args) > 0 && zoneFile { log.Fatal().Msg("-z flag provided, but not reading from STDIN") } else if len(args) == 0 { log.Info().Msg("Accepting input from STDIN. Type a domain and hit enter.") - source = scanner.TextSource(os.Stdin) + source = scanner.NewSource(os.Stdin, scanner.TextSourceType) } else { - sr := strings.NewReader(strings.Join(args, "\n")) - source = scanner.TextSource(sr) + reader := strings.NewReader(strings.Join(args, "\n")) + source = scanner.NewSource(reader, scanner.TextSourceType) } sc, err := scanner.New(opts...) @@ -48,8 +49,6 @@ var cmdScan = &cobra.Command{ log.Fatal().Err(err).Msg("An unexpected error occurred.") } - sc.DKIMSelectors = dkimSelector - domainAdvisor := advisor.NewAdvisor(time.Duration(timeout)*time.Second, cacheEnabled) if format == "csv" && outputFile == "" { diff --git a/cmd/dss/serve.go b/cmd/dss/serve.go index bfcea33..f01e2aa 100644 --- a/cmd/dss/serve.go +++ b/cmd/dss/serve.go @@ -50,10 +50,11 @@ var ( server := http.NewServer(log) opts := []scanner.ScannerOption{ - scanner.ConcurrentScans(concurrent), - scanner.UseCache(cacheEnabled), - scanner.UseNameservers(nameservers), + scanner.WithCache(cacheEnabled), + scanner.WithConcurrentScans(concurrent), + scanner.WithDKIMSelectors(dkimSelector...), scanner.WithDNSBuffer(dnsBuffer), + scanner.WithNameservers(nameservers), scanner.WithTimeout(time.Duration(timeout) * time.Second), } @@ -75,10 +76,11 @@ var ( Short: "Serve DNS security queries via a dedicated email account", Run: func(command *cobra.Command, args []string) { opts := []scanner.ScannerOption{ - scanner.ConcurrentScans(concurrent), - scanner.UseCache(cacheEnabled), - scanner.UseNameservers(nameservers), + scanner.WithCache(cacheEnabled), + scanner.WithConcurrentScans(concurrent), + scanner.WithDKIMSelectors(dkimSelector...), scanner.WithDNSBuffer(dnsBuffer), + scanner.WithNameservers(nameservers), scanner.WithTimeout(time.Duration(timeout) * time.Second), } diff --git a/pkg/http/scan.go b/pkg/http/scan.go index 70b0fc4..a44e088 100644 --- a/pkg/http/scan.go +++ b/pkg/http/scan.go @@ -62,10 +62,17 @@ func (s *Server) handleScanDomains(c *gin.Context) { } domainList := strings.NewReader(strings.Join(domains.Domains, "\n")) - source := scanner.TextSource(domainList) + source := scanner.NewSource(domainList, scanner.TextSourceType) + // TODO: temporary solution to allow for custom DKIM selectors in the API. + // This implementation is not ideal, as it will overwrite the selectors for + // future scans. if queryParam, ok := c.GetQuery("dkimSelector"); ok { - s.Scanner.DKIMSelectors = strings.Split(queryParam, ",") + if err := s.Scanner.OverwriteOption(scanner.WithDKIMSelectors(strings.Split(queryParam, ",")...)); err != nil { + s.logger.Error().Err(err).Msg("fai") + s.respond(c, 400, err.Error()) + return + } } var resultsWithAdvice []model.ScanResultWithAdvice diff --git a/pkg/mail/server.go b/pkg/mail/server.go index 9bbb08a..32b738e 100644 --- a/pkg/mail/server.go +++ b/pkg/mail/server.go @@ -77,7 +77,7 @@ func (s *Server) handler() error { } sourceDomainList := strings.NewReader(strings.Join(domainList, "\n")) - source := scanner.TextSource(sourceDomainList) + source := scanner.NewSource(sourceDomainList, scanner.TextSourceType) for result := range s.Scanner.Start(source) { sender := addresses[result.Domain].Address diff --git a/pkg/scanner/requests.go b/pkg/scanner/requests.go index f01b87d..3b922ec 100644 --- a/pkg/scanner/requests.go +++ b/pkg/scanner/requests.go @@ -8,25 +8,55 @@ import ( "github.com/pkg/errors" ) -var knownDkimSelectors = []string{ - "x", // Generic - "google", // Google - "selector1", // Microsoft - "selector2", // Microsoft - "k1", // MailChimp - "mandrill", // Mandrill - "everlytickey1", // Everlytic - "everlytickey2", // Everlytic - "dkim", // Hetzner - "mxvault", // MxVault +// getDNSRecords queries the DNS server for records of a specific type for a domain. +// It returns a slice of strings (the records) and an error if any occurred. +func (s *Scanner) getDNSRecords(domain string, recordType uint16) (records []string, err error) { + answers, err := s.getDNSAnswers(domain, recordType) + if err != nil { + return nil, err + } + + for _, answer := range answers { + if answer.Header().Rrtype == dns.TypeCNAME { + if t, ok := answer.(*dns.CNAME); ok { + recursiveLookupTxt, err := s.getDNSRecords(t.Target, recordType) + if err != nil { + return nil, fmt.Errorf("failed to recursively lookup txt record for %v: %w", t.Target, err) + } + + records = append(records, recursiveLookupTxt...) + + continue + } + + answer.Header().Rrtype = recordType + } + + switch t := answer.(type) { + case *dns.A: + records = append(records, t.A.String()) + case *dns.AAAA: + records = append(records, t.AAAA.String()) + case *dns.MX: + records = append(records, t.Mx) + case *dns.NS: + records = append(records, t.Ns) + case *dns.TXT: + records = append(records, t.Txt...) + } + } + + return records, nil } +// getDNSAnswers queries the DNS server for answers to a specific question. +// It returns a slice of dns.RR (DNS resource records) and an error if any occurred. func (s *Scanner) getDNSAnswers(domain string, recordType uint16) ([]dns.RR, error) { - req := new(dns.Msg) + req := &dns.Msg{} req.SetQuestion(dns.Fqdn(domain), recordType) req.SetEdns0(s.dnsBuffer, true) // increases the response buffer size - in, _, err := s.dnsClient.Exchange(req, s.GetNS()) + in, _, err := s.dnsClient.Exchange(req, s.getNS()) if err != nil { return nil, err } @@ -36,7 +66,7 @@ func (s *Scanner) getDNSAnswers(domain string, recordType uint16) ([]dns.RR, err req.SetEdns0(4096, true) - in, _, err = s.dnsClient.Exchange(req, s.GetNS()) + in, _, err = s.dnsClient.Exchange(req, s.getNS()) if err != nil { return nil, err } @@ -46,30 +76,35 @@ func (s *Scanner) getDNSAnswers(domain string, recordType uint16) ([]dns.RR, err } // GetDNSRecords is a convenience wrapper which will scan all provided DNS record types -// and fill the pointered ScanResult +// and fill the pointered ScanResult. It returns an error if any occurred. func (s *Scanner) GetDNSRecords(scanResult *ScanResult, recordTypes ...string) (err error) { + var records []string + for _, recordType := range recordTypes { switch strings.ToUpper(recordType) { case "A": - scanResult.A, err = s.getTypeA(scanResult.Domain) + scanResult.A, err = s.getDNSRecords(scanResult.Domain, dns.TypeA) case "AAAA": - scanResult.AAAA, err = s.getTypeAAAA(scanResult.Domain) + scanResult.AAAA, err = s.getDNSRecords(scanResult.Domain, dns.TypeAAAA) case "BIMI": scanResult.BIMI, err = s.getTypeBIMI(scanResult.Domain) case "CNAME": - scanResult.CNAME, err = s.getTypeCNAME(scanResult.Domain) + records, err = s.getDNSRecords(scanResult.Domain, dns.TypeCNAME) + if err == nil { + scanResult.CNAME = records[0] + } case "DKIM": scanResult.DKIM, err = s.getTypeDKIM(scanResult.Domain) case "DMARC": scanResult.DMARC, err = s.getTypeDMARC(scanResult.Domain) case "MX": - scanResult.MX, err = s.getTypeMX(scanResult.Domain) + scanResult.MX, err = s.getDNSRecords(scanResult.Domain, dns.TypeMX) case "NS": - scanResult.NS, err = s.getTypeNS(scanResult.Domain) + scanResult.NS, err = s.getDNSRecords(scanResult.Domain, dns.TypeNS) case "SPF": scanResult.SPF, err = s.getTypeSPF(scanResult.Domain) case "TXT": - scanResult.TXT, err = s.getTypeTXT(scanResult.Domain) + scanResult.TXT, err = s.getDNSRecords(scanResult.Domain, dns.TypeTXT) default: return errors.New("invalid dns record type") } @@ -82,50 +117,20 @@ func (s *Scanner) GetDNSRecords(scanResult *ScanResult, recordTypes ...string) ( return nil } -func (s *Scanner) getTypeA(domain string) (records []string, err error) { - answers, err := s.getDNSAnswers(domain, dns.TypeA) - if err != nil { - return nil, err - } - - for _, answer := range answers { - if t, ok := answer.(*dns.A); ok { - records = append(records, t.A.String()) - } - } - - return records, nil -} - -func (s *Scanner) getTypeAAAA(domain string) (records []string, err error) { - answers, err := s.getDNSAnswers(domain, dns.TypeAAAA) - if err != nil { - return nil, err - } - - for _, answer := range answers { - if t, ok := answer.(*dns.AAAA); ok { - records = append(records, t.AAAA.String()) - } - } - - return records, nil -} - func (s *Scanner) getTypeBIMI(domain string) (string, error) { for _, dname := range []string{ "default._bimi." + domain, domain, } { - txtRecords, err := s.getTypeTXT(dname) + records, err := s.getDNSRecords(dname, dns.TypeTXT) if err != nil { return "", nil } - for index, txt := range txtRecords { - if strings.HasPrefix(txt, BIMIPrefix) { + for index, record := range records { + if strings.HasPrefix(record, BIMIPrefix) { // TXT records can be split across multiple strings, so we need to join them - return strings.Join(txtRecords[index:], ""), nil + return strings.Join(records[index:], ""), nil } } } @@ -133,34 +138,21 @@ func (s *Scanner) getTypeBIMI(domain string) (string, error) { return "", nil } -func (s *Scanner) getTypeCNAME(domain string) (string, error) { - answers, err := s.getDNSAnswers(domain, dns.TypeCNAME) - if err != nil { - return "", err - } - - for _, answer := range answers { - if t, ok := answer.(*dns.CNAME); ok { - return t.String(), err - } - } - - return "", nil -} - +// getTypeDKIM queries the DNS server for DKIM records of a domain. +// It returns a string (DKIM record) and an error if any occurred. func (s *Scanner) getTypeDKIM(domain string) (string, error) { - selectors := append(s.DKIMSelectors, knownDkimSelectors...) + selectors := append(s.dkimSelectors, knownDkimSelectors...) for _, selector := range selectors { - txtRecords, err := s.getTypeTXT(selector + "._domainkey." + domain) + records, err := s.getDNSRecords(selector+"._domainkey."+domain, dns.TypeTXT) if err != nil { return "", nil } - for index, txt := range txtRecords { - if strings.HasPrefix(txt, DKIMPrefix) { + for index, record := range records { + if strings.HasPrefix(record, DKIMPrefix) { // TXT records can be split across multiple strings, so we need to join them - return strings.Join(txtRecords[index:], ""), nil + return strings.Join(records[index:], ""), nil } } } @@ -168,20 +160,22 @@ func (s *Scanner) getTypeDKIM(domain string) (string, error) { return "", nil } +// getTypeDMARC queries the DNS server for DMARC records of a domain. +// It returns a string (DMARC record) and an error if any occurred. func (s *Scanner) getTypeDMARC(domain string) (string, error) { for _, dname := range []string{ "_dmarc." + domain, domain, } { - txtRecords, err := s.getTypeTXT(dname) + records, err := s.getDNSRecords(dname, dns.TypeTXT) if err != nil { return "", nil } - for index, txt := range txtRecords { - if strings.HasPrefix(txt, DMARCPrefix) { + for index, record := range records { + if strings.HasPrefix(record, DMARCPrefix) { // TXT records can be split across multiple strings, so we need to join them - return strings.Join(txtRecords[index:], ""), nil + return strings.Join(records[index:], ""), nil } } } @@ -189,49 +183,21 @@ func (s *Scanner) getTypeDMARC(domain string) (string, error) { return "", nil } -func (s *Scanner) getTypeMX(domain string) (records []string, err error) { - answers, err := s.getDNSAnswers(domain, dns.TypeMX) - if err != nil { - return nil, err - } - - for _, answer := range answers { - if t, ok := answer.(*dns.MX); ok { - records = append(records, t.Mx) - } - } - - return records, nil -} - -func (s *Scanner) getTypeNS(domain string) (records []string, err error) { - answers, err := s.getDNSAnswers(domain, dns.TypeNS) - if err != nil { - return nil, err - } - - for _, answer := range answers { - if t, ok := answer.(*dns.NS); ok { - records = append(records, t.Ns) - } - } - - return records, nil -} - +// getTypeSPF queries the DNS server for SPF records of a domain. +// It returns a string (SPF record) and an error if any occurred. func (s *Scanner) getTypeSPF(domain string) (string, error) { - txtRecords, err := s.getTypeTXT(domain) + records, err := s.getDNSRecords(domain, dns.TypeTXT) if err != nil { return "", err } - for _, txt := range txtRecords { - if strings.HasPrefix(txt, SPFPrefix) { - if !strings.Contains(txt, "redirect=") { - return txt, nil + for _, record := range records { + if strings.HasPrefix(record, SPFPrefix) { + if !strings.Contains(record, "redirect=") { + return record, nil } - parts := strings.Fields(txt) + parts := strings.Fields(record) for _, part := range parts { if strings.Contains(part, "redirect=") { redirectDomain := strings.TrimPrefix(part, "redirect=") @@ -243,33 +209,3 @@ func (s *Scanner) getTypeSPF(domain string) (string, error) { return "", nil } - -func (s *Scanner) getTypeTXT(domain string) (records []string, err error) { - answers, err := s.getDNSAnswers(domain, dns.TypeTXT) - if err != nil { - return nil, err - } - - for _, answer := range answers { - // handle recursive lookups - if answer.Header().Rrtype == dns.TypeCNAME { - if t, ok := answer.(*dns.CNAME); ok { - recursiveLookupTxt, err := s.getTypeTXT(t.Target) - if err != nil { - return nil, fmt.Errorf("failed to recursively lookup txt record for %v: %w", t.Target, err) - } - - records = append(records, recursiveLookupTxt...) - - continue - } - } - - answer.Header().Rrtype = dns.TypeTXT - if t, ok := answer.(*dns.TXT); ok { - records = append(records, t.Txt...) - } - } - - return records, nil -} diff --git a/pkg/scanner/scanner.go b/pkg/scanner/scanner.go index ff68e41..878936e 100644 --- a/pkg/scanner/scanner.go +++ b/pkg/scanner/scanner.go @@ -25,20 +25,26 @@ var ( DKIMPrefix = DefaultDKIMPrefix DMARCPrefix = DefaultDMARCPrefix SPFPrefix = DefaultSPFPrefix + + // knownDkimSelectors is a list of known DKIM selectors. + knownDkimSelectors = []string{ + "x", // Generic + "google", // Google + "selector1", // Microsoft + "selector2", // Microsoft + "k1", // MailChimp + "mandrill", // Mandrill + "everlytickey1", // Everlytic + "everlytickey2", // Everlytic + "dkim", // Hetzner + "mxvault", // MxVault + } ) type ( // Scanner is a type that queries the DNS records for domain names, looking // for specific resource records. Scanner struct { - // DKIMSelectors is used to specify where a DKIM record is hosted for - // a specific domain. - DKIMSelectors []string - - // Nameservers is a slice of "host:port" strings of nameservers to - // issue queries against. - Nameservers []string - // cache is a simple in-memory cache to reduce external requests from // the scanner. cache *cache.Cache @@ -47,6 +53,10 @@ type ( // cache or not. cacheEnabled bool + // dkimSelectors is used to specify where a DKIM record is hosted for + // a specific domain. + dkimSelectors []string + // DNS client shared by all goroutines the scanner spawns. dnsClient *dns.Client @@ -54,12 +64,16 @@ type ( // DNS responses dnsBuffer uint16 - // The index of the last-used nameserver, from the Nameservers slice. + // The index of the last-used nameserver, from the nameservers slice. // // This field is managed by atomic operations, and should only ever - // be referenced by the (*Scanner).GetNS() method. + // be referenced by the (*Scanner).getNS() method. lastNameserverIndex uint32 + // nameservers is a slice of "host:port" strings of nameservers to + // issue queries against. + nameservers []string + // A channel to use as a semaphore for limiting the number of DNS // queries that can be made concurrently. sem chan struct{} @@ -89,8 +103,9 @@ type ( // New initializes and returns a new *Scanner. func New(options ...ScannerOption) (*Scanner, error) { s := &Scanner{ - dnsClient: new(dns.Client), - dnsBuffer: 1024, + dnsClient: new(dns.Client), + dnsBuffer: 1024, + nameservers: []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53"}, } for _, option := range options { @@ -110,12 +125,29 @@ func New(options ...ScannerOption) (*Scanner, error) { return s, nil } -// ConcurrentScans sets the number of domains that will be scanned +// OverwriteOption allows the caller to overwrite an existing option. +func (s *Scanner) OverwriteOption(option ScannerOption) error { + return option(s) +} + +// WithCache enables domain caching for the scanner. This is a simple +// implementation, intended to mitigate abuse attempts. +func WithCache(enable bool) ScannerOption { + return func(s *Scanner) error { + if enable { + s.cache = cache.New(1*time.Minute, 5*time.Minute) + s.cacheEnabled = true + } + return nil + } +} + +// WithConcurrentScans sets the number of domains that will be scanned // concurrently. // // If n <= 0, then this option will default to the return value of // runtime.NumCPU(). -func ConcurrentScans(n int) ScannerOption { +func WithConcurrentScans(n int) ScannerOption { return func(s *Scanner) error { if n <= 0 { n = runtime.NumCPU() @@ -131,22 +163,31 @@ func ConcurrentScans(n int) ScannerOption { } } -// UseCache enables domain caching for the scanner. This is a simple -// implementation, intended to mitigate abuse attempts. -func UseCache(enable bool) ScannerOption { +// WithDKIMSelectors allows the caller to specify which DKIM selectors to +// scan for (falling back to the default selectors if none are provided). +func WithDKIMSelectors(selectors ...string) ScannerOption { return func(s *Scanner) error { - if enable { - s.cache = cache.New(1*time.Minute, 5*time.Minute) - s.cacheEnabled = true + s.dkimSelectors = selectors + return nil + } +} + +// WithDNSBuffer increases the allocated buffer for DNS responses +func WithDNSBuffer(bufferSize uint16) ScannerOption { + return func(s *Scanner) error { + if bufferSize > 4096 { + return errors.New("buffer size should not be larger than 4096") } + + s.dnsBuffer = bufferSize return nil } } -// UseNameservers allows the caller to provide a custom set of nameservers for +// WithNameservers allows the caller to provide a custom set of nameservers for // a *Scanner to use. If ns is nil, or zero-length, the *Scanner will use // the nameservers specified in /etc/resolv.conf. -func UseNameservers(ns []string) ScannerOption { +func WithNameservers(ns []string) ScannerOption { return func(s *Scanner) error { // If the provided slice of nameservers is nil, or has zero // elements, load up /etc/resolv.conf, and get the "nameserver" @@ -180,24 +221,12 @@ func UseNameservers(ns []string) ScannerOption { } } - s.Nameservers = ns[:] + s.nameservers = ns[:] return nil } } -// WithDNSBuffer increases the allocated buffer for DNS responses -func WithDNSBuffer(bufferSize uint16) ScannerOption { - return func(s *Scanner) error { - if bufferSize > 4096 { - return errors.New("buffer size should not be larger than 4096") - } - - s.dnsBuffer = bufferSize - return nil - } -} - // WithTimeout sets the timeout duration of a DNS query. func WithTimeout(timeout time.Duration) ScannerOption { return func(s *Scanner) error { @@ -221,8 +250,8 @@ func (s *Scanner) start(src Source, ch chan *ScanResult) { for domain := range src.Read() { <-s.sem wg.Add(1) - go func(dname string) { - ch <- s.Scan(dname) + go func(domain string) { + ch <- s.Scan(domain) s.sem <- struct{}{} wg.Done() }(domain) @@ -241,30 +270,30 @@ func (s *Scanner) Scan(domain string) *ScanResult { } // check that the domain name is valid - recs, err := s.getDNSAnswers(domain, dns.TypeNS) - if err != nil || len(recs) == 0 { + records, err := s.getDNSAnswers(domain, dns.TypeNS) + if err != nil || len(records) == 0 { return &ScanResult{ Domain: domain, Error: "invalid domain name", } } - res := &ScanResult{Domain: domain} + result := &ScanResult{Domain: domain} start := time.Now() - if err = s.GetDNSRecords(res, "BIMI", "DKIM", "DMARC", "MX", "NS", "SPF"); err != nil { - res.Error = err.Error() + if err = s.GetDNSRecords(result, "BIMI", "DKIM", "DMARC", "MX", "NS", "SPF"); err != nil { + result.Error = err.Error() } - res.Elapsed = time.Since(start).Milliseconds() + result.Elapsed = time.Since(start).Milliseconds() if s.cacheEnabled { - s.cache.Set(domain, res, 1*time.Minute) + s.cache.Set(domain, result, 1*time.Minute) } - return res + return result } -func (s *Scanner) GetNS() string { - return s.Nameservers[int(atomic.AddUint32(&s.lastNameserverIndex, 1))%len(s.Nameservers)] +func (s *Scanner) getNS() string { + return s.nameservers[int(atomic.AddUint32(&s.lastNameserverIndex, 1))%len(s.nameservers)] } diff --git a/pkg/scanner/source.go b/pkg/scanner/source.go index b79103d..0912b1d 100644 --- a/pkg/scanner/source.go +++ b/pkg/scanner/source.go @@ -8,27 +8,34 @@ import ( "github.com/miekg/dns" ) -// Source defines the interface of a data source that feeds a Scanner. -type Source interface { - Read() <-chan string - Close() error -} +const ( + TextSourceType SourceType = iota + ZonefileSourceType +) -// ZonefileSource returns a Source that can be used by a Scanner to read -// domain names from an io.Reader that reads from a RFC 1035 formatted zone -// file. -func ZonefileSource(r io.Reader) Source { - return &zonefileSource{reader: r} -} +type ( + SourceType int + + // Source defines the interface of a data source that feeds a Scanner. + Source interface { + Read() <-chan string + Close() error + } + + source struct { + ch chan string + closed bool + reader io.Reader + stop chan struct{} + sourceType SourceType + } +) -type zonefileSource struct { - reader io.Reader - ch chan string - stop chan struct{} - closed bool +func NewSource(reader io.Reader, sourceType SourceType) Source { + return &source{reader: reader, sourceType: sourceType} } -func (src *zonefileSource) Read() <-chan string { +func (src *source) Read() <-chan string { if src.closed { return nil } @@ -48,94 +55,46 @@ func (src *zonefileSource) Read() <-chan string { return src.ch } -func (src *zonefileSource) read() { +func (src *source) read() { defer close(src.ch) - zoneParser := dns.NewZoneParser(src.reader, "", "") - zoneParser.SetIncludeAllowed(true) - - for tok, ok := zoneParser.Next(); ok; _, ok = zoneParser.Next() { - if tok.Header().Rrtype == dns.TypeNS { - continue - } - - name := strings.Trim(tok.Header().Name, ".") - if !strings.Contains(name, ".") { - // we have an NS record that serves as an anchor, and should skip it - continue - } - - select { - case src.ch <- name: - case <-src.stop: - return - } - } -} - -func (src *zonefileSource) Close() error { - if src.closed { - return nil - } - if len(src.ch) > 0 { - src.stop <- struct{}{} - - // drain the channel - for range src.ch { + switch src.sourceType { + case TextSourceType: + sc := bufio.NewScanner(src.reader) + for sc.Scan() { + domain := strings.Trim(sc.Text(), ".") + + select { + case src.ch <- domain: + case <-src.stop: + return + } } - } - close(src.ch) - close(src.stop) - src.closed = true - return nil -} - -// TextSource returns a new Source that can be used by a Scanner to read -// newline-separated domain names from r. -func TextSource(r io.Reader) Source { - return &textSource{reader: r} -} - -type textSource struct { - ch chan string - closed bool - reader io.Reader - stop chan struct{} -} - -func (src *textSource) Read() <-chan string { - if src.closed { - return nil - } - - if src.ch != nil { - return src.ch - } - - src.ch = make(chan string) - src.stop = make(chan struct{}) - - go src.read() - - return src.ch -} - -func (src *textSource) read() { - defer close(src.ch) - - sc := bufio.NewScanner(src.reader) - for sc.Scan() { - domain := strings.Trim(sc.Text(), ".") - - select { - case src.ch <- domain: - case <-src.stop: - return + case ZonefileSourceType: + zoneParser := dns.NewZoneParser(src.reader, "", "") + zoneParser.SetIncludeAllowed(true) + + for tok, ok := zoneParser.Next(); ok; tok, ok = zoneParser.Next() { + if tok.Header().Rrtype == dns.TypeNS { + continue + } + + name := strings.Trim(tok.Header().Name, ".") + if !strings.Contains(name, ".") { + // we have an NS record that serves as an anchor, and should skip it + continue + } + + select { + case src.ch <- name: + case <-src.stop: + return + } } } } -func (src *textSource) Close() error { +func (src *source) Close() error { if src.closed { return nil } @@ -148,7 +107,9 @@ func (src *textSource) Close() error { } } + close(src.ch) close(src.stop) src.closed = true + return nil }