diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Take.scala b/algebird-core/src/main/scala/com/twitter/algebird/Take.scala
new file mode 100644
index 000000000..d46833007
--- /dev/null
+++ b/algebird-core/src/main/scala/com/twitter/algebird/Take.scala
@@ -0,0 +1,81 @@
+/*
+Copyright 2013 Twitter, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package com.twitter.algebird
+
+object TakeState {
+  /** Wraps `item` in a state whose `emitted` flag is false. */
+  def apply[T](item: T): TakeState[T] = TakeState(item, false)
+}
+
+/**
+ * The value summed by a TakeWhileSemigroup.
+ *
+ * @param item the underlying value being accumulated
+ * @param emitted set by TakeWhileSemigroup.plus once either input was already
+ *   emitted or isDone holds for the combined item
+ */
+case class TakeState[+T](item: T, emitted: Boolean)
+
+/**
+ * Assumes you put the old item on the left.
+ * To get the normal List.take behavior, use this with the Last semigroup, e.g.
+ *   val sg = new TakeSemigroup[Last[T]](100)
+ *
+ * Items are (value, count) pairs; a pair is done once its count exceeds `count`.
+ */
+class TakeSemigroup[T](count: Long)(implicit sgT: Semigroup[T]) extends TakeWhileSemigroup[(T, Long)] {
+  assert(count > 0, "TakeSemigroup only makes sense if you take at least one")
+
+  /** True when the count component of `item` is strictly greater than `count`. */
+  def isDone(item: (T, Long)): Boolean = item._2 > count
+}
+
+object TakeWhileSemigroup {
+  /** Builds a TakeWhileSemigroup whose isDone(item) is `!fn(item)`. */
+  def apply[T](fn: T => Boolean)(implicit sg: Semigroup[T]): TakeWhileSemigroup[T] =
+    new TakeWhileSemigroup[T]()(sg) {
+      def isDone(item: T): Boolean = !fn(item)
+    }
+}
+
+/**
+ * Assumes you put the old item on the left.
+ *
+ * To get a threshold, use a sum and takeWhile:
+ *   implicit val sg: Semigroup[TakeState[Long]] = TakeWhileSemigroup[Long](_ < 10000L)
+ *   Semigroup.sumOption(items.map(TakeState(_))).collect { case TakeState(t, false) => t }
+ *
+ * NOTE: for this to be a valid semigroup (and get the desired properties on parallelism)
+ * you need isDone to have the property that
+ *   if isDone(a) || isDone(b) then isDone(a + b)
+ * i.e. isDone is monotonic under the plus of the underlying Semigroup[T].
+ */
+abstract class TakeWhileSemigroup[T](implicit sgT: Semigroup[T]) extends Semigroup[TakeState[T]] {
+  /**
+   * Given an item, are we done taking?
+   */
+  def isDone(item: T): Boolean
+
+  /**
+   * Sums the two items with Semigroup[T]; the result is emitted if either
+   * input was emitted or isDone holds for the summed item.
+   */
+  def plus(left: TakeState[T], right: TakeState[T]): TakeState[T] = {
+    val nextT = sgT.plus(left.item, right.item)
+    TakeState(nextT, left.emitted || right.emitted || isDone(nextT))
+  }
+}
+
diff --git a/algebird-test/src/test/scala/com/twitter/algebird/CombinatorTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/CombinatorTest.scala
index 1a97b275a..b863f4994 100644
--- a/algebird-test/src/test/scala/com/twitter/algebird/CombinatorTest.scala
+++ b/algebird-test/src/test/scala/com/twitter/algebird/CombinatorTest.scala
@@ -76,4 +76,38 @@ object CombinatorTest extends Properties("Combinator") {
   }
 
   property("MonoidCombinator with top-K forms a Monoid") = monoidLaws[(Map[Int, Int],Set[Int])]
