diff --git a/clickclickclick/config/models.yaml b/clickclickclick/config/models.yaml index bf32b47..72bebd1 100644 --- a/clickclickclick/config/models.yaml +++ b/clickclickclick/config/models.yaml @@ -57,4 +57,5 @@ mlx: output_width: 100 # max range of outputted values output_height: 100 finder: - output_width: 100 \ No newline at end of file + output_width: 100 + model_path: mlx-community/Molmo-7B-D-0924-4bit \ No newline at end of file diff --git a/clickclickclick/finder/mlx.py b/clickclickclick/finder/mlx.py index ff9e171..3bc903c 100644 --- a/clickclickclick/finder/mlx.py +++ b/clickclickclick/finder/mlx.py @@ -10,50 +10,63 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" -def extract_coordinates(response_text): - # Define a regex pattern to capture key-value pairs in the format key="value" or key: value - pattern = r'(\w+)=?\"?(\d*\.\d+|\d+)\"?|(\w+):\s*(\d*\.\d+)' - # Use regular expression to find all key-value pairs +def extract_coordinates(response_text): + # Attempt to find key-value pairs first + pattern = r'(\w+)\s*[:=]\s*"?(\d*\.\d+|\d+)"?' matches = re.findall(pattern, response_text) - # Convert matches into a dictionary - coordinates_dict = {} - for match in matches: - if match[0]: # for key="value" pattern + # If matches are found, process as key-value pairs + if matches: + coordinates_dict = {} + for match in matches: key = match[0] value = match[1] - else: # for key: value pattern - key = match[2] - value = match[3] - # Convert each value to int after float conversion for precision if needed - coordinates_dict[key] = int(float(value)) + coordinates_dict[key] = float(value) + else: + # Assume input is a comma-separated list of values in the order ymin,ymax,xmin,xmax + try: + values = [float(value.strip()) for value in response_text.split(',')] + except ValueError as e: + logger.info(e) + return json.dumps({"ymin": 0, "ymax": 0, "xmin": 0, "xmax": 0}) + if len(values) == 4: + coordinates_dict = { + 'ymin': values[0], + 'ymax': values[1], + 'xmin': values[2], + 'xmax': values[3] + } + else: + # Handle error case where input doesn't match expected format + raise ValueError("Input does not contain valid key-value pairs or valid coordinate list.") # Define the normalization mapping conversion_map = { - "x1": "xmin", - "y1": "ymin", - "x2": "xmax", - "y2": "ymax" + 'x1': 'xmin', + 'y1': 'ymin', + 'x2': 'xmax', + 'y2': 'ymax' } # Transform the extracted coordinates into a standardized format - standardized_coordinates = {conversion_map[key]: coordinates_dict[key] for key in conversion_map if key in coordinates_dict} - standardized_coordinates.update(coordinates_dict) + standardized_coordinates = {conversion_map.get(key, key): value for key, value in coordinates_dict.items()} + # Convert the standardized dictionary to a JSON string response_json = json.dumps(standardized_coordinates) return response_json - class MLXFinder(BaseFinder): - def __init__(self, c: BaseConfig, executor, model_path="mlx-community/Molmo-7B-D-0924-4bit"): + def __init__(self, c: BaseConfig, executor): self.executor = executor + finder_config = c.models.get("finder_config") + model_path = finder_config.get("model_path") self.config = load_config(model_path) self.model, self.processor = load(model_path, {"trust_remote_code": True}) # self.image_finder_prompt = c.prompts["image-finder-prompt"] self.system_prompt = c.prompts["finder-system-prompt"] - finder_config = c.models.get("finder_config") + self.IMAGE_WIDTH = finder_config.get("image_width") self.IMAGE_HEIGHT = finder_config.get("image_height") self.OUTPUT_WIDTH = finder_config.get("output_width") @@ -75,7 +88,7 @@ def process_image(self, image_path, prompt): # Example usage def process_segment(self, segment, model_name, prompt): - prompt = f'UI bounds of "{prompt}" as ymin,ymax,xmin,xmax format strictly. ' + prompt = f'UI bounds of "{prompt}" as ymin=,ymax=,xmin=,xmax= format strictly. ' segment_image, coordinates = segment with NamedTemporaryFile(delete=False, suffix=".png") as temp_file: segment_image.save(temp_file, format="PNG")