# counting.py: evaluate LLMs on a face-counting task over the WiderFace validation set
import argparse
import base64
import csv
import io
import os
import re
import time

import cv2
from litellm import completion
from PIL import Image
from tqdm import tqdm
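
# Third-party dependencies (assumed installed), e.g.:
#   pip install litellm opencv-python Pillow tqdm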

def parse_widerface_annotations(annotation_file):
    """
    Parse the WiderFace annotation file and return a dictionary mapping
    image filenames to the number of faces.

    Args:
        annotation_file (str): Path to the WiderFace annotation file

    Returns:
        dict: Dictionary mapping image filenames to number of faces
    """
    annotations = {}
    with open(annotation_file, 'r') as f:
        lines = f.readlines()
    i = 0
    while i < len(lines):
        # Get the image filename
        filename = lines[i].strip()
        i += 1
        # Get the number of faces
        num_faces = int(lines[i].strip())
        i += 1
        # Check if this is a "crowd" annotation (indicated by a count of 1 and a special bbox format)
        is_crowd = False
        if num_faces == 1:
            bbox_line = lines[i].strip()
            # The "crowd" marker is an all-zero bbox line
            if bbox_line == "0 0 0 0 0 0 0 0 0 0":
                is_crowd = True
                print(f"Skipping crowd image: {filename}")
        # Skip the bounding box lines
        i += num_faces
        # Store in the dictionary only if this is not a crowd image
        if not is_crowd:
            annotations[filename] = num_faces
    return annotations
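
# For reference, the annotation file interleaves filenames, face counts, and
# one bbox line per face. An illustrative (not verbatim) entry:
#
#   0--Parade/0_Parade_marchingband_1_849.jpg
#   1
#   449 330 122 149 0 0 0 0 0 0
#
# where the bbox fields are x, y, w, h followed by the WiderFace attribute flags.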

def parse_face_count_response(response_text):
    """
    Parse the model's response to extract the number of faces.

    Args:
        response_text (str): The model's raw response

    Returns:
        int: The parsed face count, or -1 if parsing failed
    """
    try:
        # Lowercase and trim surrounding whitespace
        text = response_text.lower().strip()
        # Look for standalone integers in the response
        numbers = re.findall(r'\b(\d+)\b', text)
        if numbers:
            # Return the first number found
            return int(numbers[0])
        return -1
    except Exception as e:
        print(f"Error parsing response: {str(e)}")
        return -1
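
# Quick sanity check (illustrative): parse_face_count_response("There are 3 faces.")
# returns 3, while parse_face_count_response("I cannot tell.") returns -1.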

def evaluate_face_counter(model_name, dataset_dir, annotation_file):
    """
    Evaluate a model on the face counting task and save results to CSV.

    Args:
        model_name (str): The model identifier (e.g., "openai/gpt-4o")
        dataset_dir (str): Directory containing the images
        annotation_file (str): Path to the WiderFace annotation file
    """
    # Extract the model name without the provider prefix for the filename
    model_short_name = model_name.split('/')[-1]
    csv_filename = os.path.join("out_counting", f"face_counting_{model_short_name}.csv")
    print(f"Evaluating {model_name} on face counting task...")

    # Parse annotations
    annotations = parse_widerface_annotations(annotation_file)

    # Set up the CSV file and writer
    fieldnames = ['filename', 'gt_num_faces', 'response_num_faces', 'raw_response']
    # Check if the file exists to determine whether we need to write a header
    file_exists = os.path.isfile(csv_filename)
    # Create the output directory if it doesn't exist
    os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
    # Open the file in append mode so we can add rows incrementally
    csv_file = open(csv_filename, 'a', newline='')
    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
    # Write the header only if the file is new
    if not file_exists:
        writer.writeheader()

    # If the file exists, determine which images have already been processed
    processed_images = set()
    if file_exists:
        with open(csv_filename, 'r', newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                processed_images.add(row['filename'])

    try:
        # Process each image
        for filename, gt_num_faces in tqdm(annotations.items()):
            # Skip if this image was already processed
            if filename in processed_images:
                continue

            # Load the image
            img_path = os.path.join(dataset_dir, filename)
            img = cv2.imread(img_path)
            # Skip if image loading failed
            if img is None:
                print(f"Warning: Failed to load image {img_path}, skipping")
                continue

            # Convert from BGR (OpenCV format) to RGB (for PIL)
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Convert to a PIL Image
            pil_img = Image.fromarray(img_rgb)

            # Downscale until the PNG encoding fits under the size limit
            max_size = 5 * 1024 * 1024  # 5MB in bytes
            current_width, current_height = pil_img.size
            scale_factor = 1.0
            while True:
                # Save to a bytes buffer in PNG format
                buffer = io.BytesIO()
                pil_img.save(buffer, format="PNG")
                buffer.seek(0)
                # Check the encoded size
                if len(buffer.getvalue()) <= max_size:
                    break
                # Halve the scale relative to the original dimensions
                scale_factor *= 0.5
                new_width = int(current_width * scale_factor)
                new_height = int(current_height * scale_factor)
                pil_img = pil_img.resize((new_width, new_height), Image.LANCZOS)
                print(f"Resizing image to {new_width}x{new_height} (scale: {scale_factor:.2f})")
                # Safety check to prevent an infinite loop with images that can't be compressed enough
                if new_width < 200 or new_height < 200:
                    print(f"Warning: Cannot reduce image {img_path} below 5MB even at very small size")
                    break
            buffer.seek(0)

            # Encode to base64 as a data URL
            encoded_file = base64.b64encode(buffer.getvalue()).decode("utf-8")
            base64_url = f"data:image/png;base64,{encoded_file}"

            # Prepare the prompt
            prompt = [
                {"role": "system", "content": "You are an AI assistant that specializes in analyzing images. Please provide accurate and concise information."},
                {"role": "user", "content": [
                    {"type": "text", "text": "How many visible human faces are in this image? Please respond with just a number. Any response other than an integer will be considered an error."},
                    {"type": "image_url", "image_url": {"url": base64_url}}
                ]}
            ]

            try:
                # Make the API call
                response = completion(model=model_name, messages=prompt, temperature=0.0, max_tokens=10)
                response_text = response.choices[0].message.content
                parsed_response = parse_face_count_response(response_text)
                try:
                    clean_response = response_text.strip().replace('\n', ' ').replace('\r', ' ')
                except Exception:
                    clean_response = ""

                # Create the result row
                result = {
                    'filename': filename,
                    'gt_num_faces': gt_num_faces,
                    'response_num_faces': parsed_response,
                    'raw_response': clean_response
                }
                # Immediately write this result to the CSV file
                writer.writerow(result)
                # Flush to make sure it's written to disk
                csv_file.flush()
                # Add a small delay to avoid rate limiting
                time.sleep(0.2)
            except Exception as e:
                print(f"Error processing {filename}: {str(e)}")
    finally:
        # Close the CSV file
        csv_file.close()
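
# A minimal post-hoc scoring sketch (assuming the CSV schema written above;
# adjust the filename for the model you evaluated):
#
#   import csv
#   with open("out_counting/face_counting_gpt-4o.csv") as f:
#       rows = list(csv.DictReader(f))
#   valid = [r for r in rows if int(r['response_num_faces']) >= 0]
#   acc = sum(int(r['gt_num_faces']) == int(r['response_num_faces']) for r in valid) / len(valid)
#   mae = sum(abs(int(r['gt_num_faces']) - int(r['response_num_faces'])) for r in valid) / len(valid)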

def load_api_keys():
    """Load API keys from files in the keys/ directory, if present."""
    # Load the OpenAI API key if the file exists
    if os.path.exists("keys/openai_key.txt"):
        with open("keys/openai_key.txt", "r") as f:
            os.environ["OPENAI_API_KEY"] = f.read().strip()
    # Load the Claude API key if the file exists
    if os.path.exists("keys/claude_key.txt"):
        with open("keys/claude_key.txt", "r") as f:
            os.environ["ANTHROPIC_API_KEY"] = f.read().strip()
    # Load the Gemini API key if the file exists
    if os.path.exists("keys/gemini_key.txt"):
        with open("keys/gemini_key.txt", "r") as f:
            os.environ["GEMINI_API_KEY"] = f.read().strip()
    # Load the xAI API key if the file exists
    if os.path.exists("keys/xai_key.txt"):
        with open("keys/xai_key.txt", "r") as f:
            os.environ["XAI_API_KEY"] = f.read().strip()

def main():
    parser = argparse.ArgumentParser(description='Evaluate LLMs on a face counting task')
    parser.add_argument('--model', type=str, required=True,
                        help='Model identifier (e.g., "openai/gpt-4o")')
    parser.add_argument('--dataset_dir', type=str, default='./WIDER_val/images',
                        help='Directory containing WiderFace images (default: ./WIDER_val/images)')
    parser.add_argument('--annotation_file', type=str, default='./WIDER_val/wider_face_val_bbx_gt.txt',
                        help='Path to WiderFace annotation file (default: ./WIDER_val/wider_face_val_bbx_gt.txt)')
    args = parser.parse_args()

    # Load API keys
    load_api_keys()
    # Run the evaluation
    evaluate_face_counter(args.model, args.dataset_dir, args.annotation_file)


if __name__ == "__main__":
    main()
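
# Example invocation (assuming the WiderFace validation set has been unpacked
# next to this script and the matching key file exists under keys/):
#
#   python counting.py --model openai/gpt-4o \
#       --dataset_dir ./WIDER_val/images \
#       --annotation_file ./WIDER_val/wider_face_val_bbx_gt.txt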