Skip to content

Commit 6d9c840

Browse files
Initial code commit
1 parent 83b800e commit 6d9c840

File tree

5 files changed

+1024
-1
lines changed

5 files changed

+1024
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.swp
2+
*.out

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2017 go-test
3+
Copyright 2015-2017 Daniel Nichter
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Deep Variable Equality for Humans
2+
3+
This package provides a single function: `deep.Equal`. It's like [reflect.DeepEqual](http://golang.org/pkg/reflect/#DeepEqual) but much friendlier to humans (or any sentient being) for two reason:
4+
5+
* `deep.Equal` returns a list of differences
6+
* `deep.Equal` does not compare unexported fields (by default)
7+
8+
`reflect.DeepEqual` is good (like all things Golang!), but it's a game of [Hunt the Wumpus](https://en.wikipedia.org/wiki/Hunt_the_Wumpus). For large maps, slices, and structs, finding the difference is difficult.
9+
10+
`deep.Equal` doesn't play games with you, it lists the differences:
11+
12+
```go
13+
package main_test
14+
15+
import (
16+
"testing"
17+
"github.com/go-test/deep"
18+
)
19+
20+
type T struct {
21+
Name string
22+
Numbers []float64
23+
}
24+
25+
func TestDeepEqual(t *testing.T) {
26+
// Can you spot the difference?
27+
t1 := T{
28+
Name: "Isabella",
29+
Numbers: []float64{1.13459, 2.29343, 3.010100010},
30+
}
31+
t2 := T{
32+
Name: "Isabella",
33+
Numbers: []float64{1.13459, 2.29843, 3.010100010},
34+
}
35+
36+
if diff := deep.Equal(t1, t2); diff != nil {
37+
t.Error(diff)
38+
}
39+
}
40+
```
41+
42+
43+
```
44+
$ go test
45+
--- FAIL: TestDeepEqual (0.00s)
46+
main_test.go:25: [Numbers.slice[1]: 2.29343 != 2.29843]
47+
```
48+
49+
The difference is in `Numbers.slice[1]`: the two values aren't equal using Go `==`.

