Refactor to use objathor and black formatting. #37

Merged 7 commits on May 31, 2024
37 changes: 37 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,37 @@
name: Continuous integration

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: psf/black@stable
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10']

    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install
        run: |
          python3 -m venv .env
          source .env/bin/activate
          make install
      - name: Unit tests
        run: |
          source .env/bin/activate
          make test
37 changes: 37 additions & 0 deletions .github/workflows/python-publish.yml
@@ -0,0 +1,37 @@
name: Release

on:
  push:
    branches:
      - main
jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions-ecosystem/action-regex-match@v2
        id: regex-match
        with:
          text: ${{ github.event.head_commit.message }}
          regex: '^Release ([^ ]+)'
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.10'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel twine
      - name: Release
        if: ${{ steps.regex-match.outputs.match != '' }}
        uses: softprops/action-gh-release@v1
        with:
          tag_name: ${{ steps.regex-match.outputs.group1 }}
      - name: Build and publish
        if: ${{ steps.regex-match.outputs.match != '' }}
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
        run: |
          python setup.py sdist bdist_wheel
          twine upload dist/*
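
The release workflow above is gated on the head commit message matching `^Release ([^ ]+)`, with the first capture group used as the tag name. A minimal Python sketch of that matching logic follows. The helper name `release_tag` is hypothetical and not part of this PR, and the action itself evaluates the pattern with its own regex engine, so treat this only as an illustration of what the pattern accepts.

```python
import re

# Same pattern as the workflow's `regex` input.
RELEASE_PATTERN = re.compile(r"^Release ([^ ]+)")


def release_tag(commit_message):
    """Return the tag named by a 'Release <tag>' commit message, or None.

    Hypothetical helper for illustration; mirrors the workflow's
    `match != ''` gate and its use of `group1` as the tag name.
    """
    match = RELEASE_PATTERN.match(commit_message)
    return match.group(1) if match else None


if __name__ == "__main__":
    for message in ["Release 0.0.2", "Fix typo in README", "Release 1.2.3 with notes"]:
        print(repr(message), "->", release_tag(message))
```

Because `[^ ]` excludes only the space character, the capture group in Python also consumes newlines. A multi-line commit message with no space after the version would therefore yield a tag containing the line break, so a single-line `Release <tag>` message is the safest form.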
17 changes: 17 additions & 0 deletions Makefile
@@ -0,0 +1,17 @@
install: ## [Local development] Upgrade pip, install requirements, install package.
	python -m pip install -U pip
	python -m pip install -e .

install-dev: ## [Local development] Install requirements
	python -m pip install -r requirements.txt

black: ## [Local development] Auto-format python code using black
	python -m black .

test: ## [Local development] Run unit tests
	python -m pytest -x -s -v tests

.PHONY: help

help: # Run `make help` to get help on the make commands
	@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
21 changes: 12 additions & 9 deletions README.md
@@ -8,7 +8,7 @@
</h5>

<h4 align="center">
<a href="https://arxiv.org/abs/2312.09067">Paper</i></a> | <a href="https://yueyang1996.github.io/holodeck/">Project Page</i></a>
<a href="https://arxiv.org/abs/2312.09067"><i>Paper</i></a> | <a href="https://yueyang1996.github.io/holodeck/"><i>Project Page</i></a>
</h4>

## Requirements
@@ -21,25 +21,28 @@ Holodeck is based on [AI2-THOR](https://ai2thor.allenai.org/ithor/documentation/
## Installation
After cloning the repo, you can install the required dependencies using the following commands:
```
conda create --name holodeck python=3.9.16
conda create --name holodeck python=3.10
conda activate holodeck
pip install -r requirements.txt
pip install --extra-index-url https://ai2thor-pypi.allenai.org ai2thor==0+6f165fdaf3cf2d03728f931f39261d14a67414d0
pip install --extra-index-url https://ai2thor-pypi.allenai.org ai2thor==0+8524eadda94df0ab2dbb2ef5a577e4d37c712897
```

## Data
Download the data from [google drive](https://drive.google.com/file/d/1MQbFbNfTz94x8Pxfkgbohz4l46O5e3G1/view?usp=sharing) and extract it to the `data/` folder, or use the following command to download from S3:
```
wget https://holodeck-ai2.s3.amazonaws.com/data.zip
unzip data.zip
Download the data by running the following commands:
```bash
python -m objathor.dataset.download_holodeck_metadata --version 2023_09_23
python -m objathor.dataset.download_assets --version 2023_09_23
python -m objathor.dataset.download_annotations --version 2023_09_23
python -m objathor.dataset.download_features --version 2023_09_23
```
By default these will save to `~/.objathor-assets/...`; you can change this directory by specifying the `--path` argument. If you change the `--path`, you'll need to set the `OBJAVERSE_ASSETS_DIR` environment variable to the path where the assets are stored when you use Holodeck.
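
As a rough illustration of how a custom download location is picked up, the sketch below mirrors the path logic of `ai2holodeck/constants.py` in this PR. Note that `constants.py` reads the `OBJATHOR_ASSETS_BASE_DIR` variable, which differs from the name given in the README text, so it is worth confirming which one your checkout honors. The function names are hypothetical and only for illustration.

```python
import os


def resolve_assets_base_dir(env=None):
    """Return the asset base directory, honoring an environment override.

    Assumption: the override variable is OBJATHOR_ASSETS_BASE_DIR, as read
    by constants.py in this PR.
    """
    env = os.environ if env is None else env
    default = os.path.expanduser("~/.objathor-assets")
    return env.get("OBJATHOR_ASSETS_BASE_DIR", default)


def versioned_paths(base_dir, version="2023_09_23"):
    """Build the per-version locations the same way constants.py joins them."""
    versioned = os.path.join(base_dir, version)
    return {
        "assets": os.path.join(versioned, "assets"),
        "features": os.path.join(versioned, "features"),
        "annotations": os.path.join(versioned, "annotations.json.gz"),
    }


if __name__ == "__main__":
    base = resolve_assets_base_dir({"OBJATHOR_ASSETS_BASE_DIR": "/data/objathor"})
    for name, path in versioned_paths(base).items():
        print(f"{name}: {path}")
```

The variable has to be set before Holodeck is imported, since `constants.py` evaluates these paths at import time.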

## Usage
You can use the following command to generate a new environment.
```
python main.py --query "a living room" --openai_api_key <OPENAI_API_KEY>
python holodeck/main.py --query "a living room" --openai_api_key <OPENAI_API_KEY>
```
To be noticed, our system uses `gpt-4-1106-preview`, so please ensure you have access to it.
Our system uses `gpt-4-1106-preview`, **so please ensure you have access to it.**

**Note:** To yield better layouts, use `DFS` as the solver. If you pull the repo before `12/28/2023`, you must set the [argument](https://github.com/allenai/Holodeck/blob/386b0a868def29175436dc3b1ed85b6309eb3cad/main.py#L78) `--use_milp` to `False` to use `DFS`.

Empty file added ai2holodeck/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions ai2holodeck/constants.py
@@ -0,0 +1,35 @@
import os
from pathlib import Path

ABS_PATH_OF_HOLODECK = os.path.abspath(os.path.dirname(Path(__file__)))

ASSETS_VERSION = os.environ.get("ASSETS_VERSION", "2023_09_23")
HD_BASE_VERSION = os.environ.get("HD_BASE_VERSION", "2023_09_23")

OBJATHOR_ASSETS_BASE_DIR = os.environ.get(
    "OBJATHOR_ASSETS_BASE_DIR", os.path.expanduser(f"~/.objathor-assets")
)

OBJATHOR_VERSIONED_DIR = os.path.join(OBJATHOR_ASSETS_BASE_DIR, ASSETS_VERSION)
OBJATHOR_ASSETS_DIR = os.path.join(OBJATHOR_VERSIONED_DIR, "assets")
OBJATHOR_FEATURES_DIR = os.path.join(OBJATHOR_VERSIONED_DIR, "features")
OBJATHOR_ANNOTATIONS_PATH = os.path.join(OBJATHOR_VERSIONED_DIR, "annotations.json.gz")

HOLODECK_BASE_DATA_DIR = os.path.join(
    OBJATHOR_ASSETS_BASE_DIR, "holodeck", HD_BASE_VERSION
)

HOLODECK_THOR_FEATURES_DIR = os.path.join(HOLODECK_BASE_DATA_DIR, "thor_object_data")
HOLODECK_THOR_ANNOTATIONS_PATH = os.path.join(
    HOLODECK_BASE_DATA_DIR, "thor_object_data", "annotations.json.gz"
)

if ASSETS_VERSION > "2023_09_23":
    THOR_COMMIT_ID = "8524eadda94df0ab2dbb2ef5a577e4d37c712897"
else:
    THOR_COMMIT_ID = "3213d486cd09bcbafce33561997355983bdf8d1a"

# LLM_MODEL_NAME = "gpt-4-1106-preview"
LLM_MODEL_NAME = "gpt-4o-2024-05-13"

# The value is lower-cased first, so only lower-case spellings can match.
DEBUGGING = os.environ.get("DEBUGGING", "0").lower() in ["1", "true", "t"]
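
Two small behaviors in `constants.py` are easy to overlook: the AI2-THOR commit is chosen by comparing version strings, and the `DEBUGGING` flag accepts several truthy spellings. The sketch below isolates both. The function names `select_thor_commit` and `parse_flag` are hypothetical and do not appear in the PR.

```python
def select_thor_commit(assets_version, cutoff="2023_09_23"):
    """Pick the AI2-THOR commit the way constants.py does.

    Zero-padded YYYY_MM_DD strings sort lexicographically in date order,
    so a plain string comparison is enough here.
    """
    if assets_version > cutoff:
        return "8524eadda94df0ab2dbb2ef5a577e4d37c712897"
    return "3213d486cd09bcbafce33561997355983bdf8d1a"


def parse_flag(value):
    """Interpret an environment flag case-insensitively, as DEBUGGING is."""
    return value.lower() in {"1", "true", "t"}


if __name__ == "__main__":
    print(select_thor_commit("2023_09_23"))
    print(select_thor_commit("2024_01_01"))
    print(parse_flag("True"), parse_flag("0"))
```

The string comparison only works while every version keeps the zero-padded `YYYY_MM_DD` form. A value such as `2023_9_23` would compare incorrectly.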
Empty file.
@@ -1,114 +1,143 @@
import re
import copy
import re

import torch
from colorama import Fore
import torch.nn.functional as F
import modules.prompts as prompts
from langchain import PromptTemplate
from colorama import Fore
from langchain import PromptTemplate, OpenAI
from shapely.geometry import Polygon


class CeilingObjectGenerator():
def __init__(self, llm, object_retriever):
self.json_template = {"assetId": None, "id": None, "kinematic": True,
"position": {}, "rotation": {}, "material": None, "roomId": None}
import ai2holodeck.generation.prompts as prompts
from ai2holodeck.generation.objaverse_retriever import ObjathorRetriever
from ai2holodeck.generation.utils import get_bbox_dims, get_annotations


class CeilingObjectGenerator:
def __init__(self, object_retriever: ObjathorRetriever, llm: OpenAI):
self.json_template = {
"assetId": None,
"id": None,
"kinematic": True,
"position": {},
"rotation": {},
"material": None,
"roomId": None,
}
self.llm = llm
self.object_retriever = object_retriever
self.database = object_retriever.database
self.ceiling_template = PromptTemplate(input_variables=["input", "rooms", "additional_requirements"],
template=prompts.ceiling_selection_prompt)

self.ceiling_template = PromptTemplate(
input_variables=["input", "rooms", "additional_requirements"],
template=prompts.ceiling_selection_prompt,
)

def generate_ceiling_objects(self, scene, additional_requirements_ceiling="N/A"):
room_types = [room["roomType"] for room in scene["rooms"]]
room_types_str = str(room_types).replace("'", "")[1:-1]
ceiling_prompt = self.ceiling_template.format(input=scene["query"],
rooms=room_types_str,
additional_requirements=additional_requirements_ceiling)
ceiling_prompt = self.ceiling_template.format(
input=scene["query"],
rooms=room_types_str,
additional_requirements=additional_requirements_ceiling,
)

if "raw_ceiling_plan" not in scene: raw_ceiling_plan = self.llm(ceiling_prompt)
else: raw_ceiling_plan = scene["raw_ceiling_plan"]
if "raw_ceiling_plan" not in scene:
raw_ceiling_plan = self.llm(ceiling_prompt)
else:
raw_ceiling_plan = scene["raw_ceiling_plan"]

print(f"\nUser: {ceiling_prompt}\n")
print(f"{Fore.GREEN}AI: Here is the ceiling plan:\n{raw_ceiling_plan}{Fore.RESET}")
print(
f"{Fore.GREEN}AI: Here is the ceiling plan:\n{raw_ceiling_plan}{Fore.RESET}"
)

ceiling_objects = []
parsed_ceiling_plan = self.parse_ceiling_plan(raw_ceiling_plan)
for room_type, ceiling_object_description in parsed_ceiling_plan.items():
room = self.get_room_by_type(scene["rooms"], room_type)

if room is None:
print("Room type {} not found in scene.".format(room_type))
print(f"Room type {room_type} not found in scene.")
continue

ceiling_object_id = self.select_ceiling_object(ceiling_object_description)
if ceiling_object_id is None: continue
if ceiling_object_id is None:
continue

# Temporary solution: place at the center of the room
dimension = self.database[ceiling_object_id]['assetMetadata']['boundingBox']
dimension = get_bbox_dims(self.database[ceiling_object_id])

floor_polygon = Polygon(room["vertices"])
x = floor_polygon.centroid.x
z = floor_polygon.centroid.y
y = scene["wall_height"] - dimension["y"] / 2

ceiling_object = copy.deepcopy(self.json_template)
ceiling_object["assetId"] = ceiling_object_id
ceiling_object["id"] = f"ceiling ({room_type})"
ceiling_object["position"] = {"x": x, "y": y, "z": z}
ceiling_object["rotation"] = {"x": 0, "y": 0, "z": 0}
ceiling_object["roomId"] = room["id"]
ceiling_object["object_name"] = self.database[ceiling_object_id]["annotations"]["category"]
ceiling_object["object_name"] = get_annotations(
self.database[ceiling_object_id]
)["category"]
ceiling_objects.append(ceiling_object)

return raw_ceiling_plan, ceiling_objects


def parse_ceiling_plan(self, raw_ceiling_plan):
plans = [plan.lower() for plan in raw_ceiling_plan.split("\n") if "|" in plan]
parsed_plans = {}
for plan in plans:
# remove index
pattern = re.compile(r'^\d+\.\s*')
plan = pattern.sub('', plan)
if plan[-1] == ".": plan = plan[:-1] # remove the last period
pattern = re.compile(r"^\d+\.\s*")
plan = pattern.sub("", plan)
if plan[-1] == ".":
plan = plan[:-1] # remove the last period

room_type, ceiling_object_description = plan.split("|")
room_type = room_type.strip()
ceiling_object_description = ceiling_object_description.strip()
if room_type not in parsed_plans: # only consider one type of ceiling object for each room
if (
room_type not in parsed_plans
): # only consider one type of ceiling object for each room
parsed_plans[room_type] = ceiling_object_description
return parsed_plans


def get_room_by_type(self, rooms, room_type):
for room in rooms:
if room["roomType"] == room_type:
return room
return None


def select_ceiling_object(self, description):
candidates = self.object_retriever.retrieve([f"a 3D model of {description}"], threshold=29)
ceiling_candiates = [candidate for candidate in candidates if self.database[candidate[0]]["annotations"]["onCeiling"] == True]
candidates = self.object_retriever.retrieve(
[f"a 3D model of {description}"], threshold=29
)
ceiling_candiates = [
candidate
for candidate in candidates
if get_annotations(self.database[candidate[0]])["onCeiling"] == True
]

valid_ceiling_candiates = []
for candidate in ceiling_candiates:
dimension = self.database[candidate[0]]['assetMetadata']['boundingBox']
if dimension["y"] <= 1.0: valid_ceiling_candiates.append(candidate)
dimension = get_bbox_dims(self.database[candidate[0]])
if dimension["y"] <= 1.0:
valid_ceiling_candiates.append(candidate)

if len(valid_ceiling_candiates) == 0:
print("No ceiling object found for description: {}".format(description))
return None

selected_ceiling_object_id = self.random_select(valid_ceiling_candiates)[0]
return selected_ceiling_object_id


def random_select(self, candidates):
scores = [candidate[1] for candidate in candidates]
scores_tensor = torch.Tensor(scores)
probas = F.softmax(scores_tensor, dim=0) # TODO: consider using normalized scores
probas = F.softmax(
scores_tensor, dim=0
) # TODO: consider using normalized scores
selected_index = torch.multinomial(probas, 1).item()
selected_candidate = candidates[selected_index]
return selected_candidate
return selected_candidate
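
`random_select` above draws one candidate with probability given by a softmax over the retrieval scores. A dependency-free sketch of the same idea follows. It swaps `torch.multinomial` for the standard library's `random.choices`, so it illustrates the technique and is not the PR's implementation.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def random_select(candidates, rng=random):
    """Sample one (asset_id, score) pair, weighted by softmax of the scores.

    `rng` may be the `random` module or a seeded `random.Random` instance.
    """
    probas = softmax([score for _, score in candidates])
    return rng.choices(candidates, weights=probas, k=1)[0]


if __name__ == "__main__":
    # Hypothetical asset ids and scores, for illustration only.
    candidates = [("lamp_a", 31.5), ("lamp_b", 30.0), ("fan_c", 29.2)]
    print(softmax([score for _, score in candidates]))
    print(random_select(candidates, rng=random.Random(0)))
```

Since softmax is applied to raw scores, a gap of a few points between candidates already concentrates most of the probability on the top one. This is one reason the TODO about normalized scores may be worth revisiting.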