garlandus / augment   0.0.3

GitHub

A Scala 3 alternative to comprehensions

Scala versions: 3.x

augment

A Scala 3 alternative to comprehensions

This is a zero-dependency alternative to comprehensions that requires no special syntax, and as such can be used from Java and other JVM languages.
Overall it provides a concise, direct-style notation that handles plain and container types indifferently, which can simplify effect handling.
It also inherently distinguishes the "rectangular" from everything else - or if you prefer, array comprehensions from list comprehensions, the parallelizable from the sequential, and applicatives from monads.

Overview

Say you have an existing function f0 of two variables, something along the lines of def f0(a: Int, b: Int) = a*a + b*b. Then if you write

val f = augment(f0)

you effectively have a new function f that can replace f0, in that f(a,b) is the same as f0(a,b) for any integers a and b. But it now has additional behavior: for instance if A and B are ranges of integers, then

f(A, B).plot()

will display a 3D graph using HTML. The new function by itself does the work of a comprehension, since

f(A, B)

contains the same values as

for
  a <- A
  b <- B
yield
  f0(a, b)

More general comprehensions can be replicated by changing the argument types: instead of sequences of integers, they could be functions that return sequences, as shown below.

There are more detailed explanations in the documentation

Quick start

(For Java and Clojure, setup and examples follow further below.)
In Scala, you can add the following to build.sbt:

libraryDependencies += "co.computist" %% "augment" % "0.0.3"

Imports to get you started:

import augmented._
import augmented.given

Examples

Pythagorean triples

select(1 to n, _ to n, _ to n, (a, b, c) => a * a + b * b == c * c)

Pascal's triangle

binomialCoefficient(0 to n, 0 to _)

Tetrahedron

select(1 to n, 1 to _, 1 to _)

Permutations

def prepend[A] = augment((a: A, l: Seq[A]) => Seq(a) ++ l)

def permutations[A](x: Seq[A]): Seq[Seq[A]] =
  prepend(x, a => permutations(x -- Seq(a))) until x == Seq()

Sieve of Erastosthenes (ish)

val mult = augment((a: Int, b: Int) => a * b)

def primes(n: Int): Seq[Int] =
  complement(mult(primes(sqrt(n)), x => x to n / x), 2 to n) until n == 1

8 queens problem

def isSafe(col: Int, queens: Seq[Int]): Boolean = ...

def queens(n: Int, k: Int = 0): Seq[Seq[Int]] =
  prepend(0 until n, queens(n, k + 1), isSafe) until k == n

Applicative examples

Mixed function arguments

val mult = augment((a: Int, b: Int) => a * b)
val add = augment((a: Int, b: Int, c: Int) => a + b + c)

mult(4, 5)        // 20
mult(Some(4), 5)  // Some(20)
mult(4, None)     // None

add(4, 5, 6)                    // 15
add(4, 5, Some(6))              // Some(15)
add(Some(4), Some(5), Some(6))  // Some(15)
add(4, None, 6)                 // None

Type propagation: "bubbling up"

val a = mult(4, 5)    // 20
val b = add(2, a, 3)  // 25
val c = mult(4, b)    // 100

val n = mult(4, Success(5)) // Success(20)
val p = add(2, n, 3)        // Success(25)
val q = mult(4, p)          // Success(100)
q.value()                   // 100

val x = mult(4, Future(5))  // Future(<not completed>)
val y = add(2, x, 3)        // Future(<not completed>)
val z = mult(4, y)          // Future(<not completed>)
z.value()                   // 100

IO

Basic IO / deferred valuing

val nameFromIO =
  sequence(
    println("What is your name?"),    // ordinary println and readLine, not "lifted" versions
    scala.io.StdIn.readLine,
    name => { println(s"Hello, $name\n"); name }
  )
val name = nameFromIO.value() // name is not retrieved from command line until value() is called

val ratioIO =
  sequence(
      5.0,
      Math.sqrt,
      _ + 1,
      _ / 2.0
  )
val ratio = ratioIO.value() // ratio is not calculated until value() is called 

IO with retries

val n = 
  sequence(
    println("Enter a number: "),
    scala.io.StdIn.readLine,
    _.toInt
  )
val p = add(4, 5, n)
val res = mult(4, p)
res.retry(2).value() // will ignore a duff entry or two

IO with retries: Cats Effect version

given Effects[cats.effect.IO] = Effects()

// rest of code is identical
// res is now of type cats.effect.IO

IO with retries: ZIO version

given Effects[zIO] = Effects()

// rest of code is identical
// res is now of type ZIO[Any, IOException, Int]

Cats Effect example: sequential vs parallel

def f(n: Int) =
  cats.effect.IO:
    Thread.sleep(250)
    n * 10

