|
| 1 | +/* |
| 2 | + The MIT License (MIT) |
| 3 | +
|
| 4 | + Copyright 2015-2017 Daniel Nichter |
| 5 | +
|
| 6 | + Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | + of this software and associated documentation files (the "Software"), to deal |
| 8 | + in the Software without restriction, including without limitation the rights |
| 9 | + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | + copies of the Software, and to permit persons to whom the Software is |
| 11 | + furnished to do so, subject to the following conditions: |
| 12 | +
|
| 13 | + The above copyright notice and this permission notice shall be included in all |
| 14 | + copies or substantial portions of the Software. |
| 15 | +
|
| 16 | + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | + SOFTWARE. |
| 23 | +*/ |
| 24 | + |
| 25 | +// Package deep provides function deep.Equal which is like reflect.DeepEqual but |
| 26 | +// retunrs a list of differences. This is helpful when comparing complex types |
| 27 | +// like structures and maps. |
| 28 | +package deep |
| 29 | + |
| 30 | +import ( |
| 31 | + "errors" |
| 32 | + "fmt" |
| 33 | + "log" |
| 34 | + "reflect" |
| 35 | + "strings" |
| 36 | +) |
| 37 | + |
| 38 | +var ( |
| 39 | + FloatPrecision = 10 |
| 40 | + MaxDiff = 10 |
| 41 | + MaxDepth = 10 |
| 42 | + LogErrors = false |
| 43 | + CompareUnexportedFields = false |
| 44 | +) |
| 45 | + |
| 46 | +var ( |
| 47 | + ErrMaxRecursion = errors.New("recursed to MaxDepth") |
| 48 | + ErrTypeMismatch = errors.New("variables are different reflect.Type") |
| 49 | + ErrNotHandled = errors.New("cannot compare the reflect.Kind") |
| 50 | +) |
| 51 | + |
| 52 | +type cmp struct { |
| 53 | + diff []string |
| 54 | + buff []string |
| 55 | + floatFormat string |
| 56 | +} |
| 57 | + |
| 58 | +// Equal compares variables a and b, recursing into their structure up to |
| 59 | +// MaxDepth levels deep, and returns a list of differences, or nil if there are |
| 60 | +// none. Some differences may not be found if an error is also returned. |
| 61 | +// |
| 62 | +// If a type has an Equal method, like time.Equal, it is called to check for |
| 63 | +// equality. |
| 64 | +func Equal(a, b interface{}) []string { |
| 65 | + aVal := reflect.ValueOf(a) |
| 66 | + bVal := reflect.ValueOf(b) |
| 67 | + c := &cmp{ |
| 68 | + diff: []string{}, |
| 69 | + buff: []string{}, |
| 70 | + floatFormat: fmt.Sprintf("%%.%df", FloatPrecision), |
| 71 | + } |
| 72 | + c.equals(aVal, bVal, 0) |
| 73 | + if len(c.diff) > 0 { |
| 74 | + return c.diff // diffs |
| 75 | + } |
| 76 | + return nil // no diffs |
| 77 | +} |
| 78 | + |
| 79 | +func (c *cmp) equals(a, b reflect.Value, level int) { |
| 80 | + if level > MaxDepth { |
| 81 | + logError(ErrMaxRecursion) |
| 82 | + return |
| 83 | + } |
| 84 | + |
| 85 | + aType := a.Type() |
| 86 | + bType := b.Type() |
| 87 | + if aType != bType { |
| 88 | + c.saveDiff(aType, bType) |
| 89 | + logError(ErrTypeMismatch) |
| 90 | + return |
| 91 | + } |
| 92 | + |
| 93 | + aKind := a.Kind() |
| 94 | + bKind := b.Kind() |
| 95 | + if aKind == reflect.Ptr || aKind == reflect.Interface { |
| 96 | + a = a.Elem() |
| 97 | + aKind = a.Kind() |
| 98 | + if a.IsValid() { |
| 99 | + aType = a.Type() |
| 100 | + } |
| 101 | + } |
| 102 | + if bKind == reflect.Ptr || bKind == reflect.Interface { |
| 103 | + b = b.Elem() |
| 104 | + bKind = b.Kind() |
| 105 | + if b.IsValid() { |
| 106 | + bType = b.Type() |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + // For example: T{x: *X} and T.x is nil. |
| 111 | + if !a.IsValid() || !b.IsValid() { |
| 112 | + if a.IsValid() && !b.IsValid() { |
| 113 | + c.saveDiff(aType, "<nil pointer>") |
| 114 | + } else if !a.IsValid() && b.IsValid() { |
| 115 | + c.saveDiff("<nil pointer>", bType) |
| 116 | + } |
| 117 | + return |
| 118 | + } |
| 119 | + |
| 120 | + // Types with an Equal(), like time.Time. |
| 121 | + eqFunc := a.MethodByName("Equal") |
| 122 | + if eqFunc.IsValid() { |
| 123 | + retVals := eqFunc.Call([]reflect.Value{b}) |
| 124 | + if !retVals[0].Bool() { |
| 125 | + c.saveDiff(a, b) |
| 126 | + } |
| 127 | + return |
| 128 | + } |
| 129 | + |
| 130 | + switch aKind { |
| 131 | + |
| 132 | + ///////////////////////////////////////////////////////////////////// |
| 133 | + // Iterable kinds |
| 134 | + ///////////////////////////////////////////////////////////////////// |
| 135 | + |
| 136 | + case reflect.Struct: |
| 137 | + /* |
| 138 | + The variables are structs like: |
| 139 | + type T struct { |
| 140 | + FirstName string |
| 141 | + LastName string |
| 142 | + } |
| 143 | + Type = <pkg>.T, Kind = reflect.Struct |
| 144 | +
|
| 145 | + Iterate through the fields (FirstName, LastName), recurse into their values. |
| 146 | + */ |
| 147 | + for i := 0; i < a.NumField(); i++ { |
| 148 | + if aType.Field(i).PkgPath != "" && !CompareUnexportedFields { |
| 149 | + continue // skip unexported field, e.g. s in type T struct {s string} |
| 150 | + } |
| 151 | + |
| 152 | + c.push(aType.Field(i).Name) // push field name to buff |
| 153 | + |
| 154 | + // Get the Value for each field, e.g. FirstName has Type = string, |
| 155 | + // Kind = reflect.String. |
| 156 | + af := a.Field(i) |
| 157 | + bf := b.Field(i) |
| 158 | + |
| 159 | + // Recurse to compare the field values |
| 160 | + c.equals(af, bf, level+1) |
| 161 | + |
| 162 | + c.pop() // pop field name from buff |
| 163 | + |
| 164 | + if len(c.diff) >= MaxDiff { |
| 165 | + break |
| 166 | + } |
| 167 | + } |
| 168 | + case reflect.Map: |
| 169 | + /* |
| 170 | + The variables are maps like: |
| 171 | + map[string]int{ |
| 172 | + "foo": 1, |
| 173 | + "bar": 2, |
| 174 | + } |
| 175 | + Type = map[string]int, Kind = reflect.Map |
| 176 | +
|
| 177 | + Or: |
| 178 | + type T map[string]int{} |
| 179 | + Type = <pkg>.T, Kind = reflect.Map |
| 180 | +
|
| 181 | + Iterate through the map keys (foo, bar), recurse into their values. |
| 182 | + */ |
| 183 | + |
| 184 | + if a.IsNil() || b.IsNil() { |
| 185 | + if a.IsNil() && !b.IsNil() { |
| 186 | + c.saveDiff("<nil map>", b) |
| 187 | + } else if !a.IsNil() && b.IsNil() { |
| 188 | + c.saveDiff(a, "<nil map>") |
| 189 | + } |
| 190 | + return |
| 191 | + } |
| 192 | + |
| 193 | + if a.Pointer() == b.Pointer() { |
| 194 | + return |
| 195 | + } |
| 196 | + |
| 197 | + for _, key := range a.MapKeys() { |
| 198 | + c.push(fmt.Sprintf("map[%s]", key)) |
| 199 | + |
| 200 | + aVal := a.MapIndex(key) |
| 201 | + bVal := b.MapIndex(key) |
| 202 | + if bVal.IsValid() { |
| 203 | + c.equals(aVal, bVal, level+1) |
| 204 | + } else { |
| 205 | + c.saveDiff(aVal, "<does not have key>") |
| 206 | + } |
| 207 | + |
| 208 | + c.pop() |
| 209 | + |
| 210 | + if len(c.diff) >= MaxDiff { |
| 211 | + return |
| 212 | + } |
| 213 | + } |
| 214 | + |
| 215 | + for _, key := range b.MapKeys() { |
| 216 | + if aVal := a.MapIndex(key); aVal.IsValid() { |
| 217 | + continue |
| 218 | + } |
| 219 | + |
| 220 | + c.push(fmt.Sprintf("map[%s]", key)) |
| 221 | + c.saveDiff("<does not have key>", b.MapIndex(key)) |
| 222 | + c.pop() |
| 223 | + if len(c.diff) >= MaxDiff { |
| 224 | + return |
| 225 | + } |
| 226 | + } |
| 227 | + case reflect.Slice: |
| 228 | + if a.IsNil() || b.IsNil() { |
| 229 | + if a.IsNil() && !b.IsNil() { |
| 230 | + c.saveDiff("<nil slice>", b) |
| 231 | + } else if !a.IsNil() && b.IsNil() { |
| 232 | + c.saveDiff(a, "<nil slice>") |
| 233 | + } |
| 234 | + return |
| 235 | + } |
| 236 | + |
| 237 | + if a.Pointer() == b.Pointer() { |
| 238 | + return |
| 239 | + } |
| 240 | + |
| 241 | + aLen := a.Len() |
| 242 | + bLen := b.Len() |
| 243 | + n := aLen |
| 244 | + if bLen > aLen { |
| 245 | + n = bLen |
| 246 | + } |
| 247 | + for i := 0; i < n; i++ { |
| 248 | + c.push(fmt.Sprintf("slice[%d]", i)) |
| 249 | + if i < aLen && i < bLen { |
| 250 | + c.equals(a.Index(i), b.Index(i), level+1) |
| 251 | + } else if i < aLen { |
| 252 | + c.saveDiff(a.Index(i), "<no value>") |
| 253 | + } else { |
| 254 | + c.saveDiff("<no value>", b.Index(i)) |
| 255 | + } |
| 256 | + c.pop() |
| 257 | + if len(c.diff) >= MaxDiff { |
| 258 | + break |
| 259 | + } |
| 260 | + } |
| 261 | + |
| 262 | + ///////////////////////////////////////////////////////////////////// |
| 263 | + // Primitive kinds |
| 264 | + ///////////////////////////////////////////////////////////////////// |
| 265 | + |
| 266 | + case reflect.Float32, reflect.Float64: |
| 267 | + // Avoid 0.04147685731961082 != 0.041476857319611 |
| 268 | + // 6 decimal places is close enough |
| 269 | + aval := fmt.Sprintf(c.floatFormat, a.Float()) |
| 270 | + bval := fmt.Sprintf(c.floatFormat, b.Float()) |
| 271 | + if aval != bval { |
| 272 | + c.saveDiff(a.Float(), b.Float()) |
| 273 | + } |
| 274 | + case reflect.Bool: |
| 275 | + if a.Bool() != b.Bool() { |
| 276 | + c.saveDiff(a.Bool(), b.Bool()) |
| 277 | + } |
| 278 | + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 279 | + if a.Int() != b.Int() { |
| 280 | + c.saveDiff(a.Int(), b.Int()) |
| 281 | + } |
| 282 | + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| 283 | + if a.Uint() != b.Uint() { |
| 284 | + c.saveDiff(a.Uint(), b.Uint()) |
| 285 | + } |
| 286 | + case reflect.String: |
| 287 | + if a.String() != b.String() { |
| 288 | + c.saveDiff(a.String(), b.String()) |
| 289 | + } |
| 290 | + |
| 291 | + default: |
| 292 | + logError(ErrNotHandled) |
| 293 | + } |
| 294 | +} |
| 295 | + |
| 296 | +func (c *cmp) push(name string) { |
| 297 | + c.buff = append(c.buff, name) |
| 298 | +} |
| 299 | + |
| 300 | +func (c *cmp) pop() { |
| 301 | + if len(c.buff) > 0 { |
| 302 | + c.buff = c.buff[0 : len(c.buff)-1] |
| 303 | + } |
| 304 | +} |
| 305 | + |
| 306 | +func (c *cmp) saveDiff(aval, bval interface{}) { |
| 307 | + if len(c.buff) > 0 { |
| 308 | + varName := strings.Join(c.buff, ".") |
| 309 | + c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval)) |
| 310 | + } else { |
| 311 | + c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval)) |
| 312 | + } |
| 313 | +} |
| 314 | + |
| 315 | +func init() { |
| 316 | + log.SetFlags(log.Lshortfile) |
| 317 | +} |
| 318 | + |
| 319 | +func logError(err error) { |
| 320 | + if LogErrors { |
| 321 | + log.Println(err) |
| 322 | + } |
| 323 | +} |
0 commit comments