Skip to content

Commit 57981d4

Browse files
committed
Minor code improvements
Added: - Graceful shutdown: Implemented graceful shutdown for the HTTP server to handle termination signals (SIGINT, SIGTERM). - Log level configuration: Added a --log-level flag to configure logging levels (debug, info, warn, error, fatal, panic). - DNS query count metric: Introduced a new Prometheus metric dnsexp_dns_query_count to track the total number of DNS queries. - Error context: Enhanced error messages in loadConfig with additional context for easier debugging. - DNS client pooling: Implemented a sync.Pool to reuse DNS clients, reducing memory allocations and improving performance. Changed: - Logging level for successful queries: Reduced logging level for successful DNS queries from Info to Debug to reduce log noise. - Worker pool buffer size: Reduced the buffer size of the jobs channel in workerPool to match the workerLimit, optimizing memory usage. - Validation for DNS record types: Added validation to ensure record_type in the configuration is a valid DNS record type. Fixed: - Missing DNS server check: Added a check to ensure the DNS server is specified in the configuration file. - Error handling: Improved error handling in loadConfig and added context to error messages. Optimized: - Concurrency: Optimized the worker pool to limit the number of concurrent DNS queries based on the workerLimit flag.
1 parent 9015949 commit 57981d4

File tree

1 file changed

+56
-8
lines changed

1 file changed

+56
-8
lines changed

main.go

+56-8
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ package main
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
"os"
8+
"os/signal"
79
"sync"
10+
"syscall"
811
"time"
912

1013
"github.com/miekg/dns"
@@ -30,6 +33,7 @@ var (
3033
configFile string
3134
dnsTimeout time.Duration
3235
workerLimit int
36+
logLevel string
3337
)
3438

3539
var logger = logrus.New()
@@ -49,30 +53,48 @@ var (
4953
Name: "dnsexp_dns_last_check_timestamp",
5054
Help: "Timestamp of the last DNS query attempt.",
5155
}, []string{"domain", "record_type"})
56+
57+
dnsQueryCount = prometheus.NewCounterVec(prometheus.CounterOpts{
58+
Name: "dnsexp_dns_query_count",
59+
Help: "Total number of DNS queries.",
60+
}, []string{"domain", "record_type"})
5261
)
5362

5463
func registerMetrics() {
5564
prometheus.MustRegister(dnsQueryTime)
5665
prometheus.MustRegister(dnsQuerySuccess)
5766
prometheus.MustRegister(dnsLastCheck)
67+
prometheus.MustRegister(dnsQueryCount)
5868
}
5969

6070
func loadConfig(filename string) (*Config, error) {
6171
data, err := os.ReadFile(filename)
6272
if err != nil {
63-
return nil, err
73+
return nil, fmt.Errorf("failed to read config file: %w", err)
6474
}
6575
var config Config
6676
if err := yaml.Unmarshal(data, &config); err != nil {
67-
return nil, err
77+
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
6878
}
6979
return &config, nil
7080
}
7181

82+
var clientPool = sync.Pool{
83+
New: func() interface{} {
84+
return new(dns.Client)
85+
},
86+
}
87+
7288
func queryDNS(domain, recordType, dnsServer string) {
73-
logger.Infof("Using DNS server: %s", dnsServer)
89+
dnsQueryCount.WithLabelValues(domain, recordType).Inc()
7490

75-
client := new(dns.Client)
91+
if _, ok := dns.StringToType[recordType]; !ok {
92+
logger.Errorf("Invalid record type: %s", recordType)
93+
return
94+
}
95+
96+
client := clientPool.Get().(*dns.Client)
97+
defer clientPool.Put(client)
7698

7799
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
78100
defer cancel()
@@ -99,7 +121,7 @@ func queryDNS(domain, recordType, dnsServer string) {
99121
case dns.RcodeSuccess:
100122
dnsQuerySuccess.WithLabelValues(domain, recordType).Set(1)
101123
dnsQueryTime.WithLabelValues(domain, recordType).Set(duration)
102-
logger.Infof("Query successful for %s (%s)", domain, recordType)
124+
logger.Debugf("Query successful for %s (%s)", domain, recordType)
103125
case dns.RcodeNameError:
104126
logger.Warnf("DNS query failed for %s (%s): NXDOMAIN (domain does not exist)", domain, recordType)
105127
dnsQuerySuccess.WithLabelValues(domain, recordType).Set(0)
@@ -121,7 +143,7 @@ func queryDNS(domain, recordType, dnsServer string) {
121143

122144
func workerPool(config *Config) {
123145
var wg sync.WaitGroup
124-
jobs := make(chan DomainConfig, len(config.Domains))
146+
jobs := make(chan DomainConfig, workerLimit)
125147

126148
for i := 0; i < workerLimit; i++ {
127149
wg.Add(1)
@@ -148,18 +170,43 @@ func metricsHandler(config *Config) http.Handler {
148170
}
149171

150172
func startExporter(cmd *cobra.Command, args []string) {
173+
level, err := logrus.ParseLevel(logLevel)
174+
if err != nil {
175+
logger.Fatalf("Invalid log level: %v", err)
176+
}
177+
logger.SetLevel(level)
178+
151179
config, err := loadConfig(configFile)
152180
if err != nil {
153181
logger.Fatalf("Failed to load config file: %v", err)
154182
}
155183
if len(config.Domains) == 0 {
156184
logger.Warn("No domains configured")
157185
}
186+
if config.DNSServer == "" {
187+
logger.Fatal("No DNS server specified in config")
188+
}
158189
registerMetrics()
159190

160191
http.Handle("/metrics", metricsHandler(config))
161-
logger.Infof("Starting DNS Exporter on %s with timeout %v", listenAddr, dnsTimeout)
162-
logger.Fatal(http.ListenAndServe(listenAddr, nil))
192+
193+
server := &http.Server{Addr: listenAddr}
194+
go func() {
195+
logger.Infof("Starting DNS Exporter on %s with timeout %v", listenAddr, dnsTimeout)
196+
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
197+
logger.Fatalf("Failed to start server: %v", err)
198+
}
199+
}()
200+
201+
stop := make(chan os.Signal, 1)
202+
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
203+
<-stop
204+
205+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
206+
defer cancel()
207+
if err := server.Shutdown(ctx); err != nil {
208+
logger.Errorf("Failed to shutdown server: %v", err)
209+
}
163210
}
164211

165212
func main() {
@@ -173,6 +220,7 @@ func main() {
173220
rootCmd.Flags().StringVarP(&configFile, "config", "c", "domains.yaml", "Path to YAML config file")
174221
rootCmd.Flags().DurationVarP(&dnsTimeout, "timeout", "t", 900*time.Millisecond, "DNS query timeout (e.g., 500ms, 2s)")
175222
rootCmd.Flags().IntVarP(&workerLimit, "workers", "w", 5, "Number of concurrent workers")
223+
rootCmd.Flags().StringVarP(&logLevel, "log-level", "v", "info", "Log level (debug, info, warn, error, fatal, panic)")
176224

177225
if err := rootCmd.Execute(); err != nil {
178226
logger.Error(err)

0 commit comments

Comments
 (0)