diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 0cfa06ff43..35850c2a25 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -199,27 +199,6 @@ class EvalState: self.task_states[self.dataset_type]["grader_log"] = [] self.task_states[self.dataset_type]["grader_log"].append(grader_log) - def print_task_header(self): - tasks_to_show = self.all_tasks if self.all_tasks else self.tasks - cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) - print("Tasks:") - print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status") - for i, task_id in tasks_to_show: - prompt, gold = self.get_case(i) - case = cases.get(task_id, {}) - status = case.get("status", "pending") - extracted = case.get("extracted", "N/A") if status == "ok" else "N/A" - is_correct = case.get("correct", False) if status == "ok" else False - symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "") - first_line = prompt.split('\n')[0] - truncated_prompt = first_line[:43] - if len(first_line) > 43: - truncated_prompt += "..." - else: - truncated_prompt = truncated_prompt.ljust(43) + "..." - print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}") - print() - def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0): extracted_display = task_state.extracted if task_state.extracted else "N/A" success_ratio = correct_count / self.processed if self.processed > 0 else 0.0 @@ -328,6 +307,7 @@ class EvalState: def print_all_tasks(self): cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) tasks_to_show = self.all_tasks if self.all_tasks else self.tasks + print() print("Tasks:") print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status") for i, task_id in tasks_to_show: @@ -350,7 +330,7 @@ class EvalState: cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) correct = sum(1 for c in cases.values() if c.get("correct", False)) total = len(cases) - print(f"\n{'='*60}") + print(f"{'='*60}") print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)") print(f"{'='*60}") @@ -803,16 +783,13 @@ class Processor: eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks eval_state.processed = 0 - print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} questions...") + print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...") print(f"Server: {self.server_url} (model: {self.model_name})") print(f"Grader: {self.grader.grader_type}") print(f"Threads: {self.threads}") print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}") print() - if not resume: - eval_state.print_task_header() - correct_count = 0 with ThreadPoolExecutor(max_workers=self.threads) as executor: @@ -965,14 +942,14 @@ def main(): print(f"Loading existing eval state from {args.output}") eval_state = EvalState.load(args.output) - if eval_state.is_complete(): - eval_state.print_all_tasks() - eval_state.print_existing_summary() - return - eval_state.print_all_tasks() eval_state.print_existing_summary() + if eval_state.is_complete(): + return + + print() + if not args.resume: print(f"Evaluation incomplete. Run with --resume to continue.") return @@ -1035,6 +1012,8 @@ def main(): eval_state.dump() resume = False + eval_state.print_all_tasks() + processor = Processor( server_url=args.server, grader=grader,