forked from AdguardTeam/dnsproxy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcache.go
521 lines (428 loc) · 12.5 KB
/
cache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
package proxy
import (
"bytes"
"encoding/binary"
"math"
"net"
"strings"
"sync"
"time"
glcache "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// defaultCacheSize is the size of cache in bytes by default.  createCache uses
// it whenever the configured cacheSize is not positive.
const defaultCacheSize = 64 * 1024
// cache is used to cache requests and used upstreams.  It holds two
// independent stores: one keyed without subnet information and one keyed with
// the client subnet.
type cache struct {
// items is the cache of packed responses keyed by msgToKey, that is without
// any subnet information.
items glcache.Cache
// itemsLock guards the items field: initLazy holds it for writing while
// assigning the field, get and set hold it for reading while using it.
itemsLock sync.RWMutex
// itemsWithSubnet is the cache of packed responses keyed by
// msgToKeyWithSubnet, that is including the subnet mask and address.
itemsWithSubnet glcache.Cache
// itemsWithSubnetLock guards the itemsWithSubnet field the same way
// itemsLock guards items.
itemsWithSubnetLock sync.RWMutex
// cacheSize is the maximum size handed to each underlying cache on
// creation.  A non-positive value selects defaultCacheSize.
cacheSize int
// optimistic defines if the cache should return expired items and resolve
// those again.
optimistic bool
}
// cacheItem is a single cache entry.  It's a helper type to aggregate the
// item-specific logic, namely packing into and unpacking from the byte form
// stored in the cache.
type cacheItem struct {
// m contains the cached response.
m *dns.Msg
// u contains an address of the upstream which resolved m.  It is stored
// verbatim after the packed message.
u string
}
// Sizes of the fixed-width fields that prefix every packed cacheItem.
const (
// packedMsgLenSz is the exact length of byte slice capable to store the
// length of packed DNS message.  It's essentially the size of a uint16.
packedMsgLenSz = 2
// expTimeSz is the exact length of byte slice capable to store the
// expiration time of the response.  It's essentially the size of a uint32.
expTimeSz = 4
// minPackedLen is the minimum length of the packed cacheItem, that is the
// expiration time followed by the message length with no payload at all.
minPackedLen = expTimeSz + packedMsgLenSz
)
// pack serializes ci into a byte slice with the following layout:
//
//   - expTimeSz bytes: big-endian Unix time at which the item expires, computed
//     as the current time plus the lowest TTL found in the message;
//   - packedMsgLenSz bytes: big-endian length of the packed message;
//   - the packed message itself;
//   - the address of the upstream, occupying the rest of the slice.
func (ci *cacheItem) pack() (packed []byte) {
	// NOTE(review): the packing error is discarded.  Presumably a failed
	// pack yields an empty message, which is then stored with a zero length
	// and rejected by unpackItem on the next lookup; confirm.
	msgData, _ := ci.m.Pack()
	msgLen := len(msgData)

	expire := uint32(time.Now().Unix()) + lowestTTL(ci.m)

	// Allocate the whole item at once, but only expose the fixed-size
	// prefix so that the variable parts can simply be appended.
	packed = make([]byte, minPackedLen, minPackedLen+msgLen+len(ci.u))
	binary.BigEndian.PutUint32(packed[:expTimeSz], expire)
	binary.BigEndian.PutUint16(packed[expTimeSz:minPackedLen], uint16(msgLen))

	packed = append(packed, msgData...)
	packed = append(packed, ci.u...)

	return packed
}
// optimisticTTL is the default TTL for expired cached responses in seconds.
// unpackItem applies it to every record of an expired item that is still
// served because the cache is optimistic.
const optimisticTTL = 10
// unpackItem decodes data, as produced by cacheItem.pack, into a cacheItem
// holding a response built for req.  req must not be nil.
//
// expired reports whether the stored expiration time has passed.  An expired
// item is only returned when c is optimistic, in which case its records get
// optimisticTTL; otherwise the records get the remaining lifetime as their
// TTL.  ci is nil when data is too short, carries an empty or malformed
// message, or is expired and c isn't optimistic.
func (c *cache) unpackItem(data []byte, req *dns.Msg) (ci *cacheItem, expired bool) {
	if len(data) < minPackedLen {
		return nil, false
	}

	buf := bytes.NewBuffer(data)

	expireAt := int64(binary.BigEndian.Uint32(buf.Next(expTimeSz)))
	now := time.Now().Unix()
	expired = expireAt <= now

	var ttl uint32
	switch {
	case !expired:
		ttl = uint32(expireAt - now)
	case c.optimistic:
		ttl = optimisticTTL
	default:
		return nil, expired
	}

	msgLen := int(binary.BigEndian.Uint16(buf.Next(packedMsgLenSz)))
	if msgLen == 0 {
		return nil, expired
	}

	cached := &dns.Msg{}
	if err := cached.Unpack(buf.Next(msgLen)); err != nil {
		return nil, expired
	}

	// Build the reply from the request so that it carries the request's
	// identity, and take the response-specific flags from the cached message.
	resp := new(dns.Msg).SetRcode(req, cached.Rcode)
	resp.AuthenticatedData = cached.AuthenticatedData
	resp.RecursionAvailable = cached.RecursionAvailable

	opt := req.IsEdns0()
	doBit := opt != nil && opt.Do()

	// Don't return OPT records from cache since it's deprecated by RFC 6891.
	// If the request has DO bit set we only remove all the OPT RRs, and also
	// all DNSSEC RRs otherwise.
	filterMsg(resp, cached, req.AuthenticatedData, doBit, ttl)

	// Whatever follows the packed message is the upstream address.
	upsAddr := string(buf.Next(buf.Len()))

	return &cacheItem{m: resp, u: upsAddr}, expired
}
// initCache sets up the cache of p according to its configuration.  It does
// nothing when caching is disabled.  The subnet-aware store is only created
// up front when EDNS Client Subnet support is enabled.
func (p *Proxy) initCache() {
	if !p.CacheEnabled {
		return
	}

	log.Printf("DNS cache is enabled")

	newCache := &cache{
		cacheSize:  p.CacheSizeBytes,
		optimistic: p.CacheOptimistic,
	}
	p.cache = newCache

	newCache.initLazy()
	if p.EnableEDNSClientSubnet {
		newCache.initLazyWithSubnet()
	}

	p.shortFlighter = newOptimisticResolver(p)
}
// get looks req up in the cache without subnet information.  ci is nil if
// nothing usable is stored.  expired is true if the stored item's TTL has
// run out.  key is the cache key computed for req; it's returned so that
// callers don't have to compute it again, and is nil when the cache isn't
// initialized or req doesn't contain exactly one question.
func (c *cache) get(req *dns.Msg) (ci *cacheItem, expired bool, key []byte) {
	c.itemsLock.RLock()
	defer c.itemsLock.RUnlock()

	if c.items == nil || req == nil || len(req.Question) != 1 {
		return nil, false, nil
	}

	key = msgToKey(req)

	packed := c.items.Get(key)
	if packed == nil {
		return nil, false, key
	}

	ci, expired = c.unpackItem(packed, req)
	if ci == nil {
		// The stored data can't be served, so there is no point in keeping
		// it.
		c.items.Del(key)
	}

	return ci, expired, key
}
// getWithSubnet returns cached item for the req if it's found by n.  expired
// is true if the item's TTL is expired.  k is the last key tried for req.
// It's returned to avoid recalculating it afterwards, and is nil when the
// cache isn't initialized or req doesn't contain exactly one question.  n must
// not be nil.
//
// Note that a slow longest-prefix-match algorithm is used, so cache searches
// are performed up to mask+1 times: the prefix length of n is tried first and
// then every shorter one down to zero until something is found.
//
// For every prefix shorter than the one of n the address is masked down to
// that prefix before building the key, since msgToKeyWithSubnet expects an
// already masked address.  Without that, an item stored for a wider subnet
// could only be found if the extra bits of n.IP happened to be zero.
func (c *cache) getWithSubnet(req *dns.Msg, n *net.IPNet) (ci *cacheItem, expired bool, k []byte) {
	ones, bits := n.Mask.Size()

	c.itemsWithSubnetLock.RLock()
	defer c.itemsWithSubnetLock.RUnlock()

	if c.itemsWithSubnet == nil || req == nil || len(req.Question) != 1 {
		return nil, false, nil
	}

	// Only re-mask when the address and the mask are of the same length, so
	// that the masked address keeps the byte length of n.IP and the keys
	// stay comparable to the ones built when storing.
	canMask := bits == 8*len(n.IP)

	var data []byte
	for mask := ones; mask >= 0 && data == nil; mask-- {
		ip := n.IP
		if canMask && mask < ones {
			// Fall back to the unmodified address if masking isn't
			// possible for some reason.
			if maskedIP := ip.Mask(net.CIDRMask(mask, bits)); maskedIP != nil {
				ip = maskedIP
			}
		}

		k = msgToKeyWithSubnet(req, ip, mask)
		data = c.itemsWithSubnet.Get(k)
	}

	if data == nil {
		return nil, false, k
	}

	ci, expired = c.unpackItem(data, req)
	if ci == nil {
		// The stored data can't be served, so drop it.
		c.itemsWithSubnet.Del(k)
	}

	return ci, expired, k
}
// initLazy creates the cache for general requests unless it already exists.
// It's safe for concurrent use.
func (c *cache) initLazy() {
	c.itemsLock.Lock()
	defer c.itemsLock.Unlock()

	if c.items != nil {
		return
	}

	c.items = c.createCache()
}
// initLazyWithSubnet creates the cache for requests with subnets unless it
// already exists.  It's safe for concurrent use.
func (c *cache) initLazyWithSubnet() {
	c.itemsWithSubnetLock.Lock()
	defer c.itemsWithSubnetLock.Unlock()

	if c.itemsWithSubnet != nil {
		return
	}

	c.itemsWithSubnet = c.createCache()
}
// createCache returns a new LRU-enabled cache.  Its maximum size is
// c.cacheSize if that is positive and defaultCacheSize otherwise.
func (c *cache) createCache() (glc glcache.Cache) {
	maxSize := uint(defaultCacheSize)
	if c.cacheSize > 0 {
		maxSize = uint(c.cacheSize)
	}

	return glcache.New(glcache.Config{
		MaxSize:   maxSize,
		EnableLRU: true,
	})
}
// set stores ci in the cache for general requests, creating that cache first
// if needed.  Items whose message isn't cacheable are silently skipped.
func (c *cache) set(ci *cacheItem) {
	if !isCacheable(ci.m) {
		return
	}

	c.initLazy()

	// Do the preparation before taking the lock to keep the critical section
	// short.
	itemKey := msgToKey(ci.m)
	itemData := ci.pack()

	c.itemsLock.RLock()
	defer c.itemsLock.RUnlock()

	c.items.Set(itemKey, itemData)
}
// setWithSubnet stores ci in the cache for requests with subnets, creating
// that cache first if needed.  The key is built from the address of subnet and
// the number of leading ones in its mask.  Items whose message isn't
// cacheable are silently skipped.
func (c *cache) setWithSubnet(ci *cacheItem, subnet *net.IPNet) {
	if !isCacheable(ci.m) {
		return
	}

	c.initLazyWithSubnet()

	// Do the preparation before taking the lock to keep the critical section
	// short.
	prefixLen, _ := subnet.Mask.Size()
	itemKey := msgToKeyWithSubnet(ci.m, subnet.IP, prefixLen)
	itemData := ci.pack()

	c.itemsWithSubnetLock.RLock()
	defer c.itemsWithSubnetLock.RUnlock()

	c.itemsWithSubnet.Set(itemKey, itemData)
}
// isCacheable reports whether m may be put into the cache.  A message is
// rejected if it's nil, truncated, doesn't have exactly one question, or has
// no usable TTL.  Beyond that the decision depends on the response code:
//
//   - successful responses to non-address questions are always cacheable;
//   - successful responses to A and AAAA questions must either contain an
//     address record or qualify as a cacheable negative response;
//   - NXDOMAIN responses must qualify as a cacheable negative response;
//   - anything else isn't cached.
//
// For negative answers it follows RFC 2308 on how to cache NXDOMAIN and NODATA
// kinds of responses.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-2.1,
// https://datatracker.ietf.org/doc/html/rfc2308#section-2.2.
func isCacheable(m *dns.Msg) bool {
	if m == nil {
		return false
	}

	if m.Truncated {
		log.Tracef("refusing to cache truncated message")

		return false
	}

	if len(m.Question) != 1 {
		log.Tracef("refusing to cache message with wrong number of questions")

		return false
	}

	if lowestTTL(m) == 0 {
		return false
	}

	q := m.Question[0]

	switch m.Rcode {
	case dns.RcodeNameError:
		return isCacheableNegative(m)
	case dns.RcodeSuccess:
		if q.Qtype == dns.TypeA || q.Qtype == dns.TypeAAAA {
			return hasIPAns(m) || isCacheableNegative(m)
		}

		return true
	default:
		log.Tracef(
			"%s: refusing to cache message with response code %s",
			q.Name,
			dns.RcodeToString[m.Rcode],
		)

		return false
	}
}
// hasIPAns reports whether the answer section of m contains at least one A or
// AAAA record.
func hasIPAns(m *dns.Msg) (ok bool) {
	for _, ans := range m.Answer {
		switch ans.Header().Rrtype {
		case dns.TypeA, dns.TypeAAAA:
			return true
		}
	}

	return false
}
// isCacheableNegative returns true if the authority section of m has at least
// a single SOA RR and no NS records so that it can be declared authoritative.
//
// See https://datatracker.ietf.org/doc/html/rfc2308#section-5 for the
// information on the responses from the authoritative server that should be
// cached by the forwarder.
func isCacheableNegative(m *dns.Msg) (ok bool) {
	for _, auth := range m.Ns {
		rrType := auth.Header().Rrtype
		if rrType == dns.TypeNS {
			// A single NS record disqualifies the whole response.
			return false
		}

		if rrType == dns.TypeSOA {
			ok = true
		}
	}

	return ok
}
// lowestTTL returns the smallest TTL among the records of the answer,
// authority, and additional sections of m, or 0 if none of them provides one.
func lowestTTL(m *dns.Msg) (ttl uint32) {
	// Start from the largest possible value, which also serves as the
	// "nothing found" marker below.
	ttl = math.MaxUint32

	sections := [...][]dns.RR{m.Answer, m.Ns, m.Extra}
	for i := range sections {
		for _, rr := range sections[i] {
			ttl = minTTL(rr.Header(), ttl)
		}
	}

	if ttl != math.MaxUint32 {
		return ttl
	}

	return 0
}
// minTTL returns the smaller of h's TTL and the passed ttl.  Headers of OPT
// records are skipped, so ttl is returned unchanged for them.
func minTTL(h *dns.RR_Header, ttl uint32) uint32 {
	if h.Rrtype != dns.TypeOPT && h.Ttl < ttl {
		return h.Ttl
	}

	return ttl
}
// respectTTLOverrides returns ttl adjusted to the limits set by cacheMinTTL and
// cacheMaxTTL.  The lower limit is checked first, so it wins when the limits
// contradict each other.  A cacheMaxTTL of zero means there is no upper limit.
func respectTTLOverrides(ttl, cacheMinTTL, cacheMaxTTL uint32) uint32 {
	switch {
	case ttl < cacheMinTTL:
		return cacheMinTTL
	case cacheMaxTTL != 0 && ttl > cacheMaxTTL:
		return cacheMaxTTL
	default:
		return ttl
	}
}
// msgToKey builds the cache key for m from its only question: two bytes of
// big-endian QTYPE, two bytes of big-endian QCLASS, and the lowercased QNAME.
// m must contain at least one question.
func msgToKey(m *dns.Msg) (b []byte) {
	const nameOff = 2 * packedMsgLenSz

	question := m.Question[0]

	b = make([]byte, nameOff+len(question.Name))
	binary.BigEndian.PutUint16(b[:packedMsgLenSz], question.Qtype)
	binary.BigEndian.PutUint16(b[packedMsgLenSz:nameOff], question.Qclass)
	copy(b[nameOff:], strings.ToLower(question.Name))

	return b
}
// msgToKeyWithSubnet constructs the cache key for the only question of m
// within the subnet described by ecsIP and mask.  ecsIP is expected to be
// masked already; it's only included into the key when mask isn't zero.  m
// must contain at least one question.
//
// The key is allocated as one byte for the DO bit, two bytes for QTYPE, two
// bytes for QCLASS, one byte for the mask, the optional address, and the
// lowercased question name.
//
// NOTE(review): QTYPE is written through key[:] rather than key[k:], so it
// lands in bytes 0 and 1, overwriting the DO byte stored just before it, while
// byte 2 always stays zero.  As a result the DO bit doesn't currently take
// part in the key.  Confirm the intended behavior before changing this: keys
// are built from cached responses when storing and from requests when looking
// up, so making the DO bit significant may turn existing hits into misses.
func msgToKeyWithSubnet(m *dns.Msg, ecsIP net.IP, mask int) (key []byte) {
q := m.Question[0]
// The fixed part: DO byte, QTYPE, QCLASS, and the mask byte.
cap := 1 + 2*packedMsgLenSz + 1 + len(q.Name)
ipLen := len(ecsIP)
// A zero mask means the item is valid for any client, so the address is
// left out of the key entirely.
masked := mask != 0
if masked {
cap += ipLen
}
// Initialize the slice.
key = make([]byte, cap)
k := 0
// Put DO.
if opt := m.IsEdns0(); opt != nil && opt.Do() {
key[k] = 1
} else {
key[k] = 0
}
k++
// Put Qtype.  Note that this writes at the very start of the key, not at
// offset k; see the review note in the function's documentation.
binary.BigEndian.PutUint16(key[:], q.Qtype)
k += packedMsgLenSz
// Put Qclass.
binary.BigEndian.PutUint16(key[k:], q.Qclass)
k += packedMsgLenSz
// Add mask.
key[k] = uint8(mask)
k++
if masked {
k += copy(key[k:], ecsIP)
}
// Names are compared case-insensitively, so store the lowercased form.
copy(key[k:], strings.ToLower(q.Name))
return key
}
// isDNSSEC reports whether r is a DNSSEC record, that is one of NSEC, NSEC3,
// DS, DNSKEY, RRSIG, or SIG.
func isDNSSEC(r dns.RR) bool {
	switch r.Header().Rrtype {
	case dns.TypeNSEC, dns.TypeNSEC3, dns.TypeDS, dns.TypeRRSIG, dns.TypeSIG, dns.TypeDNSKEY:
		return true
	}

	return false
}
// filterRRSlice returns deep copies of those records of rrs that pass the
// filter:
//
//   - OPT records are always dropped;
//   - DNSSEC records are dropped if do is false, unless their type is except.
//
// If ttl isn't zero, it's set as the TTL of every returned record.  The
// records in rrs themselves are never modified.  filtered is nil if rrs is
// empty and a non-nil, possibly empty, slice otherwise.
func filterRRSlice(rrs []dns.RR, do bool, ttl uint32, except uint16) (filtered []dns.RR) {
	if len(rrs) == 0 {
		return nil
	}

	filtered = make([]dns.RR, 0, len(rrs))
	for _, r := range rrs {
		rrType := r.Header().Rrtype
		if rrType == dns.TypeOPT || (!do && rrType != except && isDNSSEC(r)) {
			continue
		}

		// Copy first and only then override the TTL, so that the caller's
		// records stay intact.
		rr := dns.Copy(r)
		if ttl != 0 {
			rr.Header().Ttl = ttl
		}

		filtered = append(filtered, rr)
	}

	return filtered
}
// filterMsg fills the answer, authority, and additional sections of dst with
// filtered copies of the corresponding sections of m.  OPT RRs are always
// removed and DNSSEC RRs are removed if do is false, except for the answer
// records of the type that was asked for.  If ttl isn't zero, it's set on
// every resulting record.  The AD bit of dst is cleared if both ad and do are
// false.  m must contain at least one question.
func filterMsg(dst, m *dns.Msg, ad, do bool, ttl uint32) {
	// As RFC 6840 says, validating resolvers should only set the AD bit when a
	// response both meets the conditions listed in RFC 4035, and the request
	// contained either a set DO bit or a set AD bit.
	if !ad && !do {
		dst.AuthenticatedData = false
	}

	// It's important to filter out only DNSSEC RRs that aren't explicitly
	// requested.
	//
	// See https://datatracker.ietf.org/doc/html/rfc4035#section-3.2.1 and
	// https://github.com/AdguardTeam/dnsproxy/issues/144.
	askedType := m.Question[0].Qtype

	dst.Answer = filterRRSlice(m.Answer, do, ttl, askedType)
	dst.Ns = filterRRSlice(m.Ns, do, ttl, dns.TypeNone)
	dst.Extra = filterRRSlice(m.Extra, do, ttl, dns.TypeNone)
}