Skip to content

Commit

Permalink
fix data races (#49)
Browse files Browse the repository at this point in the history
* make routeOpt.funcMap concurrency safe

* set get channel conncurrency safe

* route concurrency safe

* remove setting option which is data race unsafe
  • Loading branch information
adnaan authored Feb 10, 2024
1 parent 654a09f commit b810db8
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 30 deletions.
1 change: 1 addition & 0 deletions controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ func (c *controller) defaults() *routeOpt {
onLoad: func(ctx RouteContext) error {
return nil
},
funcMapMutex: &sync.RWMutex{},
}
return defaultRouteOpt
}
Expand Down
14 changes: 7 additions & 7 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ func layoutEmptyContentSet(opt routeOpt, content, layoutContentName string) (*te
return parseString(
template.New(
layoutContentName).
Funcs(opt.funcMap),
Funcs(opt.getFuncMap()),
content)
}
// content must be a file or directory
pageFiles := getPartials(opt, find(pageContentPath, opt.extensions, opt.embedfs))
contentTemplate := template.New(filepath.Base(pageContentPath)).Funcs(opt.funcMap)
contentTemplate := template.New(filepath.Base(pageContentPath)).Funcs(opt.getFuncMap())

return parseFiles(contentTemplate, opt.readFile, pageFiles...)
}
Expand All @@ -34,7 +34,7 @@ func layoutSetContentEmpty(opt routeOpt, layout string) (*template.Template, eve
evt := make(eventTemplates)
// is layout html content or a file/directory
if !opt.existFile(pageLayoutPath) {
return parseString(template.New("").Funcs(opt.funcMap), layout)
return parseString(template.New("").Funcs(opt.getFuncMap()), layout)
}

// layout must be a file
Expand All @@ -44,7 +44,7 @@ func layoutSetContentEmpty(opt routeOpt, layout string) (*template.Template, eve

// compile layout
commonFiles := getPartials(opt, []string{pageLayoutPath})
layoutTemplate := template.New(filepath.Base(pageLayoutPath)).Funcs(opt.funcMap)
layoutTemplate := template.New(filepath.Base(pageLayoutPath)).Funcs(opt.getFuncMap())

return parseFiles(template.Must(layoutTemplate.Clone()), opt.readFile, commonFiles...)
}
Expand Down Expand Up @@ -76,7 +76,7 @@ func layoutSetContentSet(opt routeOpt, content, layout, layoutContentName string
return pageTemplate, evt, nil
} else {
pageFiles := getPartials(opt, find(pageContentPath, opt.extensions, opt.embedfs))
pageTemplate, currEvt, err := parseFiles(layoutTemplate.Funcs(opt.funcMap), opt.readFile, pageFiles...)
pageTemplate, currEvt, err := parseFiles(layoutTemplate.Funcs(opt.getFuncMap()), opt.readFile, pageFiles...)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func checkPageContent(tmpl *template.Template, layoutContentName string) error {

// creates a html/template for the route
func parseTemplate(opt routeOpt) (*template.Template, eventTemplates, error) {
opt.funcMap["fir"] = newFirFuncMap(RouteContext{}, nil)["fir"]
opt.addFunc("fir", newFirFuncMap(RouteContext{}, nil)["fir"])

// if both layout and content is empty show a default page.
if opt.layout == "" && opt.content == "" {
Expand All @@ -130,7 +130,7 @@ func parseTemplate(opt routeOpt) (*template.Template, eventTemplates, error) {

// creates a html/template for the route errors
func parseErrorTemplate(opt routeOpt) (*template.Template, eventTemplates, error) {
opt.funcMap["fir"] = newFirFuncMap(RouteContext{}, nil)["fir"]
opt.addFunc("fir", newFirFuncMap(RouteContext{}, nil)["fir"])
if opt.errorLayout == "" {
opt.errorLayout = opt.layout
opt.errorLayoutContentName = opt.layoutContentName
Expand Down
9 changes: 4 additions & 5 deletions render.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ func renderRoute(ctx RouteContext, errorRouteTemplate bool) routeRenderer {
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)

tmpl := ctx.route.template
tmpl := ctx.route.getTemplate()
if errorRouteTemplate {
tmpl = ctx.route.errorTemplate
tmpl = ctx.route.getErrorTemplate()
}
var errs map[string]any
errMap, ok := data["errors"]
Expand Down Expand Up @@ -59,7 +59,7 @@ func renderRoute(ctx RouteContext, errorRouteTemplate bool) routeRenderer {
func renderDOMEvents(ctx RouteContext, pubsubEvent pubsub.Event) []dom.Event {
eventIDWithState := fmt.Sprintf("%s:%s", *pubsubEvent.ID, pubsubEvent.State)
var templateNames []string
for k := range ctx.route.eventTemplates[eventIDWithState] {
for k := range ctx.route.getEventTemplates()[eventIDWithState] {
templateNames = append(templateNames, k)
}

Expand Down Expand Up @@ -127,7 +127,7 @@ func buildDOMEventFromTemplate(ctx RouteContext, pubsubEvent pubsub.Event, event
}
eventType := fir(eventIDWithState, templateName)
templateData := pubsubEvent.Detail
routeTemplate := ctx.route.template.Funcs(newFirFuncMap(ctx, nil))
routeTemplate := ctx.route.getTemplate().Funcs(newFirFuncMap(ctx, nil))
if pubsubEvent.State == eventstate.Error && pubsubEvent.Detail != nil {
errs, ok := pubsubEvent.Detail.(map[string]any)
if !ok {
Expand Down Expand Up @@ -224,7 +224,6 @@ func buildTemplateValue(t *template.Template, templateName string, data any) (st
if templateName == "_fir_html" {
dataBuf.WriteString(data.(string))
} else {
t.Option("missingkey=zero")
err := t.ExecuteTemplate(dataBuf, templateName, data)
if err != nil {
return "", err
Expand Down
108 changes: 95 additions & 13 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,7 @@ func Extensions(extensions ...string) RouteOption {
// FuncMap appends to the default template function map for the route's template engine
func FuncMap(funcMap template.FuncMap) RouteOption {
return func(opt *routeOpt) {
mergedFuncMap := make(template.FuncMap)
for k, v := range opt.funcMap {
mergedFuncMap[k] = v
}
for k, v := range funcMap {
mergedFuncMap[k] = v
}
opt.funcMap = mergedFuncMap
opt.mergeFuncMap(funcMap)
}
}

Expand Down Expand Up @@ -168,12 +161,38 @@ type routeOpt struct {
partials []string
extensions []string
funcMap template.FuncMap
funcMapMutex *sync.RWMutex
eventSender chan Event
onLoad OnEventFunc
onEvents map[string]OnEventFunc
opt
}

// add func to funcMap
func (opt *routeOpt) addFunc(key string, f any) {
opt.funcMapMutex.Lock()
defer opt.funcMapMutex.Unlock()

opt.funcMap[key] = f
}

// mergeFuncMap merges a value to the funcMap in a concurrency safe way.
func (opt *routeOpt) mergeFuncMap(funcMap template.FuncMap) {
opt.funcMapMutex.Lock()
defer opt.funcMapMutex.Unlock()
for k, v := range funcMap {
opt.funcMap[k] = v
}
}

// getFuncMap lists the funcMap in a concurrency safe way.
func (opt *routeOpt) getFuncMap() template.FuncMap {
opt.funcMapMutex.Lock()
defer opt.funcMapMutex.Unlock()

return opt.funcMap
}

type route struct {
template *template.Template
errorTemplate *template.Template
Expand Down Expand Up @@ -231,6 +250,63 @@ func writeAndPublishEvents(ctx RouteContext) eventPublisher {
}
}

// set route channel concurrency safe
func (rt *route) setChannel(channel string) {
rt.Lock()
defer rt.Unlock()
rt.channel = channel
}

// get route channel concurrency safe
func (rt *route) getChannel() string {
rt.RLock()
defer rt.RUnlock()
return rt.channel
}

// set route template concurrency safe
func (rt *route) setTemplate(t *template.Template) {
rt.Lock()
defer rt.Unlock()
rt.template = t
}

// get route template concurrency safe

func (rt *route) getTemplate() *template.Template {
rt.RLock()
defer rt.RUnlock()
return rt.template
}

// set route error template concurrency safe
func (rt *route) setErrorTemplate(t *template.Template) {
rt.Lock()
defer rt.Unlock()
rt.errorTemplate = t
}

// get route error template concurrency safe
func (rt *route) getErrorTemplate() *template.Template {
rt.RLock()
defer rt.RUnlock()
return rt.errorTemplate
}

// set event templates concurrency safe
func (rt *route) setEventTemplates(templates eventTemplates) {
rt.Lock()
defer rt.Unlock()
rt.eventTemplates = templates
}

// get event templates concurrency safe
func (rt *route) getEventTemplates() eventTemplates {
rt.RLock()
defer rt.RUnlock()
return rt.eventTemplates
}

func (rt *route) ServeHTTP(w http.ResponseWriter, r *http.Request) {
timing := servertiming.FromContext(r.Context())
defer timing.NewMetric("route").Start().Stop()
Expand Down Expand Up @@ -592,20 +668,25 @@ func handleOnLoadResult(err, onFormErr error, ctx RouteContext) {

func (rt *route) parseTemplates() {
var err error
if rt.template == nil || (rt.template != nil && rt.disableTemplateCache) {
if rt.getTemplate() == nil || (rt.getTemplate() != nil && rt.disableTemplateCache) {
var successEventTemplates eventTemplates
rt.template, successEventTemplates, err = parseTemplate(rt.routeOpt)
var rtTemplate *template.Template
rtTemplate, successEventTemplates, err = parseTemplate(rt.routeOpt)
if err != nil {
panic(err)
}
rt.setTemplate(rtTemplate)

var errorEventTemplates eventTemplates
rt.errorTemplate, errorEventTemplates, err = parseErrorTemplate(rt.routeOpt)
var rtErrorTemplate *template.Template
rtErrorTemplate, errorEventTemplates, err = parseErrorTemplate(rt.routeOpt)
if err != nil {
panic(err)
}
rt.setErrorTemplate(rtErrorTemplate)

rt.eventTemplates = deepMergeEventTemplates(errorEventTemplates, successEventTemplates)
for eventID, templates := range rt.eventTemplates {
rtEventTemplates := deepMergeEventTemplates(errorEventTemplates, successEventTemplates)
for eventID, templates := range rt.getEventTemplates() {
var templatesStr string
for k := range templates {
if k == "-" {
Expand All @@ -615,6 +696,7 @@ func (rt *route) parseTemplates() {
}
fmt.Println("eventID: ", eventID, " templates: ", templatesStr)
}
rt.setEventTemplates(rtEventTemplates)

}
}
10 changes: 5 additions & 5 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ func onWebsocket(w http.ResponseWriter, r *http.Request, cntrl *controller) {
http.Error(w, "channel is empty", http.StatusUnauthorized)
return
}
route.channel = *routeChannel
route.setChannel(*routeChannel)

// subscribers: subscribe to pubsub events
subscription, err := route.pubsub.Subscribe(ctx, route.channel)
subscription, err := route.pubsub.Subscribe(ctx, route.getChannel())
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -138,7 +138,7 @@ func onWebsocket(w http.ResponseWriter, r *http.Request, cntrl *controller) {
response: w,
route: route,
}
go renderAndWriteEvent(send, route.channel, routeCtx, pubsubEvent)
go renderAndWriteEvent(send, route.getChannel(), routeCtx, pubsubEvent)
}
}()

Expand Down Expand Up @@ -169,7 +169,7 @@ func onWebsocket(w http.ResponseWriter, r *http.Request, cntrl *controller) {
// update request context with user
eventCtx.request = eventCtx.request.WithContext(context.WithValue(context.Background(), UserKey, user))

handleOnEventResult(onEventFunc(eventCtx), eventCtx, publishEvents(ctx, eventCtx, route.channel))
handleOnEventResult(onEventFunc(eventCtx), eventCtx, publishEvents(ctx, eventCtx, route.getChannel()))
}
}()

Expand Down Expand Up @@ -287,7 +287,7 @@ loop:
}

// handle user events
go handleOnEventResult(onEventFunc(eventCtx), eventCtx, publishEvents(ctx, eventCtx, eventRoute.channel))
go handleOnEventResult(onEventFunc(eventCtx), eventCtx, publishEvents(ctx, eventCtx, eventRoute.getChannel()))
}
// close writers to send
close(done)
Expand Down

0 comments on commit b810db8

Please sign in to comment.