Accessing the underlying socket of a net/http response

Since Go 1.13 (the outcome of golang/go issue #30694), http.Server has a ConnContext hook that lets you store the net.Conn in each request's context, which makes this fairly clean and simple:

package main

import (
  "net/http"
  "context"
  "net"
  "log"
)

// contextKey is an unexported key type, so values stored under it cannot
// collide with context keys defined by any other package.
type contextKey struct {
  key string
}

// ConnContextKey is the context key under which SaveConnInContext stores the
// connection's net.Conn.
var ConnContextKey = &contextKey{key: "http-conn"}
// SaveConnInContext has the shape of http.Server.ConnContext: it returns a
// child of ctx that carries c under ConnContextKey.
func SaveConnInContext(ctx context.Context, c net.Conn) context.Context {
  withConn := context.WithValue(ctx, ConnContextKey, c)
  return withConn
}
// GetConn returns the net.Conn that SaveConnInContext stored in the context of
// the connection r arrived on. It returns nil if none was stored, for example
// when the server was not configured with ConnContext: SaveConnInContext.
func GetConn(r *http.Request) net.Conn {
  // Comma-ok form: a bare .(net.Conn) panics when the value is absent.
  conn, _ := r.Context().Value(ConnContextKey).(net.Conn)
  return conn
}

// main serves myHandler on :8080, saving each connection's net.Conn in the
// context of the requests it carries.
func main() {
  http.HandleFunc("/", myHandler)

  server := http.Server{
    Addr: ":8080",
    // ConnContext (Go 1.13+) derives the per-connection context that request
    // contexts on that connection are built from.
    ConnContext: SaveConnInContext,
  }
  // ListenAndServe only returns on failure; report it instead of exiting silently.
  log.Fatal(server.ListenAndServe())
}

func myHandler(w http.ResponseWriter, r *http.Request) {
  conn := GetConn(r)
  ...
}

On Go versions before 1.13, which lack ConnContext, there is another option for a server listening on a TCP port. There, net.Conn.RemoteAddr().String() is unique among open connections and is available to the http.Handler as r.RemoteAddr, so it can be used as a key into a global map of Conns:

package main
import (
  "net/http"
  "net"
  "fmt"
  "log"
)

// conns maps a connection's RemoteAddr().String() (which net/http exposes to
// handlers as Request.RemoteAddr) to the connection itself.
var conns = make(map[string]net.Conn)

// connsLock guards conns. ConnState hooks and handlers run on many goroutines
// at once, and unsynchronized map access is a data race; the runtime aborts
// on concurrent map writes. A one-slot buffered channel acts as a mutex:
// sending locks, receiving unlocks.
var connsLock = make(chan struct{}, 1)

// ConnStateEvent is an http.Server.ConnState hook. It records a connection
// when it becomes active and forgets it once it is hijacked or closed.
func ConnStateEvent(conn net.Conn, event http.ConnState) {
  key := conn.RemoteAddr().String()

  connsLock <- struct{}{}
  defer func() { <-connsLock }()

  switch event {
  case http.StateActive:
    conns[key] = conn
  case http.StateHijacked, http.StateClosed:
    delete(conns, key)
  }
}

// GetConn returns the connection r arrived on, or nil if it is not recorded.
func GetConn(r *http.Request) net.Conn {
  connsLock <- struct{}{}
  defer func() { <-connsLock }()
  return conns[r.RemoteAddr]
}

// main serves myHandler on :8080 with ConnStateEvent tracking connections.
func main() {
  http.HandleFunc("/", myHandler)

  server := http.Server{
    Addr:      ":8080",
    ConnState: ConnStateEvent,
  }
  // ListenAndServe only returns on failure; report it instead of exiting silently.
  log.Fatal(server.ListenAndServe())
}

func myHandler(w http.ResponseWriter, r *http.Request) {
  conn := GetConn(r)
  ...
}

For a server listening on a UNIX socket, net.Conn.RemoteAddr().String() is always "@", so the above doesn't work. To make it work, we can wrap the net.Listener so that its Accept() returns connections whose RemoteAddr().String() is a unique string for each connection:

package main

import (
  "fmt"
  "log"
  "net"
  "net/http"
  "os"
  "sync"

  "golang.org/x/sys/unix"
)

// main serves myHandler on a UNIX socket at /var/run/go_server.sock. Every
// accepted connection is registered so handlers can find it with GetConn.
func main() {
  http.HandleFunc("/", myHandler)

  listenPath := "/var/run/go_server.sock"
  l, err := NewUnixListener(listenPath)
  if err != nil {
    log.Fatal(err)
  }

  server := http.Server{
    ConnState: ConnStateEvent,
  }
  // Serve only returns on failure. log.Fatal exits without running deferred
  // calls, so remove the socket file explicitly first (best effort: it may
  // already be gone).
  err = server.Serve(NewConnSaveListener(l))
  os.Remove(listenPath)
  log.Fatal(err)
}

// myHandler logs the peer's UID, read via the SO_PEERCRED socket option, for
// requests that arrived over a UNIX socket. Other requests are left untouched.
// If the credentials cannot be read, it logs the error and replies 500.
func myHandler(w http.ResponseWriter, r *http.Request) {
  unixConn, isUnix := GetConn(r).(*net.UnixConn)
  if !isUnix {
    return
  }

  // SyscallConn().Control runs the callback on the connection's own
  // descriptor, so no duplicate fd has to be created and closed (as File()
  // requires).
  raw, err := unixConn.SyscallConn()
  if err != nil {
    log.Printf("SyscallConn: %v", err)
    http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    return
  }

  var pcred *unix.Ucred
  var credErr error
  if err := raw.Control(func(fd uintptr) {
    pcred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
  }); err != nil {
    credErr = err
  }
  if credErr != nil {
    // pcred is nil on failure; dereferencing it would panic.
    log.Printf("reading peer credentials: %v", credErr)
    http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    return
  }

  log.Printf("Remote UID: %d", pcred.Uid)
}

// conns maps the synthetic remote-address key that connSaveListener assigns
// to each accepted connection back to the underlying (unwrapped) net.Conn.
var conns = make(map[string]net.Conn)

// connsMu guards conns and lastConnID. Accept runs on the Serve goroutine,
// while ConnState hooks and handlers run on per-connection goroutines.
var connsMu sync.Mutex

// lastConnID is the number of the most recently accepted connection.
var lastConnID uint64

// connSaveListener wraps a net.Listener and registers every accepted
// connection in conns.
type connSaveListener struct {
  net.Listener
}

// NewConnSaveListener wraps wrap so that its connections can later be found
// with GetConn.
func NewConnSaveListener(wrap net.Listener) net.Listener {
  return connSaveListener{wrap}
}

// Accept accepts a connection from the wrapped listener and records it under
// a fresh key. It returns the connection wrapped so that
// RemoteAddr().String(), and hence http.Request.RemoteAddr, is that key.
//
// Keys come from a counter rather than a memory address. An address can be
// garbage collected and reused while an earlier connection is still open,
// which would give two live connections the same key.
func (l connSaveListener) Accept() (net.Conn, error) {
  conn, err := l.Listener.Accept()
  if err != nil {
    // Nothing was accepted: don't register or wrap a nil conn.
    return nil, err
  }

  connsMu.Lock()
  lastConnID++
  key := fmt.Sprintf("%d", lastConnID)
  conns[key] = conn
  connsMu.Unlock()

  return remoteAddrPtrConn{conn, key}, nil
}

// GetConn returns the underlying connection r arrived on, or nil if it is not
// (or no longer) registered.
func GetConn(r *http.Request) net.Conn {
  connsMu.Lock()
  defer connsMu.Unlock()
  return conns[r.RemoteAddr]
}

// ConnStateEvent is an http.Server.ConnState hook. It drops a connection from
// conns once it is hijacked or closed.
func ConnStateEvent(conn net.Conn, event http.ConnState) {
  if event != http.StateHijacked && event != http.StateClosed {
    return
  }
  connsMu.Lock()
  defer connsMu.Unlock()
  delete(conns, conn.RemoteAddr().String())
}
// remoteAddrPtrConn wraps a net.Conn and overrides RemoteAddr so that it
// reports the synthetic key stored in ptrStr instead of the peer's address.
type remoteAddrPtrConn struct {
  net.Conn
  ptrStr string
}

// RemoteAddr reports the synthetic key as a net.Addr.
func (c remoteAddrPtrConn) RemoteAddr() net.Addr {
  return remoteAddrPtr{ptrStr: c.ptrStr}
}

// remoteAddrPtr is a net.Addr whose String is the synthetic key and whose
// Network is the empty string.
type remoteAddrPtr struct {
  ptrStr string
}

func (a remoteAddrPtr) Network() string { return "" }

func (a remoteAddrPtr) String() string { return a.ptrStr }

// NewUnixListener listens on a UNIX socket at path. It first removes any
// existing file there, and leaves the socket file with mode 0660.
func NewUnixListener(path string) (net.Listener, error) {
  // A missing file is fine; any other unlink failure is fatal.
  unlinkErr := unix.Unlink(path)
  if unlinkErr != nil && !os.IsNotExist(unlinkErr) {
    return nil, unlinkErr
  }

  // With umask 0777 the socket file is created with no permission bits at
  // all; the previous umask is restored on return.
  oldMask := unix.Umask(0777)
  defer unix.Umask(oldMask)

  listener, err := net.Listen("unix", path)
  if err != nil {
    return nil, err
  }

  if chmodErr := os.Chmod(path, 0660); chmodErr != nil {
    listener.Close()
    return nil, chmodErr
  }
  return listener, nil
}

Note that in the current implementation the http.ResponseWriter passed to handlers is a *http.response (note the lowercase!), which holds the connection. However, that field is unexported, so you can't access it.

Instead, take a look at the Server.ConnState hook. You can register a function that is called whenever a connection's state changes; see http.ConnState for details. For example, you get the net.Conn even before the request enters the handler (the http.StateNew and http.StateActive states).

You can install a connection state listener by creating a custom Server like this:

// main serves myHandler on :8081 through a custom http.Server so that a
// ConnState hook can be installed. It panics with the error that
// ListenAndServe returns.
func main() {
    const timeout = 10 * time.Second

    http.HandleFunc("/", myHandler)

    srv := new(http.Server)
    srv.Addr = ":8081"
    srv.ReadTimeout, srv.WriteTimeout = timeout, timeout
    srv.MaxHeaderBytes = 1 << 20 // 1 MiB
    srv.ConnState = ConnStateListener

    err := srv.ListenAndServe()
    panic(err)
}

// ConnStateListener is an http.Server.ConnState hook that prints each state
// transition together with the connection it applies to.
func ConnStateListener(c net.Conn, cs http.ConnState) {
    line := fmt.Sprintf("CONN STATE: %v, %v\n", cs, c)
    fmt.Print(line)
}

This way you get exactly the desired net.Conn before (and also during and after) the handler runs. The downside is that it is not paired with the ResponseWriter; if you need that, you have to do the pairing yourself.


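This can be done with reflection. It's a bit "dirty": it relies on unexported fields of net/http and net, which can change with any Go release (for example, the descriptor is no longer in a netFD.sysfd field but in the embedded poll.FD, as pfd.Sysfd). It works as long as the code matches the internals of the Go version you build with: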

package main

import "net/http"
import "fmt"
import "runtime"

import "reflect"

// myHandler prints the OS socket descriptor behind w and answers "Hello World".
//
// WARNING: socketFD reads unexported fields of net/http and net. Those are
// implementation details that differ between Go releases, so the lookup may
// find nothing; the handler then says so instead of panicking.
func myHandler(w http.ResponseWriter, r *http.Request) {
    fd, ok := socketFD(w)
    switch {
    case !ok:
        fmt.Println("socket descriptor not found (unsupported Go version or connection type)")
    case runtime.GOOS == "linux":
        fmt.Printf("fd = %v\n", fd)
        // fd is the socket - we can call unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd),....) on it for instance
    default:
        fmt.Printf("sysfd = %v\n", fd)
    }

    fmt.Fprintf(w, "Hello World")
}

// socketFD follows the field chain
//
//     *http.response .conn -> *http.conn .rwc -> net.Conn (e.g. *net.TCPConn)
//     .conn -> net.conn .fd -> *net.netFD .sysfd | .pfd.Sysfd
//
// and returns the descriptor at the end. ok is false if any link is missing,
// nil, or of an unexpected kind. v is interface{} rather than
// http.ResponseWriter so that any value with this shape can be walked.
func socketFD(v interface{}) (fd int, ok bool) {
    netFD := walkFields(reflect.ValueOf(v), "conn", "rwc", "conn", "fd")

    // Older releases kept the descriptor in netFD.sysfd; newer ones keep it
    // in the embedded poll.FD as pfd.Sysfd.
    sysfd := walkFields(netFD, "sysfd")
    if !sysfd.IsValid() {
        sysfd = walkFields(netFD, "pfd", "Sysfd")
    }

    switch sysfd.Kind() {
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return int(sysfd.Int()), true
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
        // e.g. a handle type defined as uintptr on some platforms
        return int(sysfd.Uint()), true
    }
    return 0, false
}

// walkFields follows the named struct fields in order. Before each lookup it
// dereferences pointers and unwraps interfaces. It returns the zero Value as
// soon as a step meets a nil, a non-struct, or a missing field.
func walkFields(v reflect.Value, names ...string) reflect.Value {
    for _, name := range names {
        for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
            if v.IsNil() {
                return reflect.Value{}
            }
            v = v.Elem()
        }
        if v.Kind() != reflect.Struct {
            return reflect.Value{}
        }
        v = v.FieldByName(name)
    }
    return v
}

// main serves myHandler on :8081 and prints the error that ends serving.
func main() {
    http.HandleFunc("/", myHandler)
    // ListenAndServe only ever returns with a non-nil error.
    fmt.Println(http.ListenAndServe(":8081", nil).Error())
}

Ideally net/http would provide a direct way to get the underlying net.Conn. Since Go 1.13, the Server.ConnContext hook shown in the first answer provides exactly that.


You can use an http.Hijacker to take over the TCP connection from the ResponseWriter. Once you've done that, the HTTP server no longer manages the connection, and you're free to use the socket however you want.

See https://pkg.go.dev/net/http#Hijacker, which also contains a good example.

Tags:

Sockets

Go