diff --git a/src/FSharpPlus/Control/Collection.fs b/src/FSharpPlus/Control/Collection.fs index 0cb45df88..d2d1ff34a 100644 --- a/src/FSharpPlus/Control/Collection.fs +++ b/src/FSharpPlus/Control/Collection.fs @@ -12,24 +12,24 @@ open FSharpPlus.Internals type OfSeq = inherit Default1 - - static member inline OfSeq ((x: seq<'t>, _: 'R), _: Default5) : 'R = + + static member inline OfSeq ((x: seq<'t>, _: '``Foldable'``), _: Default5) : '``Foldable'`` = #if TEST_TRACE - Traces.add "OfSeq, Default5-seq<'t>" + Traces.add "OfSeq, Return+Sum<'t>" #endif - (^R : (new : seq<'t> -> ^R) x) + x |> Seq.map Return.Invoke |> Sum.Invoke - static member inline OfSeq ((x: seq>, _: 'R), _: Default5) : 'R = + static member inline OfSeq ((x: seq<'t>, _: 'R), _: Default4) : 'R = #if TEST_TRACE - Traces.add "OfSeq, Default5-seq>" + Traces.add "OfSeq, #new seq<'t>" #endif - (^R : (new : seq<'k*'v> -> ^R) (Seq.map (|KeyValue|) x)) + (^R : (new : seq<'t> -> ^R) x) - static member inline OfSeq ((x: seq<'t>, _: '``Foldable'``), _: Default4) : '``Foldable'`` = + static member inline OfSeq ((x: seq>, _: 'R), _: Default4) : 'R = #if TEST_TRACE - Traces.add "OfSeq, Default4-seq<'t>" + Traces.add "OfSeq, #new seq>" #endif - x |> Seq.map Return.Invoke |> Sum.Invoke + (^R : (new : seq<'k*'v> -> ^R) (Seq.map (|KeyValue|) x)) static member OfSeq ((x: seq<'t> , _: seq<'t> ), _: Default3) = x static member OfSeq ((x: seq<'t> , _: ICollection<'t> ), _: Default3) = let d = ResizeArray () in Seq.iter d.Add x; d :> ICollection<'t> @@ -81,10 +81,10 @@ type OfSeq = type OfList = inherit Default1 - static member inline OfList ((x: list<'t> , _: 'R ), _: Default6) = (^R : (new : seq<'t> -> ^R) (List.toSeq x)) : 'R - static member inline OfList ((x: list>, _: 'R ), _: Default6) = (^R : (new : seq<'k*'v> -> ^R) (Seq.map (|KeyValue|) x)) : 'R + static member inline OfList ((x: list<'t> , _: '``Foldable'`` ), _: Default6) = x |> List.map Return.Invoke |> Sum.Invoke : '``Foldable'`` - static member inline OfList ((x: list<'t> , _: '``Foldable'`` ), _: Default5) = x |> List.map Return.Invoke |> Sum.Invoke : '``Foldable'`` + static member inline OfList ((x: list<'t> , _: 'R ), _: Default5) = (^R : (new : seq<'t> -> ^R) (List.toSeq x)) : 'R + static member inline OfList ((x: list>, _: 'R ), _: Default5) = (^R : (new : seq<'k*'v> -> ^R) (Seq.map (|KeyValue|) x)) : 'R static member OfList ((x: list<'t> , _: seq<'t> ), _: Default4) = List.toSeq x #if !FABLE_COMPILER diff --git a/src/FSharpPlus/Control/Monad.fs b/src/FSharpPlus/Control/Monad.fs index c355dc020..ba4f17cd2 100644 --- a/src/FSharpPlus/Control/Monad.fs +++ b/src/FSharpPlus/Control/Monad.fs @@ -111,8 +111,12 @@ type Return = - static member Return (_: seq<'a> , _: Default2) = fun x -> Seq.singleton x : seq<'a> - static member Return (_: IEnumerator<'a>, _: Default2) = fun x -> Enumerator.upto None (fun _ -> x) : IEnumerator<'a> + static member Return (_: seq<'a> , _: Default4) = fun x -> Seq.singleton x : seq<'a> + static member Return (_: IEnumerator<'a>, _: Default3) = fun x -> Enumerator.upto None (fun _ -> x) : IEnumerator<'a> + static member Return (_: IDictionary<'k,'t> , _: Default2) = fun x -> Dict.emptyWithDefault x : IDictionary<'k,'t> + #if (!FABLE_COMPILER_3) // TODO Dummy overload for now + static member Return (_: IReadOnlyDictionary<'k,'t>, _: Default3) = fun x -> readOnlyDict [Unchecked.defaultof<'k>, x] : IReadOnlyDictionary<'k,'t> + #endif static member inline Return (_: 'R , _: Default1) = fun (x: 'T) -> Return.InvokeOnInstance x : 'R static member Return (_: Lazy<'a> , _: Return ) = fun x -> Lazy<_>.CreateFromValue x : Lazy<'a> #if !FABLE_COMPILER @@ -132,9 +136,11 @@ type Return = static member Return (_: 'a Async , _: Return ) = fun (x: 'a) -> async.Return x static member Return (_: Result<'a,'e> , _: Return ) = fun x -> Ok x : Result<'a,'e> static member Return (_: Choice<'a,'e> , _: Return ) = fun x -> Choice1Of2 x : Choice<'a,'e> + #if !FABLE_COMPILER static member Return (_: Expr<'a> , _: Return ) = fun x -> Expr.Cast<'a> (Expr.Value (x: 'a)) #endif + static member Return (_: ResizeArray<'a>, _: Return ) = fun x -> ResizeArray<'a> (Seq.singleton x) //Restricted diff --git a/src/FSharpPlus/Control/ZipApplicative.fs b/src/FSharpPlus/Control/ZipApplicative.fs index 53ccfe1f1..62a7ac583 100644 --- a/src/FSharpPlus/Control/ZipApplicative.fs +++ b/src/FSharpPlus/Control/ZipApplicative.fs @@ -29,8 +29,12 @@ type Pure = let inline call (mthd: ^M, output: ^R) = ((^M or ^R) : (static member Pure : _*_ -> _) output, mthd) call (Unchecked.defaultof, Unchecked.defaultof<'``ZipApplicative<'T>``>) x - static member Pure (_: seq<'a> , _: Default2 ) = fun x -> Seq.initInfinite (fun _ -> x) : seq<'a> - static member Pure (_: IEnumerator<'a> , _: Default2 ) = fun x -> Enumerator.upto None (fun _ -> x) : IEnumerator<'a> + static member Pure (_: seq<'a> , _: Default4 ) = fun x -> Seq.initInfinite (fun _ -> x) : seq<'a> + static member Pure (_: IEnumerator<'a> , _: Default3 ) = fun x -> Enumerator.upto None (fun _ -> x) : IEnumerator<'a> + static member Pure (_: IDictionary<'k,'t>, _: Default2) = fun x -> Dict.emptyWithDefault x: IDictionary<'k,'t> + #if (!FABLE_COMPILER_3) // TODO Dummy overload for now + static member Pure (_: IReadOnlyDictionary<'k,'t>, _: Default3) = fun x -> readOnlyDict [Unchecked.defaultof<'k>, x] : IReadOnlyDictionary<'k,'t> + #endif static member inline Pure (_: 'R , _: Default1 ) = fun (x: 'T) -> Pure.InvokeOnInstance x : 'R static member Pure (x: Lazy<'a> , _: Pure) = Return.Return (x, Unchecked.defaultof) : _ -> Lazy<'a> #if !FABLE_COMPILER diff --git a/src/FSharpPlus/Extensions/Dict.fs b/src/FSharpPlus/Extensions/Dict.fs index b10bddacc..f1d3b3f55 100644 --- a/src/FSharpPlus/Extensions/Dict.fs +++ b/src/FSharpPlus/Extensions/Dict.fs @@ -1,10 +1,73 @@ namespace FSharpPlus +[] +module Auto = + open System + open System.Collections + open System.Collections.Generic + + let icollection (konst: 'TValue) (source: IDictionary<'TKey,'TValue>) = + { + new ICollection<'TValue> with + member _.Contains item = source.Values.Contains item || obj.ReferenceEquals (item, konst) + member _.GetEnumerator () = source.Values.GetEnumerator () :> System.Collections.IEnumerator + member _.GetEnumerator () = source.Values.GetEnumerator () : IEnumerator<'TValue> + member _.IsReadOnly = true + member _.Add (_item: 'TValue) : unit = raise (NotImplementedException ()) + member _.Clear () : unit = raise (NotImplementedException ()) + member _.CopyTo (_array: 'TValue [], _arrayIndex: int) : unit = raise (NotImplementedException ()) + member _.Count : int = source.Count + member _.Remove (_item: 'TValue): bool = raise (NotImplementedException ()) + } + + type DefaultableDict<'TKey, 'TValue> (konst: 'TValue, source: IDictionary<'TKey,'TValue>) = + + interface IDictionary<'TKey, 'TValue> with + member _.TryGetValue (key: 'TKey, value: byref<'TValue>) = + match source.TryGetValue key with + | true, v -> value <- v + | _ -> value <- konst + true + member _.Count = source.Count + member _.ContainsKey (_key: 'TKey) = true + member _.Contains (item: KeyValuePair<'TKey,'TValue>) = + match source.TryGetValue item.Key with + | true, v -> obj.ReferenceEquals (item.Value, v) + | _ -> obj.ReferenceEquals (item.Value, konst) + member _.GetEnumerator () = source.GetEnumerator () : System.Collections.IEnumerator + member _.GetEnumerator () = source.GetEnumerator () : IEnumerator> + member _.IsReadOnly = true + member _.Values = icollection konst source + member _.Item + with get (key: 'TKey) : 'TValue = match source.TryGetValue key with (true, v) -> v | _ -> konst + and set (_key: 'TKey) (_: 'TValue) : unit = raise (System.NotImplementedException()) + + member _.Add (_key: 'TKey, _value: 'TValue) : unit = raise (NotImplementedException ()) + member _.Add (_item: KeyValuePair<'TKey,'TValue>) : unit = raise (NotImplementedException ()) + member _.Clear () : unit = raise (NotImplementedException ()) + member _.CopyTo (_arr: KeyValuePair<'TKey,'TValue> [], _arrayIndex: int) : unit = raise (NotImplementedException ()) + member _.Keys : ICollection<'TKey> = raise (NotImplementedException ()) + member _.Remove (_key: 'TKey) : bool = raise (NotImplementedException ()) + member _.Remove (_item: KeyValuePair<'TKey,'TValue>) : bool = raise (NotImplementedException ()) + + member _.DefaultValue = konst + /// Additional operations on IDictionary<'Key, 'Value> [] module Dict = open System.Collections.Generic open System.Collections.ObjectModel + open Auto + + /// Creates a defaultable dictionary. + /// The value for all missing keys. + /// The source dictionary. + let emptyWithDefault<'TKey,'TValue when 'TKey : equality> (konst: 'TValue) : IDictionary<'TKey,'TValue> = new DefaultableDict<'TKey,'TValue>(konst, dict []) + + /// Creates a defaultable dictionary. + /// The value for all missing keys. + /// The source dictionary. + let initWithDefault<'TKey,'TValue> (konst: 'TValue) (source: IDictionary<'TKey,'TValue>) : IDictionary<'TKey,'TValue> = new DefaultableDict<'TKey,'TValue>(konst, source) #if !FABLE_COMPILER open System.Linq @@ -66,10 +129,13 @@ module Dict = /// /// The mapped dictionary. let map mapper (source: IDictionary<'Key, 'T>) = - let dct = Dictionary<'Key, 'U> () + let dct = + match source with + | :? DefaultableDict<'Key, 'T> as s -> emptyWithDefault (mapper s.DefaultValue) + | _ -> Dictionary<'Key, 'U> () :> IDictionary<'Key, 'U> for KeyValue(k, v) in source do dct.Add (k, mapper v) - dct :> IDictionary<'Key, 'U> + dct /// Applies each function in the dictionary of functions to the corresponding value in the dictionary of values, /// producing a new dictionary of values. @@ -88,13 +154,17 @@ module Dict = /// This function is useful for applying a set of transformations to a dictionary of values, /// where each transformation is defined by a function in a dictionary of functions. /// - let apply (f: IDictionary<'Key, _>) (x: IDictionary<'Key, 'T>) : IDictionary<'Key, 'U> = - let dct = Dictionary () - for KeyValue (k, vf) in f do - match x.TryGetValue k with - | true, vx -> dct.Add (k, vf vx) - | _ -> () - dct :> IDictionary<'Key, 'U> + let apply (f: IDictionary<'Key, 'T -> 'U>) (x: IDictionary<'Key, 'T>) : IDictionary<'Key, 'U> = + let apply () = + let dct = Dictionary () + for KeyValue (k, vf) in f do + match x.TryGetValue k with + | true, vx -> dct.Add (k, vf vx) + | _ -> () + dct :> IDictionary<'Key, 'U> + match f, x with + | (:? DefaultableDict<'Key, 'T -> 'U> as s1), (:? DefaultableDict<'Key, 'T> as s2) -> initWithDefault (s1.DefaultValue s2.DefaultValue) (apply ()) + | _, _ -> apply () /// Creates a Dictionary value from a pair of Dictionaries, using a function to combine them. /// Keys that are not present on both dictionaries are dropped. @@ -104,30 +174,41 @@ module Dict = /// /// The combined dictionary. let map2 mapper (source1: IDictionary<'Key, 'T1>) (source2: IDictionary<'Key, 'T2>) = - let dct = Dictionary<'Key, 'U> () - let f = OptimizedClosures.FSharpFunc<_, _, _>.Adapt mapper - for KeyValue(k, vx) in source1 do - match tryGetValue k source2 with - | Some vy -> dct.Add (k, f.Invoke (vx, vy)) - | None -> () - dct :> IDictionary<'Key, 'U> + let map () = + let dct = Dictionary<'Key, 'U> () + let f = OptimizedClosures.FSharpFunc<_, _, _>.Adapt mapper + let keys = Seq.append source1.Keys source2.Keys |> Seq.distinct + for k in keys do + match tryGetValue k source1, tryGetValue k source2 with + | Some vx, Some vy -> dct.Add (k, f.Invoke (vx, vy)) + | _ , _ -> () + dct :> IDictionary<'Key, 'U> + match source1, source2 with + | (:? DefaultableDict<'Key,'T1> as s1), (:? DefaultableDict<'Key,'T2> as s2) -> initWithDefault (mapper s1.DefaultValue s2.DefaultValue) (map ()) + | _, _ -> map () /// Combines values from three dictionaries using mapping function. /// Keys that are not present on every dictionary are dropped. - /// The mapping function. + /// The mapping function. /// First input dictionary. /// Second input dictionary. /// Third input dictionary. /// /// The mapped dictionary. - let map3 mapping (source1: IDictionary<'Key, 'T1>) (source2: IDictionary<'Key, 'T2>) (source3: IDictionary<'Key, 'T3>) = - let dct = Dictionary<'Key, 'U> () - let f = OptimizedClosures.FSharpFunc<_,_,_,_>.Adapt mapping - for KeyValue(k, vx) in source1 do - match tryGetValue k source2, tryGetValue k source3 with - | Some vy, Some vz -> dct.Add (k, f.Invoke (vx, vy, vz)) - | _ , _ -> () - dct :> IDictionary<'Key, 'U> + let map3 mapper (source1: IDictionary<'Key, 'T1>) (source2: IDictionary<'Key, 'T2>) (source3: IDictionary<'Key, 'T3>) = + let map () = + let dct = Dictionary<'Key, 'U> () + let f = OptimizedClosures.FSharpFunc<_,_,_,_>.Adapt mapper + let keys = source1.Keys |> Seq.append source2.Keys |> Seq.append source3.Keys |> Seq.distinct + for k in keys do + match tryGetValue k source1, tryGetValue k source2, tryGetValue k source3 with + | Some vx, Some vy, Some vz -> dct.Add (k, f.Invoke (vx, vy, vz)) + | _ , _ , _ -> () + dct :> IDictionary<'Key, 'U> + match source1, source2, source3 with + | (:? DefaultableDict<'Key,'T1> as s1), (:? DefaultableDict<'Key,'T2> as s2), (:? DefaultableDict<'Key,'T3> as s3) -> + initWithDefault (mapper s1.DefaultValue s2.DefaultValue s3.DefaultValue) (map ()) + | _, _, _ -> map () /// Applies given function to each value of the given dictionary. /// The mapping function. @@ -148,13 +229,7 @@ module Dict = /// The second input dictionary. /// /// The tupled dictionary. - let zip (source1: IDictionary<'Key, 'T1>) (source2: IDictionary<'Key, 'T2>) = - let dct = Dictionary<'Key, 'T1 * 'T2> () - for KeyValue(k, vx) in source1 do - match tryGetValue k source2 with - | Some vy -> dct.Add (k, (vx, vy)) - | None -> () - dct :> IDictionary<'Key, 'T1 * 'T2> + let zip (source1: IDictionary<'Key, 'T1>) (source2: IDictionary<'Key, 'T2>) = map2 (fun x y -> (x, y)) source1 source2 /// Tuples values of three dictionaries. /// Keys that are not present on all three dictionaries are dropped. @@ -225,11 +300,18 @@ module Dict = /// Returns the union of two dictionaries, using the combiner function for duplicate keys. let unionWith combiner (source1: IDictionary<'Key, 'Value>) (source2: IDictionary<'Key, 'Value>) = - let d = Dictionary<'Key,'Value> () - let f = OptimizedClosures.FSharpFunc<_,_,_>.Adapt combiner - for KeyValue(k, v ) in source1 do d.[k] <- v - for KeyValue(k, v') in source2 do d.[k] <- match d.TryGetValue k with true, v -> f.Invoke (v, v') | _ -> v' - d :> IDictionary<'Key,'Value> + let combine () = + let d = Dictionary<'Key,'Value> () + let f = OptimizedClosures.FSharpFunc<_,_,_>.Adapt combiner + for KeyValue(k, v ) in source1 do d.[k] <- v + for KeyValue(k, v') in source2 do d.[k] <- match d.TryGetValue k with true, v -> f.Invoke (v, v') | _ -> v' + d :> IDictionary<'Key,'Value> + match source1, source2 with + | (:? DefaultableDict<'Key,'Value> as s1) , (:? DefaultableDict<'Key,'Value> as s2) -> initWithDefault (combiner s1.DefaultValue s2.DefaultValue) (combine()) + | (:? DefaultableDict<'Key,'Value> as s), _ | _, (:? DefaultableDict<'Key,'Value> as s) -> initWithDefault s.DefaultValue (combine()) + | s, empty | empty, s when empty.Count = 0 -> s + | _, _ -> combine() + #if !FABLE_COMPILER ///Returns the union of two maps, preferring values from the first in case of duplicate keys. diff --git a/tests/FSharpPlus.Tests/Collections.fs b/tests/FSharpPlus.Tests/Collections.fs index b3da89094..f919dc5f2 100644 --- a/tests/FSharpPlus.Tests/Collections.fs +++ b/tests/FSharpPlus.Tests/Collections.fs @@ -153,7 +153,7 @@ module Collections = let _12: WrappedListI<_> = seq [1;2] |> ofSeq #if TEST_TRACE - CollectionAssert.AreEqual (["OfSeq, Default2-#Add"; "OfSeq, Default2-#Add"; "OfSeq, Default2-#Add"; "OfSeq, Default4-seq<'t>"], Traces.get()) + CollectionAssert.AreEqual (["OfSeq, Default2-#Add"; "OfSeq, Default2-#Add"; "OfSeq, Default2-#Add"; "OfSeq, Return+Sum"], Traces.get()) #endif () diff --git a/tests/FSharpPlus.Tests/General.fs b/tests/FSharpPlus.Tests/General.fs index 62da5a027..bcd6df2f8 100644 --- a/tests/FSharpPlus.Tests/General.fs +++ b/tests/FSharpPlus.Tests/General.fs @@ -396,6 +396,18 @@ module Functor = Assert.IsInstanceOf>> (Some testVal10) areEqual 2 (testVal10 |> Async.RunSynchronously) + let testVal11 = (+) "h" dict [1, "i"; 2, "ello"] + CollectionAssert.AreEqual (dict [(1, "hi"); (2, "hello")], testVal11) + + let testVal12 = + let h: IDictionary = result "h" + try + (+) h <*> dict [1, "i"; 2, "ello"] + with _ -> dict [0, "failure"] + CollectionAssert.AreEqual (dict [0, "failure"], testVal12) + + + [] let mapSquared () = let x =