How to properly write Middleware in Go

Adding middleware in Go without breaking existing functionality is surprisingly difficult. This post explains how you can safely add middleware to Go.

To demonstrate, we'll add 2 middleware functions to a simple HTTP server:

{
  "type": "unordered-list",
  "id": "4d9dc7b8-2e00-4a34-91f3-569e2c7ae3f6",
  "items": [
    {
      "id": "4d9dc7b8-2e00-4a34-91f3-569e2c7ae3f6",
      "content": [
        {
          "type": "text",
          "content": "Response Time",
          "styles": [
            {
              "type": "bold"
            }
          ]
        },
        {
          "type": "text",
          "content": ": Adds an "
        },
        {
          "type": "code",
          "content": "X-Response-Time"
        },
        {
          "type": "text",
          "content": " response header with the response time in milliseconds."
        }
      ]
    },
    {
      "id": "fb291508-4646-47c5-87ac-1fa664d43ebf",
      "content": [
        {
          "type": "text",
          "content": "Response Logging",
          "styles": [
            {
              "type": "bold"
            }
          ]
        },
        {
          "type": "text",
          "content": ": Logs the response status, bytes written and duration"
        }
      ]
    }
  ]
}

We chose these two examples because they are simple and both need to act after the handler runs: one writes a response header and the other reads the response status and size. If your middleware only does work before calling your HTTP handler, you don't need the solution below.
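For contrast, here is a minimal sketch of middleware that only acts before the handler runs, which is the case that needs no special treatment. The ServerName middleware, the X-Server-Name header and the serverNameHeader helper are hypothetical names made up for this illustration, and httptest.NewRecorder stands in for a real connection.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// ServerName is a hypothetical middleware that only acts before the handler runs.
func ServerName(name string) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Nothing has been written yet, so this header is sent normally.
			w.Header().Set("X-Server-Name", name)
			next.ServeHTTP(w, r)
		})
	}
}

func hello(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("Howdy!"))
}

// serverNameHeader runs one request through the middleware and returns
// the header value as it appears in the recorded response.
func serverNameHeader() string {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	ServerName("demo")(http.HandlerFunc(hello)).ServeHTTP(rec, req)
	return rec.Result().Header.Get("X-Server-Name")
}

func main() {
	fmt.Println(serverNameHeader()) // prints "demo"
}
```

Because the header is set before next.ServeHTTP is called, the plain http.ResponseWriter is enough here.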

How we expect this to work

Let's first try implementing it without overthinking it:

{
  "id": "745f1a3c-08a3-4ea9-9d35-7d9d09ba511b",
  "type": "code",
  "content": [
    {
      "type": "text",
      "content": "package main\n\nimport (\n\t\"log\"\n\t\"net/http\"\n\t\"os\"\n\t\"strconv\"\n\t\"time\"\n)\n\nfunc ResponseTime() func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tnow := time.Now()\n\t\t\tnext.ServeHTTP(w, r)\n\t\t\t// Header values are strings: format the elapsed time in milliseconds\n\t\t\tms := strconv.FormatInt(time.Since(now).Milliseconds(), 10)\n\t\t\tw.Header().Set(\"X-Response-Time\", ms)\n\t\t})\n\t}\n}\n\nfunc Log(log *log.Logger) func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tnow := time.Now()\n\t\t\tnext.ServeHTTP(w, r)\n\t\t\t// wait... how do I get the response status and content length from next?\n\t\t\t// (status and length are placeholders; nothing defines them yet)\n\t\t\tlog.Printf(\"%d %d %dus\", status, length, time.Since(now).Microseconds())\n\t\t})\n\t}\n}\n\nfunc Handler(w http.ResponseWriter, r *http.Request) {\n\ttime.Sleep(time.Second)\n\tw.Write([]byte(\"Howdy!\"))\n}\n\n// Helper function to compose middleware together\nfunc compose(stack ...func(next http.Handler) http.Handler) func(http.Handler) http.Handler {\n\treturn func(h http.Handler) http.Handler {\n\t\tif len(stack) == 0 {\n\t\t\treturn h\n\t\t}\n\t\t// Apply in reverse so the first middleware listed is the outermost\n\t\tfor i := len(stack) - 1; i \u003e= 0; i-- {\n\t\t\th = stack[i](h)\n\t\t}\n\t\treturn h\n\t}\n}\n\nfunc main() {\n\tlogger := log.New(os.Stderr, \"\", log.Ldate|log.Ltime)\n\tmiddleware := compose(\n\t\tResponseTime(),\n\t\tLog(logger),\n\t)\n\tfn := middleware(http.HandlerFunc(Handler))\n\tlog.Fatal(http.ListenAndServe(\":8000\", fn))\n}"
    }
  ],
  "language": "go"
}

This simple example has 2 problems:

{
  "id": "29625531-ecd5-478b-a26b-a75c24cacbea",
  "type": "ordered-list",
  "items": [
    {
      "id": "29625531-ecd5-478b-a26b-a75c24cacbea",
      "content": [
        {
          "type": "text",
          "content": "The "
        },
        {
          "type": "code",
          "content": "Log"
        },
        {
          "type": "text",
          "content": " middleware doesn't have access to the response status and content length."
        }
      ]
    },
    {
      "id": "c4f6c7fc-abc6-4b6d-9950-19ce5d44fe1a",
      "content": [
        {
          "type": "text",
          "content": "The "
        },
        {
          "type": "code",
          "content": "ResponseTime"
        },
        {
          "type": "text",
          "content": " middleware writes a header "
        },
        {
          "type": "text",
          "content": "after ",
          "styles": [
            {
              "type": "italic"
            }
          ]
        },
        {
          "type": "text",
          "content": "the response status or body has been written. If you try to run this code... 💥. Just kidding, nothing happens. You just won't see your "
        },
        {
          "type": "code",
          "content": "X-Response-Time"
        },
        {
          "type": "text",
          "content": " response header. It silently fails."
        }
      ]
    }
  ]
}
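The silent failure in the second problem can be reproduced with the standard library alone. This is a small sketch rather than the server above: lateHeader and lateHeaderValue are hypothetical names, the header value is a fixed string, and httptest.NewRecorder mimics the real server by snapshotting headers at the first write.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// lateHeader sets a header only after the wrapped handler has returned.
func lateHeader(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
		// Too late: the handler already wrote the body, so the headers are gone.
		w.Header().Set("X-Response-Time", "1000")
	})
}

func howdy(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("Howdy!"))
}

// lateHeaderValue returns the X-Response-Time header as it appears in the
// recorded response.
func lateHeaderValue() string {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	lateHeader(http.HandlerFunc(howdy)).ServeHTTP(rec, req)
	return rec.Result().Header.Get("X-Response-Time")
}

func main() {
	fmt.Printf("%q\n", lateHeaderValue()) // prints ""
}
```

No error is returned and nothing panics; the header simply never makes it into the response.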

Let's start by fixing the first problem.

The Log middleware doesn't have access to the response status and content length.

This one is an easy fix. http.ResponseWriter is an interface that we can wrap.

We'll create a responseWriter wrapper that implements http.ResponseWriter. This wrapper intercepts the data that passes through, storing the status code and computing the content length.

{
  "id": "c8280464-129b-4862-8449-835497ac6552",
  "type": "code",
  "content": [
    {
      "type": "text",
      "content": "type responseWriter struct {\n\thttp.ResponseWriter\n\tstatus int\n\tlength int\n}\n\n// Ensure that responseWriter implements `http.ResponseWriter`\nvar _ http.ResponseWriter = (*responseWriter)(nil)\n\n// WriteHeader records the status code before passing it through\nfunc (w *responseWriter) WriteHeader(status int) {\n\tw.status = status\n\tw.ResponseWriter.WriteHeader(status)\n}\n\n// Write counts the bytes written; a Write without WriteHeader implies a 200\nfunc (w *responseWriter) Write(b []byte) (int, error) {\n\tif w.status == 0 {\n\t\tw.status = 200\n\t}\n\tw.length += len(b)\n\treturn w.ResponseWriter.Write(b)\n}\n\nfunc Log(log *log.Logger) func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tnow := time.Now()\n\t\t\t// Wrap http.ResponseWriter inside responseWriter\n\t\t\trw := \u0026responseWriter{ResponseWriter: w}\n\t\t\tnext.ServeHTTP(rw, r)\n\t\t\t// Log the response status, content length and response time.\n\t\t\tlog.Printf(\"%d %d %dus\", rw.status, rw.length, time.Since(now).Microseconds())\n\t\t})\n\t}\n}"
    }
  ],
  "language": "go"
}

By wrapping the http.ResponseWriter and intercepting the data that passes through the interface, we can take note of the response status and compute the content length.

With this simple adjustment, we just silently broke a lot of functionality. The main problem is that an http.ResponseWriter may also implement several optional interfaces, such as http.Flusher, http.Hijacker and io.ReaderFrom.

For example, the http.ResponseWriter that net/http passes to handlers also implements the http.Flusher interface for HTTP/1.x and HTTP/2 connections. Because our wrapper only implements http.ResponseWriter, handlers further down the chain can no longer flush their buffered data.

It sounds like we just need to implement more methods, right? That's not a perfect solution either, because some of these interfaces are hard to fake when the underlying http.ResponseWriter doesn't provide the implementation [1].

Fortunately, there's a solution: Felix Geisendörfer's wonderful httpsnoop package.

Wrapping middleware with httpsnoop

This package uses code generation to cover every known combination of optional interfaces, so the wrapper it returns implements exactly the ones the underlying http.ResponseWriter does.

Here's how you use it:

{
  "id": "017fea07-83bd-49e2-9d0a-344d125b1bf8",
  "type": "code",
  "content": [
    {
      "type": "text",
      "content": "func Log(log *log.Logger) func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tstatus := 0\n\t\t\tlength := 0\n\t\t\tnow := time.Now()\n\t\t\t// Wrap the WriteHeader and Write function calls\n\t\t\twrapped := httpsnoop.Wrap(w, httpsnoop.Hooks{\n\t\t\t\tWriteHeader: func(fn httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {\n\t\t\t\t\treturn func(code int) {\n\t\t\t\t\t\tstatus = code\n\t\t\t\t\t\tfn(code)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t\tWrite: func(fn httpsnoop.WriteFunc) httpsnoop.WriteFunc {\n\t\t\t\t\treturn func(b []byte) (int, error) {\n\t\t\t\t\t\t// A Write without WriteHeader implies a 200\n\t\t\t\t\t\tif status == 0 {\n\t\t\t\t\t\t\tstatus = 200\n\t\t\t\t\t\t}\n\t\t\t\t\t\tlength += len(b)\n\t\t\t\t\t\treturn fn(b)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t})\n\t\t\tnext.ServeHTTP(wrapped, r)\n\t\t\t// Log the response status, content length and response time.\n\t\t\tlog.Printf(\"%d %d %dus\", status, length, time.Since(now).Microseconds())\n\t\t})\n\t}\n}"
    }
  ],
  "language": "go"
}

Okay, this looks super complicated, but I promise it's not too bad.

We wrap the http.ResponseWriter with httpsnoop.Wrap and provide the hooks we want to intercept. In this example we've hooked into our response writer's WriteHeader and Write methods. Any hook we do not implement (e.g. Header() http.Header) will be passed through.

And there you have it. We've solved the first problem. Let's now turn our attention to the second problem.

The ResponseTime middleware writes a header after the response status or body have been written.

For this one, we somehow need to write our response headers before we write our status code or response body.

Doesn't that sound familiar? We know how to do this already! In the previous problem, we intercepted the http.ResponseWriter to record the status code and compute the content length before responding. Instead of storing data, how about we write the headers before calling the methods?

Let's apply the same technique to response headers.

{
  "id": "1e6ff8c4-9cf2-42ac-9464-e133ca1b948b",
  "type": "code",
  "content": [
    {
      "type": "text",
      "content": "func ResponseTime() func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tnow := time.Now()\n\t\t\t// Header values are strings: format the elapsed time in milliseconds\n\t\t\telapsed := func() string {\n\t\t\t\treturn strconv.FormatInt(time.Since(now).Milliseconds(), 10)\n\t\t\t}\n\t\t\t// Wrap the WriteHeader and Write function calls\n\t\t\twrapped := httpsnoop.Wrap(w, httpsnoop.Hooks{\n\t\t\t\tWriteHeader: func(fn httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {\n\t\t\t\t\treturn func(code int) {\n\t\t\t\t\t\tw.Header().Set(\"X-Response-Time\", elapsed())\n\t\t\t\t\t\tfn(code)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t\tWrite: func(fn httpsnoop.WriteFunc) httpsnoop.WriteFunc {\n\t\t\t\t\treturn func(b []byte) (int, error) {\n\t\t\t\t\t\tw.Header().Set(\"X-Response-Time\", elapsed())\n\t\t\t\t\t\treturn fn(b)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t})\n\t\t\tnext.ServeHTTP(wrapped, r)\n\t\t})\n\t}\n}"
    }
  ],
  "language": "go"
}

Yup, one wrapper per middleware

You may be tempted to consolidate your httpsnoop wrappers into one top-level middleware. Resist the urge! The only way to do this is by buffering the response, and if you buffer the response, you lose the ability to stream it back. Fortunately, httpsnoop doesn't add much overhead [2], and you only need to wrap the response writer if your middleware has something to do after you call your handler.

All Together Now

Here's the full file in all its wrapped glory.

{
  "id": "d80fdce6-510f-4d67-976f-fbc5ae24ace2",
  "type": "code",
  "content": [
    {
      "type": "text",
      "content": "package main\n\nimport (\n\t\"log\"\n\t\"net/http\"\n\t\"os\"\n\t\"strconv\"\n\t\"time\"\n\n\t\"github.com/felixge/httpsnoop\"\n)\n\nfunc ResponseTime() func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tnow := time.Now()\n\t\t\t// Header values are strings: format the elapsed time in milliseconds\n\t\t\telapsed := func() string {\n\t\t\t\treturn strconv.FormatInt(time.Since(now).Milliseconds(), 10)\n\t\t\t}\n\t\t\t// Wrap the WriteHeader and Write function calls\n\t\t\twrapped := httpsnoop.Wrap(w, httpsnoop.Hooks{\n\t\t\t\tWriteHeader: func(fn httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {\n\t\t\t\t\treturn func(code int) {\n\t\t\t\t\t\tw.Header().Set(\"X-Response-Time\", elapsed())\n\t\t\t\t\t\tfn(code)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t\tWrite: func(fn httpsnoop.WriteFunc) httpsnoop.WriteFunc {\n\t\t\t\t\treturn func(b []byte) (int, error) {\n\t\t\t\t\t\tw.Header().Set(\"X-Response-Time\", elapsed())\n\t\t\t\t\t\treturn fn(b)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t})\n\t\t\tnext.ServeHTTP(wrapped, r)\n\t\t})\n\t}\n}\n\nfunc Log(log *log.Logger) func(next http.Handler) http.Handler {\n\treturn func(next http.Handler) http.Handler {\n\t\treturn http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {\n\t\t\tstatus := 0\n\t\t\tlength := 0\n\t\t\tnow := time.Now()\n\t\t\t// Wrap the WriteHeader and Write function calls\n\t\t\twrapped := httpsnoop.Wrap(w, httpsnoop.Hooks{\n\t\t\t\tWriteHeader: func(fn httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {\n\t\t\t\t\treturn func(code int) {\n\t\t\t\t\t\tstatus = code\n\t\t\t\t\t\tfn(code)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t\tWrite: func(fn httpsnoop.WriteFunc) httpsnoop.WriteFunc {\n\t\t\t\t\treturn func(b []byte) (int, error) {\n\t\t\t\t\t\t// A Write without WriteHeader implies a 200\n\t\t\t\t\t\tif status == 0 {\n\t\t\t\t\t\t\tstatus = 200\n\t\t\t\t\t\t}\n\t\t\t\t\t\tlength += len(b)\n\t\t\t\t\t\treturn fn(b)\n\t\t\t\t\t}\n\t\t\t\t},\n\t\t\t})\n\t\t\tnext.ServeHTTP(wrapped, r)\n\t\t\t// Log the response status, content length and response time.\n\t\t\tlog.Printf(\"%d %d %dus\", status, length, time.Since(now).Microseconds())\n\t\t})\n\t}\n}\n\nfunc Handler(w http.ResponseWriter, r *http.Request) {\n\ttime.Sleep(time.Second)\n\tw.Write([]byte(\"Howdy!\"))\n}\n\n// Helper function to compose middleware together\nfunc compose(stack ...func(next http.Handler) http.Handler) func(http.Handler) http.Handler {\n\treturn func(h http.Handler) http.Handler {\n\t\tif len(stack) == 0 {\n\t\t\treturn h\n\t\t}\n\t\t// Apply in reverse so the first middleware listed is the outermost\n\t\tfor i := len(stack) - 1; i \u003e= 0; i-- {\n\t\t\th = stack[i](h)\n\t\t}\n\t\treturn h\n\t}\n}\n\nfunc main() {\n\tlogger := log.New(os.Stderr, \"\", log.Ldate|log.Ltime)\n\tmiddleware := compose(\n\t\tResponseTime(),\n\t\tLog(logger),\n\t)\n\tfn := middleware(http.HandlerFunc(Handler))\n\tlog.Fatal(http.ListenAndServe(\":8000\", fn))\n}"
    }
  ],
  "language": "go"
}

Certainly not the most beautiful thing in the world, but we successfully created middleware without losing features. Fortunately, all these details can be tucked away in neat and tidy middleware packages.

{
  "id": "a3dab8e0-d749-4008-a37c-4bc7c2d9cc5b",
  "type": "divider"
}

And... that's a wrap!

References

[1] I'll spare you all the details here, but if you're curious, I highly recommend reading Felix's excellent explanation in the httpsnoop README. https://github.com/felixge/httpsnoop#why-this-package-exists

[2] See the performance section of the httpsnoop README. https://github.com/felixge/httpsnoop#performance