Skip to content

Commit

Permalink
use black formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedkhaleel2004 committed Jan 17, 2025
1 parent 57b5adb commit 8998b44
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 97 deletions.
10 changes: 4 additions & 6 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
app = FastAPI()


origins = [
"http://localhost:3000",
"https://gitdiagram.com"
]
origins = ["http://localhost:3000", "https://gitdiagram.com"]

app.add_middleware(
CORSMiddleware,
Expand All @@ -31,8 +28,9 @@
app.add_middleware(Analytics, api_key=API_ANALYTICS_KEY)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, cast(
ExceptionMiddleware, _rate_limit_exceeded_handler))
app.add_exception_handler(
RateLimitExceeded, cast(ExceptionMiddleware, _rate_limit_exceeded_handler)
)

app.include_router(generate.router)
app.include_router(modify.router)
Expand Down
65 changes: 33 additions & 32 deletions backend/app/routers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from app.services.github_service import GitHubService
from app.services.claude_service import ClaudeService
from app.core.limiter import limiter
import os
from app.prompts import SYSTEM_FIRST_PROMPT, SYSTEM_SECOND_PROMPT, SYSTEM_THIRD_PROMPT, ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
from app.prompts import (
SYSTEM_FIRST_PROMPT,
SYSTEM_SECOND_PROMPT,
SYSTEM_THIRD_PROMPT,
ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT,
)
from anthropic._exceptions import RateLimitError
from pydantic import BaseModel
from functools import lru_cache
Expand All @@ -29,11 +33,7 @@ def get_cached_github_data(username: str, repo: str):
file_tree = github_service.get_github_file_paths_as_list(username, repo)
readme = github_service.get_github_readme(username, repo)

return {
"default_branch": default_branch,
"file_tree": file_tree,
"readme": readme
}
return {"default_branch": default_branch, "file_tree": file_tree, "readme": readme}


class ApiRequest(BaseModel):
Expand All @@ -51,7 +51,13 @@ async def generate(request: Request, body: ApiRequest):
if len(body.instructions) > 1000:
return {"error": "Instructions exceed maximum length of 1000 characters"}

if body.repo in ["fastapi", "streamlit", "flask", "api-analytics", "monkeytype"]:
if body.repo in [
"fastapi",
"streamlit",
"flask",
"api-analytics",
"monkeytype",
]:
return {"error": "Example repos cannot be regenerated"}

# Get cached github data
Expand All @@ -71,7 +77,7 @@ async def generate(request: Request, body: ApiRequest):
return {
"error": f"File tree and README combined exceeds token limit (50,000). Current size: {token_count} tokens. This GitHub repository is too large for my wallet, but you can continue by providing your own Anthropic API key.",
"token_count": token_count,
"requires_api_key": True
"requires_api_key": True,
}
elif token_count > 200000:
return {
Expand All @@ -82,20 +88,22 @@ async def generate(request: Request, body: ApiRequest):
first_system_prompt = SYSTEM_FIRST_PROMPT
third_system_prompt = SYSTEM_THIRD_PROMPT
if body.instructions:
first_system_prompt = first_system_prompt + \
"\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
third_system_prompt = third_system_prompt + \
"\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
first_system_prompt = (
first_system_prompt + "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
)
third_system_prompt = (
third_system_prompt + "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
)

# get the explanation for sysdesign from claude
explanation = claude_service.call_claude_api(
system_prompt=first_system_prompt,
data={
"file_tree": file_tree,
"readme": readme,
"instructions": body.instructions
"instructions": body.instructions,
},
api_key=body.api_key
api_key=body.api_key,
)

# Check for BAD_INSTRUCTIONS response
Expand All @@ -104,18 +112,14 @@ async def generate(request: Request, body: ApiRequest):

full_second_response = claude_service.call_claude_api(
system_prompt=SYSTEM_SECOND_PROMPT,
data={
"explanation": explanation,
"file_tree": file_tree
}
data={"explanation": explanation, "file_tree": file_tree},
)

# Extract component mapping from the response
start_tag = "<component_mapping>"
end_tag = "</component_mapping>"
component_mapping_text = full_second_response[
full_second_response.find(start_tag):
full_second_response.find(end_tag)
full_second_response.find(start_tag) : full_second_response.find(end_tag)
]

# get mermaid.js code from claude
Expand All @@ -124,8 +128,8 @@ async def generate(request: Request, body: ApiRequest):
data={
"explanation": explanation,
"component_mapping": component_mapping_text,
"instructions": body.instructions
}
"instructions": body.instructions,
},
)

# Check for BAD_INSTRUCTIONS response
Expand All @@ -134,18 +138,14 @@ async def generate(request: Request, body: ApiRequest):

# Process click events to include full GitHub URLs
processed_diagram = process_click_events(
mermaid_code,
body.username,
body.repo,
default_branch
mermaid_code, body.username, body.repo, default_branch
)

