httpsify.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. package main
  2. import (
  3. "crypto/tls"
  4. "flag"
  5. "fmt"
  6. "io"
  7. "log"
  8. "net"
  9. "net/http"
  10. "net/http/httputil"
  11. "net/url"
  12. "strings"
  13. "golang.org/x/crypto/acme/autocert"
  14. )
  15. var (
  16. // CMD options
  17. listen = flag.String("listen", ":443", "the local listen address")
  18. domains = flag.String("domains", "", "a comma separated strings of domain[->[ip]:port]")
  19. backend = flag.String("backend", ":80", "the default backend to be used")
  20. sslCacheDir = flag.String("ssl-cache-dir", "./httpsify-ssl-cache", "the cache directory to cache generated ssl certs")
  21. // internal vars
  22. domain_backend = map[string]string{}
  23. whitelisted = []string{}
  24. )
  25. func main() {
  26. flag.Parse()
  27. if *domains == "" {
  28. flag.Usage()
  29. fmt.Println(`Example(1): httpsify -domains "example.org,api.example.org->localhost:366, api2.example.org->:367"`)
  30. fmt.Println(`Example(2): httpsify -domains "www.site.com,apiv1.site.com->:8080,apiv2.site.com->:8081"`)
  31. return
  32. }
  33. for _, zone := range strings.Split(*domains, ",") {
  34. parts := strings.SplitN(zone, "->", 2)
  35. if len(parts) < 2 {
  36. parts = append(parts, *backend)
  37. }
  38. parts[1] = fixUrl(parts[1])
  39. domain_backend[parts[0]] = parts[1]
  40. whitelisted = append(whitelisted, parts[0])
  41. }
  42. m := autocert.Manager{
  43. Prompt: autocert.AcceptTOS,
  44. HostPolicy: autocert.HostWhitelist(whitelisted...),
  45. Cache: autocert.DirCache(*sslCacheDir),
  46. }
  47. h := handler()
  48. s := &http.Server{
  49. Addr: *listen,
  50. Handler: h,
  51. TLSConfig: &tls.Config{GetCertificate: m.GetCertificate},
  52. }
  53. log.Fatal(s.ListenAndServeTLS("", ""))
  54. }
  55. // fix the specified url
  56. // this function will make sure that "http://" already exists,
  57. // also it will make sure that it has a hostname .
  58. func fixUrl(u string) string {
  59. u = strings.TrimPrefix(strings.TrimSpace(u), "https://")
  60. if strings.Index(u, ":") == 0 {
  61. u = "localhost" + u
  62. }
  63. if !strings.HasPrefix(u, "ws://") && !strings.HasPrefix(u, "http://") {
  64. u = "http://" + u
  65. }
  66. u = strings.TrimRight(u, "/")
  67. return u
  68. }
  69. // the proxy handler
  70. func handler() http.Handler {
  71. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  72. r.Host = strings.SplitN(r.Host, ":", 2)[0]
  73. if _, found := domain_backend[r.Host]; !found {
  74. http.Error(w, r.Host+": not found", http.StatusNotImplemented)
  75. return
  76. }
  77. r.Header["X-Forwarded-Proto"] = []string{"https"}
  78. r.Header["X-Forwarded-For"] = append(r.Header["X-Forwarded-For"], strings.SplitN(r.RemoteAddr, ":", 2)[0])
  79. u, _ := url.Parse(domain_backend[r.Host] + "/" + strings.TrimLeft(r.URL.RequestURI(), "/"))
  80. if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
  81. NewWebsocketReverseProxy(u).ServeHTTP(w, r)
  82. return
  83. } else {
  84. proxy := httputil.NewSingleHostReverseProxy(u)
  85. defaultDirector := proxy.Director
  86. proxy.Director = func(req *http.Request) {
  87. defaultDirector(req)
  88. req.Host = r.Host
  89. req.URL = u
  90. }
  91. proxy.ServeHTTP(w, r)
  92. return
  93. }
  94. })
  95. }
  96. // the websocket proxy handler
  97. func NewWebsocketReverseProxy(u *url.URL) http.Handler {
  98. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  99. backConn, err := net.Dial("tcp", u.Host)
  100. if err != nil {
  101. http.Error(w, err.Error(), http.StatusInternalServerError)
  102. return
  103. }
  104. defer backConn.Close()
  105. hj, ok := w.(http.Hijacker)
  106. if !ok {
  107. http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
  108. return
  109. }
  110. clientConn, _, err := hj.Hijack()
  111. if err != nil {
  112. http.Error(w, err.Error(), http.StatusInternalServerError)
  113. return
  114. }
  115. defer clientConn.Close()
  116. message := r.Method + " " + r.URL.RequestURI() + " " + r.Proto + "\n"
  117. message += "Host: " + r.Host + "\n"
  118. for k, vals := range r.Header {
  119. for _, v := range vals {
  120. message += k + ": " + v + "\n"
  121. }
  122. }
  123. message += "\n"
  124. go io.Copy(backConn, io.MultiReader(strings.NewReader(message), r.Body, clientConn))
  125. io.Copy(clientConn, backConn)
  126. })
  127. }