deep.go

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
/*
2+
The MIT License (MIT)
3+
4+
Copyright 2015-2017 Daniel Nichter
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
*/
24+
25+
// Package deep provides function deep.Equal which is like reflect.DeepEqual but
26+
// retunrs a list of differences. This is helpful when comparing complex types
27+
// like structures and maps.
28+
package deep
29+
30+
import (
31+
"errors"
32+
"fmt"
33+
"log"
34+
"reflect"
35+
"strings"
36+
)
37+
38+
var (
39+
FloatPrecision = 10
40+
MaxDiff = 10
41+
MaxDepth = 10
42+
LogErrors = false
43+
CompareUnexportedFields = false
44+
)
45+
46+
var (
47+
ErrMaxRecursion = errors.New("recursed to MaxDepth")
48+
ErrTypeMismatch = errors.New("variables are different reflect.Type")
49+
ErrNotHandled = errors.New("cannot compare the reflect.Kind")
50+
)
51+
52+
type cmp struct {
53+
diff []string
54+
buff []string
55+
floatFormat string
56+
}
57+
58+
// Equal compares variables a and b, recursing into their structure up to
59+
// MaxDepth levels deep, and returns a list of differences, or nil if there are
60+
// none. Some differences may not be found if an error is also returned.
61+
//
62+
// If a type has an Equal method, like time.Equal, it is called to check for
63+
// equality.
64+
func Equal(a, b interface{}) []string {
65+
aVal := reflect.ValueOf(a)
66+
bVal := reflect.ValueOf(b)
67+
c := &cmp{
68+
diff: []string{},
69+
buff: []string{},
70+
floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
71+
}
72+
c.equals(aVal, bVal, 0)
73+
if len(c.diff) > 0 {
74+
return c.diff // diffs
75+
}
76+
return nil // no diffs
77+
}
78+
79+
func (c *cmp) equals(a, b reflect.Value, level int) {
80+
if level > MaxDepth {
81+
logError(ErrMaxRecursion)
82+
return
83+
}
84+
85+
aType := a.Type()
86+
bType := b.Type()
87+
if aType != bType {
88+
c.saveDiff(aType, bType)
89+
logError(ErrTypeMismatch)
90+
return
91+
}
92+
93+
aKind := a.Kind()
94+
bKind := b.Kind()
95+
if aKind == reflect.Ptr || aKind == reflect.Interface {
96+
a = a.Elem()
97+
aKind = a.Kind()
98+
if a.IsValid() {
99+
aType = a.Type()
100+
}
101+
}
102+
if bKind == reflect.Ptr || bKind == reflect.Interface {
103+
b = b.Elem()
104+
bKind = b.Kind()
105+
if b.IsValid() {
106+
bType = b.Type()
107+
}
108+
}
109+
110+
// For example: T{x: *X} and T.x is nil.
111+
if !a.IsValid() || !b.IsValid() {
112+
if a.IsValid() && !b.IsValid() {
113+
c.saveDiff(aType, "<nil pointer>")
114+
} else if !a.IsValid() && b.IsValid() {
115+
c.saveDiff("<nil pointer>", bType)
116+
}
117+
return
118+
}
119+
120+
// Types with an Equal(), like time.Time.
121+
eqFunc := a.MethodByName("Equal")
122+
if eqFunc.IsValid() {
123+
retVals := eqFunc.Call([]reflect.Value{b})
124+
if !retVals[0].Bool() {
125+
c.saveDiff(a, b)
126+
}
127+
return
128+
}
129+
130+
switch aKind {
131+
132+
/////////////////////////////////////////////////////////////////////
133+
// Iterable kinds
134+
/////////////////////////////////////////////////////////////////////
135+
136+
case reflect.Struct:
137+
/*
138+
The variables are structs like:
139+
type T struct {
140+
FirstName string
141+
LastName string
142+
}
143+
Type = <pkg>.T, Kind = reflect.Struct
144+
145+
Iterate through the fields (FirstName, LastName), recurse into their values.
146+
*/
147+
for i := 0; i < a.NumField(); i++ {
148+
if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
149+
continue // skip unexported field, e.g. s in type T struct {s string}
150+
}
151+
152+
c.push(aType.Field(i).Name) // push field name to buff
153+
154+
// Get the Value for each field, e.g. FirstName has Type = string,
155+
// Kind = reflect.String.
156+
af := a.Field(i)
157+
bf := b.Field(i)
158+
159+
// Recurse to compare the field values
160+
c.equals(af, bf, level+1)
161+
162+
c.pop() // pop field name from buff
163+
164+
if len(c.diff) >= MaxDiff {
165+
break
166+
}
167+
}
168+
case reflect.Map:
169+
/*
170+
The variables are maps like:
171+
map[string]int{
172+
"foo": 1,
173+
"bar": 2,
174+
}
175+
Type = map[string]int, Kind = reflect.Map
176+
177+
Or:
178+
type T map[string]int{}
179+
Type = <pkg>.T, Kind = reflect.Map
180+
181+
Iterate through the map keys (foo, bar), recurse into their values.
182+
*/
183+
184+
if a.IsNil() || b.IsNil() {
185+
if a.IsNil() && !b.IsNil() {
186+
c.saveDiff("<nil map>", b)
187+
} else if !a.IsNil() && b.IsNil() {
188+
c.saveDiff(a, "<nil map>")
189+
}
190+
return
191+
}
192+
193+
if a.Pointer() == b.Pointer() {
194+
return
195+
}
196+
197+
for _, key := range a.MapKeys() {
198+
c.push(fmt.Sprintf("map[%s]", key))
199+
200+
aVal := a.MapIndex(key)
201+
bVal := b.MapIndex(key)
202+
if bVal.IsValid() {
203+
c.equals(aVal, bVal, level+1)
204+
} else {
205+
c.saveDiff(aVal, "<does not have key>")
206+
}
207+
208+
c.pop()
209+
210+
if len(c.diff) >= MaxDiff {
211+
return
212+
}
213+
}
214+
215+
for _, key := range b.MapKeys() {
216+
if aVal := a.MapIndex(key); aVal.IsValid() {
217+
continue
218+
}
219+
220+
c.push(fmt.Sprintf("map[%s]", key))
221+
c.saveDiff("<does not have key>", b.MapIndex(key))
222+
c.pop()
223+
if len(c.diff) >= MaxDiff {
224+
return
225+
}
226+
}
227+
case reflect.Slice:
228+
if a.IsNil() || b.IsNil() {
229+
if a.IsNil() && !b.IsNil() {
230+
c.saveDiff("<nil slice>", b)
231+
} else if !a.IsNil() && b.IsNil() {
232+
c.saveDiff(a, "<nil slice>")
233+
}
234+
return
235+
}
236+
237+
if a.Pointer() == b.Pointer() {
238+
return
239+
}
240+
241+
aLen := a.Len()
242+
bLen := b.Len()
243+
n := aLen
244+
if bLen > aLen {
245+
n = bLen
246+
}
247+
for i := 0; i < n; i++ {
248+
c.push(fmt.Sprintf("slice[%d]", i))
249+
if i < aLen && i < bLen {
250+
c.equals(a.Index(i), b.Index(i), level+1)
251+
} else if i < aLen {
252+
c.saveDiff(a.Index(i), "<no value>")
253+
} else {
254+
c.saveDiff("<no value>", b.Index(i))
255+
}
256+
c.pop()
257+
if len(c.diff) >= MaxDiff {
258+
break
259+
}
260+
}
261+
262+
/////////////////////////////////////////////////////////////////////
263+
// Primitive kinds
264+
/////////////////////////////////////////////////////////////////////
265+
266+
case reflect.Float32, reflect.Float64:
267+
// Avoid 0.04147685731961082 != 0.041476857319611
268+
// 6 decimal places is close enough
269+
aval := fmt.Sprintf(c.floatFormat, a.Float())
270+
bval := fmt.Sprintf(c.floatFormat, b.Float())
271+
if aval != bval {
272+
c.saveDiff(a.Float(), b.Float())
273+
}
274+
case reflect.Bool:
275+
if a.Bool() != b.Bool() {
276+
c.saveDiff(a.Bool(), b.Bool())
277+
}
278+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
279+
if a.Int() != b.Int() {
280+
c.saveDiff(a.Int(), b.Int())
281+
}
282+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
283+
if a.Uint() != b.Uint() {
284+
c.saveDiff(a.Uint(), b.Uint())
285+
}
286+
case reflect.String:
287+
if a.String() != b.String() {
288+
c.saveDiff(a.String(), b.String())
289+
}
290+
291+
default:
292+
logError(ErrNotHandled)
293+
}
294+
}
295+
296+
func (c *cmp) push(name string) {
297+
c.buff = append(c.buff, name)
298+
}
299+
300+
func (c *cmp) pop() {
301+
if len(c.buff) > 0 {
302+
c.buff = c.buff[0 : len(c.buff)-1]
303+
}
304+
}
305+
306+
func (c *cmp) saveDiff(aval, bval interface{}) {
307+
if len(c.buff) > 0 {
308+
varName := strings.Join(c.buff, ".")
309+
c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
310+
} else {
311+
c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
312+
}
313+
}
314+
315+
func init() {
316+
log.SetFlags(log.Lshortfile)
317+
}
318+
319+
func logError(err error) {
320+
if LogErrors {
321+
log.Println(err)
322+
}
323+
}

0 commit comments

Comments
 (0)