| import json |
| import os |
| import sys |
| import time |
| import yaml |
| import spacy |
| import ast |
| from PIL import Image |
| from glob import glob |
| from tqdm import tqdm |
| from collections import defaultdict |
| import pandas as pd |
| from io import BytesIO |
| import base64 |
| from anls import anls_score |
| import torch |
| from torch.utils.data import Dataset, DataLoader, DistributedSampler |
| import torchvision.transforms as T |
| from eval import conversation as conversation_lib |
| from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \ |
| process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro |
| from eval.mmmu_utils import evaluate as evaluate_mmmu |
| from torchvision.transforms.functional import InterpolationMode |
| from datasets import load_dataset, concatenate_datasets |
|
|
# Per-channel (R, G, B) mean and standard deviation handed to T.Normalize
# in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
|
|
|
|
def build_transform(input_size):
    """Build the per-tile preprocessing pipeline.

    The returned transform converts non-RGB images to RGB, resizes them to a
    square of ``input_size`` pixels using bicubic interpolation, converts the
    result to a tensor and normalizes it with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid from ``target_ratios`` whose aspect ratio is closest.

    Each candidate is indexed as ``candidate[0] / candidate[1]``. The first
    candidate with the smallest absolute difference to ``aspect_ratio`` wins;
    a later candidate with exactly the same difference replaces it only when
    the image area (``width * height``) exceeds half the pixel area that the
    candidate grid would cover. Returns ``(1, 1)`` if ``target_ratios`` is
    empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Exact tie: prefer the later grid only for sufficiently large images.
        if gap == smallest_gap and \
                image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen
|
|
|
|
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize ``image`` onto a tile grid and cut it into square tiles.

    All ``(cols, rows)`` grids whose tile count lies in
    ``[min_num, max_num]`` are enumerated, the one closest to the image's
    aspect ratio is selected with ``find_closest_aspect_ratio``, the image is
    resized to exactly fill that grid and then cropped row by row into
    ``image_size`` x ``image_size`` tiles. When ``use_thumbnail`` is true and
    more than one tile was produced, a resized copy of the whole image is
    appended as the last element.

    Returns the list of tile images.
    """
    src_width, src_height = image.size

    # Candidate grids, ordered by the number of tiles they contain.
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num)
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(
        src_width / src_height, candidate_grids, src_width, src_height, image_size)

    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size
    processed_images = []
    for tile_idx in range(blocks):
        row, col = divmod(tile_idx, tiles_per_row)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images
