forked from commercialhaskell/stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVerified.hs
More file actions
320 lines (288 loc) · 12.5 KB
/
Verified.hs
File metadata and controls
320 lines (288 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
module Network.HTTP.Download.Verified
( verifiedDownload
, recoveringHttp
, DownloadRequest(..)
, drRetryPolicyDefault
, HashCheck(..)
, CheckHexDigest(..)
, LengthCheck
, VerifiedDownloadException(..)
) where
import qualified Data.List as List
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Base64 as B64
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.List as CL
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import Control.Monad
import Control.Monad.Catch (Handler (..)) -- would be nice if retry exported this itself
import Stack.Prelude hiding (Handler (..))
import Control.Retry (recovering,limitRetries,RetryPolicy,constantDelay,RetryStatus(..))
import Crypto.Hash
import Crypto.Hash.Conduit (sinkHash)
import Data.ByteArray as Mem (convert)
import Data.ByteArray.Encoding as Mem (convertToBase, Base(Base16))
import Data.ByteString.Char8 (readInteger)
import Data.Conduit
import Data.Conduit.Binary (sourceHandle)
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import GHC.IO.Exception (IOException(..),IOErrorType(..))
import Network.HTTP.Client (getUri, path)
import Network.HTTP.StackClient (httpSink)
import Network.HTTP.Simple (Request, HttpException, getResponseHeaders)
import Network.HTTP.Types.Header (hContentLength, hContentMD5)
import Path
import Stack.Types.Runner
import Stack.PrettyPrint
import System.Directory
import qualified System.FilePath as FP ((<.>))
-- | A request together with some checks to perform.
data DownloadRequest = DownloadRequest
    { drRequest :: Request
      -- ^ The HTTP request to perform; it is also attached to any
      -- 'VerifiedDownloadException' that is thrown.
    , drHashChecks :: [HashCheck]
      -- ^ Hash checks the content must pass.
    , drLengthCheck :: Maybe LengthCheck
      -- ^ Expected content length, if known. 'Nothing' disables the
      -- length-related checks.
    , drRetryPolicy :: RetryPolicy
      -- ^ Retry policy handed to 'recoveringHttp' for the download.
    }
-- | The default retry policy: give up after three retries, with a short
-- constant delay (100 milliseconds) between attempts.
drRetryPolicyDefault :: RetryPolicy
drRetryPolicyDefault = limitRetries 3 <> constantDelay delayMicroseconds
  where
    -- 100 milliseconds, expressed in microseconds
    delayMicroseconds = 100000
-- | A hash algorithm paired with the digest the content is expected to have.
--
-- The algorithm type is existentially quantified, so checks that use
-- different algorithms can be stored together in one list.
data HashCheck = forall a. (Show a, HashAlgorithm a) => HashCheck
    { hashCheckAlgorithm :: a
      -- ^ Selects the hash algorithm; its 'Show' output names the algorithm
      -- in a 'WrongDigest' exception.
    , hashCheckHexDigest :: CheckHexDigest
      -- ^ The expected digest.
    }
deriving instance Show HashCheck
-- | An expected digest, in one of several representations. The
-- representation determines how 'sinkCheckHash' compares it with the
-- computed digest.
data CheckHexDigest
    -- Compared against the 'show' rendering of the computed digest.
    = CheckHexDigestString String
    -- Compared against the base16 encoding of the computed digest.
    | CheckHexDigestByteString ByteString
    -- A digest taken from a response header. It is accepted if its base64
    -- decoding equals either the base16 encoding or the raw bytes of the
    -- computed digest, or if the undecoded value equals the base16 encoding.
    | CheckHexDigestHeader ByteString
  deriving Show
-- String literals (with OverloadedStrings) become 'CheckHexDigestString'.
instance IsString CheckHexDigest where
    fromString = CheckHexDigestString
-- | The expected size of the content, in bytes.
type LengthCheck = Int
-- | An exception regarding verification of a download.
data VerifiedDownloadException
    -- The response's content-length header does not match the expected length.
    = WrongContentLength
          Request
          Int -- expected
          ByteString -- actual (as listed in the header)
    -- The number of bytes received does not match the expected length.
    | WrongStreamLength
          Request
          Int -- expected
          Int -- actual
    -- The digest computed over the content does not satisfy a hash check.
    | WrongDigest
          Request
          String -- algorithm
          CheckHexDigest -- expected
          String -- actual (shown)
  deriving (Typeable)
-- | Renders a multi-line, human-readable report: a headline naming the failed
-- expectation, the expected and actual values, and the URI of the request.
instance Show VerifiedDownloadException where
    show failure = case failure of
        WrongContentLength req expected actual ->
            report
                "Download expectation failure: ContentLength header\n"
                (show expected)
                (displayByteString actual)
                req
        WrongStreamLength req expected actual ->
            report
                "Download expectation failure: download size\n"
                (show expected)
                (show actual)
                req
        WrongDigest req algo expected actual ->
            report
                ("Download expectation failure: content hash (" ++ algo ++ ")\n")
                (displayCheckHexDigest expected)
                actual
                req
      where
        -- Shared layout for all three failure kinds.
        report headline expectedText actualText req =
            headline
            ++ "Expected: " ++ expectedText ++ "\n"
            ++ "Actual: " ++ actualText ++ "\n"
            ++ "For: " ++ show (getUri req)
instance Exception VerifiedDownloadException
-- Raised while checking an already-downloaded file when its size does not
-- match the expected length.
--
-- This exception is always caught and never thrown outside of this module.
data VerifyFileException
    = WrongFileSize
          Int -- expected
          Integer -- actual (as listed by hFileSize)
  deriving (Show, Typeable)
instance Exception VerifyFileException
-- | Render a 'ByteString' as a 'String' for display, with leading and
-- trailing whitespace stripped.
--
-- The bytes are expected to be UTF-8, but they are decoded leniently: any
-- invalid sequence is replaced with U+FFFD instead of raising an exception.
-- This matters because the function is used while rendering error messages,
-- including ones that show header values received from a remote server;
-- a malformed value must not turn the error report itself into a crash.
displayByteString :: ByteString -> String
displayByteString =
    Text.unpack . Text.strip . decodeUtf8With lenientDecode
-- | Render a 'CheckHexDigest' for humans, tagged with the representation it
-- was given in. Header digests are shown base64-decoded, followed by the
-- undecoded value.
displayCheckHexDigest :: CheckHexDigest -> String
displayCheckHexDigest expected = case expected of
    CheckHexDigestString str ->
        str ++ " (String)"
    CheckHexDigestByteString bytes ->
        displayByteString bytes ++ " (ByteString)"
    CheckHexDigestHeader header ->
        show (B64.decodeLenient header)
            ++ " (Header. unencoded: "
            ++ show header
            ++ ")"
-- | Consume a finite stream of bytes, hash it with the algorithm of the
-- given 'HashCheck', and verify the result against the expected digest.
--
-- Throws WrongDigest (VerifiedDownloadException)
sinkCheckHash :: MonadThrow m
    => Request
    -> HashCheck
    -> ConduitM ByteString o m ()
sinkCheckHash req HashCheck{..} = do
    digest <- sinkHashUsing hashCheckAlgorithm
    let shownDigest = show digest
        hexDigest = Mem.convertToBase Mem.Base16 digest
        rawDigest = Mem.convert digest
        digestMatches = case hashCheckHexDigest of
            CheckHexDigestString expected ->
                expected == shownDigest
            CheckHexDigestByteString expected ->
                expected == hexDigest
            CheckHexDigestHeader header ->
                let decoded = B64.decodeLenient header
                in decoded == hexDigest
                    || decoded == rawDigest
                    -- A hack to allow hackage tarballs to download.
                    -- They should really base64-encode their md5 header as per rfc2616#sec14.15.
                    -- https://github.com/commercialhaskell/stack/issues/240
                    || header == hexDigest
    unless digestMatches $
        throwM (WrongDigest req (show hashCheckAlgorithm) hashCheckHexDigest shownDigest)
-- | A sink that adds up the lengths of all chunks it receives and, once the
-- stream ends, checks the total against the expected length.
--
-- Throws WrongStreamLength (VerifiedDownloadException)
assertLengthSink :: MonadThrow m
    => Request
    -> LengthCheck
    -> ZipSink ByteString m ()
assertLengthSink req expectedStreamLength = ZipSink countAndCompare
  where
    countAndCompare = do
        Sum receivedBytes <- CL.foldMap (\chunk -> Sum (ByteString.length chunk))
        unless (receivedBytes == expectedStreamLength) $
            throwM (WrongStreamLength req expectedStreamLength receivedBytes)
-- | 'sinkHash' with the hash algorithm chosen by a value argument. The
-- argument only fixes the digest type; its value is never inspected.
sinkHashUsing :: (Monad m, HashAlgorithm a) => a -> ConduitM ByteString o m (Digest a)
sinkHashUsing = const sinkHash
-- | Combine a list of hash checks into a single 'ZipSink' in which every
-- check sees the whole stream.
hashChecksToZipSink :: MonadThrow m => Request -> [HashCheck] -> ZipSink ByteString m ()
hashChecksToZipSink req checks =
    for_ checks $ \check -> ZipSink (sinkCheckHash req check)
-- | 'Control.Retry.recovering' customized for HTTP failures.
--
-- Every 'HttpException' triggers a retry (after logging a warning), and so
-- does an 'IOException' of type 'ResourceVanished'. Other exceptions are not
-- retried.
recoveringHttp :: forall env a. HasRunner env => RetryPolicy -> RIO env a -> RIO env a
recoveringHttp retryPolicy action =
    withUnliftIO $ \unlift ->
        recovering retryPolicy (handlers unlift) $
#if MIN_VERSION_retry(0,7,0)
            const (unliftIO unlift action)
#else
            unliftIO unlift action
#endif
  where
    handlers :: UnliftIO (RIO env) -> [RetryStatus -> Handler IO Bool]
    handlers unlift =
        [ \status -> Handler (alwaysRetryHttp unlift status)
        , \_ -> Handler retrySomeIO
        ]

    -- Warn about the retry, then always ask for another attempt.
    alwaysRetryHttp :: UnliftIO (RIO env) -> RetryStatus -> HttpException -> IO Bool
    alwaysRetryHttp unlift status _ = do
        let retryNotice = flow $ unwords
                [ "Retry number"
                , show (rsIterNumber status)
                , "after a total delay of"
                , show (rsCumulativeDelay status)
                , "us"
                ]
            reportRequest = flow $ unwords
                [ "If you see this warning and stack fails to download,"
                , "but running the command again solves the problem,"
                , "please report here: https://github.com/commercialhaskell/stack/issues/3510"
                ]
        unliftIO unlift $ prettyWarn $ vcat [retryNotice, reportRequest]
        return True

    retrySomeIO :: Monad m => IOException -> m Bool
    retrySomeIO ioErr =
        case ioe_type ioErr of
            -- hGetBuf: resource vanished (Connection reset by peer)
            ResourceVanished -> return True
            -- conservatively exclude all others
            _ -> return False
-- | Copied and extended version of Network.HTTP.Download.download.
--
-- On top of a plain download, this:
--
-- * checks that the response's content-length header (if present)
--   matches the expected length
-- * limits the download to (close to) the expected number of bytes
-- * checks that the expected number of bytes were downloaded (not too few)
-- * checks the md5 if the response includes a content-md5 header
-- * checks the expected hashes
--
-- If the destination file already exists and passes the length and hash
-- expectations, nothing is downloaded. Otherwise the response body is
-- written to a @.tmp@ sibling of the destination, which is renamed to the
-- destination once the download has completed.
--
-- Throws VerifiedDownloadException.
-- Throws IOExceptions related to file system operations.
-- Throws HttpException.
verifiedDownload
    :: HasRunner env
    => DownloadRequest
    -> Path Abs File -- ^ destination
    -> (Maybe Integer -> ConduitM ByteString Void (RIO env) ()) -- ^ custom hook to observe progress
    -> RIO env Bool -- ^ Whether a download was performed
verifiedDownload DownloadRequest{..} destpath progressSink = do
    needsDownload <- liftIO downloadRequired
    if needsDownload
        then do
            logDebug $ "Downloading " <> decodeUtf8With lenientDecode (path drRequest)
            liftIO $ createDirectoryIfMissing True destDir
            recoveringHttp drRetryPolicy $
                withSinkFile tmpFile $ \fileSink ->
                    httpSink drRequest (verifyingSink fileSink)
            liftIO $ renameFile tmpFile destFile
            return True
        else return False
  where
    destFile = toFilePath destpath
    tmpFile = destFile FP.<.> "tmp"
    destDir = toFilePath (parent destpath)

    -- Download when the file doesn't exist yet, or when it exists but
    -- does not match expectations.
    downloadRequired = do
        alreadyPresent <- doesFileExist destFile
        if not alreadyPresent
            then return True
            else not <$> existingFileIsValid

    -- precondition: file exists
    -- TODO: add logging
    existingFileIsValid =
        ((verifyExistingFile >> return True)
            `catch` \(_ :: VerifyFileException) -> return False)
            `catch` \(_ :: VerifiedDownloadException) -> return False

    -- Check the size (when a length is expected) and then every requested
    -- hash of the file already on disk.
    verifyExistingFile = withBinaryFile destFile ReadMode $ \hdl -> do
        for_ drLengthCheck (checkFileSizeExpectations hdl)
        runConduit
            (sourceHandle hdl
                .| getZipSink (hashChecksToZipSink drRequest drHashChecks))

    -- doesn't move the handle
    checkFileSizeExpectations hdl expectedFileSize = do
        actualSize <- hFileSize hdl
        -- A size that does not even fit in an Int can never match.
        when (actualSize > toInteger (maxBound :: Int)
                || fromInteger actualSize /= expectedFileSize) $
            throwM (WrongFileSize expectedFileSize actualSize)

    -- Compare the content-length header, when the response has one, with
    -- the expected length.
    checkContentLengthHeader responseHeaders expectedContentLength =
        for_ (List.lookup hContentLength responseHeaders) $ \rawLength ->
            when (displayByteString rawLength /= show expectedContentLength) $
                throwM (WrongContentLength drRequest expectedContentLength rawLength)

    -- The sink for the response body: validates the headers first, then
    -- feeds every chunk to the hash checks, the length assertion, the file
    -- sink and the progress hook at once.
    verifyingSink fileSink response = do
        for_ drLengthCheck (checkContentLengthHeader responseHeaders)
        capLength $ getZipSink $
            hashChecksToZipSink drRequest allHashChecks
                *> lengthAssertion
                *> ZipSink fileSink
                *> ZipSink (progressSink advertisedLength)
      where
        responseHeaders = getResponseHeaders response

        -- The length announced by the server, if present and parseable.
        advertisedLength =
            case List.lookup hContentLength responseHeaders of
                Nothing -> Nothing
                Just rawLength ->
                    case readInteger rawLength of
                        Just (n, _) -> Just n
                        Nothing -> Nothing

        -- An md5 check taken from the content-md5 header, if the response
        -- has one.
        headerHashChecks =
            case List.lookup hContentMD5 responseHeaders of
                Nothing -> []
                Just md5BS -> [HashCheck MD5 (CheckHexDigestHeader md5BS)]

        allHashChecks = headerHashChecks ++ drHashChecks

        -- Never read more than the expected number of bytes.
        capLength =
            case drLengthCheck of
                Nothing -> id
                Just len -> (CB.isolate len .|)

        lengthAssertion =
            case drLengthCheck of
                Nothing -> pure ()
                Just len -> assertLengthSink drRequest len