ref: 02effd9dc46201df250564cdbbee1ca2291eb8b4
parent: f52b040ee126ec0c48f1d273681a860fe7814314
author: Bjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
date: Mon Mar 21 16:42:27 EDT 2016
Protect against concurrent Scratch reads and writes. Fixes #2005
--- a/hugolib/scratch.go
+++ b/hugolib/scratch.go
@@ -17,11 +17,13 @@
"github.com/spf13/hugo/helpers"
"reflect"
"sort"
+ "sync"
)
// Scratch is a writable context used for stateful operations in Page/Node rendering.
//
// values is guarded by mu: Add, Set and SetInMap take the write lock, while
// Get and GetSortedMapValues take the read lock. Because Scratch contains a
// sync.RWMutex it must not be copied after first use; share it as *Scratch.
type Scratch struct {
	values map[string]interface{} // guarded by mu
	mu     sync.RWMutex
}
// For single values, Add will add (using the + operator) the addend to the existing addend (if found).
@@ -29,6 +31,9 @@
//
// If the first add for a key is an array or slice, then the next value(s) will be appended.
func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {+ c.mu.Lock()
+ defer c.mu.Unlock()
+
var newVal interface{}existingAddend, found := c.values[key]
if found {@@ -59,6 +64,9 @@
// Set stores a value with the given key in the Node context.
// This value can later be retrieved with Get.
func (c *Scratch) Set(key string, value interface{}) string {+ c.mu.Lock()
+ defer c.mu.Unlock()
+
c.values[key] = value
return ""
}
@@ -65,6 +73,9 @@
// Get returns a value previously set by Add or Set
func (c *Scratch) Get(key string) interface{} {+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
return c.values[key]
}
@@ -71,6 +82,9 @@
// SetInMap stores a value to a map with the given key in the Node context.
// This map can later be retrieved with GetSortedMapValues.
func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string {+ c.mu.Lock()
+ defer c.mu.Unlock()
+
_, found := c.values[key]
if !found { c.values[key] = make(map[string]interface{})@@ -82,6 +96,9 @@
// GetSortedMapValues returns a sorted map previously filled with SetInMap
func (c *Scratch) GetSortedMapValues(key string) interface{} {+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
if c.values[key] == nil {return nil
}
--- a/hugolib/scratch_test.go
+++ b/hugolib/scratch_test.go
@@ -16,6 +16,7 @@
import (
"github.com/stretchr/testify/assert"
"reflect"
+ "sync"
"testing"
)
@@ -78,6 +79,41 @@
scratch := newScratch()
scratch.Set("key", "val") assert.Equal(t, "val", scratch.Get("key"))+}
+
+// Issue #2005
+func TestScratchInParallel(t *testing.T) {+ var wg sync.WaitGroup
+ scratch := newScratch()
+ key := "counter"
+ scratch.Set(key, 1)
+ for i := 1; i <= 10; i++ {+ wg.Add(1)
+ go func(j int) {+ for k := 0; k < 10; k++ {+ newVal := k + j
+
+ _, err := scratch.Add(key, newVal)
+ if err != nil {+ t.Errorf("Got err %s", err)+ }
+
+ scratch.Set(key, newVal)
+
+ val := scratch.Get(key)
+
+ if counter, ok := val.(int); ok {+ if counter < 1 {+ t.Errorf("Got %d", counter)+ }
+ } else {+ t.Errorf("Got %T", val)+ }
+ }
+ wg.Done()
+ }(i)
+ }
+ wg.Wait()
}
func TestScratchGet(t *testing.T) {--
⑨