Scala中Seq的排序实现

2020/04/29 算法 共 4980 字,约 15 分钟
梦境迷离

对Scala Seq进行排序,常见的是使用sortBy、sorted、sortWith三个函数。

其中sortBy实现很简洁,如下

  def sortBy[B](f: A => B)(implicit ord: Ordering[B]): Repr = sorted(ord on f)

其本质是调用sorted函数,使用隐式对象ord,将f参数转化为比较函数,之所以这样我猜是为了利用现有的Java排序实现。

on函数实现如下

  def on[U](f: U => T): Ordering[U] = new Ordering[U] {
     def compare(x: U, y: U) = outer.compare(f(x), f(y))
   }

on函数的作用:对于给定接收参数为U返回为T的f函数,使用outer(混入Ordering特质的实例,Ordering继承了Java的Comparator接口)创建一个比较函数,其比较函数等于

  def compare(x:U, y:U) = Ordering[T].compare(f(x), f(y))

当使用sorted而不是sortBy和sortWith时,默认使用的是Ordering伴生对象中的预定义的比较函数。这些函数最终使用Java类型的compare方法来比较一些常见的基本类型的大小。

sortBy和sortWith自身都是调用了sorted方法,只不过它们传递的ord排序函数不同。

这两个方法调用sorted是使用显示传递参数,而不是隐式,直接调用sorted时,会在Ordering及其伴生对象域内查找可用的隐式排序函数),比较函数就是Java的比较器。

  def sorted[B >: A](implicit ord: Ordering[B]): Repr = {
    val len = this.length
    //可变集合构建(从这可以看到Scala常规操作就是在函数式结构内部封装可变结构,而不对用户暴露可变性)
    val b = newBuilder
    //单个元素不用排序,直接返回当前集合
    if (len == 1) b ++= this
    else if (len > 1) {
        //标记需要添加多少个元素
        //当调用“result”返回集合结果时。一些生成器(builder)将根据此值优化它们的表示。如Set对个数为0~4的集合会进行优化处理
      b.sizeHint(len)
      val arr = new Array[AnyRef](len)  //以前使用的ArraySeq用于更紧凑但速度较慢的代码
      var i = 0
      for (x <- this) {
          //此处进行类型转换为了调用Java方法
        arr(i) = x.asInstanceOf[AnyRef]
        i += 1
      }
      //使用Java的Arrays.sort进行排序
      java.util.Arrays.sort(arr, ord.asInstanceOf[Ordering[Object]])
      i = 0
      while (i < arr.length) {
          //重新转换为原类型
        b += arr(i).asInstanceOf[A]
        i += 1
      }
    }
    //返回排序结果
    b.result()
  }

Java的Arrays.sort方法主要对数组进行排序

 public static <T> void sort(T[] a, Comparator<? super T> c) {
        if (c == null) {
            //没有比较器,调用sort,并判断是否选择旧的归并排序(归并)
            sort(a);
        } else {
            //有比较器,仍然需要判断是否用户选择了旧的归并排序
            if (LegacyMergeSort.userRequested)
                legacyMergeSort(a, c);
            else
            //调用新的TimSort排序
                TimSort.sort(a, 0, a.length, c, null, 0, 0);
        }
    }

根据指定的比较器对指定的对象数组进行排序。数组中的所有元素都必须是由指定的比较器进行相互比较(即c.compare(e1,e2)不能对数组中的任何元素抛出CastException异常)。 对Scala而言,Seq可以存储不同类型的数据,但是由于底层排序是基于Java数组的,所以存储了不同类型数据的Seq将不能被sorted排序。

TimSort是一种稳定的、自适应的、迭代的归并排序,在部分排序的数组上运行时需要的比较远远少于 nlg(n),而在随机未排序的数组上运行时提供的性能与传统的归并排序相当。 像所有正确的归并排序一样,这种排序是稳定的,运行需要 O(nlogn) 时间(最坏情况)。在最坏的情况下,这种排序需要n/2个对象引用的临时存储空间;在最好的情况下,它只需要少量的恒定空间。

本质是一种经过优化能很好利用元素已有序列状态的归并排序。

无比较器的Arrays.sort实现

  public static void sort(Object[] a) {
      //可以使用系统属性java.util.Arrays.useLegacyMergeSort来选择旧的归并排序实现
      if (LegacyMergeSort.userRequested)
          legacyMergeSort(a);
      else
      //ComparableTimSort实际是TimSort的一个近似副本实现,用于实现Comparable的对象数组,而不是使用显式比较器。很显然,这就是一个没有比较器的TimSort,不对其进行讨论。
          ComparableTimSort.sort(a, 0, a.length, null, 0, 0);
  }

