diff --git a/internnav/habitat_extensions/vln/habitat_vln_evaluator.py b/internnav/habitat_extensions/vln/habitat_vln_evaluator.py index 01378df3..d93b3c44 100644 --- a/internnav/habitat_extensions/vln/habitat_vln_evaluator.py +++ b/internnav/habitat_extensions/vln/habitat_vln_evaluator.py @@ -243,8 +243,9 @@ def parse_actions(self, output): def resume_from_output_path(self) -> None: sucs, spls, oss, nes, ndtw = [], [], [], [], [] + completed_episodes = set() if self.rank != 0: - return sucs, spls, oss, nes, ndtw + return sucs, spls, oss, nes, ndtw, completed_episodes # resume from previous results if os.path.exists(os.path.join(self.output_path, 'progress.json')): @@ -257,13 +258,14 @@ def resume_from_output_path(self) -> None: nes.append(res['ne']) if 'ndtw' in res: ndtw.append(res['ndtw']) - return sucs, spls, oss, nes, ndtw + completed_episodes.add((res['scene_id'], res['episode_id'])) + return sucs, spls, oss, nes, ndtw, completed_episodes def _run_eval_dual_system(self) -> tuple: # noqa: C901 self.model.eval() # resume from previous results - sucs, spls, oss, nes, ndtw = self.resume_from_output_path() + sucs, spls, oss, nes, ndtw, completed_episodes = self.resume_from_output_path() # Episode loop is now driven by env.reset() + env.is_running process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}") @@ -281,6 +283,12 @@ def _run_eval_dual_system(self) -> tuple: # noqa: C901 scene_id = episode.scene_id.split('/')[-2] episode_id = int(episode.episode_id) episode_instruction = episode.instruction.instruction_text + + # skip already completed episodes + if (scene_id, episode_id) in completed_episodes: + process_bar.update(1) + continue + print("episode start", episode_instruction) # save first frame per rank to validate sim quality @@ -632,7 +640,7 @@ def _run_eval_system2(self) -> tuple: self.model.eval() # resume from previous results - sucs, spls, oss, nes, ndtw = self.resume_from_output_path() + sucs, spls, oss, nes, ndtw, _ = self.resume_from_output_path() # Episode loop is now driven by env.reset() + env.is_running process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}")