From 06aa2b9c980a96d2b7ae6e986231c8971c819a28 Mon Sep 17 00:00:00 2001 From: David Stainton Date: Sat, 4 Apr 2015 22:56:08 +0000 Subject: [PATCH] critical fix to getOverlapBytes - this fixes several bugs in getOverlapBytes - removes a few small helper functions - removes several helper function unit tests --- retrospective.go | 160 +++++++++++++----------------------- retrospective_test.go | 186 +----------------------------------------- 2 files changed, 58 insertions(+), 288 deletions(-) diff --git a/retrospective.go b/retrospective.go index 1161595..7071a9f 100644 --- a/retrospective.go +++ b/retrospective.go @@ -43,6 +43,8 @@ func injectionInStreamRing(p PacketManifest, flow *types.TcpIpFlow, ringPtr *typ return nil } + log.Printf("len overlapBytes %d startOffset %d endOffset %d\n", len(overlapBytes), startOffset, endOffset) + if !bytes.Equal(overlapBytes, p.Payload[startOffset:endOffset]) { log.Print("injection attack detected\n") e := &types.Event{ @@ -63,86 +65,87 @@ func injectionInStreamRing(p PacketManifest, flow *types.TcpIpFlow, ringPtr *typ } } -// getOverlapBytes returns the overlap byte array; that is the contiguous data stored in our ring buffer +// getOverlapBytes takes several arguments: +// head and tail - ring pointers used to indentify a list of ring elements. +// start and end - sequence numbers representing locations in head and tail respectively. +// NOTE: here we assume that the head and tail were calculated properly such that: +// 1. start must be located within the head segment's sequence boundaries or BEFORE. +// 2. end must be located within the tail segment's sequence boundaries or AFTER. +// normally head and tail values would be procured with a call to getOverlapRings like this: +// head, tail := getOverlapRings(p, flow, ringPtr) +// Given these arguments, getOverlapBytes returns the overlap byte array; +// that is the contiguous data stored in our ring buffer // that overlaps with the stream segment specified by the start and end Sequence boundaries. // The other return values are the slice offsets of the original packet payload that can be used to derive -// the new overlapping portion of the stream segment. +// calculate the section of the packet that has overlapped with our Reassembly ring buffer. func getOverlapBytes(head, tail *types.Ring, start, end types.Sequence) ([]byte, int, int) { var overlapStartSlice, overlapEndSlice int var overlapBytes []byte + if head == nil || tail == nil { panic("wtf; head or tail is nil\n") } - if len(head.Reassembly.Bytes) == 0 { panic("length of head ring element is zero") } - if len(tail.Reassembly.Bytes) == 0 { panic("length of tail ring element is zero") } - // XXX todo : something here is broken. fix it! - sequenceStart, overlapStartSlice := getStartOverlapSequenceAndOffset(head, start) - headOffset := getHeadRingOffset(head, sequenceStart) - - if headOffset > len(head.Reassembly.Bytes) { - log.Printf("sequenceStart %d overlapStartSlice %d headOffset %d\n", sequenceStart, overlapStartSlice, headOffset) - panic("getOverlapBytes headOffset > len head") + packetLength := start.Difference(end) + if packetLength <= 0 { + panic("wtf") } - - if headOffset < 0 { - panic("headOffset is below zero") + var headOffset int + tailLastSeq := tail.Reassembly.Seq.Add(len(tail.Reassembly.Bytes) - 1) + diff := head.Reassembly.Seq.Difference(start) + if diff < 0 { + headOffset = 0 + overlapStartSlice = -1 * diff + } else if diff == 0 { + headOffset = 0 + overlapStartSlice = 0 + } else { + headOffset = diff + overlapStartSlice = 0 } - - sequenceEnd, overlapEndOffset := getEndOverlapSequenceAndOffset(tail, end) - tailOffset := getTailRingOffset(tail, sequenceEnd) - if head.Reassembly.Seq == tail.Reassembly.Seq { - var endOffset int - log.Print("head == tail\n") - - if tailOffset < 0 { - panic("wtf") - } else { - endOffset = len(head.Reassembly.Bytes) - tailOffset - } - - if overlapStartSlice < 0 { - panic("wtf impossible") + var endOffset int + diff = tailLastSeq.Difference(end) + if diff <= 0 { + overlapEndSlice = packetLength + 1 + tailDiff := end.Difference(tailLastSeq) + endOffset = len(head.Reassembly.Bytes) - tailDiff } else { - overlapEndSlice = len(head.Reassembly.Bytes) - tailOffset + overlapStartSlice - headOffset + overlapEndSlice = packetLength - diff + 1 + endOffset = len(head.Reassembly.Bytes) + log.Printf("endOffset %d diff %d", endOffset, diff) } - + log.Printf("len head %d headOffset %d endOffset %d", len(head.Reassembly.Bytes), headOffset, endOffset) overlapBytes = head.Reassembly.Bytes[headOffset:endOffset] } else { log.Print("head != tail\n") - - if tailOffset < 0 { - panic("wtf") - } - - // XXX wrong - totalLen := start.Difference(end) + 1 - - if overlapEndOffset < 0 { - overlapEndSlice = 0 + diff = tailLastSeq.Difference(end) + var tailSlice int + // if end is equal or less than tailLastSeq + if diff <= 0 { + overlapEndSlice = packetLength + if (-1 * diff) > len(tail.Reassembly.Bytes) { + tailSlice = len(tail.Reassembly.Bytes) + } else { + tailSlice = len(tail.Reassembly.Bytes) - (diff * -1) + } } else { - overlapEndSlice = totalLen - overlapEndOffset - } - - tailSlice := len(tail.Reassembly.Bytes) - tailOffset - - if tailSlice < 0 { - panic("tailSlice is below zero") + overlapEndSlice = packetLength - diff + 1 + tailSlice = len(tail.Reassembly.Bytes) } - overlapBytes = getRingSlice(head, tail, headOffset, tailSlice) if overlapBytes == nil { return nil, 0, 0 } } + log.Printf("len overlapBytes %d overlapStartSlice %d overlapEndSlice %d", len(overlapBytes), overlapStartSlice, overlapEndSlice) return overlapBytes, overlapStartSlice, overlapEndSlice } @@ -232,26 +235,14 @@ func getStartSequence(head *types.Ring, start types.Sequence) types.Sequence { return startSeq } -// getEndSequence receives a ring pointer and an ending sequence number -// and returns the closest available ending sequence number that is available from the ring. -func getEndSequence(tail *types.Ring, end types.Sequence) types.Sequence { - var seqEnd types.Sequence - diff := tail.Reassembly.Seq.Add(len(tail.Reassembly.Bytes) - 1).Difference(end) - if diff <= 0 { - seqEnd = end - } else { - seqEnd = tail.Reassembly.Seq.Add(len(tail.Reassembly.Bytes) - 1) - } - return seqEnd -} - // getRingSlice returns a byte slice from the ring buffer given the head -// and tail of the ring segment. sliceStart indicates the zero-indexed byte offset into -// the head that we should copy from; sliceEnd indicates the number of bytes from the tail -// that we should skip. +// and tail of the ring segment AND the slice indexes for head and tail. +// That is, for head's byte slice, sliceStart is the a slice start index. +// For tail's byte slice, sliceEnd is the slice end index. func getRingSlice(head, tail *types.Ring, sliceStart, sliceEnd int) []byte { var overlapBytes []byte if sliceStart < 0 || sliceEnd < 0 { + log.Printf("sliceStart %d sliceEnd %d", sliceStart, sliceEnd) panic("sliceStart < 0 || sliceEnd < 0") } if sliceStart >= len(head.Reassembly.Bytes) { @@ -275,42 +266,3 @@ func getRingSlice(head, tail *types.Ring, sliceStart, sliceEnd int) []byte { overlapBytes = append(overlapBytes, tail.Reassembly.Bytes[:sliceEnd]...) return overlapBytes } - -// getHeadRingOffset receives a given ring element and starting sequence number -// and returns the offset into the ring element where the start sequence is found -func getHeadRingOffset(head *types.Ring, start types.Sequence) int { - return head.Reassembly.Seq.Difference(start) -} - -// getStartOverlapSequenceAndOffset takes a ring element and start sequence and -// returns the closest sequence number available in the element... and the offset -// from the beginning of that element -func getStartOverlapSequenceAndOffset(head *types.Ring, start types.Sequence) (types.Sequence, int) { - seqStart := getStartSequence(head, start) - offset := int(start.Difference(seqStart)) - if offset < 0 { - panic("getStartOverlapSequenceAndOffset offset < 0") - } - return seqStart, offset -} - -// getRingSegmentLastSequence returns the last sequence number represented by -// a given ring elements stream segment -func getRingSegmentLastSequence(segment *types.Ring) types.Sequence { - return segment.Reassembly.Seq.Add(len(segment.Reassembly.Bytes) - 1) -} - -// getTailRingOffset returns the number of bytes the from end of the -// ring element's stream segment that the end sequence is found -func getTailRingOffset(tail *types.Ring, end types.Sequence) int { - tailEndSequence := getRingSegmentLastSequence(tail) - return end.Difference(tailEndSequence) -} - -// getEndOverlapSequenceAndOffset receives a ring element and end sequence. -// It returns the last sequence number represented by that ring element and the offset from the end. -func getEndOverlapSequenceAndOffset(tail *types.Ring, end types.Sequence) (types.Sequence, int) { - seqEnd := getEndSequence(tail, end) - offset := int(seqEnd.Difference(end)) - return seqEnd, offset -} diff --git a/retrospective_test.go b/retrospective_test.go index d8841e2..6544113 100644 --- a/retrospective_test.go +++ b/retrospective_test.go @@ -298,33 +298,6 @@ func TestGetRingSlicePanic4(t *testing.T) { _ = getRingSlice(head, head, 0, 0) } -func TestGetEndSequence(t *testing.T) { - var tail *types.Ring = types.NewRing(10) - var end types.Sequence - - end = 9 - tail.Reassembly = &types.Reassembly{ - Seq: 5, - Bytes: []byte{1, 2, 3, 4, 5}, - } - endSeq := getEndSequence(tail, end) - if endSeq.Difference(end) != 0 { - t.Errorf("endSeq %d != end %d\n", endSeq, end) - t.Fail() - } - - end = 9 - tail.Reassembly = &types.Reassembly{ - Seq: 5, - Bytes: []byte{1, 2, 3, 4, 5}, - } - endSeq = getEndSequence(tail, end.Add(1)) - if endSeq.Difference(end) != 0 { - t.Errorf("endSeq %d != end %d\n", endSeq, end) - t.Fail() - } -} - func TestGetStartSequence(t *testing.T) { var start types.Sequence = 4 var head *types.Ring = types.NewRing(10) @@ -346,162 +319,6 @@ func TestGetStartSequence(t *testing.T) { } } -func TestGetHeadRingOffset(t *testing.T) { - head := types.NewRing(3) - head.Reassembly = &types.Reassembly{ - Seq: 3, - Bytes: []byte{1, 2, 3, 4, 5, 6, 7}, - } - offset := getHeadRingOffset(head, 5) - if offset < 0 { - t.Error("offset less than zero\n") - t.Fail() - } - if offset != 2 { - t.Error("offset incorrect\n") - t.Fail() - } - offset = getHeadRingOffset(head, 3) - if offset != 0 { - t.Error("offset incorrect\n") - t.Fail() - } - - offset = getHeadRingOffset(head, 4) - if offset != 1 { - t.Error("offset incorrect\n") - t.Fail() - } -} - -func TestGetTailRingOffset(t *testing.T) { - tail := types.NewRing(3) - tail.Reassembly = &types.Reassembly{ - Seq: 3, - Bytes: []byte{1, 2, 3, 4, 5, 6, 7}, - } - - offset := getTailRingOffset(tail, 4) - if offset != 5 { - t.Errorf("want 5 got %d\n", offset) - t.Fail() - } - - offset = getTailRingOffset(tail, 5) - if offset != 4 { - t.Errorf("want 4 got %d\n", offset) - t.Fail() - } - - offset = getTailRingOffset(tail, 6) - if offset != 3 { - t.Errorf("want 3 got %d\n", offset) - t.Fail() - } -} - -func TestGetStartOverlapSequenceAndOffset(t *testing.T) { - var start types.Sequence = 3 - head := types.NewRing(3) - head.Reassembly = &types.Reassembly{ - Seq: 3, - Bytes: []byte{1, 2, 3, 4, 5, 6, 7}, - } - sequence, offset := getStartOverlapSequenceAndOffset(head, start) - if offset != 0 { - t.Error("offset != 0\n") - t.Fail() - } - if sequence != 3 { - t.Error("incorrect sequence") - t.Fail() - } - - start = 4 - sequence, offset = getStartOverlapSequenceAndOffset(head, start) - if offset != 0 { - t.Errorf("offset %d != 1\n", offset) - t.Fail() - } - if sequence != 4 { - t.Error("incorrect sequence") - t.Fail() - } - - start = 2 - sequence, offset = getStartOverlapSequenceAndOffset(head, start) - if offset != 1 { - t.Errorf("offset %d != 1\n", offset) - t.Fail() - } - if sequence != 3 { - t.Error("incorrect sequence") - t.Fail() - } - - start = 1 - sequence, offset = getStartOverlapSequenceAndOffset(head, start) - if offset != 2 { - t.Errorf("offset %d != 2\n", offset) - t.Fail() - } - if sequence != 3 { - t.Error("incorrect sequence") - t.Fail() - } -} - -func TestGetEndOverlapSequenceAndOffset(t *testing.T) { - var end types.Sequence = 3 - tail := types.NewRing(3) - tail.Reassembly = &types.Reassembly{ - Seq: 3, - Bytes: []byte{1, 2, 3, 4, 5, 6, 7}, - } - sequence, offset := getEndOverlapSequenceAndOffset(tail, end) - if offset != 0 { - t.Error("offset != 0\n") - t.Fail() - } - if sequence != 3 { - t.Error("incorrect sequence") - t.Fail() - } - - end = 9 - sequence, offset = getEndOverlapSequenceAndOffset(tail, end) - if offset != 0 { - t.Error("offset != 0\n") - t.Fail() - } - if sequence != end { - t.Error("incorrect sequence") - t.Fail() - } - - end = 10 - sequence, offset = getEndOverlapSequenceAndOffset(tail, end) - if offset != 1 { - t.Error("offset != 1\n") - t.Fail() - } - if sequence != end-1 { - t.Error("incorrect sequence") - t.Fail() - } - - end = 11 - sequence, offset = getEndOverlapSequenceAndOffset(tail, end) - if offset != 2 { - t.Error("offset != 2\n") - t.Fail() - } - if sequence != end-2 { - t.Error("incorrect sequence") - t.Fail() - } -} - func TestGetOverlapBytes(t *testing.T) { overlapBytesTests := []struct { in reassemblyInput @@ -539,7 +356,7 @@ func TestGetOverlapBytes(t *testing.T) { reassemblyInput{4, []byte{1, 2, 3, 4, 5, 6, 7}}, TestOverlapBytesWant{ bytes: []byte{6, 7, 8, 9, 10, 11}, startOffset: 1, - endOffset: 7, + endOffset: 6, }, }, { @@ -660,6 +477,7 @@ func TestGetOverlapBytes(t *testing.T) { continue } + log.Printf("test #%d", i) overlapBytes, startOffset, endOffset := getOverlapBytes(head, tail, start, end) if startOffset != overlapBytesTests[i].want.startOffset {