Scala:从元组数组/RDD中获取第n个元素的总和

问题描述 投票:0回答:4

我有一系列

tuple
,如下所示:

val a = Array((1,2,3), (2,3,4))

我想为如下方法编写一个通用方法:

def sum2nd(aa: Array[(Int, Int, Int)]) = {
      aa.map { a => a._2 }.sum
      }

所以我正在寻找这样的方法:

def sumNth(aa: Array[(Int, Int, Int)], n: Int)
scala collections functional-programming
4个回答
9
投票

有几种方法可以解决这个问题。最简单的是使用

productElement
:

def unsafeSumNth[P <: Product](xs: Seq[P], n: Int): Int =
  xs.map(_.productElement(n).asInstanceOf[Int]).sum

然后(请注意,索引从零开始,因此

n = 1
为我们提供了第二个元素):

scala> val a = Array((1, 2, 3), (2, 3, 4))
a: Array[(Int, Int, Int)] = Array((1,2,3), (2,3,4))

scala> unsafeSumNth(a, 1)
res0: Int = 5

但是,此实现可能会以两种不同的方式在运行时崩溃:

scala> unsafeSumNth(List((1, 2), (2, 3)), 3)
java.lang.IndexOutOfBoundsException: 3
  at ...

scala> unsafeSumNth(List((1, "a"), (2, "b")), 1)
java.lang.ClassCastException: java.lang.String cannot be cast to java.lang.Integer
  at ...

即,如果元组没有足够的元素,或者您要求的元素不是

Int

您可以编写一个在运行时不会崩溃的版本:

import scala.util.Try

def saferSumNth[P <: Product](xs: Seq[P], n: Int): Try[Int] = Try(
  xs.map(_.productElement(n).asInstanceOf[Int]).sum
)

然后:

scala> saferSumNth(a, 1)
res4: scala.util.Try[Int] = Success(5)

scala> saferSumNth(List((1, 2), (2, 3)), 3)
res5: scala.util.Try[Int] = Failure(java.lang.IndexOutOfBoundsException: 3)

scala> saferSumNth(List((1, "a"), (2, "b")), 1)
res6: scala.util.Try[Int] = Failure(java.lang.ClassCastException: ...

这是一项改进,因为它迫使调用者解决失败的可能性,但它也有点烦人,因为它迫使调用者解决失败的可能性。

如果您愿意使用Shapeless,您可以两全其美:

import shapeless._, shapeless.ops.tuple.At

def sumNth[P <: Product](xs: Seq[P], n: Nat)(implicit
  atN: At.Aux[P, n.N, Int]
): Int = xs.map(p => atN(p)).sum

然后:

scala> sumNth(a, 1)
res7: Int = 5

但是坏的甚至无法编译:

scala> sumNth(List((1, 2), (2, 3)), 3)
<console>:17: error: could not find implicit value for parameter atN: ...

但这仍然不完美,因为这意味着第二个参数必须是一个文字数字(因为它需要在编译时知道):

scala> val x = 1
x: Int = 1

scala> sumNth(a, x)
<console>:19: error: Expression x does not evaluate to a non-negative Int literal
       sumNth(a, x)
                 ^

但在很多情况下这不是问题。

总结一下:如果您愿意为导致程序崩溃的合理代码承担责任,请使用

productElement
。如果您想要更多的安全性(以一些不便为代价),请将
productElement
Try
一起使用。如果您想要编译时安全(但有一些限制),请使用 Shapeless。


2
投票

你可以做这样的事情,尽管它并不是真正的类型安全:

  def sumNth(aa: Array[Product], n: Int)= {
    aa.map { a =>
      a.productElement(n) match {
        case i:Int => i
        case _ => 0
      }
    }.sum
  }

sumNth(Array((1,2,3), (2,3,4)), 2) // 7

2
投票

不使用

shapeless
的另一种类型安全方法是提供一个函数来提取您需要的元素:

def sumNth[T, E: Numeric](array: Array[T])(extract: T => E) =
  array.map(extract).sum

然后你可以像这样定义

sum2nd

def sum2nd(array: Array[(Int, Int, Int)]): Int = sumNth(array)(_._2)

或者像这样:

val sum2nd: Array[(Int, Int, Int)] => Int = sumNth(_)(_._2)

0
投票

自 Scala 3 起,可以通过使用类似函数调用的语法进行索引来访问元组的元素(例如

t(1)
获取元组的第二个元素
t
)。因此
sumNth
可以这样实现:

def sumNth(arr: Array[(Int, Int, Int)], n: Int) = arr.map(_(n).asInstanceOf[Int]).sum

Scastie 演示

© www.soinside.com 2019 - 2024. All rights reserved.