Breeze "sortrows" function - first implementation below - help wanted? #770

Open
Quafadas opened this issue Jan 15, 2020 · 1 comment
Contributor

Quafadas commented Jan 15, 2020

One thing I've observed with Breeze compared with other libraries is that some "higher order" functions are missing. For example, I wanted "sortrows", similar to MATLAB's.

I have a naive implementation below. I tried to add it to Breeze itself by putting together a PR, but drowned in generics, implicits, and the difficulties of working inside a large project.

I'd love to see it added to the core library, simply because, for my use case, it makes seeing inside the data much easier. Unfortunately, the integration piece is beyond my skill, and I'm not clear what the quality requirements would be.

I'm posting it in case it helps someone with a similar requirement, in case someone is willing to help me contribute it, or in case it's useful as an algorithmic outline.

 package helperFuns

  import breeze.linalg.{*, DenseMatrix, unique}

  object sortrows {

   def apply(dm: DenseMatrix[Double], cols: IndexedSeq[Int]): DenseMatrix[Double] = {

    val colOfInterest = cols.head
    val uniqueIndex = unique(dm(::, colOfInterest))
    val theTail = cols.tail
    val numColumns = dm.cols

    val naiveLog = false // possibly the worst logging framework ever.

    if (naiveLog) {
      println("Num columns : " + numColumns)
      println("Sort in column order: " + cols)
      println("Current sort column : " + colOfInterest)
    }

    val lotsOfSmallMatricies = for (i <- uniqueIndex.toArray) yield {
      //      Identify the subgroup to sort
      val tmp = dm(dm(::, colOfInterest) :== i, *).underlying.toDenseMatrix

      if (naiveLog) {
        println("sort val  : " + i)
        println("sort matrix  : " + tmp)
      }
      if (tmp.rows == 1) {
        //        Optimisation... if there's only 1 row, then just return it! Don't need all the complexity below
        tmp
      } else {
         //        println("Remaining columns to sort " + theTail)
        theTail.isEmpty match {
          case false => {
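            // Filter the Range rather than diffing Sets: Set iteration order is not guaranteed,
            // so the remaining columns could otherwise come back in a different order.
            val remaining = (0 until dm.cols).filter(_ != colOfInterest).toList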
            if (naiveLog) {
              println("Unsorted : ")
              println(remaining)
            }

            val remainingSort = tmp(*, remaining).underlying.toDenseMatrix
            if (naiveLog) {
              println("Chop out sorted column : \n" + remainingSort)
            }

            // Maintain the correct indices of the columns we want to "sort" during the recursion:
            // removing colOfInterest shifts every column to its right one position left.
            val processTailIndicies = theTail.map(x => if (x > colOfInterest) x - 1 else x)
            if (naiveLog) {
              println("Remaining cols to sort : " + processTailIndicies)
            }

            val nextSort = sortrows(remainingSort, processTailIndicies)
            val nextSortCols = nextSort.cols
            val thisColumn = tmp(::, colOfInterest).toDenseMatrix.t

            colOfInterest match {
              case 0 => DenseMatrix.horzcat(thisColumn, nextSort)

              // Columns are zero-based, so the last column index is numColumns - 1
              case c if c == numColumns - 1 => DenseMatrix.horzcat(nextSort, thisColumn)

              case _ => {
                if (naiveLog) {
                  println("\n" + nextSort )
                  println("0 until colOfInterest : \n" + (0 until colOfInterest))
                  println("colOfInterest until num cols: \n" + ( colOfInterest + 1 until numColumns))
                }
                DenseMatrix.horzcat(nextSort(::, 0 until colOfInterest), thisColumn, nextSort(::, colOfInterest until nextSort.cols ))
              }
            }
          }
          case true => tmp
        }
      }
    }
    lotsOfSmallMatricies.reduce(DenseMatrix.vertcat(_, _))
   } 
 }
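
For anyone reading this as an algorithmic outline, a different approach is to compute a row permutation with a single stable sort using a lexicographic comparison on the key columns, instead of recursing over unique values. The sketch below is only that: it uses plain Scala collections rather than Breeze types, and `SortRowsSketch` and `sortedRowOrder` are hypothetical names, not part of Breeze. I'd assume the resulting indices could then be applied to a `DenseMatrix` by row slicing, but I haven't wired that up here.

```scala
object SortRowsSketch {
  // Hypothetical helper, not part of Breeze: returns the row indices of `rows`
  // ordered by the columns in `cols` (the first entry is the primary key).
  def sortedRowOrder(rows: IndexedSeq[IndexedSeq[Double]], cols: Seq[Int]): IndexedSeq[Int] = {
    // Compare two rows key by key, stopping at the first column that differs.
    // Returns false when the rows tie on every key.
    def lessThan(a: Int, b: Int): Boolean =
      cols.iterator
        .map(c => java.lang.Double.compare(rows(a)(c), rows(b)(c)))
        .find(_ != 0)
        .exists(_ < 0)

    // sortWith is a stable sort, so rows that tie on every key keep their input order.
    rows.indices.sortWith(lessThan)
  }
}
```

Because the comparison only looks at the key columns, no columns need to be cut out and stitched back together, and ties keep their original relative order.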

Tests

package models

import java.io.File

import breeze.linalg.{DenseMatrix, DenseVector, convert, csvwrite}
import org.scalatest.{FlatSpec, Matchers}
import helperFuns.sortrows

class SortRowsSpec extends FlatSpec with Matchers {

  val groups = DenseMatrix(1.0, 1.0, 3.0, 2.0, 3.0, 1.0, 1.0)

  "sortrows" should "sort a column" in {

    val expected = DenseMatrix(1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 3.0)
    val out = sortrows(groups, Vector(0))

    out shouldEqual expected
  }

  it should "correctly sort subgroups in column preference order" in {

    val dv = DenseMatrix(0.0, 5.0, 6.0, 3.0, 7.0, 3.0, 1.0)
    val dv1 = DenseMatrix(2.0, 6.0, 7.0, 4.0, 8.0, 4.0, 1.0)
    val dv2 = DenseMatrix(10.0, 1.0, 4.0, 2.0, 3.0, 0.0, 10.0)

    val dm = DenseMatrix.horzcat(groups, dv, dv1, dv2)
    println(dm)

    val expected = DenseMatrix(
      (1.0, 3.0, 4.0, 0.0),
      (1.0, 5.0, 6.0, 1.0),
      (1.0, 1.0, 1.0, 10.0),
      (1.0, 0.0, 2.0, 10.0),
      (2.0, 3.0, 4.0, 2.0),
      (3.0, 7.0, 8.0, 3.0),
      (3.0, 6.0, 7.0, 4.0)
    )

    val sorted = sortrows(dm, Vector(0, 3, 2))

    sorted shouldEqual expected
  }


    it should "correctly sort subgroups in column preference order again" in {

      val dv = DenseMatrix(0.0, 5.0, 6.0, 3.0, 7.0, 3.0, 1.0)
      val dv1 = DenseMatrix(2.0, 6.0, 7.0, 4.0, 8.0, 4.0, 1.0)
      val dv2 = DenseMatrix(10.0, 1.0, 4.0, 2.0, 3.0, 0.0, 10.0)
      val dm = DenseMatrix.horzcat(groups, dv, dv1, dv2)
      println(dm)

      val expected2 = DenseMatrix(
        (1.0, 1.0, 1.0, 10.0),
        (1.0, 0.0, 2.0, 10.0),
        (1.0, 3.0, 4.0, 0.0),
        (2.0, 3.0, 4.0, 2.0),
        (1.0, 5.0, 6.0, 1.0),
        (3.0, 6.0, 7.0, 4.0),
        (3.0, 7.0, 8.0, 3.0),
      )

      val sorted2 = sortrows(dm, Vector(2, 0))

      sorted2 shouldEqual expected2

    }

  it should "deal with at least a non trivial number of rows" in {
    val r = scala.util.Random
    val dm = DenseMatrix.tabulate(1000000, 5){case (i, j) => i%5 + j + r.nextInt(5) }
//    println(dm)

    //    This takes about 2s on my PC, which isn't that bad!
     val sorted = sortrows(convert(dm, Double) , Vector(4,1,3,0,2))


//    val fileloc = new File("c:/temp/temp.csv")
//     csvwrite(fileloc, sorted, ',')
      sorted.rows shouldEqual 1000000
  }


}
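
The large test above only asserts the row count, so it would still pass if the output were in the wrong order. One way to tighten it is to check that every adjacent pair of rows is in non-decreasing lexicographic order on the sort columns. The helper below is a sketch with hypothetical names (`SortedCheckSketch`, `isSortedBy`), written against plain Scala sequences; it assumes each matrix row has first been copied out into an `IndexedSeq[Double]`.

```scala
object SortedCheckSketch {
  // Hypothetical helper: true when every adjacent pair of rows is in
  // non-decreasing lexicographic order on the key columns `cols`.
  def isSortedBy(rows: IndexedSeq[IndexedSeq[Double]], cols: Seq[Int]): Boolean = {
    // Negative, zero, or positive, decided by the first key column that differs.
    def compareRows(a: IndexedSeq[Double], b: IndexedSeq[Double]): Int =
      cols.iterator
        .map(c => java.lang.Double.compare(a(c), b(c)))
        .find(_ != 0)
        .getOrElse(0)

    // Pair each row with its successor; empty and single-row inputs are trivially sorted.
    rows.iterator.zip(rows.iterator.drop(1)).forall { case (a, b) => compareRows(a, b) <= 0 }
  }
}
```

In the spec, this could be asserted alongside the existing row-count check, using the same column order that was passed to `sortrows`.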


Member

dlwh commented Jan 30, 2020

Thanks for sending this my way! I'll try to make it breezy sometime soon and I'll follow up.