|
|
|
|
def load_image(image, input_size=448, max_num=6, decoded=False):
    """Turn an image into a stacked tensor of preprocessed tiles.

    ``image`` is opened with PIL and converted to RGB unless ``decoded`` is
    true, in which case it is used as given. The image is tiled by
    ``dynamic_preprocess`` (thumbnail enabled, at most ``max_num`` tiles),
    each tile goes through ``build_transform(input_size)`` and the results
    are stacked along a new leading dimension.
    """
    source = image if decoded else Image.open(image).convert('RGB')
    tile_transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(source, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([tile_transform(tile) for tile in tiles])
|
|
|
|
def levenshtein_distance(s1, s2):
    """Return the edit distance (insert / delete / substitute) between two sequences.

    Only one row of the dynamic-programming table is kept at a time; the
    shorter input is used for the row so the row is as small as possible.
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    previous_row = list(range(len(shorter) + 1))
    for row_number, longer_item in enumerate(longer, start=1):
        current_row = [row_number]
        for col, shorter_item in enumerate(shorter):
            if shorter_item == longer_item:
                cost = previous_row[col]
            else:
                cost = 1 + min(previous_row[col], previous_row[col + 1], current_row[-1])
            current_row.append(cost)
        previous_row = current_row
    return previous_row[-1]
|
|
|
|
def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
    """Score ``pred`` against the best-matching entry of ``gold_labels``.

    Each gold answer and the prediction are lower-cased and
    whitespace-normalized, their Levenshtein distance is divided by the
    length of the longer *un-normalized* string, and the score is
    ``1 - min(normalized distances)``.

    With ``llava_eval`` the result is binarized to ``1.0`` / ``0.0`` around
    ``threshold``; otherwise scores below ``threshold`` become ``0`` and
    other scores are returned unchanged.

    Raises ``ValueError`` (from ``min``) when ``gold_labels`` is empty.
    """
    normalized_distances = []
    for gold in gold_labels:
        gold_text = ' '.join(gold.strip().lower().split())
        pred_text = ' '.join(pred.strip().lower().split())

        edits = levenshtein_distance(gold_text, pred_text)
        longest = max(len(gold.upper()), len(pred.upper()))
        if longest == 0:
            normalized_distances.append(0.0)
        else:
            normalized_distances.append(float(edits) / float(longest))

    score = 1 - min(normalized_distances)

    if llava_eval:
        return 1.0 if score >= threshold else 0.0
    return 0 if score < threshold else score
|
|
|
|
def isNumber(n: str):
    """Return True when ``float(n)`` succeeds, False when it raises ValueError."""
    try:
        float(n)
    except ValueError:
        return False
    else:
        return True
|
|
|
|
class COCOEvalDataset(Dataset):
    """Captioning dataset over every file found directly inside ``img_dir``.

    Files are taken in sorted path order and optionally truncated to the
    first ``subset`` entries. The image id is the integer between the last
    underscore of the file name and its first dot. Each item is
    ``(image_id, tiles)`` where ``tiles`` is the bfloat16 output of
    ``load_image``.
    """

    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_files = sorted(glob(os.path.join(img_dir, "*")))

        if subset:
            self.img_files = self.img_files[:subset]

        # Parse the id from the file name only, so that underscores or dots
        # in the directory part of the path cannot corrupt it.
        self.image_ids = [
            int(os.path.basename(img_file).split("_")[-1].split(".")[0])
            for img_file in self.img_files
        ]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img = load_image(self.img_files[idx], max_num=6).to(torch.bfloat16)
        return self.image_ids[idx], img
|
|
|
|
class Flickr30KEvalDataset(Dataset):
    """Captioning dataset driven by ``<img_dir>/flickr30k_test.json``.

    The JSON file is a list of samples, each with an ``"image"`` path relative
    to ``img_dir``; ``subset`` keeps only the first N samples. Each item is
    ``(image_id, tiles)`` where ``image_id`` is the integer file stem of the
    ``.jpg`` and ``tiles`` is the bfloat16 output of ``load_image``.
    """

    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8') as f:
            self.test_samples = json.load(f)

        if subset:
            self.test_samples = self.test_samples[:subset]

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        rel_path = self.test_samples[idx]["image"]
        img = load_image(os.path.join(self.img_dir, rel_path), max_num=6).to(torch.bfloat16)
        image_id = int(rel_path.split("/")[-1].replace(".jpg", ""))
        return image_id, img
|
|
|
|
class VQAv2EvalDataset(Dataset):
    """VQA dataset read from a JSON list at ``gt_path``.

    Every record supplies ``"image"`` (relative to ``img_dir``),
    ``"question_id"``, ``"question"`` and ``"answer"``; ``subset`` keeps only
    the first N records. Each item is
    ``(tiles, question_id, question, answer)``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, record["image"]), max_num=6).to(torch.bfloat16)
        return img, record["question_id"], record["question"], record["answer"]
|
|
|
|
class TextVQAEvalDataset(Dataset):
    """TextVQA dataset read from the ``"data"`` list of the JSON at ``gt_path``.

    Images are looked up as ``<img_dir>/<image_id>.jpg`` and fall back to the
    same name with a ``.png`` extension when the ``.jpg`` does not exist.
    ``subset`` keeps only the first N records. Each item is
    ``(tiles, question_id, question, answers)``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)['data']

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img_path = os.path.join(self.img_dir, record["image_id"] + '.jpg')
        if not os.path.exists(img_path):
            # Swap only the trailing extension; a ".jpg" that happens to
            # appear in a directory name must stay untouched.
            img_path = img_path[:-len('.jpg')] + '.png'
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        return img, record["question_id"], record["question"], record["answers"]
|
|
|
|
class GQAEvalDataset(Dataset):
    """GQA dataset read from a JSON object keyed by question id.

    Each ``{id: {"imageId", "question", "answer"}}`` entry is flattened into a
    record with an integer ``question_id`` and an ``image`` file name of
    ``<imageId>.jpg``; ``subset`` keeps only the first N records. Each item is
    ``(tiles, question_id, question, [answer])``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            raw = json.load(f)
        self.gt = [
            {
                "question_id": int(qid),
                "image": entry['imageId'] + ".jpg",
                "question": entry['question'],
                "answer": entry['answer'],
            }
            for qid, entry in raw.items()
        ]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, record["image"]), max_num=6).to(torch.bfloat16)
        return img, record["question_id"], record["question"], [record["answer"]]
