Skip to content

Add the Take semigroup #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions algebird-core/src/main/scala/com/twitter/algebird/Take.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
Copyright 2013 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package com.twitter.algebird

/**
 * A value paired with a flag recording whether taking has already finished.
 *
 * @param item    the value carried by this state
 * @param emitted false while the value is still within the take; the
 *                one-argument factory in the companion starts it at false
 */
case class TakeState[+T](item: T, emitted: Boolean)

object TakeState {
  /** Wraps `item` in a fresh state whose `emitted` flag is false. */
  def apply[T](item: T): TakeState[T] =
    new TakeState(item, emitted = false)
}

/**
 * Take-semigroup over (value, count) pairs: a pair is done as soon as its
 * Long component is strictly greater than `count`.
 * Assumes you put the old item on the left.
 *
 * To get the normal List.take behavior, use this with the Last semigroup:
 * e.g. val sg = new TakeSemigroup[Last[T]](100)
 *
 * @param count upper bound on the summed Long component; must be positive
 */
class TakeSemigroup[T](count: Long)(implicit sgT: Semigroup[T]) extends TakeWhileSemigroup[(T, Long)] {
  assert(count > 0, "TakeSemigroup only makes sense if you take at least one")

  /** Only the running count matters; the carried value is ignored. */
  def isDone(item: (T, Long)): Boolean = item match {
    case (_, seen) => seen > count
  }
}

object TakeWhileSemigroup {
  /**
   * Builds a TakeWhileSemigroup that keeps taking while `fn` holds:
   * the resulting isDone is the negation of `fn`.
   *
   * @param fn predicate that is true while the summed value is still acceptable
   * @param sg semigroup used to combine the underlying values
   */
  def apply[T](fn: T => Boolean)(implicit sg: Semigroup[T]): TakeWhileSemigroup[T] =
    new TakeWhileSemigroup[T]()(sg) {
      def isDone(item: T): Boolean = {
        val stillTaking = fn(item)
        !stillTaking
      }
    }
}
/**
 * Semigroup over TakeState[T]: the wrapped values are combined with
 * Semigroup[T], and the result is flagged as emitted when either input was
 * already emitted or isDone holds for the combined value.
 * Assumes you put the old item on the left.
 *
 * To get a threshold, use a sum and takeWhile:
 *   val sg = TakeWhileSemigroup[Long](_ < 10000L)
 *   Semigroup.sumOption(items.map(TakeState(_)))(sg).collect { case TakeState(t, false) => t }
 *
 * NOTE: for this to be a valid semigroup (and get the desired properties on parallelism)
 * isDone needs to have the property that
 *   if isDone(a) || isDone(b) then isDone(a + b)
 * i.e. isDone is monotonic under addition, in the Semigroup[T], of the valid elements of T.
 */
abstract class TakeWhileSemigroup[T](implicit sgT: Semigroup[T]) extends Semigroup[TakeState[T]] {
  /**
   * Given an item, are we done taking?
   */
  def isDone(item: T): Boolean

  def plus(left: TakeState[T], right: TakeState[T]): TakeState[T] =
    (left, right) match {
      case (TakeState(older, olderDone), TakeState(newer, newerDone)) =>
        val merged = sgT.plus(older, newer)
        // isDone is only consulted when neither side was already emitted
        TakeState(merged, olderDone || newerDone || isDone(merged))
    }
}

Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,44 @@ object CombinatorTest extends Properties("Combinator") {
}
// Law check for the pair type (Map[Int, Int], Set[Int]). The implicit instance it
// exercises is declared outside this excerpt (presumably the top-K MonoidCombinator
// named in the property title -- confirm against the preceding lines).
property("MonoidCombinator with top-K forms a Monoid") = monoidLaws[(Map[Int, Int],Set[Int])]


/**
 * Threshold crossing logic can also be done with one sum on
 * the following combined Semigroup:
 *
 * The value is (sum, (done, cross)): `done` says the threshold has been
 * reached, `cross` is Some(true) only on the step that reaches it.
 */
