From 38b57241280e29c1367272031567e148a5f318dc Mon Sep 17 00:00:00 2001 From: jbykkk <209502307+jbykkk@users.noreply.github.com> Date: Sun, 26 Apr 2026 12:08:05 +0800 Subject: [PATCH] Fix inference checkpoint resume to skip completed episodes --- .../vln/habitat_vln_evaluator.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/internnav/habitat_extensions/vln/habitat_vln_evaluator.py b/internnav/habitat_extensions/vln/habitat_vln_evaluator.py index 01378df3..d93b3c44 100644 --- a/internnav/habitat_extensions/vln/habitat_vln_evaluator.py +++ b/internnav/habitat_extensions/vln/habitat_vln_evaluator.py @@ -243,8 +243,9 @@ def parse_actions(self, output): def resume_from_output_path(self) -> None: sucs, spls, oss, nes, ndtw = [], [], [], [], [] + completed_episodes = set() if self.rank != 0: - return sucs, spls, oss, nes, ndtw + return sucs, spls, oss, nes, ndtw, completed_episodes # resume from previous results if os.path.exists(os.path.join(self.output_path, 'progress.json')): @@ -257,13 +258,14 @@ def resume_from_output_path(self) -> None: nes.append(res['ne']) if 'ndtw' in res: ndtw.append(res['ndtw']) - return sucs, spls, oss, nes, ndtw + completed_episodes.add((res['scene_id'], res['episode_id'])) + return sucs, spls, oss, nes, ndtw, completed_episodes def _run_eval_dual_system(self) -> tuple: # noqa: C901 self.model.eval() # resume from previous results - sucs, spls, oss, nes, ndtw = self.resume_from_output_path() + sucs, spls, oss, nes, ndtw, completed_episodes = self.resume_from_output_path() # Episode loop is now driven by env.reset() + env.is_running process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}") @@ -281,6 +283,12 @@ def _run_eval_dual_system(self) -> tuple: # noqa: C901 scene_id = episode.scene_id.split('/')[-2] episode_id = int(episode.episode_id) episode_instruction = episode.instruction.instruction_text + + # skip already completed episodes + if (scene_id, episode_id) in completed_episodes: + process_bar.update(1) + continue + print("episode start", episode_instruction) # save first frame per rank to validate sim quality @@ -632,7 +640,7 @@ def _run_eval_system2(self) -> tuple: self.model.eval() # resume from previous results - sucs, spls, oss, nes, ndtw = self.resume_from_output_path() + sucs, spls, oss, nes, ndtw, _ = self.resume_from_output_path() # Episode loop is now driven by env.reset() + env.is_running process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}")