|
|
|
|
class ChartQAEvalDataset(Dataset):
    """ChartQA dataset read from a JSON list at ``gt_path``.

    Records supply ``"imgname"``, ``"query"`` and ``"label"``; a
    ``"question_id"`` equal to the record's position in the file is added on
    load. ``subset`` keeps only the first N records. Each item is
    ``(tiles, question_id, question, [answer])``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)
        for position, record in enumerate(self.gt):
            record['question_id'] = position

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, record["imgname"]), max_num=6).to(torch.bfloat16)
        return img, record["question_id"], record["query"], [record["label"]]
|
|
|
|
class OKVQAEvalDataset(Dataset):
    """OK-VQA dataset built from an annotation file and a question file.

    ``gt_path`` provides the ``"annotations"`` list and ``question_path`` the
    ``"questions"`` list; they are joined on ``question_id``. On load each
    annotation's ``answers`` is flattened to a list of answer strings and its
    question text is attached. ``subset`` keeps only the first N annotations.
    Images are read from ``<img_dir>/COCO_val2014_<12-digit image_id>.jpg``.
    Each item is ``(tiles, question_id, question, answers)``.
    """

    def __init__(self, args, img_dir, gt_path, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context managers so both file handles are closed promptly; both
        # files are read as UTF-8.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)['annotations']
        with open(question_path, 'r', encoding='utf-8') as f:
            self.questions = json.load(f)['questions']

        if subset:
            self.gt = self.gt[:subset]

        question_by_id = {q['question_id']: q['question'] for q in self.questions}

        for ann in self.gt:
            ann['answers'] = [ans['answer'] for ans in ann['answers']]
            ann['question'] = question_by_id[ann['question_id']]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        ann = self.gt[idx]
        # Left-pad the numeric id with zeros to 12 characters.
        padded_id = str(ann["image_id"]).rjust(12, '0')
        img_path = os.path.join(self.img_dir, f"COCO_val2014_{padded_id}.jpg")
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        return img, ann["question_id"], ann["question"], ann["answers"]
|
|
|
|
class DocVQAEvalDataset(Dataset):
    """DocVQA dataset read from the ``"data"`` list of the JSON at ``gt_path``.

    Only the final path component of each record's ``"image"`` is used and is
    resolved against ``img_dir``. ``subset`` keeps only the first N records.
    Each item is ``(tiles, question_id, question, answers)``; for any
    ``split`` other than ``'val'`` the answers are the placeholder ``['']``.
    """

    def __init__(self, args, img_dir, gt_path, split='val', subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)['data']

        if subset:
            self.gt = self.gt[:subset]

        self.split = split

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img_path = os.path.join(self.img_dir, record['image'].split('/')[-1])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        question_id = record["questionId"]
        question = record["question"]
        answer = record["answers"] if self.split == 'val' else ['']

        return img, question_id, question, answer
|
|
|
|
class OCRBenchEvalDataset(Dataset):
    """OCRBench dataset read from a JSON list at ``gt_path``.

    Records supply ``"image_path"`` (relative to ``img_dir``),
    ``"dataset_name"``, ``"question"``, ``"answers"`` and ``"type"``.
    ``subset`` keeps only the first N records. Each item is
    ``(tiles, question_id, question, answers, dataset_name, data_type)``
    where ``question_id`` is the item index rendered as a string.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, record['image_path']), max_num=6).to(torch.bfloat16)

        dataset_name = record["dataset_name"]
        question_id = f"{idx}"
        question = record["question"]
        answer = record["answers"]
        data_type = record["type"]

        return img, question_id, question, answer, dataset_name, data_type
