Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions cgi.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,35 @@ func addKnownVariablesToServer(fc *frankenPHPContext, trackVarsArray *C.zval) {
}

func addHeadersToServer(ctx context.Context, request *http.Request, trackVarsArray *C.zval) {
var uncommonKeys, uncommonValues []string

for field, val := range request.Header {
if k := commonHeaders[field]; k != nil {
v := strings.Join(val, ", ")
C.frankenphp_register_known_variable(k, toUnsafeChar(v), C.size_t(len(v)), trackVarsArray)

continue
}

// if the header name could not be cached, it needs to be registered safely
// this is more inefficient but allows additional sanitizing by PHP
k := phpheaders.GetUnCommonHeader(ctx, field)
v := strings.Join(val, ", ")
C.frankenphp_register_variable_safe(toUnsafeChar(k), toUnsafeChar(v), C.size_t(len(v)), trackVarsArray)
// lazily allocate only when the first uncommon header is seen
if uncommonKeys == nil {
uncommonKeys = make([]string, 0, 4)
uncommonValues = make([]string, 0, 4)
}

uncommonKeys = append(uncommonKeys, field)
uncommonValues = append(uncommonValues, strings.Join(val, ", "))
}

if uncommonKeys == nil {
return
}

// uncommon header names need to be registered safely
// this is more inefficient but allows additional sanitizing by PHP
phpKeys := phpheaders.GetUnCommonHeaders(ctx, uncommonKeys)
for i, v := range uncommonValues {
C.frankenphp_register_variable_safe(toUnsafeChar(phpKeys[i]), toUnsafeChar(v), C.size_t(len(v)), trackVarsArray)
}
}

Expand Down
30 changes: 17 additions & 13 deletions internal/phpheaders/phpheaders.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,25 @@ var CommonRequestHeaders = map[string]string{

// Cache up to 256 uncommon headers
// This is ~2.5x faster than converting the header each time
var headerKeyCache = otter.Must[string, string](&otter.Options[string, string]{MaximumSize: 256})
var (
headerKeyCache = otter.Must[string, string](&otter.Options[string, string]{MaximumSize: 256})
headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")
loader = otter.LoaderFunc[string, string](func(_ context.Context, key string) (string, error) {
return "HTTP_" + headerNameReplacer.Replace(strings.ToUpper(key)) + "\x00", nil
})
)

var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")
// GetUnCommonHeaders returns PHP header keys aligned with the input keys slice.
func GetUnCommonHeaders(ctx context.Context, keys []string) []string {
phpHeaderKeys := make([]string, len(keys))
for i, key := range keys {
phpHeaderKey, err := headerKeyCache.Get(ctx, key, loader)
if err != nil {
panic(err)
}

func GetUnCommonHeader(ctx context.Context, key string) string {
phpHeaderKey, err := headerKeyCache.Get(
ctx,
key,
otter.LoaderFunc[string, string](func(_ context.Context, key string) (string, error) {
return "HTTP_" + headerNameReplacer.Replace(strings.ToUpper(key)) + "\x00", nil
}),
)
if err != nil {
panic(err)
phpHeaderKeys[i] = phpHeaderKey
}

return phpHeaderKey
return phpHeaderKeys
}
11 changes: 8 additions & 3 deletions internal/phpheaders/phpheaders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ import (
)

func TestAllCommonHeadersAreCorrect(t *testing.T) {
keys := make([]string, 0, len(CommonRequestHeaders))
for k := range CommonRequestHeaders {
keys = append(keys, k)
}
phpKeys := GetUnCommonHeaders(t.Context(), keys)
fakeRequest := httptest.NewRequest("GET", "http://localhost", nil)

for header, phpHeader := range CommonRequestHeaders {
for i, header := range keys {
phpHeader := CommonRequestHeaders[header]
// verify that common and uncommon headers return the same result
expectedPHPHeader := GetUnCommonHeader(t.Context(), header)
assert.Equal(t, phpHeader+"\x00", expectedPHPHeader, "header is not well formed: "+phpHeader)
assert.Equal(t, phpHeader+"\x00", phpKeys[i], "header is not well formed: "+phpHeader)

// net/http will capitalize lowercase headers, verify that headers are capitalized
fakeRequest.Header.Add(header, "foo")
Expand Down