Skip to content

Commit

Permalink
fix regex in mlx to parse more scenarios; model_path in models.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
mkagenius committed Dec 25, 2024
1 parent e7ada94 commit f7b195b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 24 deletions.
3 changes: 2 additions & 1 deletion clickclickclick/config/models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ mlx:
output_width: 100 # max range of outputted values
output_height: 100
finder:
output_width: 100
output_width: 100
model_path: mlx-community/Molmo-7B-D-0924-4bit
59 changes: 36 additions & 23 deletions clickclickclick/finder/mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,50 +10,63 @@

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Matches ``key=value`` / ``key: value`` pairs; the numeric value may be
# wrapped in double quotes. Compiled once instead of on every call.
_KEY_VALUE_RE = re.compile(r'(\w+)\s*[:=]\s*"?(\d*\.\d+|\d+)"?')

# Corner-style keys are renamed to the min/max names used in the output.
_KEY_ALIASES = {
    'x1': 'xmin',
    'y1': 'ymin',
    'x2': 'xmax',
    'y2': 'ymax',
}

# Meaning of each position when the response is a bare list of four numbers.
_BARE_LIST_ORDER = ('ymin', 'ymax', 'xmin', 'xmax')


def extract_coordinates(response_text):
    """Parse bounding-box coordinates out of a text response.

    Two input shapes are recognised, tried in this order:

    1. Key-value pairs such as ``ymin=10, ymax=20`` or ``x1: 1.5 y1: "2"``.
       Keys are lower-cased, and ``x1``/``y1``/``x2``/``y2`` are renamed to
       ``xmin``/``ymin``/``xmax``/``ymax``. Any other key is passed through
       unchanged. When a key repeats, the last occurrence wins.
    2. A bare comma-separated list of exactly four numbers, read as
       ``ymin, ymax, xmin, xmax``. Surrounding whitespace, enclosing
       brackets/parentheses/braces and one trailing sentence period are
       tolerated (``[10, 20, 30, 40].``).

    Args:
        response_text: The raw text to parse.

    Returns:
        A JSON object string mapping coordinate names to floats. If the text
        has no key-value pairs and is not a numeric list, the all-zero box
        ``{"ymin": 0, "ymax": 0, "xmin": 0, "xmax": 0}`` is returned and the
        parse error is logged.

    Raises:
        ValueError: The text is a numeric comma-separated list but does not
            contain exactly four values.
    """
    pairs = _KEY_VALUE_RE.findall(response_text)

    if pairs:
        # Lower-case keys so "Ymin" / "X1" normalise like their lower-case forms.
        coordinates_dict = {key.lower(): float(value) for key, value in pairs}
    else:
        # Strip the wrappers that commonly surround a bare list before splitting,
        # so "[1, 2, 3, 4]." parses instead of degrading to the zero box.
        bare_list = response_text.strip().rstrip('.').strip(' \t\r\n[](){}')
        try:
            values = [float(item.strip()) for item in bare_list.split(',')]
        except ValueError as e:
            # Best-effort fallback kept from the original: unparseable text
            # yields an all-zero box rather than an exception.
            logger.info(e)
            return json.dumps({"ymin": 0, "ymax": 0, "xmin": 0, "xmax": 0})
        if len(values) != len(_BARE_LIST_ORDER):
            raise ValueError("Input does not contain valid key-value pairs or valid coordinate list.")
        coordinates_dict = dict(zip(_BARE_LIST_ORDER, values))

    standardized_coordinates = {
        _KEY_ALIASES.get(key, key): value
        for key, value in coordinates_dict.items()
    }

    return json.dumps(standardized_coordinates)


class MLXFinder(BaseFinder):
def __init__(self, c: BaseConfig, executor, model_path="mlx-community/Molmo-7B-D-0924-4bit"):
def __init__(self, c: BaseConfig, executor):
self.executor = executor
finder_config = c.models.get("finder_config")
model_path = finder_config.get("model_path")
self.config = load_config(model_path)
self.model, self.processor = load(model_path, {"trust_remote_code": True})
# self.image_finder_prompt = c.prompts["image-finder-prompt"]
self.system_prompt = c.prompts["finder-system-prompt"]
finder_config = c.models.get("finder_config")

self.IMAGE_WIDTH = finder_config.get("image_width")
self.IMAGE_HEIGHT = finder_config.get("image_height")
self.OUTPUT_WIDTH = finder_config.get("output_width")
Expand All @@ -75,7 +88,7 @@ def process_image(self, image_path, prompt):

# Example usage
def process_segment(self, segment, model_name, prompt):
prompt = f'UI bounds of "{prompt}" as ymin,ymax,xmin,xmax format strictly. '
prompt = f'UI bounds of "{prompt}" as ymin=,ymax=,xmin=,xmax= format strictly. '
segment_image, coordinates = segment
with NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
segment_image.save(temp_file, format="PNG")
Expand Down

0 comments on commit f7b195b

Please sign in to comment.