improved efficiency

This commit is contained in:
pem 2022-04-24 19:11:56 +02:00 committed by Ralph Caraveo
parent 41eeb0673a
commit eb73d93c11
1 changed files with 47 additions and 23 deletions

View File

@ -47,6 +47,11 @@ func (s *threadUnsafeSet[T]) Add(v T) bool {
return prevLen != len(*s)
}
// private version of Add which doesn't return a value
func (s *threadUnsafeSet[T]) add(v T) {
(*s)[v] = struct{}{}
}
func (s *threadUnsafeSet[T]) Cardinality() int {
return len(*s)
}
@ -56,9 +61,9 @@ func (s *threadUnsafeSet[T]) Clear() {
}
func (s *threadUnsafeSet[T]) Clone() Set[T] {
clonedSet := newThreadUnsafeSet[T]()
clonedSet := make(threadUnsafeSet[T], s.Cardinality())
for elem := range *s {
clonedSet.Add(elem)
clonedSet.add(elem)
}
return &clonedSet
}
@ -72,13 +77,19 @@ func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
return true
}
// private version of Contains for a single element v
func (s *threadUnsafeSet[T]) contains(v T) bool {
_, ok := (*s)[v]
return ok
}
func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
_ = other.(*threadUnsafeSet[T])
o := other.(*threadUnsafeSet[T])
diff := newThreadUnsafeSet[T]()
for elem := range *s {
if !other.Contains(elem) {
diff.Add(elem)
if !o.contains(elem) {
diff.add(elem)
}
}
return &diff
@ -93,13 +104,13 @@ func (s *threadUnsafeSet[T]) Each(cb func(T) bool) {
}
func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool {
_ = other.(*threadUnsafeSet[T])
o := other.(*threadUnsafeSet[T])
if s.Cardinality() != other.Cardinality() {
return false
}
for elem := range *s {
if !other.Contains(elem) {
if !o.contains(elem) {
return false
}
}
@ -113,14 +124,14 @@ func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
// loop over smaller set
if s.Cardinality() < other.Cardinality() {
for elem := range *s {
if other.Contains(elem) {
intersection.Add(elem)
if o.contains(elem) {
intersection.add(elem)
}
}
} else {
for elem := range *o {
if s.Contains(elem) {
intersection.Add(elem)
if s.contains(elem) {
intersection.add(elem)
}
}
}
@ -128,20 +139,20 @@ func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
}
func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
return s.IsSubset(other) && !s.Equal(other)
return s.Cardinality() < other.Cardinality() && s.IsSubset(other)
}
func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
return s.IsSuperset(other) && !s.Equal(other)
return s.Cardinality() > other.Cardinality() && s.IsSuperset(other)
}
func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
_ = other.(*threadUnsafeSet[T])
o := other.(*threadUnsafeSet[T])
if s.Cardinality() > other.Cardinality() {
return false
}
for elem := range *s {
if !other.Contains(elem) {
if !o.contains(elem) {
return false
}
}
@ -205,11 +216,20 @@ func (s *threadUnsafeSet[T]) String() string {
}
func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
_ = other.(*threadUnsafeSet[T])
o := other.(*threadUnsafeSet[T])
a := s.Difference(other)
b := other.Difference(s)
return a.Union(b)
sd := newThreadUnsafeSet[T]()
for elem := range *s {
if !o.contains(elem) {
sd.add(elem)
}
}
for elem := range *o {
if !s.contains(elem) {
sd.add(elem)
}
}
return &sd
}
func (s *threadUnsafeSet[T]) ToSlice() []T {
@ -224,13 +244,17 @@ func (s *threadUnsafeSet[T]) ToSlice() []T {
func (s *threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
unionedSet := newThreadUnsafeSet[T]()
n := s.Cardinality()
if o.Cardinality() > n {
n = o.Cardinality()
}
unionedSet := make(threadUnsafeSet[T], n)
for elem := range *s {
unionedSet.Add(elem)
unionedSet.add(elem)
}
for elem := range *o {
unionedSet.Add(elem)
unionedSet.add(elem)
}
return &unionedSet
}
@ -266,7 +290,7 @@ func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
for _, v := range i {
switch t := v.(type) {
case T:
s.Add(t)
s.add(t)
default:
// anything else must be skipped.
continue