|
|
|
|
class AI2DiagramEvalDataset(Dataset):
    """AI2 Diagram dataset read from a JSON-lines file at ``gt_path``.

    Every non-blank line is one JSON record with ``"image"`` (relative to
    ``img_dir``), ``"question_id"``, ``"question"`` and ``"answer"``.
    ``subset`` keeps only the first N records. Each item is
    ``(tiles, question_id, question, answer)``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir

        with open(gt_path, 'r', encoding='utf-8') as json_file:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file), which json.loads would otherwise reject.
            self.gt = [json.loads(line) for line in json_file if line.strip()]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, record['image']), max_num=6).to(torch.bfloat16)
        return img, record["question_id"], record["question"], record["answer"]
|
|
|
|
class AI2DiagramNoMaskEvalDataset(Dataset):
    """AI2 Diagram dataset variant that reads the un-masked images.

    Records come from a JSON-lines file at ``gt_path`` (one JSON object per
    non-blank line). When an item is fetched, ``"AI2D_TEST"`` in the record's
    image path is replaced by ``"AI2D_TEST_NO_MASK_IMAGES"`` before it is
    resolved against ``img_dir``. ``subset`` keeps only the first N records.
    Each item is ``(tiles, question_id, question, answer)``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir

        with open(gt_path, 'r', encoding='utf-8') as json_file:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file), which json.loads would otherwise reject.
            self.gt = [json.loads(line) for line in json_file if line.strip()]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img_file_name = record['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
        img = load_image(os.path.join(self.img_dir, img_file_name), max_num=6).to(torch.bfloat16)
        return img, record["question_id"], record["question"], record["answer"]
|
|
|
|
class RealworldQAEvalDataset(Dataset):
    """RealworldQA dataset read from a JSON list at ``gt_path``.

    ``subset`` keeps only the first N records. The question id is the
    integer stem of the record's ``.webp`` image name. For ``"multi-choice"``
    records the choices are appended to the question as lettered lines
    (``A. ...``) followed by an answer-with-letter instruction, and the
    answer is the letter for ``correct_choice_index``. For every other
    question type a short-answer instruction is appended and the record's
    ``answer`` is used. Each item is
    ``(tiles, question_id, question, [answer])``.
    """

    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the annotation file handle is closed promptly.
        with open(gt_path, encoding='utf-8') as f:
            self.gt = json.load(f)

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        record = self.gt[idx]
        img = load_image(os.path.join(self.img_dir, record['image']), max_num=6).to(torch.bfloat16)

        question_id = int(record['image'].replace(".webp", ""))
        question = record["question"]

        if record['question_type'] == "multi-choice":
            choices_str = ''.join(
                f"{chr(ord('A') + position)}. {choice}\n"
                for position, choice in enumerate(record["choices"])
            )
            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + record['correct_choice_index'])
        else:
            question = question + "\nAnswer the question using a single word or phrase."
            answer = record['answer']

        return img, question_id, question, [answer]
|
|
|
|
class MathVistaEvalDataset(Dataset):
    """MathVista ``testmini`` split loaded through ``load_dataset``.

    ``gt_path`` is accepted but not used. For ``'multi_choice'`` rows the
    choices are appended to the question as lettered lines followed by an
    answer-with-letter instruction, and the answer becomes the letter of the
    matching choice. For other rows the question is the row's ``query`` with
    every ``"Hint: "`` removed. Each item is
    ``(tiles, question_id, question_type, question, answer,
    str(index2ans), str(all_choices))``.
    """

    def __init__(self, args, task_cfg, gt_path=None):
        self.args = args
        self.task_cfg = task_cfg
        self.dataset = load_dataset("AI4Math/MathVista")['testmini']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once instead of re-indexing the dataset for every
        # field that is read below.
        row = self.dataset[idx]

        img = load_image(row['decoded_image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = row["pid"]
        question = row["question"]
        question_type = row["question_type"]
        query = row["query"]
        choices = row["choices"]
        answer = row["answer"]

        if question_type == 'multi_choice':
            all_choices = [chr(ord('A') + position) for position in range(len(choices))]
            index2ans = dict(zip(all_choices, choices))
            choices_str = ''.join(
                f"{letter}. {choice}\n" for letter, choice in zip(all_choices, choices)
            )

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + choices.index(answer))
        else:
            question = query.replace("Hint: ", "")
            index2ans = {}
            all_choices = []

        return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)