// cats.effect.IO computations that produce the same result, either sequentially or in parallel
add(f(3), f(4), f(5))                 // parallel
image(f(3), f(4), f(5), _ + _ + _)    // parallel
(f(3), f(4), f(5)).mapN(_ + _ + _)    // sequential
(f(3), f(4), f(5)).parMapN(_ + _ + _) // parallel

add(f(3), f(4), f(5)).value()         // 120

ZIO example: reading/sending over channel

def readZIO[A](ch: Channel[A])        = ZIO.attempt(ch.read())
def sendZIO[A](ch: Channel[A], a: A)  = ZIO.attempt(ch.send(a))

// Notations that describe the same effect

// comprehension notation
for
  a <- readZIO(c1)
  b <- readZIO(c2)
  _ <- sendZIO(c3, a + b)
yield()

// ZIO direct
defer {
  val a = c1.read()
  val b = c2.read()
  c3.send(a + b)
}

// augmented function notation
sequence(
  c1.read(),
  c2.read(),
  _ + _,
  c3.send
)

// or equivalently:
sequence(
  c1.read(),
  c2.read(),
  (a, b) => c3.send(a + b)
)

From Java

Quick start

You can add the following to the dependencies in a pom.xml file:

<dependency>
  <groupId>co.computist</groupId>
  <artifactId>augment_3</artifactId>
  <version>0.0.3</version>
</dependency>
<dependency>
  <groupId>org.scala-lang</groupId>
  <artifactId>scala3-library_3</artifactId>
  <version>3.3.1</version>
</dependency>  

Imports to get you started:

import static augmented.augmentJ.*;
import static java.util.stream.IntStream.range;

Although Java 8 is sufficient, Java 11 is recommended since var means you can avoid lengthy explicit type names (often with multiple generic parameters).

Examples

Pythagorean triples

This can be compared with e.g. https://rosettacode.org/wiki/List_comprehensions#Java

var n = 20;

var triangles =
  select(
    range(1, n),
    a -> range(a, n),
    b -> range(b, n),
    (a, b, c) -> a * a + b * b == c * c);     // [[3 4 5], [5 12 13], [6 8 10], [8 15 17], [9 12 15]]

Propagation of future values

import static mappable.Mapper.mappable;

var mult = augment((Integer a, Integer b) -> a * b);
var add = augment((Integer a, Integer b, Integer c) -> a + b + c);

mult.apply(4, 5);   // here mult returns an ordinary value (20)
add.apply(4, 5, 6); // 15

var executor = Executors.newSingleThreadExecutor();
var futureVal = mappable(4, a -> executor.submit(() -> {Thread.sleep(500); return a;}));

var x = mult.apply(futureVal, 5);
var y = add.apply(2, x, 3);
var z = mult.apply(4, y);   // here mult returns a future value

assertEquals(z.mappable() instanceof FutureTask, true);
assertEquals(z.hasValue(), false);
Thread.sleep(1000);
assertEquals(z.hasValue(), true);
assertEquals(z.value(), (Integer) 100);

From Clojure

Quick start

You can add the following to the dependencies in a project.clj file:

[co.computist/augment_3 "0.0.3"]
[org.scala-lang/scala3-library_3 "3.3.1"]

Examples

Pythagorean triples

(defn augment [f] (augmentedClj.augment/apply f))
(def triple (augment (fn [a b c] [a b c])))
(def n 20)

(def triples
  (triple
    (range 1 n)
    #(range % n)
    #(range % n)
    (fn [a b c] (= (+ (* a a) (* b b)) (* c c)))))

(is (= triples [[3 4 5] [5 12 13] [6 8 10] [8 15 17] [9 12 15]]))

Function graph

(def squares (augment (fn [a b] (- 100 (+ (* a a) (* b b))))))
(squares 5 5)                                      ; 50
(.graph (squares (range -10 11) (range -10 11)))   ; plots function using HTML / JavaScript / plotly

Propagation of future values

(defn mappable [x] (augmentedClj.Mapper/mappable x))

;; this returns a Clojure function, i.e. one that implements IFn
(def mult (augment (fn [a b] (* a b))))
(def add (augment (fn [a b c] (+ a b c))))

(mult 4 5)  ; here mult returns an ordinary value (20)
(add 4 5 6) ; 15

(def futureVal (mappable (future (Thread/sleep 500) (println "done") (+ 1 3))))

(def x (mult futureVal 5))
(def y (add 2 x 3))
(def z (mult 4 y))  ; here mult returns a future value

(is (= (type (. z mappable)) FutureTask))
(is (= (. z hasValue) false))
(Thread/sleep 1000)
(is (= (. z hasValue) true))
(is (= (. z value) 100))