
Functional Programming - Implementing an Integer Set in Scala

Implementation Code

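This assignment is to implement a set class that stores integers (IntSet), together with the basic set operations: intersection, union, difference (named complement in the code), disjunctive union (symmetric difference) and membership (contains). It also needs the basic operations include (insert) and remove. The set is stored as a binary search tree. Empty is the empty set, and NonEmpty(elem, left, right) is a node holding elem, with smaller elements in left and larger ones in right. As usual, here is the code first.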

import scala.annotation.targetName

object BinaryTree:

  // In-order traversal of a binary search tree yields its elements in ascending order
  def inorderTraversal(set: IntSet): List[Int] = set match
    case Empty => List()
    case NonEmpty(elem, left, right) =>
      inorderTraversal(left) ++ List(elem) ++ inorderTraversal(right)

  // Builds a balanced binary search tree. Sorting and de-duplicating first
  // makes the result depend only on which numbers lst contains.
  def listToIntSet(lst: List[Int]): IntSet =
    def buildBalanced(elems: List[Int]): IntSet = elems match
      case Nil => Empty
      case _ =>
        // The middle element becomes the root, the two halves become subtrees
        val (left, (mid :: right)) = elems.splitAt(elems.length / 2)
        NonEmpty(mid, buildBalanced(left), buildBalanced(right))

    buildBalanced(lst.distinct.sorted)

  // Rebuilds a tree into the canonical balanced shape for its elements
  def normalize(set: IntSet): IntSet = listToIntSet(inorderTraversal(set))

abstract class IntSet:

  infix def include(x: Int): IntSet

  infix def remove(x: Int): IntSet

  infix def contains(x: Int): Boolean

  @targetName("union")
  infix def ∪(that: IntSet): IntSet

  @targetName("intersection")
  infix def ∩(that: IntSet): IntSet

  @targetName("complement")
  infix def ∖(that: IntSet): IntSet

  @targetName("disjunctive union")
  infix def ∆(that: IntSet): IntSet

end IntSet

type Empty = Empty.type

case object Empty extends IntSet:

  infix def include(x: Int): IntSet = NonEmpty(x, Empty, Empty)

  infix def contains(x: Int): Boolean = false

  infix def remove(x: Int): IntSet = this

  @targetName("union")
  infix def ∪(that: IntSet): IntSet = that

  @targetName("intersection")
  infix def ∩(that: IntSet): IntSet = this

  @targetName("complement")
  infix def ∖(that: IntSet): IntSet = this

  @targetName("disjunctive union")
  infix def ∆(that: IntSet): IntSet = that

  override def toString: String = "[*]"

end Empty

case class NonEmpty(elem: Int, left: IntSet, right: IntSet) extends IntSet:

  // Insert as in a plain BST, then rebuild into the canonical balanced shape
  infix def include(x: Int): IntSet =
    BinaryTree.normalize(
      if x < elem       then NonEmpty(elem, left include x, right)
      else if x > elem  then NonEmpty(elem, left, right include x)
      else              this
    )

  infix def contains(x: Int): Boolean =
    if x < elem       then left contains x
    else if x > elem  then right contains x
    else              true

  // Optional task
  infix def remove(x: Int): IntSet =
    BinaryTree.normalize(
      if x < elem then NonEmpty(elem, left remove x, right)
      else if x > elem then NonEmpty(elem, left, right remove x)
      // Removing this node: merge its two subtrees
      else left ∪ right
    )

  @targetName("union")
  infix def ∪(that: IntSet): IntSet = (right ∪ (left ∪ that)) include elem

  @targetName("intersection")
  infix def ∩(that: IntSet): IntSet =
    if that.contains(elem) then
      ((left ∩ that) ∪ (right ∩ that)) include elem
    else
      (left ∩ that) ∪ (right ∩ that)

  @targetName("complement")
  infix def ∖(that: IntSet): IntSet =
    if that.contains(elem) then
      (left ∖ that) ∪ (right ∖ that)
    else
      ((left ∖ that) ∪ (right ∖ that)) include elem

  @targetName("disjunctive union")
  infix def ∆(that: IntSet): IntSet =
    (this ∖ that) ∪ (that ∖ this)

  override def toString: String = s"[$left - [$elem] - $right]"

end NonEmpty
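Scala's built-in immutable `Set[Int]` can pin down what each operation should return. That makes it a handy reference model when checking a hand-written IntSet. This is a minimal sketch. The mapping of the assignment's operations onto `Set` methods and the helper `symmetricDiff` are my own illustration, not part of the assignment.

```scala
// Reference model: Scala's immutable Set[Int] already implements the same algebra
val a: Set[Int] = Set(1, 2, 3, 4)
val b: Set[Int] = Set(3, 4, 5)

// Symmetric difference ("disjunctive union"): elements in exactly one of the sets
def symmetricDiff(x: Set[Int], y: Set[Int]): Set[Int] =
  (x diff y) union (y diff x)

val unionAB        = a union b            // union:        == Set(1, 2, 3, 4, 5)
val intersectionAB = a intersect b        // intersection: == Set(3, 4)
val differenceAB   = a diff b             // complement:   == Set(1, 2)
val disjunctiveAB  = symmetricDiff(a, b)  // disjunctive:  == Set(1, 2, 5)
val withSeven      = a + 7                // include 7
val withoutTwo     = a - 2                // remove 2
val hasThree       = a contains 3         // true
```

One way to cross-check the tree-based code is to compare an IntSet result with such a model, for example via `BinaryTree.inorderTraversal(result).toSet`.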

Unit Tests

As before, the unit tests come from the teacher's fill-in-the-blanks skeleton, which I completed and extended:

import org.scalacheck.{Arbitrary, Properties}
import org.scalacheck.Prop.{forAll, propBoolean}
import BinaryTree.listToIntSet

// Add additional cases if needed
object EmptySpecification extends Properties("Empty"):
  import arbitraries.{given Arbitrary[Int], given Arbitrary[NonEmpty], given Arbitrary[IntSet]}

  property("equals to Empty") = propBoolean {
    Empty == Empty
  }

  property("not equal to NonEmpty") = forAll { (nonEmpty: NonEmpty) =>
    Empty != nonEmpty
  }

  property("include") = forAll { (element: Int) =>
    (Empty include element) == NonEmpty(element, Empty, Empty)
  }

  property("contains") = forAll { (element: Int) =>
    !(Empty contains element)
  }

  property("remove") = forAll { (element: Int) =>
    (Empty remove element) == Empty
  }

  property("union") = forAll { (set: IntSet) =>
    (Empty ∪ set) == set
  }

  property("intersection") = forAll { (set: IntSet) =>
    (Empty ∩ set) == Empty
  }

  property("complement of Empty") = forAll { (set: IntSet) =>
    (set ∖ Empty) == set
  }

  property("complement of set") = forAll { (set: IntSet) =>
    (Empty ∖ set) == Empty
  }

  property("left disjunctive union") = forAll { (set: IntSet) =>
    (Empty ∆ set) == set
  }

  property("right disjunctive union") = forAll { (set: IntSet) =>
    (set ∆ Empty) == set
  }

end EmptySpecification

// Add additional cases if needed
object NonEmptySpecification extends Properties("NonEmpty"):
  import arbitraries.{given Arbitrary[Int], given Arbitrary[NonEmpty], given Arbitrary[IntSet]}
  import arbitraries.{nonEmptyAndAnyElement, removeAnyElement, differentIntSet}
  import AbitrariesSpecification.validate

  property("normalize") = forAll { (nonEmpty: NonEmpty) =>
    validate(nonEmpty)
  }

  property("not equals to Empty") = forAll { (nonEmpty: NonEmpty) =>
    nonEmpty != Empty
  }

  property("equal") = forAll { (nonEmpty: NonEmpty) =>
    nonEmpty == nonEmpty
  }

  property("not equal") = forAll { (nonEmpty: NonEmpty) =>
    forAll(differentIntSet(nonEmpty)) { (different) =>
      nonEmpty != different
    }
  }

  property("include") = forAll { (nonEmpty: NonEmpty, element: Int) =>
    (nonEmpty include element) == listToIntSet(BinaryTree.inorderTraversal(nonEmpty) :+ element)
  }

  property("contains") = forAll { (nonEmpty: NonEmpty, element: Int) =>
    (nonEmpty include element) contains element
  }

  property("remove") = forAll { (nonEmpty: NonEmpty, element: Int) =>
    // Only covers removing the element that was just inserted
    !(((nonEmpty include element) remove element) contains element)
  }

  property("remove any element") = forAll { (nonEmpty: NonEmpty) =>
    forAll(removeAnyElement(nonEmpty)) { case (setAfterRemoval: IntSet, removedElement: Int) =>
      (nonEmpty remove removedElement) == setAfterRemoval
    }
  }

  property("remove and keep BST") = forAll(nonEmptyAndAnyElement) { case (nonEmpty: NonEmpty, element: Int) =>
    validate(nonEmpty remove element)
  }

  property("remove a non-existent element") = forAll { (nonEmpty: NonEmpty) =>
    forAll(removeAnyElement(nonEmpty)) { case (setAfterRemoval: IntSet, removedElement: Int) =>
      (setAfterRemoval remove removedElement) == setAfterRemoval
    }
  }

  property("union") = forAll { (nonEmpty: NonEmpty, set: IntSet) =>
    val unionSet = nonEmpty ∪ set
    // This check alone does not seem to cover everything:
    // unionSet ∖ nonEmpty ∖ set == Empty

    val nonEmptyAndSetInUnion = (nonEmpty ∖ unionSet) == Empty && (set ∖ unionSet) == Empty
    val noExtraElements = ((unionSet ∖ nonEmpty) ∖ set) == Empty
    nonEmptyAndSetInUnion && noExtraElements
  }

  property("intersection") = forAll { (nonEmpty: NonEmpty, set: IntSet) =>
    (nonEmpty ∩ set) == (nonEmpty ∪ set) ∖ (nonEmpty ∆ set)
  }

  property("complement") = forAll { (nonEmpty: NonEmpty, set: IntSet) =>
    (nonEmpty ∖ set) == (nonEmpty ∪ set) ∖ set
  }

  property("disjunctive") = forAll { (nonEmpty: NonEmpty, set: IntSet) =>
    (nonEmpty ∆ set) == (nonEmpty ∪ set) ∖ (nonEmpty ∩ set)
  }

end NonEmptySpecification

// Add additional cases if needed
object IntSetSpecification extends Properties("IntSet"):
  import arbitraries.{given Arbitrary[Int], given Arbitrary[IntSet]}
  import arbitraries.{removeAnyElement, differentIntSet}
  import AbitrariesSpecification.validate

  property("normalize") = forAll { (set: IntSet) =>
    validate(set)
  }

  property("equals") = forAll { (set: IntSet) =>
    set == set
  }

  property("not equal") = forAll { (set: IntSet) =>
    forAll(differentIntSet(set)) { (different) =>
      set != different
    }
  }

  property("include") = forAll { (set: IntSet, element: Int) =>
    (set include element) == listToIntSet(BinaryTree.inorderTraversal(set) :+ element)
  }

  property("contains") = forAll { (set: IntSet, element: Int) =>
    (set include element) contains element
  }

  property("remove") = forAll { (set: IntSet, element: Int) =>
    // Only covers removing the element that was just inserted
    !(((set include element) remove element) contains element)
  }

  property("remove any element") = forAll { (set: IntSet) =>
    forAll(removeAnyElement(set)) { case (setAfterRemoval: IntSet, removedElement: Int) =>
      (set remove removedElement) == setAfterRemoval
    }
  }

  property("remove with Empty") = forAll { (set: IntSet, element: Int) =>
    ((set ∖ set) remove element) == Empty
  }

  property("remove a non-existent element") = forAll { (set: IntSet) =>
    forAll(removeAnyElement(set)) { case (setAfterRemoval: IntSet, removedElement: Int) =>
      (setAfterRemoval remove removedElement) == setAfterRemoval
    }
  }

  property("union") = forAll { (left: IntSet, right: IntSet) =>
    val unionSet = left ∪ right
    val nonEmptyAndSetInUnion = (left ∖ unionSet) == Empty && (right ∖ unionSet) == Empty
    val noExtraElements = ((unionSet ∖ left) ∖ right) == Empty
    nonEmptyAndSetInUnion && noExtraElements
  }

  property("intersection") = forAll { (left: IntSet, right: IntSet) =>
    (left ∩ right) == (left ∪ right) ∖ (left ∆ right)
  }

  property("complement") = forAll { (left: IntSet, right: IntSet) =>
    (left ∖ right) == (left ∪ right) ∖ right
  }

  property("disjunctive") = forAll { (left: IntSet, right: IntSet) =>
    (left ∆ right) == (left ∪ right) ∖ (left ∩ right)
  }

end IntSetSpecification
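The intersection, complement and disjunctive properties rest on three standard set identities:

- A ∩ B = (A ∪ B) ∖ (A ∆ B)
- A ∖ B = (A ∪ B) ∖ B
- A ∆ B = (A ∪ B) ∖ (A ∩ B)

These identities can be sanity-checked on their own against Scala's built-in `Set[Int]`. The sketch below is a hand-rolled loop over random pairs with a fixed seed, not ScalaCheck. `randomSet` and `identitiesHold` are hypothetical helper names.

```scala
import scala.util.Random

// Symmetric difference ("disjunctive union"): elements in exactly one of the sets
def symmetricDiff(a: Set[Int], b: Set[Int]): Set[Int] =
  (a diff b) union (b diff a)

// The identities behind the intersection / complement / disjunctive properties
def identitiesHold(a: Set[Int], b: Set[Int]): Boolean =
  val u = a union b
  val intersectionOk = (a intersect b) == (u diff symmetricDiff(a, b))
  val complementOk   = (a diff b) == (u diff b)
  val disjunctiveOk  = symmetricDiff(a, b) == (u diff (a intersect b))
  intersectionOk && complementOk && disjunctiveOk

// Hypothetical generator: up to 10 integers drawn from 0 until 20
def randomSet(rng: Random): Set[Int] =
  List.fill(rng.nextInt(11))(rng.nextInt(20)).toSet

// 200 random pairs; the fixed seed keeps the run reproducible
val rng = new Random(42)
val allHold = (1 to 200).forall(_ => identitiesHold(randomSet(rng), randomSet(rng)))
println(allHold) // true
```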

Equals

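Apart from the Equals part, none of the other unit tests caused much trouble. For Equals, however, the teacher set a trap with the following requirement: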

That's very good that you followed the contract.
But case objects and case classes have appropriate equals method out of the box, so you don't need to implement it. So I'd rely on that and remove equals and hashCode for Empty and NonEmpty.

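The teacher asked me to remove the equals and hashCode methods from Empty and NonEmpty. While working on this, I noticed that the teacher's skeleton code did not implement equals and hashCode for Empty and NonEmpty either. So instead of writing my own, I relied on the equals and hashCode that case classes and case objects provide automatically.

Relying on them directly caused a problem, though. Once the custom equals is removed, IntSet instances are compared with the equals that case classes provide out of the box. That comparison can fail because of ordering. The derived equals compares the tree structure, and the structure of a binary search tree depends on the order in which the elements were inserted. Consider the following scenario: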

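// Both sets contain the same numbers, inserted in a different order.
// Here include is plain BST insertion (the version before normalization
// was added), so the tree shape follows the insertion order.
val setA = List(5, 2, 19, 4, 7, 12).foldLeft(Empty: IntSet)(_ include _)
val setB = List(12, 5, 2, 19, 4, 7).foldLeft(Empty: IntSet)(_ include _)
// Yields false: the derived equals compares tree shapes, not contents
setA == setB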

setA:                           setB:
       5                               12
      / \                             /  \
     /   \                           /    \
    2     19                        5      19
     \   /                         / \
      4 7                        2   7
         \                        \
         12                        4
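The behaviour comes from how case classes derive equals and hashCode. Constructor arguments are compared recursively, so two trees are equal only if they have the same shape. The self-contained sketch below reproduces this with a hypothetical, stripped-down `Tree` enum and plain (non-normalizing) BST insertion. It also checks that the in-order traversals still agree.

```scala
// Hypothetical stand-in for Empty / NonEmpty, without normalization
enum Tree:
  case Leaf
  case Node(elem: Int, left: Tree, right: Tree)

import Tree.*

// Plain BST insertion: the resulting shape depends on insertion order
def insert(t: Tree, x: Int): Tree = t match
  case Leaf => Node(x, Leaf, Leaf)
  case Node(e, l, r) =>
    if x < e then Node(e, insert(l, x), r)
    else if x > e then Node(e, l, insert(r, x))
    else t

// In-order traversal: lists the elements in ascending order
def toList(t: Tree): List[Int] = t match
  case Leaf          => Nil
  case Node(e, l, r) => toList(l) ++ (e :: toList(r))

val treeA  = List(5, 2, 19, 4, 7, 12).foldLeft(Leaf: Tree)(insert)
val treeB  = List(12, 5, 2, 19, 4, 7).foldLeft(Leaf: Tree)(insert)
val treeA2 = List(5, 2, 19, 4, 7, 12).foldLeft(Leaf: Tree)(insert)

println(treeA == treeB)                 // false: different shapes
println(toList(treeA) == toList(treeB)) // true: same elements
println(treeA == treeA2 && treeA.hashCode == treeA2.hashCode) // true
```

The same insertion order gives equal trees with equal hash codes, so the derived methods are consistent with each other. They simply encode tree equality rather than set equality.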

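So the problem turned from comparing two sets into comparing two binary trees. My approach was a normalization function that runs every time a node is added or removed. It collects all elements in sorted order and rebuilds them into a balanced tree: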

val normA = BinaryTree.normalize(setA)

setA:                                   normA:
         5                                 7
       /   \                             /   \
      2     19                          4     19
       \   /                           / \   /
        4 7                           2   5 12
           \
           12

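As a result, every set is stored as a binary search tree (also called a sorted binary tree). Its balanced shape depends only on its elements, not on the order in which they were inserted. Structural equality of the trees therefore coincides with equality of the sets, so comparing two sets always gives the correct result.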
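The normalization step can be sketched on the same kind of hypothetical `Tree` enum. The sketch flattens the tree in order, then rebuilds it with the middle element as the root. One deliberate difference from the post's inorderTraversal is my own assumption, not the author's code. The traversal threads an accumulator instead of concatenating lists with `++`, which keeps it linear in the number of elements.

```scala
// Hypothetical stand-in for Empty / NonEmpty
enum Tree:
  case Leaf
  case Node(elem: Int, left: Tree, right: Tree)

import Tree.*

// Plain BST insertion, used only to build differently shaped inputs
def insert(t: Tree, x: Int): Tree = t match
  case Leaf => Node(x, Leaf, Leaf)
  case Node(e, l, r) =>
    if x < e then Node(e, insert(l, x), r)
    else if x > e then Node(e, l, insert(r, x))
    else t

// In-order traversal with an accumulator: the right subtree is processed
// first, so each element is prepended exactly once
def inorder(t: Tree, acc: List[Int] = Nil): List[Int] = t match
  case Leaf          => acc
  case Node(e, l, r) => inorder(l, e :: inorder(r, acc))

// Rebuild a balanced tree from a sorted, duplicate-free list:
// the middle element becomes the root, each half becomes a subtree
def build(sorted: List[Int]): Tree =
  if sorted.isEmpty then Leaf
  else
    val (left, rest) = sorted.splitAt(sorted.length / 2)
    Node(rest.head, build(left), build(rest.tail))

// Assumes t is a valid BST, so inorder(t) is sorted and duplicate-free
def normalize(t: Tree): Tree = build(inorder(t))

val setA  = List(5, 2, 19, 4, 7, 12).foldLeft(Leaf: Tree)(insert)
val setB  = List(12, 5, 2, 19, 4, 7).foldLeft(Leaf: Tree)(insert)
val normA = normalize(setA)

println(setA == setB)                       // false
println(normalize(setA) == normalize(setB)) // true
```

With this splitting rule, the normalized form of setA is rooted at 7 with 4 and 19 as children, matching the normA diagram.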

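Of course, this is inefficient, since every include or remove traverses and rebuilds the entire tree, and the teacher pointed that out as well. However, the teacher also noted that the goal of this assignment is to get familiar with functional programming in Scala and that the course does not require efficiency, so I did not optimize it any further.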

This post is licensed under CC BY 4.0 by the author.