|
|
|
|
def construct_prompt_for_fewshot(sample):
    """Build the prompt dictionary for one sample.

    ``sample`` must provide ``'question'``, ``'options'`` (the string form of
    a Python list), ``'question_type'`` and ``'answer'``.

    For ``'multiple-choice'`` samples the options are rendered as
    ``(A) ...`` lines and the result carries ``index2ans``,
    ``correct_choice``, ``all_choices`` and a ``gt_content`` equal to the
    option selected by the answer letter. All other samples get the
    short-answer prompt and ``gt_content`` is the answer itself.

    The returned dict is finally updated with every key of ``sample``, so
    sample keys win on a name clash.

    Raises ``ValueError`` / ``SyntaxError`` when ``sample['options']`` is not
    a Python literal.
    """
    config = {
        "task_instructions": "",
        "multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
        "short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
    }

    question = sample['question'].strip()

    # NOTE(security): the options string comes from dataset files, so it is
    # parsed with ast.literal_eval rather than eval; only Python literals are
    # accepted and no arbitrary code can run.
    options = ast.literal_eval(sample['options'])

    if sample['question_type'] == 'multiple-choice':
        prediction_range = [chr(ord('A') + position) for position in range(len(options))]
        index2ans = dict(zip(prediction_range, options))
        example = "".join(
            f"({letter}) {option}\n" for letter, option in zip(prediction_range, options)
        )
        empty_prompt = config['multi_choice_example_format'].format(question, example)
        res_dict = {
            'type': 'multichoice',
            'index2ans': index2ans,
            'correct_choice': sample['answer'],
            'all_choices': prediction_range,
            'empty_prompt': empty_prompt,
        }
        gt_content = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        empty_prompt = config['short_ans_example_format'].format(question)
        res_dict = {'type': 'open', 'empty_prompt': empty_prompt}
        gt_content = sample['answer']

    if config['task_instructions']:
        res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
    else:
        res_dict['final_input_prompt'] = empty_prompt
    res_dict['gt_content'] = gt_content

    res_dict.update(sample)
    return res_dict
|
|
|
|
def _rewrite_first_image_tag(q):
    """Apply the first matching ``<image 1>`` rewrite rule to a stripped question."""
    if q == '<image 1>':
        return 'Answer the question in the image.'
    if ':<image 1>' in q:
        return q.replace(':<image 1>', ' in the image. ').strip()
    if ': <image 1>' in q:
        return q.replace(': <image 1>', ' in the image. ').strip()
    if '.<image 1>' in q or '. <image 1>' in q:
        # Drop the tag and re-join the surrounding non-empty pieces.
        pieces = (piece.strip() for piece in q.split('<image 1>'))
        return ' '.join(piece for piece in pieces if piece != '')
    if q.startswith('<image 1> '):
        # q[10] is the first character after "<image 1> ".
        substitute = '' if q[10].isupper() else 'The image'
        return q.replace('<image 1>', substitute).strip()
    if q.startswith('<image 1>'):
        return q.replace('<image 1>', '')
    if q.endswith('<image 1>?'):
        return q.replace('<image 1>', 'the image')
    if q.endswith(('?<image 1>', '? <image 1>', '\n<image 1>')):
        return q.replace('<image 1>', '').strip()
    if ' <image 1>' in q:
        # Covers both " <image 1> " and " <image 1>".
        return q.replace('<image 1>', 'the image')
    if '()<image 1>' in q:
        return q.replace('()<image 1>', '')
    if '(<image 1>)' in q:
        return q.replace('(<image 1>)', '')
    if '<image 1>.' in q:
        return q.replace("<image 1>.", ". ")
    return q.replace("<image 1>", ". ").strip()