+
+  /**
+   * Threshold crossing logic can also be done with one sum on
+   * the following combined Semigroup:
+   */
+  property("Threshold combinator works") = {
+    val threshold = 10000L
+
+    implicit val orSemi = Semigroup.from[Boolean](_ || _)
+    implicit val andSemi = Semigroup.from[Option[Boolean]] { (l, r) =>
+      (l, r) match {
+        case (None, _) => r
+        case (_, None) => l
+        case (Some(lb), Some(rb)) => Some(lb && rb)
+      }
+    }
+    implicit val thresh: Semigroup[(Long, (Boolean, Option[Boolean]))] =
+      new SemigroupCombinator({ (sum: Long, doneCross: (Boolean, Option[Boolean])) =>
+        val (done, cross) = doneCross
+        if (done) (true, Some(false)) // already done
+        else if (sum >= threshold && cross.isEmpty) (true, Some(true)) // just crossed
+        else (false, None) // not yet crossed
+      })
+    semigroupLaws[(Long, (Boolean, Option[Boolean]))] && {
+      // Thresholds could be implemented like so: true when any prefix sum reaches t.
+      def sumCrosses(l: List[Long], t: Long): Boolean =
+        l.scanLeft(0L)(_ + _).exists(_ >= t)
+      forAll { (t: List[Long]) =>
+        val inputs = (0L :: t).map { x => (x, (false, None: Option[Boolean])) }
+        // Compare against the `done` flag (first Boolean of the second component) of the sum.
+        sumCrosses(t, threshold) == Semigroup.sumOption(inputs).exists(_._2._1)
+      }
+    }
+  }
 }
diff --git a/algebird-test/src/test/scala/com/twitter/algebird/TakeTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/TakeTest.scala
new file mode 100644
index 000000000..9f3fb6976
--- /dev/null
+++ b/algebird-test/src/test/scala/com/twitter/algebird/TakeTest.scala
@@ -0,0 +1,69 @@
+/*
+Copyright 2013 Twitter, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package com.twitter.algebird
+
+import org.scalacheck.Arbitrary
+import org.scalacheck.Properties
+import org.scalacheck.Prop.forAll
+import org.scalacheck.Gen.oneOf
+import org.scalacheck.Gen
+
+import scala.annotation.tailrec
+
+/** Property checks for TakeSemigroup and TakeWhileSemigroup. */
+object TakeTest extends Properties("TakeSemigroup") {
+  import BaseProperties._
+
+  /** Generates TakeState values over (value, count) pairs, counts drawn by Gen.choose(0L, 100000L). */
+  implicit def arbTake[T](implicit tarb: Arbitrary[T]): Arbitrary[TakeState[(T, Long)]] = Arbitrary(for {
+    s <- tarb.arbitrary
+    c <- Gen.choose(0L, 100000L)
+    b <- Gen.oneOf(true, false)
+  } yield TakeState((s, c), b))
+
+  implicit def takeSemi[T: Semigroup]: TakeSemigroup[T] = new TakeSemigroup[T](100)
+
+  implicit def arbLast[T](implicit tarb: Arbitrary[T]): Arbitrary[Last[T]] =
+    Arbitrary(tarb.arbitrary.map(Last(_)))
+
+  property("Take is a semigroup") = semigroupLaws[TakeState[(Last[Int], Long)]]
+  property("Take while summing is a semigroup") = semigroupLaws[TakeState[(Int, Long)]]
+
+  /**
+   * List.take expressed with TakeSemigroup: every item is paired with a count of 1,
+   * the running sums are scanned from the left, and items of non-emitted states are kept.
+   */
+  def take[T](t: List[T], cnt: Long): List[T] = {
+    implicit val takes: Semigroup[TakeState[(Last[T], Long)]] = new TakeSemigroup[Last[T]](cnt)
+    t.map(item => Option(TakeState((Last(item), 1L))))
+      .scanLeft(None: Option[TakeState[(Last[T], Long)]])(Monoid.plus[Option[TakeState[(Last[T], Long)]]](_, _))
+      .collect { case Some(TakeState((Last(item), _), false)) => item }
+  }
+
+  property("Take works as expected") = forAll { (t: List[Int], seed: Int) =>
+    // Derive the count from generated input instead of sampling a Gen inside the
+    // property, so a failing case is reproducible. posCnt lies in 1 to 2 * t.size + 2.
+    val posCnt = math.abs(seed % (2 * t.size + 2)) + 1
+    t.take(posCnt) == take(t, posCnt)
+  }
+
+  implicit val arbTakeD: Arbitrary[TakeState[BigInt]] =
+    Arbitrary(Arbitrary.arbitrary[BigInt].map(b => TakeState(b.abs)))
+
+  implicit val threshold: Semigroup[TakeState[BigInt]] = TakeWhileSemigroup[BigInt](_ < 100L)
+  property("Takewhile sum < 100 is a semigroup") = semigroupLaws[TakeState[BigInt]]
+}