return {"diagram": processed_diagram,
"explanation": explanation}
return {"diagram": processed_diagram, "explanation": explanation}
except RateLimitError as e:
raise HTTPException(
status_code=429,
detail="Service is currently experiencing high demand. Please try again in a few minutes."
detail="Service is currently experiencing high demand. Please try again in a few minutes.",
)
except Exception as e:
return {"error": str(e)}
Expand Down Expand Up @@ -184,12 +184,13 @@ def process_click_events(diagram: str, username: str, repo: str, branch: str) ->
Process click events in Mermaid diagram to include full GitHub URLs.
Detects if path is file or directory and uses appropriate URL format.
"""

def replace_path(match):
# Extract the path from the click event
path = match.group(2).strip('"\'')
path = match.group(2).strip("\"'")

# Determine if path is likely a file (has extension) or directory
is_file = '.' in path.split('/')[-1]
is_file = "." in path.split("/")[-1]

# Construct GitHub URL
base_url = f"https://github.com/{username}/{repo}"
Expand Down
18 changes: 13 additions & 5 deletions backend/app/routers/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,27 @@ async def modify(request: Request, body: ModifyRequest):
# Check instructions length
if not body.instructions or not body.current_diagram:
return {"error": "Instructions and/or current diagram are required"}
elif len(body.instructions) > 1000 or len(body.current_diagram) > 100000: # just being safe
elif (
len(body.instructions) > 1000 or len(body.current_diagram) > 100000
): # just being safe
return {"error": "Instructions exceed maximum length of 1000 characters"}

if body.repo in ["fastapi", "streamlit", "flask", "api-analytics", "monkeytype"]:
if body.repo in [
"fastapi",
"streamlit",
"flask",
"api-analytics",
"monkeytype",
]:
return {"error": "Example repos cannot be modified"}

modified_mermaid_code = claude_service.call_claude_api(
system_prompt=SYSTEM_MODIFY_PROMPT,
data={
"instructions": body.instructions,
"explanation": body.explanation,
"diagram": body.current_diagram
}
"diagram": body.current_diagram,
},
)

# Check for BAD_INSTRUCTIONS response
Expand All @@ -54,7 +62,7 @@ async def modify(request: Request, body: ModifyRequest):
except RateLimitError as e:
raise HTTPException(
status_code=429,
detail="Service is currently experiencing high demand. Please try again in a few minutes."
detail="Service is currently experiencing high demand. Please try again in a few minutes.",
)
except Exception as e:
return {"error": str(e)}
37 changes: 13 additions & 24 deletions backend/app/services/claude_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ class ClaudeService:
def __init__(self):
self.default_client = Anthropic()

def call_claude_api(self, system_prompt: str, data: dict, api_key: str | None = None) -> str:
def call_claude_api(
self, system_prompt: str, data: dict, api_key: str | None = None
) -> str:
"""
Makes an API call to Claude and returns the response.
Expand All @@ -32,40 +34,30 @@ def call_claude_api(self, system_prompt: str, data: dict, api_key: str | None =
temperature=0,
system=system_prompt,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": user_message
}
]
}
]
{"role": "user", "content": [{"type": "text", "text": user_message}]}
],
)
return message.content[0].text # type: ignore

# autopep8: off
def _format_user_message(self, data: dict[str, str]) -> str:
"""Helper method to format the data into a user message"""
parts = []
for key, value in data.items():
if key == 'file_tree':
if key == "file_tree":
parts.append(f"<file_tree>\n{value}\n</file_tree>")
elif key == 'readme':
elif key == "readme":
parts.append(f"<readme>\n{value}\n</readme>")
elif key == 'explanation':
elif key == "explanation":
parts.append(f"<explanation>\n{value}\n</explanation>")
elif key == 'component_mapping':
elif key == "component_mapping":
parts.append(f"<component_mapping>\n{value}\n</component_mapping>")
elif key == 'instructions' and value != "":
elif key == "instructions" and value != "":
parts.append(f"<instructions>\n{value}\n</instructions>")
elif key == 'diagram':
elif key == "diagram":
parts.append(f"<diagram>\n{value}\n</diagram>")
elif key == 'explanation':
elif key == "explanation":
parts.append(f"<explanation>\n{value}\n</explanation>")
return "\n\n".join(parts)
# autopep8: on

def count_tokens(self, prompt: str) -> int:
"""
Expand All @@ -79,9 +71,6 @@ def count_tokens(self, prompt: str) -> int:
"""
response = self.default_client.messages.count_tokens(
model="claude-3-5-sonnet-latest",
messages=[{
"role": "user",
"content": prompt
}]
messages=[{"role": "user", "content": prompt}],
)
return response.input_tokens
Loading

0 comments on commit 8998b44

Please sign in to comment.