def process_image_tag(q):
    """Rewrite image placeholders in a question into plain text.

    The question is stripped, ``<image 1>`` is rewritten by the first
    matching heuristic (the rules are order dependent), and the placeholders
    ``<image 2>`` through ``<image 7>`` are then removed.
    """
    q = _rewrite_first_image_tag(q.strip())

    for i in range(2, 8):
        q = q.replace(f"<image {i}>", "")

    return q
|
|
|
|
class MMMUProEvalDataset(Dataset):
    """MMMU-Pro ``standard`` configuration, ``test`` split.

    ``subset`` keeps only the first N rows. Samples are prepared with
    ``process_single_sample_pro`` and ``construct_prompt_pro``; the prompt's
    image tags are rewritten by ``process_image_tag`` and the configured
    ``default_image_token`` is prepended on its own line. Each item is
    ``(tiles, question_id, subfield, question_type, question, answer,
    str(index2ans), str(all_choices))``.
    """

    def __init__(self, args, task_cfg, subset=None):
        self.args = args
        self.task_cfg = task_cfg

        MMMU_path = "MMMU/MMMU_Pro"
        _split = "test"

        self.dataset = load_dataset(MMMU_path, "standard", split=_split)
        if subset:
            # Slicing a datasets.Dataset returns a dict of columns rather
            # than a dataset, which breaks __len__ and __getitem__; select()
            # keeps row-wise access.
            self.dataset = self.dataset.select(range(min(subset, len(self.dataset))))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        sample = process_single_sample_pro(sample)
        sample = construct_prompt_pro(sample, self.task_cfg)
        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return (img, question_id, sample['subfield'], sample['question_type'], question, answer,
                str(index2ans), str(all_choices))
