精慕HU
通过阅读github.com/gorilla/websocket 的源代码,我发现函数中有一个逻辑(*Dialer) DialContext将 http.Request 指针传递给可定制的函数,这意味着我可以编写一个与forProxy做同样事情的函数http.RoundTripper注入标头。WebSocketpackage transportimport ( "context" "crypto/sha256" "encoding/hex" "net/http" "net/url" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/gorilla/websocket")func NewWebSocketDialer(config aws.Config) (*websocket.Dialer, error) { return &websocket.Dialer{ HandshakeTimeout: 45 * time.Second, Proxy: func(request *http.Request) (*url.URL, error) { credentials, err := config.Credentials.Retrieve(request.Context()) if err != nil { return nil, err } // Because AWS may sign some unrelated headers and cause authentication failure, you need to create a blank request. internalRequest, err := http.NewRequest(http.MethodGet, request.URL.String(), nil) if err != nil { return nil, err } header := request.Header.Clone() hash := sha256.New() signer := v4.NewSigner() if err := signer.SignHTTP(context.Background(), credentials, internalRequest, hex.EncodeToString(hash.Sum(nil)), "managedblockchain", config.Region, time.Now()); err != nil { return nil, err } request.Header = internalRequest.Header request.Header.Set("Connection", header["Connection"][0]) request.Header.Set("Sec-WebSocket-Key", header["Sec-WebSocket-Key"][0]) request.Header.Set("Sec-WebSocket-Version", header["Sec-WebSocket-Version"][0]) request.Header.Set("Upgrade", header["Upgrade"][0]) return http.ProxyFromEnvironment(request) }, }, nil}HTTPpackage transportimport ( "compress/gzip" "context" "crypto/sha256" "encoding/base64" "encoding/hex" "io" "net/http" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/signer/v4")var _ http.RoundTripper = &httpRoundTripper{}type httpRoundTripper struct { config aws.Config}func (h httpRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { credentials, err := h.config.Credentials.Retrieve(request.Context()) if err != nil { return nil, err } internalRequest := request.Clone(request.Context()) bodyReader, err := request.GetBody() if err != nil { return nil, err } hash := sha256.New() if _, err := io.Copy(hash, bodyReader); err != nil { return nil, err } signer := v4.NewSigner() if err := signer.SignHTTP(context.Background(), credentials, internalRequest, hex.EncodeToString(hash.Sum(nil)), "managedblockchain", h.config.Region, time.Now()); err != nil { return nil, err } response, err := h.config.HTTPClient.Do(internalRequest) if err != nil { return nil, err } if response.Header.Get("Content-Type") == "gzip" { gzipReader, err := gzip.NewReader(base64.NewDecoder(base64.StdEncoding, response.Body)) if err != nil { return nil, err } request.Header.Set("Content-Type", "application/json") response.Body = gzipReader } return response, nil}func NewHttpRoundTripper(cfg aws.Config) http.RoundTripper { return httpRoundTripper{ config: cfg, }}