Scala 函数与递归

851 阅读2分钟

Scala 函数的进阶用法,见:Scala :Function as Value - 掘金 (juejin.cn)

函数的基本定义方式

在 Scala 中,函数的基本声明方式有两种:一种是显式标注返回值类型的写法,或者是交由编译器推导类型的缺省写法。

def foo1(p: T) : R = {/*body*/}
def foo2(p : T) = {/*body*/}

如果函数体的结构很简单,比如只有一行,那么花括号 {} 就不是必须的。同时,如果函数体的最后一个表达式能充当返回值,那么 return 关键字就可以缺省掉。比如:

def add(a : Int,b : Int) = a + b

当返回值类型由编译器推导时,它将遵循以下的策略:

  1. 如果函数体最终没有返回有效值,则认为返回值为 Unit 类型。
  2. 如果函数体最终返回的值类型无法具体判断,编译器优先选择可推断出的最上界类型,如 Any

参数列表内的所有类型被视作不可变的值 val。这意味着:

  1. 对于数值型数据,无法进行重新赋值。
  2. 对于引用型数据,无法重新改变引用。

注意,Scala 的不可变值和 Java 的 final 关键字等价。对于引用型数据,如果它内部属性是可变的,那么在函数内仍然可以对其属性重新赋值。

函数内部还可以定义局部函数。比如:

def main(args: Array[String]): Unit = {
  def inner() = print("hello")
  inner()  
}

局部函数只能在其定义域内部使用,比如说上述代码块的 inner 函数只能在 main 函数内部调用。

默认参数

参数后面可以通过 = 号赋默认值,比如:

def foo(w : String = "Scala") = println(s"Hello,${w}!")
foo()

在这个示例中,我们直接使用 w 的默认值,因而无需在调用 foo 函数时再显式地传入参数,程序仍可以打印:Hello,Scala!。但是,当参数列表只有部分参数具备默认值的情况下,可能会引起混淆,比如:

def foo(w1 : String = "Hello,", w2 : String) = println(s"${w1},${w2}!")
foo("Scala!")

编译器不确定其入参将覆盖 w1 的默认值,还是将要赋值给 w2。一个解决方案是在调用时做显式的标注,比如:

foo(w2 = "Scala!")

使用这种方法可以显著地提高代码的可读性。

变长参数

当参数类型后面被标注为 * 符号时,表示该函数接收的是变长参数。比如:

def g(xs : Int*): Unit = for{i <- xs} println(i)

g(1,2,3)
g(1,2,3,4,5)

变长参数最终被会包装为 scala.collection.mutable.WrappedArray 类型,它可以像普通序列那样被遍历。为了避免混淆,变长参数只能在参数列表的最后一个位置,比如:

def g(head : Int,left : Int*) = head :: left.toList

注意,Array[T] 类型,或者 Seq[T] 类型不能直接作为变长参数传入,它们必须使用特殊的标注符 _* 进行一步转换:

def g(xs : Int*) = ()
g(Array(1,2,3) : _*)
g(Seq(1,2,3) : _*)

在 Scala 3 版本,这种标记方式被简化了。见:Scala 3 新特性一览 - 掘金 (juejin.cn) 可变参数拼接部分。

递归函数

递归函数可概括出两个基本的要素:

  1. 临界条件。递归函数必须有明确的条件表示它何时应该退出。
  2. 递归公式。可理解成递归函数在未达到临界条件时所执行的操作,并再一次调用自身,并传入新的参数。一个正确的递归函数将总是向着临界条件收敛。

有一种特殊的递归函数,它不设立临界条件,而仅有递归方程。这种递归函数必定是惰性加载的,常用于生成流数据 ( Stream ),又被称之为 共递归。它是函数式编程中的概念之一,见:探究 Scala 非严格求值与流式数据结构 - 掘金 (juejin.cn) 最后一节。

斐波那契数列

斐波那契(Fibonacci) 数列指:1,1,2,3,5,8,... 。从第 3 项开始,它的值等于前两项值的和,即 f(n)=f(n1)+f(n2),n>2f(n)=f(n-1)+f(n-2),n>2

下面将编写一个 fibonacci 函数以递归形式求第 n 个位置上的值。由于斐波那契对 n <= 2 的情况做了特殊定义,因此可以将它设置为临界条件。当 n > 3 时,则需要根据递归公式执行操作。

// n > 0
def fibonacci(n: Int): Int = if (n <= 2) 1 else fibonacci(n - 1) + fibonacci(n - 2)

避免重复计算

从效率上来看,这段代码的性能并不是很好。观察这段内联的代码: fibonacci(n - 1) + fibonacci(n - 2)。Scala 按照严格的从左到右的顺序进行计算,因此会首先计算 fibonacci(n - 1)的值,再去计算 fibonacci(n - 2) 的值。

实际上,fibonacci(n - 1) 本身就包含了一部分 fibonacci(n - 2) 的计算。而由于函数在递归的过程中是无记忆性的,因此大量的中间结果被反复计算。对于这类问题,其中一个解决思路是以空间换取时间。比如,使用映射表 Map 增加一层缓存功能,避免对中间结果重复计算。

val cache = mutable.Map[Int, Int]()
def fibonacci(n: Int): Int = if (n <= 2) 1 else if (cache.contains(n)) cache(n) else {
    val v = fibonacci(n - 2) + fibonacci(n - 1)
    cache += (n -> v)
    v
}

递归是 Scala 乃至其它支持函数式编程语言的重要一部分。稍微留意我们平常使用的树,链表等就能发现它们的定义都是递归的。对于递归定义的数据类型,使用递归能够方便地表述各类操作,而无需借助传统的 for-loop 或者是 while-loop 进行迭代。见:Scala 函数式数据结构与递归的艺术 - 掘金 (juejin.cn)