218 lines
6.1 KiB
Go
218 lines
6.1 KiB
Go
|
package runtime
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"html"
|
||
|
"maps"
|
||
|
"reflect"
|
||
|
"slices"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/a-h/templ"
|
||
|
"github.com/a-h/templ/safehtml"
|
||
|
)
|
||
|
|
||
|
// SanitizeStyleAttributeValues renders a style attribute value.
|
||
|
// The supported types are:
|
||
|
// - string
|
||
|
// - templ.SafeCSS
|
||
|
// - map[string]string
|
||
|
// - map[string]templ.SafeCSSProperty
|
||
|
// - templ.KeyValue[string, string] - A map of key/values where the key is the CSS property name and the value is the CSS property value.
|
||
|
// - templ.KeyValue[string, templ.SafeCSSProperty] - A map of key/values where the key is the CSS property name and the value is the CSS property value.
|
||
|
// - templ.KeyValue[string, bool] - The bool determines whether the value should be included.
|
||
|
// - templ.KeyValue[templ.SafeCSS, bool] - The bool determines whether the value should be included.
|
||
|
// - func() (anyOfTheAboveTypes)
|
||
|
// - func() (anyOfTheAboveTypes, error)
|
||
|
// - []anyOfTheAboveTypes
|
||
|
//
|
||
|
// In the above, templ.SafeCSS and templ.SafeCSSProperty are types that are used to indicate that the value is safe to render as CSS without sanitization.
|
||
|
// All other types are sanitized before rendering.
|
||
|
//
|
||
|
// If an error is returned by any function, or a non-nil error is included in the input, the error is returned.
|
||
|
func SanitizeStyleAttributeValues(values ...any) (string, error) {
|
||
|
if err := getJoinedErrorsFromValues(values...); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
sb := new(strings.Builder)
|
||
|
for _, v := range values {
|
||
|
if v == nil {
|
||
|
continue
|
||
|
}
|
||
|
if err := sanitizeStyleAttributeValue(sb, v); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
}
|
||
|
return sb.String(), nil
|
||
|
}
|
||
|
|
||
|
func sanitizeStyleAttributeValue(sb *strings.Builder, v any) error {
|
||
|
// Process concrete types.
|
||
|
switch v := v.(type) {
|
||
|
case string:
|
||
|
return processString(sb, v)
|
||
|
|
||
|
case templ.SafeCSS:
|
||
|
return processSafeCSS(sb, v)
|
||
|
|
||
|
case map[string]string:
|
||
|
return processStringMap(sb, v)
|
||
|
|
||
|
case map[string]templ.SafeCSSProperty:
|
||
|
return processSafeCSSPropertyMap(sb, v)
|
||
|
|
||
|
case templ.KeyValue[string, string]:
|
||
|
return processStringKV(sb, v)
|
||
|
|
||
|
case templ.KeyValue[string, bool]:
|
||
|
if v.Value {
|
||
|
return processString(sb, v.Key)
|
||
|
}
|
||
|
return nil
|
||
|
|
||
|
case templ.KeyValue[templ.SafeCSS, bool]:
|
||
|
if v.Value {
|
||
|
return processSafeCSS(sb, v.Key)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Fall back to reflection.
|
||
|
|
||
|
// Handle functions first using reflection.
|
||
|
if handled, err := handleFuncWithReflection(sb, v); handled {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Handle slices using reflection before concrete types.
|
||
|
if handled, err := handleSliceWithReflection(sb, v); handled {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, err := sb.WriteString(TemplUnsupportedStyleAttributeValue)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func processSafeCSS(sb *strings.Builder, v templ.SafeCSS) error {
|
||
|
if v == "" {
|
||
|
return nil
|
||
|
}
|
||
|
sb.WriteString(html.EscapeString(string(v)))
|
||
|
if !strings.HasSuffix(string(v), ";") {
|
||
|
sb.WriteRune(';')
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func processString(sb *strings.Builder, v string) error {
|
||
|
if v == "" {
|
||
|
return nil
|
||
|
}
|
||
|
sanitized := strings.TrimSpace(safehtml.SanitizeStyleValue(v))
|
||
|
sb.WriteString(html.EscapeString(sanitized))
|
||
|
if !strings.HasSuffix(sanitized, ";") {
|
||
|
sb.WriteRune(';')
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// ErrInvalidStyleAttributeFunctionSignature is the error handleFuncWithReflection
// produces for a function style value that takes arguments, or that does not
// return either a single value or a (value, error) pair.
var ErrInvalidStyleAttributeFunctionSignature = errors.New("invalid function signature, should be in the form func() (string, error)")
|
||
|
|
||
|
// handleFuncWithReflection handles functions using reflection.
|
||
|
func handleFuncWithReflection(sb *strings.Builder, v any) (bool, error) {
|
||
|
rv := reflect.ValueOf(v)
|
||
|
if rv.Kind() != reflect.Func {
|
||
|
return false, nil
|
||
|
}
|
||
|
|
||
|
t := rv.Type()
|
||
|
if t.NumIn() != 0 || (t.NumOut() != 1 && t.NumOut() != 2) {
|
||
|
return false, ErrInvalidStyleAttributeFunctionSignature
|
||
|
}
|
||
|
|
||
|
// Check the types of the return values
|
||
|
if t.NumOut() == 2 {
|
||
|
// Ensure the second return value is of type `error`
|
||
|
secondReturnType := t.Out(1)
|
||
|
if !secondReturnType.Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||
|
return false, fmt.Errorf("second return value must be of type error, got %v", secondReturnType)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
results := rv.Call(nil)
|
||
|
|
||
|
if t.NumOut() == 2 {
|
||
|
// Check if the second return value is an error
|
||
|
if errVal := results[1].Interface(); errVal != nil {
|
||
|
if err, ok := errVal.(error); ok && err != nil {
|
||
|
return true, err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true, sanitizeStyleAttributeValue(sb, results[0].Interface())
|
||
|
}
|
||
|
|
||
|
// handleSliceWithReflection handles slices using reflection.
|
||
|
func handleSliceWithReflection(sb *strings.Builder, v any) (bool, error) {
|
||
|
rv := reflect.ValueOf(v)
|
||
|
if rv.Kind() != reflect.Slice {
|
||
|
return false, nil
|
||
|
}
|
||
|
for i := 0; i < rv.Len(); i++ {
|
||
|
elem := rv.Index(i).Interface()
|
||
|
if err := sanitizeStyleAttributeValue(sb, elem); err != nil {
|
||
|
return true, err
|
||
|
}
|
||
|
}
|
||
|
return true, nil
|
||
|
}
|
||
|
|
||
|
// processStringMap processes a map[string]string.
|
||
|
func processStringMap(sb *strings.Builder, m map[string]string) error {
|
||
|
for _, name := range slices.Sorted(maps.Keys(m)) {
|
||
|
name, value := safehtml.SanitizeCSS(name, m[name])
|
||
|
sb.WriteString(html.EscapeString(name))
|
||
|
sb.WriteRune(':')
|
||
|
sb.WriteString(html.EscapeString(value))
|
||
|
sb.WriteRune(';')
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// processSafeCSSPropertyMap processes a map[string]templ.SafeCSSProperty.
|
||
|
func processSafeCSSPropertyMap(sb *strings.Builder, m map[string]templ.SafeCSSProperty) error {
|
||
|
for _, name := range slices.Sorted(maps.Keys(m)) {
|
||
|
sb.WriteString(html.EscapeString(safehtml.SanitizeCSSProperty(name)))
|
||
|
sb.WriteRune(':')
|
||
|
sb.WriteString(html.EscapeString(string(m[name])))
|
||
|
sb.WriteRune(';')
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// processStringKV processes a templ.KeyValue[string, string].
|
||
|
func processStringKV(sb *strings.Builder, kv templ.KeyValue[string, string]) error {
|
||
|
name, value := safehtml.SanitizeCSS(kv.Key, kv.Value)
|
||
|
sb.WriteString(html.EscapeString(name))
|
||
|
sb.WriteRune(':')
|
||
|
sb.WriteString(html.EscapeString(value))
|
||
|
sb.WriteRune(';')
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// getJoinedErrorsFromValues returns every top-level value that is an error,
// combined with errors.Join in argument order.
//
// It returns nil when no value is an error. Values nested inside slices, maps,
// or function results are not inspected.
func getJoinedErrorsFromValues(values ...any) error {
	var found []error
	for i := range values {
		if err, isErr := values[i].(error); isErr {
			found = append(found, err)
		}
	}
	return errors.Join(found...)
}
|
||
|
|
||
|
// TemplUnsupportedStyleAttributeValue is the text written into the style
// attribute in place of a value whose type the sanitizer does not support.
// Rendering continues and no error is returned. It is formatted as a single
// "name:value;" declaration.
var TemplUnsupportedStyleAttributeValue = "zTemplUnsupportedStyleAttributeValue:Invalid;"
|