From 387149656a426b76bd1e7c5dd48d19b55cc6ebbb Mon Sep 17 00:00:00 2001
From: "Dustin L. Howett" <dustin@howett.net>
Date: Mon, 16 Oct 2017 01:22:08 -0700
Subject: [PATCH] Remove RangeFetcher.Initialize (do it once w/ sync),
 Length->ExpectedLength

---
 http.go      | 69 +++++++++++++++++++++++++++++++---------------------
 http_test.go |  2 +-
 range.go     |  3 +--
 reader.go    |  7 +-----
 4 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/http.go b/http.go
index 66ea440..2537e47 100644
--- a/http.go
+++ b/http.go
@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"net/url"
 	"strings"
+	"sync"
 )
 
 const httpMethodGet = "GET"
@@ -38,7 +39,8 @@ type HTTPRanger struct {
 
 	validator string
 	length    int64
-	blockSize int
+
+	once sync.Once
 }
 
 func statusCodeError(status int) error {
@@ -63,40 +65,46 @@ func validatorFromResponse(resp *http.Response) (string, error) {
 	return "", errors.New("no applicable validator in response")
 }
 
-// Initialize implements the Initialize function from the RangeFetcher interface.
-// It performs a HEAD request to retrieve the required information from the server.
-func (r *HTTPRanger) Initialize(bs int) error {
-	if r.Client == nil {
-		r.Client = &http.Client{}
-	}
+// init performs a HEAD request to determine whether the resource is rangeable.
+func (r *HTTPRanger) init() error {
+	var outerErr error
+	r.once.Do(func() {
+		if r.Client == nil {
+			r.Client = &http.Client{}
+		}
 
-	resp, err := r.Client.Head(r.URL.String())
-	if err != nil {
-		return err
-	}
+		resp, err := r.Client.Head(r.URL.String())
+		if err != nil {
+			outerErr = err
+			return
+		}
 
-	if !statusIsAcceptable(resp.StatusCode) {
-		return statusCodeError(resp.StatusCode)
-	}
+		if !statusIsAcceptable(resp.StatusCode) {
+			outerErr = statusCodeError(resp.StatusCode)
+			return
+		}
 
-	if !strings.Contains(resp.Header.Get(httpHeaderAcceptRanges), "bytes") {
-		return errors.New(r.URL.String() + " does not support byte-ranged requests.")
-	}
+		if !strings.Contains(resp.Header.Get(httpHeaderAcceptRanges), "bytes") {
+			outerErr = errors.New(r.URL.String() + " does not support byte-ranged requests.")
+			return
+		}
 
-	validator, err := validatorFromResponse(resp)
-	if err != nil {
-		return errors.New(r.URL.String() + " did not offer a strong-enough validator for subsequent requests")
-	}
+		validator, err := validatorFromResponse(resp)
+		if err != nil {
+			outerErr = errors.New(r.URL.String() + " did not offer a strong-enough validator for subsequent requests")
+			return
+		}
 
-	r.blockSize = bs
-	r.validator = validator
-	r.length = resp.ContentLength
-	return nil
+		r.validator = validator
+		r.length = resp.ContentLength
+	})
+	return outerErr
 }
 
-// Length returns the length, in bytes, of the ranged-over file.
-func (r *HTTPRanger) Length() int64 {
-	return r.length
+// ExpectedLength returns the length, in bytes, of the ranged-over file.
+func (r *HTTPRanger) ExpectedLength() (int64, error) {
+	err := r.init()
+	return r.length, err
 }
 
 func makeByteRangeHeader(ranges []ByteRange) string {
@@ -135,6 +143,11 @@ func (r *HTTPRanger) FetchRanges(ranges []ByteRange) ([]Block, error) {
 		return nil, nil
 	}
 
+	err := r.init()
+	if err != nil {
+		return nil, err
+	}
+
 	req, err := http.NewRequest(httpMethodGet, r.URL.String(), nil)
 	if err != nil {
 		return nil, err
diff --git a/http_test.go b/http_test.go
index 7e95213..8aabbe2 100644
--- a/http_test.go
+++ b/http_test.go
@@ -8,7 +8,7 @@ import (
 func TestFailureToConnect(t *testing.T) {
 	u, _ := url.Parse("http://257.0.1.258/file")
 	r := &HTTPRanger{URL: u}
-	err := r.Initialize(1048576)
+	err := r.init()
 	if err == nil {
 		t.Fail()
 	} else {
diff --git a/range.go b/range.go
index 85f1c6b..0966f0e 100644
--- a/range.go
+++ b/range.go
@@ -9,8 +9,7 @@ package ranger
 // Initialize, called once and passed the Reader's block size, performs any necessary setup tasks for the RangeFetcher
 type RangeFetcher interface {
 	FetchRanges([]ByteRange) ([]Block, error)
-	Length() int64
-	Initialize(int) error
+	ExpectedLength() (int64, error)
 }
 
 // Block represents a block returned from a ranged read
diff --git a/reader.go b/reader.go
index ad9b143..0d178a1 100644
--- a/reader.go
+++ b/reader.go
@@ -198,12 +198,7 @@ func (r *Reader) init() (err error) {
 			r.BlockSize = DefaultBlockSize
 		}
 
-		err = r.Fetcher.Initialize(r.BlockSize)
-		if err != nil {
-			return
-		}
-
-		r.len = r.Fetcher.Length()
+		r.len, err = r.Fetcher.ExpectedLength()
 	})
 	return
 }
-- 
GitLab