有比较器的TimSort.sort的实现

  static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
                        T[] work, int workBase, int workLen) {
      assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;

      //计算需要排序的元素个数  
      int nRemaining  = hi - lo;
      if (nRemaining < 2)
          return;  //大小为0和1的数组始终是排序的

    //如果数组很小,则使用不需要归并的“mini-TimSort”,这个值为32,小于32将不会使用归并,而使用binarySort(二分排序, 它需要O(nlogn)次比较,但是最坏的情况需要O(n^2)次数据移动)
    //MIN_MERGE常数应为2的幂。Tim Peter的C实现中为64,但是凭经验确定32在此实现中能更好地工作。 万一您将此常量设置为一个不是2的幂的数字,则需要更改minRunLength计算
    
    //对于binarySort,如果指定范围的初始部分已经排序,则此方法可以利用它:此方法假定索引lo(包含)到start(不包含)中的元素已经排序。
    //1.从数组开始处找到一组连接升序或严格降序(找到后翻转)的数
    //2.使用二分查找的方法将后续的数插入之前的已排序数组,binarySort对数组a[lo:hi]进行排序,并且a[lo:start]是已经排好序的。
    //算法的思路是对a[start:hi]中的元素,每次使用binarySearch为它在a[lo:start]中找到相应位置,并插入。

      if (nRemaining < MIN_MERGE) {
          int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
          //使用二分排序,计算出从哪开始二分排序并利用已有序的元素(截止到initRunLen,数组一定是升序的)
          binarySort(a, lo, hi, lo + initRunLen, c);
          return;
      }

      TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
      //从剩下需要排序的元素个数中,选取minRun大小,之后待排序数组将被分成以minRun大小为区块的一块块子数组
      int minRun = minRunLength(nRemaining);
      do {
        //找到初始的一组升序数列,countRunAndMakeAscending会找到一个run
        //这个run必须是已经排序的,并且该函数会保证它为升序,也就是说,如果找到的是一个降序的,会对其进行翻转
          int runLen = countRunAndMakeAscending(a, lo, hi, c);

        //若这组区块大小小于minRun,则将后续的数补足,利用binarySort对run进行扩展,并且扩展后,run仍然是有序的
          if (runLen < minRun) {
              int force = nRemaining <= minRun ? nRemaining : minRun;
              binarySort(a, lo, lo + force, lo + runLen, c);
              runLen = force;
          }

        //当前的run位于a[lo:runLen],将其入栈ts.pushRun(lo, runLen);
        //为后续merge各区块作准备:记录当前已排序的各区块的大小(pushRun记录run中元素索引和元素个数)
          ts.pushRun(lo, runLen);
        //对当前的各区块进行merge
          ts.mergeCollapse();

         //寻找下一个run,lo之前是已经排序的
          lo += runLen;
          nRemaining -= runLen; //计算剩下未排序的元素
      } while (nRemaining != 0); //直到将待排序数组排序完退出循环

      //归并所有剩余run以完成排序
      assert lo == hi;
      ts.mergeForceCollapse();
      assert ts.stackSize == 1;
}

countRunAndMakeAscending实现

  private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                  Comparator<? super T> c) {
      assert lo < hi;
      int runHi = lo + 1;
      if (runHi == hi)
          return 1;

      if (c.compare(a[runHi++], a[lo]) < 0) { //降序的需要反转一下
          while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
              runHi++;
          reverseRange(a, lo, runHi);
      } else {                              //升序
          while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
              runHi++;
      }
      //返回已经升序的元素个数
      return runHi - lo;
}

minRunLength实现

  private static int minRunLength(int n) {
      assert n >= 0;
      int r = 0;     
      while (n >= MIN_MERGE) {
          r |= (n & 1);
           n >>= 1;
      }
      return n + r;
  }
  • 如果数组大小小于MIN_MERGE,返回n
  • 如果数组大小为2的N次幂,则返回16(MIN_MERGE / 2)
  • 其他情况下,逐位向右位移(即除以2),直到找到介于16和32间的一个数
  private void mergeCollapse() {
      while (stackSize > 1) {
          int n = stackSize - 2;
          if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
              if (runLen[n - 1] < runLen[n + 1])
                  n--;
               mergeAt(n);
          } else if (runLen[n] <= runLen[n + 1]) {
              mergeAt(n);
          } else {
              break; // Invariant is established
          }
      }
  }
  • 只对相邻的区块归并,注意:要归并的两个run是已经排序的
  • 若当前区块数仅为2,If X<=Y,将X和Y归并
  • 若当前区块数>=3,If X<=Y+Z,将X和Y归并,直到同时满足X>Y+Z和Y>Z

文档信息

Search

    Table of Contents