|
|
|
|
class MMMUEvalDataset(Dataset):
    """MMMU dataset concatenated over every subject in ``CAT_SHORT2LONG``.

    The ``test`` split is loaded when ``task_cfg["split"] == "test"``;
    otherwise the ``validation`` split is loaded and filtered to samples
    whose id starts with ``task_cfg["split"]``. With ``subset`` set, only the
    window of ``subset`` samples starting at ``start_idx`` (0 when omitted)
    is kept. Each item is
    ``(tiles, question_id, subfield, question_type, question, answer,
    str(index2ans), str(all_choices))``.
    """

    def __init__(self, args, task_cfg, subset=None, start_idx=None):
        self.args = args
        self.task_cfg = task_cfg

        MMMU_path = "MMMU/MMMU"
        _split = "test" if task_cfg["split"] == "test" else "validation"

        sub_dataset_list = [
            load_dataset(MMMU_path, subject, split=_split)
            for subject in CAT_SHORT2LONG.values()
        ]
        dataset = concatenate_datasets(sub_dataset_list)

        if task_cfg["split"] != "test":
            dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]

        self.dataset = dataset

        if subset:
            # start_idx defaults to None; treat that as "start at the first
            # sample" instead of failing inside range().
            first = 0 if start_idx is None else start_idx
            last = min(first + subset, len(dataset))
            self.dataset = [dataset[i] for i in range(first, last)]
            print(f"Evaluating a subset of dataset: {len(self.dataset)} from {first} to {first + subset}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        sample = process_single_sample(sample)
        sample = construct_prompt(sample, self.task_cfg)

        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return (img, question_id, sample['subfield'], sample['question_type'], question, answer,
                str(index2ans), str(all_choices))
|
|
|
|
|
|
class VizWizEvalDataset(Dataset):
    """VizWiz dataset read from a JSON list at ``question_path``.

    Records supply ``"image"`` (relative to ``img_dir``) and ``"question"``;
    ``subset`` keeps only the first N records. Each item is
    ``(tiles, question_id, question)`` where ``question_id`` is the record's
    image file name.
    """

    def __init__(self, args, img_dir, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        # Context manager so the question file handle is closed promptly.
        with open(question_path, encoding='utf-8') as f:
            self.questions = json.load(f)

        # Honor the subset argument the same way the sibling datasets do.
        if subset:
            self.questions = self.questions[:subset]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        record = self.questions[idx]
        img = load_image(os.path.join(self.img_dir, record["image"]), max_num=6).to(torch.bfloat16)
        question = record["question"]
        question_id = record["image"]

        return img, question_id, question
|
|
|
|
class MMBenchEvalDataset(Dataset):
    """MMBench dataset read from a tab-separated file at ``gt_path``.

    Each row supplies ``index``, ``question``, ``hint``, ``category``, a
    base64-encoded ``image`` and option columns ``A``-``D`` (options whose
    string form is ``'nan'`` are dropped). The ``answer`` column is optional
    and defaults to ``''``. ``subset`` keeps only the first N rows. Each item
    is ``(tiles, question_id, question_with_choices, answer, str(index2ans),
    str(all_choices), original_question)``.
    """

    def __init__(self, args, gt_path, subset=None):
        self.args = args
        df = pd.read_csv(gt_path, sep='\t')
        # Apply the subset before decoding so that unused images are never
        # base64-decoded or opened.
        if subset:
            df = df.iloc[:subset]
        has_answer = 'answer' in df.columns

        self.dataset = []
        for _, row in df.iterrows():
            choices = [row[letter] for letter in ['A', 'B', 'C', 'D'] if str(row[letter]) != 'nan']

            self.dataset.append({
                'index': row['index'],
                'question': row['question'],
                'hint': row['hint'],
                'category': row['category'],
                'image': Image.open(BytesIO(base64.b64decode(row['image']))),
                'choices': choices,
                'answer': row['answer'] if has_answer else '',
            })

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        img = load_image(sample["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question = sample["question"]
        question_id = sample["index"]
        choices = sample["choices"]
        answer = sample["answer"]

        all_choices = [chr(ord('A') + position) for position in range(len(choices))]
        index2ans = dict(zip(all_choices, choices))
        choices_str = ''.join(
            f"{letter}. {choice}\n" for letter, choice in zip(all_choices, choices)
        )

        question = question + '\n' + choices_str

        return img, question_id, question, answer, str(index2ans), str(all_choices), sample["question"]
|
|
|
|
def get_task_dataloader(task_name, task_cfg, args):
    """Create the evaluation ``DataLoader`` for ``task_name``.

    The dataset class and its constructor arguments are chosen from
    ``task_name``; paths come from ``task_cfg`` and the optional
    ``task_cfg["subset"]`` is forwarded to the datasets that take it. The
    ``mmmu`` task instead uses ``args.subset`` and ``args.start_idx``. The
    loader uses batch size 1, no shuffling and pinned memory.

    Raises ``NotImplementedError`` for an unknown ``task_name``.
    """
    subset = task_cfg["subset"] if "subset" in task_cfg.keys() else None

    # Tasks constructed as Dataset(args, image_dir, subset).
    image_only_tasks = {
        "coco_caption": COCOEvalDataset,
        "flickr30k_caption": Flickr30KEvalDataset,
    }
    # Tasks constructed as Dataset(args, image_dir, gt_path, subset).
    image_gt_tasks = {
        "vqav2": VQAv2EvalDataset,
        "textvqa": TextVQAEvalDataset,
        "gqa": GQAEvalDataset,
        "chartqa": ChartQAEvalDataset,
        "realworldqa": RealworldQAEvalDataset,
        "ocrbench": OCRBenchEvalDataset,
        "ai2diagram": AI2DiagramEvalDataset,
        "ai2diagram_nomask": AI2DiagramNoMaskEvalDataset,
    }

    if task_name in image_only_tasks:
        dataset = image_only_tasks[task_name](args, task_cfg["image_dir"], subset)
    elif task_name in image_gt_tasks:
        dataset = image_gt_tasks[task_name](args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "okvqa":
        dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
    elif task_name == "vizwiz":
        dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
    elif task_name in ("docvqa", "docvqa_test"):
        split = 'val' if task_name == "docvqa" else 'test'
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split=split, subset=subset)
    elif task_name == "mmmu":
        dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
    elif task_name == "mmmu_pro":
        dataset = MMMUProEvalDataset(args, task_cfg)
    elif task_name == "mathvista":
        dataset = MathVistaEvalDataset(args, task_cfg)
    elif task_name == "mmbench":
        dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")

    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
|
|