diff --git a/examples/llama-eval/llama-eval-new.py b/examples/llama-eval/llama-eval-new.py index 3f202a952b..0c09753cfc 100755 --- a/examples/llama-eval/llama-eval-new.py +++ b/examples/llama-eval/llama-eval-new.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Dict, List, Optional, Any import requests from tqdm import tqdm +import random cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) @@ -194,10 +195,10 @@ class Processor: response.raise_for_status() return response.json() - def _process_single_case(self, i: int) -> TaskState: + def _process_single_case(self, i: int, task_id: str) -> TaskState: """Process a single case (thread-safe)""" question = self.dataset.get_question(i) - case_id = f"aime_{self.dataset.split}_{question['id']}" + dataset_id = f"aime_{self.dataset.split}_{question['id']}" gold = self.dataset.get_answer(question) # Apply template if available @@ -207,7 +208,7 @@ class Processor: prompt = question["problem"] task_state = TaskState( - case_id=case_id, + case_id=task_id, prompt=prompt, gold=gold ) @@ -223,7 +224,7 @@ class Processor: return task_state - def process(self, n_cases: int = None, seed: int = 42): + def process(self, n_cases: int = None, seed: int = 1234): """Process cases and update eval state""" if n_cases is None: n_cases = len(self.dataset.questions) @@ -234,26 +235,37 @@ class Processor: print(f"Max tokens: {self.n_predict}") print() + dataset_size = len(self.dataset.questions) + random.seed(seed) + + task_list = [] + for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size): + chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size) + indices = list(range(dataset_size)) + random.shuffle(indices) + chunk_indices = indices[:chunk_size] + + for i in chunk_indices: + task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}" + task_list.append((i, task_id)) + # Print task summary table print("Tasks:") print(" Task ID Dataset Prompt (first 40 chars) Expected Status") - for i in range(min(n_cases, len(self.dataset.questions))): + for i, task_id in task_list: question = self.dataset.get_question(i) - case_id = f"aime_{self.dataset.split}_{question['id']}" prompt = question["problem"] gold = self.dataset.get_answer(question) truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt - print(f" {case_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending") + print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending") print() task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks} total = 0 correct = 0 - indices = list(range(min(n_cases, len(self.dataset.questions)))) - with ThreadPoolExecutor(max_workers=self.threads) as executor: - futures = {executor.submit(self._process_single_case, i): i for i in indices} + futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list} for future in as_completed(futures): task_state = future.result() @@ -309,6 +321,12 @@ def main(): default=None, help="Number of cases to evaluate (default: all)" ) + parser.add_argument( + "--seed", + type=int, + default=1234, + help="Random seed for shuffling (default: 1234)" + ) parser.add_argument( "--n_predict", type=int, @@ -376,7 +394,7 @@ def main(): model_name=args.model ) - eval_state = processor.process(n_cases=args.n_cases) + eval_state = processor.process(n_cases=args.n_cases, seed=args.seed) processor.dump_state(args.output) if __name__ == "__main__":