Skip to content

Commit 15422ae

Browse files
vanpeltraubitsj
andauthored
[WB-3544] Apply diff.patch when calling wandb.restore (wandb#1339)
* wow, this has been busted since gorilla was released... * Attempt to fix windows * Just skip the test in Windows Co-authored-by: Jeff Raubitschek <jeff@wandb.com>
1 parent b30f465 commit 15422ae

4 files changed

Lines changed: 56 additions & 29 deletions

File tree

tests/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,7 @@ def test_local_already_running(runner, docker, local_settings):
667667
assert "A container named wandb-local is already running" in result.output
668668

669669

670+
@pytest.mark.skipif(platform.system() == "Windows", reason="The patch in mock_server.py doesn't work in windows")
670671
def test_restore_no_remote(runner, mock_server, git_repo, docker, monkeypatch):
671672
with open("patch.txt", "w") as f:
672673
f.write("test")

tests/utils/mock_server.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -199,22 +199,20 @@ def set_ctx(ctx):
199199

200200
def _bucket_config():
201201
return {
202-
'patch': '''
203-
diff --git a/patch.txt b/patch.txt
204-
index 30d74d2..9a2c773 100644
205-
--- a/patch.txt
206-
+++ b/patch.txt
207-
@@ -1 +1 @@
208-
-test
209-
\ No newline at end of file
210-
+testing
211-
\ No newline at end of file
212-
''',
213202
'commit': 'HEAD',
214203
'github': 'https://github.com/vanpelt',
215204
'config': '{"foo":{"value":"bar"}}',
216205
'files': {
217-
'edges': [{'node': {'directUrl': request.url_root + "/storage?file=metadata.json"}}]
206+
'edges': [
207+
{'node': {
208+
'directUrl': request.url_root + "/storage?file=wandb-metadata.json",
209+
'name': 'wandb-metadata.json'
210+
}},
211+
{'node': {
212+
'directUrl': request.url_root + "/storage?file=diff.patch",
213+
'name': 'diff.patch'
214+
}}
215+
]
218216
}
219217
}
220218

@@ -629,7 +627,7 @@ def storage():
629627
if request.method == "GET" and size:
630628
return os.urandom(size), 200
631629
# make sure to read the data
632-
data = request.get_data()
630+
request.get_data()
633631
if file == "wandb_manifest.json":
634632
return {
635633
"version": 1,
@@ -639,8 +637,27 @@ def storage():
639637
"digits.h5": {"digest": "TeSJ4xxXg0ohuL5xEdq2Ew==", "size": 81299},
640638
},
641639
}
642-
elif file == "metadata.json":
643-
return {"docker": "test/docker", "program": "train.py", "args": ["--test", "foo"], "git": ctx.get("git", {})}
640+
elif file == "wandb-metadata.json":
641+
return {
642+
"docker": "test/docker",
643+
"program": "train.py",
644+
"args": ["--test", "foo"],
645+
"git": ctx.get("git", {})
646+
}
647+
elif file == "diff.patch":
648+
# TODO: make sure the patch is valid for windows as well,
649+
# and un skip the test in test_cli.py
650+
return '''
651+
diff --git a/patch.txt b/patch.txt
652+
index 30d74d2..9a2c773 100644
653+
--- a/patch.txt
654+
+++ b/patch.txt
655+
@@ -1 +1 @@
656+
-test
657+
\ No newline at end of file
658+
+testing
659+
\ No newline at end of file
660+
'''
644661
return "", 200
645662

646663
@app.route("/artifacts/<entity>/<digest>", methods=["GET", "POST"])

wandb/cli/cli.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,7 +1306,6 @@ def restore(ctx, run, no_git, branch, project, entity):
13061306
commit, json_config, patch_content, metadata = api.run_config(
13071307
project, run=run, entity=entity
13081308
)
1309-
print(metadata)
13101309
repo = metadata.get("git", {}).get("repo")
13111310
image = metadata.get("docker")
13121311
restore_message = (
@@ -1325,6 +1324,7 @@ def restore(ctx, run, no_git, branch, project, entity):
13251324
)
13261325

13271326
if commit and api.git.enabled:
1327+
wandb.termlog("Fetching origin and finding commit: {}".format(commit))
13281328
subprocess.check_call(["git", "fetch", "--all"])
13291329
try:
13301330
api.git.repo.commit(commit)
@@ -1378,10 +1378,15 @@ def restore(ctx, run, no_git, branch, project, entity):
13781378
patch_rel_path = os.path.relpath(patch_path, start=root)
13791379
# --reject is necessary or else this fails any time a binary file
13801380
# occurs in the diff
1381-
# we use .call() instead of .check_call() for the same reason
1382-
# TODO(adrian): this means there is no error checking here
1383-
subprocess.call(["git", "apply", "--reject", patch_rel_path], cwd=root)
1384-
wandb.termlog("Applied patch")
1381+
exit_code = subprocess.call(
1382+
["git", "apply", "--reject", patch_rel_path], cwd=root
1383+
)
1384+
if exit_code == 0:
1385+
wandb.termlog("Applied patch")
1386+
else:
1387+
wandb.termerror(
1388+
"Failed to apply patch, try un-staging any un-committed changes"
1389+
)
13851390

13861391
util.mkdir_exists_ok(wandb_dir())
13871392
config_path = os.path.join(wandb_dir(), "config.yaml")

wandb/internal/internal_api.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -661,10 +661,10 @@ def run_config(self, project, run=None, entity=None):
661661
bucket(name: $run) {
662662
config
663663
commit
664-
patch
665-
files(names: ["wandb-metadata.json"]) {
664+
files(names: ["wandb-metadata.json", "diff.patch"]) {
666665
edges {
667666
node {
667+
name
668668
directUrl
669669
}
670670
}
@@ -682,15 +682,19 @@ def run_config(self, project, run=None, entity=None):
682682
raise CommError("Run {}/{}/{} not found".format(entity, project, run))
683683
run = response["model"]["bucket"]
684684
commit = run["commit"]
685-
patch = run["patch"]
686685
config = json.loads(run["config"] or "{}")
686+
patch = None
687+
metadata = {}
687688
if len(run["files"]["edges"]) > 0:
688-
url = run["files"]["edges"][0]["node"]["directUrl"]
689-
res = requests.get(url)
690-
res.raise_for_status()
691-
metadata = res.json()
692-
else:
693-
metadata = {}
689+
for file_edge in run["files"]["edges"]:
690+
name = file_edge["node"]["name"]
691+
url = file_edge["node"]["directUrl"]
692+
res = requests.get(url)
693+
res.raise_for_status()
694+
if name == "wandb-metadata.json":
695+
metadata = res.json()
696+
elif name == "diff.patch":
697+
patch = res.text
694698
return (commit, config, patch, metadata)
695699

696700
@normalize_exceptions

0 commit comments

Comments
 (0)