property("Threshold combinator works") = {
  val threshold = 10000L

  implicit val orSemi = Semigroup.from[Boolean](_ || _)
  // None is neutral; two defined values are and-ed together.
  implicit val andSemi = Semigroup.from[Option[Boolean]] { (l, r) =>
    (l, r) match {
      case (None, _) => r
      case (_, None) => l
      case (Some(lb), Some(rb)) => Some(lb && rb)
    }
  }
  implicit val thresh: Semigroup[(Long, (Boolean, Option[Boolean]))] =
    new SemigroupCombinator({ (sum: Long, doneCross: (Boolean, Option[Boolean])) =>
      val (done, cross) = doneCross
      if (done) (true, Some(false)) // already done before this step
      else if (sum >= threshold && cross.isEmpty) (true, Some(true)) // just crossed
      else (false, None) // not yet crossed
    })
  semigroupLaws[(Long, (Boolean, Option[Boolean]))] && {
    // thresholds could be implemented as so:
    //
    // Reference check: does any running total (starting from 0) reach t?
    def sumCrosses(l: List[Long], t: Long): Boolean =
      l.scanLeft(0L)(_ + _).exists(_ >= t)
    forAll { (t: List[Long]) =>
      val summed =
        Semigroup.sumOption((0L :: t).map { (_, (false, None: Option[Boolean])) })
      // an empty sum counts as "not crossed"
      sumCrosses(t, threshold) == summed.exists { case (_, (done, _)) => done }
    }
  }
}
}
61 changes: 61 additions & 0 deletions algebird-test/src/test/scala/com/twitter/algebird/TakeTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
Copyright 2013 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

import org.scalacheck.Arbitrary
import org.scalacheck.Properties
import org.scalacheck.Prop.forAll
import org.scalacheck.Gen.oneOf
import org.scalacheck.Gen

import scala.annotation.tailrec

/**
 * ScalaCheck properties for TakeSemigroup and TakeWhileSemigroup.
 */
object TakeTest extends Properties("TakeSemigroup") {
  import BaseProperties._
  import org.scalacheck.Prop.forAllNoShrink

  /**
   * Generates take states over (value, count) pairs: counts are drawn from
   * 0 to 100000 and the emitted flag is chosen at random.
   */
  implicit def arbTake[T](implicit tarb: Arbitrary[T]): Arbitrary[TakeState[(T, Long)]] = Arbitrary(for {
    s <- tarb.arbitrary
    c <- Gen.choose(0L, 100000L)
    b <- Gen.oneOf(true, false)
  } yield TakeState((s, c), b))

  /** Semigroup used by the law checks below; its limit is fixed at 100. */
  implicit def takeSemi[T: Semigroup]: TakeSemigroup[T] = new TakeSemigroup[T](100)

  implicit def arbLast[T](implicit tarb: Arbitrary[T]): Arbitrary[Last[T]] =
    Arbitrary(tarb.arbitrary.map(Last(_)))

  property("Take is a semigroup") = semigroupLaws[TakeState[(Last[Int], Long)]]
  property("Take while summing is a semigroup") = semigroupLaws[TakeState[(Int, Long)]]

  /**
   * List.take expressed through TakeSemigroup: every element is wrapped in
   * Last with a count of 1, the running sums are scanned from the left, and
   * only the elements whose state is not yet emitted are kept.
   *
   * @param t   the input list
   * @param cnt how many leading elements to keep; must be positive
   */
  def take[T](t: List[T], cnt: Long): List[T] = {
    implicit val takes: Semigroup[TakeState[(Last[T], Long)]] = new TakeSemigroup[Last[T]](cnt)
    t.map(item => Option(TakeState((Last(item), 1L))))
      // the initial None yields no element in the collect below
      .scanLeft(None: Option[TakeState[(Last[T], Long)]])(Monoid.plus[Option[TakeState[(Last[T], Long)]]](_, _))
      .collect { case Some(TakeState((Last(item), _), false)) => item }
  }

  property("Take works as expected") = forAll { (t: List[Int]) =>
    // Draw the count through ScalaCheck itself rather than Gen.sample.get, so it
    // follows the run's generator parameters and shows up in a counterexample.
    // No shrinking: a shrunk count could reach 0, which TakeSemigroup rejects.
    forAllNoShrink(Gen.choose(1, 2 * t.size + 2)) { posCnt =>
      t.take(posCnt) == take(t, posCnt)
    }
  }

  // Non-negative values only (via abs).
  implicit val arbTakeD: Arbitrary[TakeState[BigInt]] =
    Arbitrary(Arbitrary.arbitrary[BigInt].map(b => TakeState(b.abs)))

  implicit val threshold: Semigroup[TakeState[BigInt]] = TakeWhileSemigroup[BigInt](_ < 100L)
  property("Takewhile sum < 100 is a semigroup") = semigroupLaws[TakeState[BigInt]]
}