diff --git a/backend/app/main.py b/backend/app/main.py
index 769293d..61c02cf 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -13,10 +13,7 @@
app = FastAPI()
-origins = [
- "http://localhost:3000",
- "https://gitdiagram.com"
-]
+origins = ["http://localhost:3000", "https://gitdiagram.com"]
app.add_middleware(
CORSMiddleware,
@@ -31,8 +28,9 @@
app.add_middleware(Analytics, api_key=API_ANALYTICS_KEY)
app.state.limiter = limiter
-app.add_exception_handler(RateLimitExceeded, cast(
- ExceptionMiddleware, _rate_limit_exceeded_handler))
+app.add_exception_handler(
+ RateLimitExceeded, cast(ExceptionMiddleware, _rate_limit_exceeded_handler)
+)
app.include_router(generate.router)
app.include_router(modify.router)
diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py
index 8bac027..5a147bf 100644
--- a/backend/app/routers/generate.py
+++ b/backend/app/routers/generate.py
@@ -3,8 +3,12 @@
from app.services.github_service import GitHubService
from app.services.claude_service import ClaudeService
from app.core.limiter import limiter
-import os
-from app.prompts import SYSTEM_FIRST_PROMPT, SYSTEM_SECOND_PROMPT, SYSTEM_THIRD_PROMPT, ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
+from app.prompts import (
+ SYSTEM_FIRST_PROMPT,
+ SYSTEM_SECOND_PROMPT,
+ SYSTEM_THIRD_PROMPT,
+ ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT,
+)
from anthropic._exceptions import RateLimitError
from pydantic import BaseModel
from functools import lru_cache
@@ -29,11 +33,7 @@ def get_cached_github_data(username: str, repo: str):
file_tree = github_service.get_github_file_paths_as_list(username, repo)
readme = github_service.get_github_readme(username, repo)
- return {
- "default_branch": default_branch,
- "file_tree": file_tree,
- "readme": readme
- }
+ return {"default_branch": default_branch, "file_tree": file_tree, "readme": readme}
class ApiRequest(BaseModel):
@@ -51,7 +51,13 @@ async def generate(request: Request, body: ApiRequest):
if len(body.instructions) > 1000:
return {"error": "Instructions exceed maximum length of 1000 characters"}
- if body.repo in ["fastapi", "streamlit", "flask", "api-analytics", "monkeytype"]:
+ if body.repo in [
+ "fastapi",
+ "streamlit",
+ "flask",
+ "api-analytics",
+ "monkeytype",
+ ]:
return {"error": "Example repos cannot be regenerated"}
# Get cached github data
@@ -71,7 +77,7 @@ async def generate(request: Request, body: ApiRequest):
return {
"error": f"File tree and README combined exceeds token limit (50,000). Current size: {token_count} tokens. This GitHub repository is too large for my wallet, but you can continue by providing your own Anthropic API key.",
"token_count": token_count,
- "requires_api_key": True
+ "requires_api_key": True,
}
elif token_count > 200000:
return {
@@ -82,10 +88,12 @@ async def generate(request: Request, body: ApiRequest):
first_system_prompt = SYSTEM_FIRST_PROMPT
third_system_prompt = SYSTEM_THIRD_PROMPT
if body.instructions:
- first_system_prompt = first_system_prompt + \
- "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
- third_system_prompt = third_system_prompt + \
- "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
+ first_system_prompt = (
+ first_system_prompt + "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
+ )
+ third_system_prompt = (
+ third_system_prompt + "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
+ )
# get the explanation for sysdesign from claude
explanation = claude_service.call_claude_api(
@@ -93,9 +101,9 @@ async def generate(request: Request, body: ApiRequest):
data={
"file_tree": file_tree,
"readme": readme,
- "instructions": body.instructions
+ "instructions": body.instructions,
},
- api_key=body.api_key
+ api_key=body.api_key,
)
# Check for BAD_INSTRUCTIONS response
@@ -104,18 +112,14 @@ async def generate(request: Request, body: ApiRequest):
full_second_response = claude_service.call_claude_api(
system_prompt=SYSTEM_SECOND_PROMPT,
- data={
- "explanation": explanation,
- "file_tree": file_tree
- }
+ data={"explanation": explanation, "file_tree": file_tree},
)
# Extract component mapping from the response
start_tag = ""
end_tag = ""
component_mapping_text = full_second_response[
- full_second_response.find(start_tag):
- full_second_response.find(end_tag)
+ full_second_response.find(start_tag) : full_second_response.find(end_tag)
]
# get mermaid.js code from claude
@@ -124,8 +128,8 @@ async def generate(request: Request, body: ApiRequest):
data={
"explanation": explanation,
"component_mapping": component_mapping_text,
- "instructions": body.instructions
- }
+ "instructions": body.instructions,
+ },
)
# Check for BAD_INSTRUCTIONS response
@@ -134,18 +138,14 @@ async def generate(request: Request, body: ApiRequest):
# Process click events to include full GitHub URLs
processed_diagram = process_click_events(
- mermaid_code,
- body.username,
- body.repo,
- default_branch
+ mermaid_code, body.username, body.repo, default_branch
)
- return {"diagram": processed_diagram,
- "explanation": explanation}
+ return {"diagram": processed_diagram, "explanation": explanation}
except RateLimitError as e:
raise HTTPException(
status_code=429,
- detail="Service is currently experiencing high demand. Please try again in a few minutes."
+ detail="Service is currently experiencing high demand. Please try again in a few minutes.",
)
except Exception as e:
return {"error": str(e)}
@@ -184,12 +184,13 @@ def process_click_events(diagram: str, username: str, repo: str, branch: str) ->
Process click events in Mermaid diagram to include full GitHub URLs.
Detects if path is file or directory and uses appropriate URL format.
"""
+
def replace_path(match):
# Extract the path from the click event
- path = match.group(2).strip('"\'')
+ path = match.group(2).strip("\"'")
# Determine if path is likely a file (has extension) or directory
- is_file = '.' in path.split('/')[-1]
+ is_file = "." in path.split("/")[-1]
# Construct GitHub URL
base_url = f"https://github.com/{username}/{repo}"
diff --git a/backend/app/routers/modify.py b/backend/app/routers/modify.py
index b99f9e0..2215f9f 100644
--- a/backend/app/routers/modify.py
+++ b/backend/app/routers/modify.py
@@ -31,10 +31,18 @@ async def modify(request: Request, body: ModifyRequest):
# Check instructions length
if not body.instructions or not body.current_diagram:
return {"error": "Instructions and/or current diagram are required"}
- elif len(body.instructions) > 1000 or len(body.current_diagram) > 100000: # just being safe
+ elif (
+ len(body.instructions) > 1000 or len(body.current_diagram) > 100000
+ ): # just being safe
return {"error": "Instructions exceed maximum length of 1000 characters"}
- if body.repo in ["fastapi", "streamlit", "flask", "api-analytics", "monkeytype"]:
+ if body.repo in [
+ "fastapi",
+ "streamlit",
+ "flask",
+ "api-analytics",
+ "monkeytype",
+ ]:
return {"error": "Example repos cannot be modified"}
modified_mermaid_code = claude_service.call_claude_api(
@@ -42,8 +50,8 @@ async def modify(request: Request, body: ModifyRequest):
data={
"instructions": body.instructions,
"explanation": body.explanation,
- "diagram": body.current_diagram
- }
+ "diagram": body.current_diagram,
+ },
)
# Check for BAD_INSTRUCTIONS response
@@ -54,7 +62,7 @@ async def modify(request: Request, body: ModifyRequest):
except RateLimitError as e:
raise HTTPException(
status_code=429,
- detail="Service is currently experiencing high demand. Please try again in a few minutes."
+ detail="Service is currently experiencing high demand. Please try again in a few minutes.",
)
except Exception as e:
return {"error": str(e)}
diff --git a/backend/app/services/claude_service.py b/backend/app/services/claude_service.py
index 9cc39af..d7c72ae 100644
--- a/backend/app/services/claude_service.py
+++ b/backend/app/services/claude_service.py
@@ -8,7 +8,9 @@ class ClaudeService:
def __init__(self):
self.default_client = Anthropic()
- def call_claude_api(self, system_prompt: str, data: dict, api_key: str | None = None) -> str:
+ def call_claude_api(
+ self, system_prompt: str, data: dict, api_key: str | None = None
+ ) -> str:
"""
Makes an API call to Claude and returns the response.
@@ -32,40 +34,30 @@ def call_claude_api(self, system_prompt: str, data: dict, api_key: str | None =
temperature=0,
system=system_prompt,
messages=[
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": user_message
- }
- ]
- }
- ]
+ {"role": "user", "content": [{"type": "text", "text": user_message}]}
+ ],
)
return message.content[0].text # type: ignore
- # autopep8: off
def _format_user_message(self, data: dict[str, str]) -> str:
"""Helper method to format the data into a user message"""
parts = []
for key, value in data.items():
- if key == 'file_tree':
+ if key == "file_tree":
parts.append(f"\n{value}\n")
- elif key == 'readme':
+ elif key == "readme":
parts.append(f"\n{value}\n")
- elif key == 'explanation':
+ elif key == "explanation":
parts.append(f"\n{value}\n")
- elif key == 'component_mapping':
+ elif key == "component_mapping":
parts.append(f"\n{value}\n")
- elif key == 'instructions' and value != "":
+ elif key == "instructions" and value != "":
parts.append(f"\n{value}\n")
- elif key == 'diagram':
+ elif key == "diagram":
parts.append(f"\n{value}\n")
- elif key == 'explanation':
+ elif key == "explanation":
parts.append(f"\n{value}\n")
return "\n\n".join(parts)
- # autopep8: on
def count_tokens(self, prompt: str) -> int:
"""
@@ -79,9 +71,6 @@ def count_tokens(self, prompt: str) -> int:
"""
response = self.default_client.messages.count_tokens(
model="claude-3-5-sonnet-latest",
- messages=[{
- "role": "user",
- "content": prompt
- }]
+ messages=[{"role": "user", "content": prompt}],
)
return response.input_tokens
diff --git a/backend/app/services/github_service.py b/backend/app/services/github_service.py
index f46af6a..7b1b7ba 100644
--- a/backend/app/services/github_service.py
+++ b/backend/app/services/github_service.py
@@ -19,8 +19,13 @@ def __init__(self):
self.github_token = os.getenv("GITHUB_PAT")
# If no credentials are provided, warn about rate limits
- if not all([self.client_id, self.private_key, self.installation_id]) and not self.github_token:
- print("\033[93mWarning: No GitHub credentials provided. Using unauthenticated requests with rate limit of 60 requests/hour.\033[0m")
+ if (
+ not all([self.client_id, self.private_key, self.installation_id])
+ and not self.github_token
+ ):
+ print(
+ "\033[93mWarning: No GitHub credentials provided. Using unauthenticated requests with rate limit of 60 requests/hour.\033[0m"
+ )
self.access_token = None
self.token_expires_at = None
@@ -31,10 +36,11 @@ def _generate_jwt(self):
payload = {
"iat": now,
"exp": now + (10 * 60), # 10 minutes
- "iss": self.client_id
+ "iss": self.client_id,
}
# Convert PEM string format to proper newlines
return jwt.encode(payload, self.private_key, algorithm="RS256") # type: ignore
+
# autopep8: on
def _get_installation_token(self):
@@ -47,8 +53,8 @@ def _get_installation_token(self):
self.installation_id}/access_tokens",
headers={
"Authorization": f"Bearer {jwt_token}",
- "Accept": "application/vnd.github+json"
- }
+ "Accept": "application/vnd.github+json",
+ },
)
data = response.json()
self.access_token = data["token"]
@@ -57,16 +63,17 @@ def _get_installation_token(self):
def _get_headers(self):
# If no credentials are available, return basic headers
- if not all([self.client_id, self.private_key, self.installation_id]) and not self.github_token:
- return {
- "Accept": "application/vnd.github+json"
- }
+ if (
+ not all([self.client_id, self.private_key, self.installation_id])
+ and not self.github_token
+ ):
+ return {"Accept": "application/vnd.github+json"}
# Use PAT if available
if self.github_token:
return {
"Authorization": f"token {self.github_token}",
- "Accept": "application/vnd.github+json"
+ "Accept": "application/vnd.github+json",
}
# Otherwise use app authentication
@@ -74,7 +81,7 @@ def _get_headers(self):
return {
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github+json",
- "X-GitHub-Api-Version": "2022-11-28"
+ "X-GitHub-Api-Version": "2022-11-28",
}
def _check_repository_exists(self, username, repo):
@@ -87,15 +94,17 @@ def _check_repository_exists(self, username, repo):
if response.status_code == 404:
raise ValueError("Repository not found.")
elif response.status_code != 200:
- raise Exception(f"Failed to check repository: {response.status_code}, {response.json()}")
-
+ raise Exception(
+ f"Failed to check repository: {response.status_code}, {response.json()}"
+ )
+
def get_default_branch(self, username, repo):
"""Get the default branch of the repository."""
api_url = f"https://api.github.com/repos/{username}/{repo}"
response = requests.get(api_url, headers=self._get_headers())
if response.status_code == 200:
- return response.json().get('default_branch')
+ return response.json().get("default_branch")
return None
def get_github_file_paths_as_list(self, username, repo):
@@ -110,21 +119,43 @@ def get_github_file_paths_as_list(self, username, repo):
Returns:
str: A filtered and formatted string of file paths in the repository, one per line.
"""
+
def should_include_file(path):
# Patterns to exclude
excluded_patterns = [
# Dependencies
- 'node_modules/', 'vendor/', 'venv/',
+ "node_modules/",
+ "vendor/",
+ "venv/",
# Compiled files
- '.min.', '.pyc', '.pyo', '.pyd', '.so', '.dll', '.class',
+ ".min.",
+ ".pyc",
+ ".pyo",
+ ".pyd",
+ ".so",
+ ".dll",
+ ".class",
# Asset files
- '.jpg', '.jpeg', '.png', '.gif', '.ico', '.svg', '.ttf', '.woff', '.webp',
+ ".jpg",
+ ".jpeg",
+ ".png",
+ ".gif",
+ ".ico",
+ ".svg",
+ ".ttf",
+ ".woff",
+ ".webp",
# Cache and temporary files
- '__pycache__/', '.cache/', '.tmp/',
+ "__pycache__/",
+ ".cache/",
+ ".tmp/",
# Lock files and logs
- 'yarn.lock', 'poetry.lock', '*.log',
+ "yarn.lock",
+ "poetry.lock",
+ "*.log",
# Configuration files
- '.vscode/', '.idea/'
+ ".vscode/",
+ ".idea/",
]
return not any(pattern in path.lower() for pattern in excluded_patterns)
@@ -140,12 +171,15 @@ def should_include_file(path):
data = response.json()
if "tree" in data:
# Filter the paths and join them with newlines
- paths = [item['path'] for item in data['tree']
- if should_include_file(item['path'])]
+ paths = [
+ item["path"]
+ for item in data["tree"]
+ if should_include_file(item["path"])
+ ]
return "\n".join(paths)
# If default branch didn't work or wasn't found, try common branch names
- for branch in ['main', 'master']:
+ for branch in ["main", "master"]:
api_url = f"https://api.github.com/repos/{
username}/{repo}/git/trees/{branch}?recursive=1"
response = requests.get(api_url, headers=self._get_headers())
@@ -154,12 +188,16 @@ def should_include_file(path):
data = response.json()
if "tree" in data:
# Filter the paths and join them with newlines
- paths = [item['path'] for item in data['tree']
- if should_include_file(item['path'])]
+ paths = [
+ item["path"]
+ for item in data["tree"]
+ if should_include_file(item["path"])
+ ]
return "\n".join(paths)
raise ValueError(
- "Could not fetch repository file tree. Repository might not exist, be empty or private.")
+ "Could not fetch repository file tree. Repository might not exist, be empty or private."
+ )
def get_github_readme(self, username, repo):
"""
@@ -186,9 +224,11 @@ def get_github_readme(self, username, repo):
if response.status_code == 404:
raise ValueError("No README found for the specified repository.")
elif response.status_code != 200:
- raise Exception(f"Failed to fetch README: {
- response.status_code}, {response.json()}")
+ raise Exception(
+ f"Failed to fetch README: {
+ response.status_code}, {response.json()}"
+ )
data = response.json()
- readme_content = requests.get(data['download_url']).text
+ readme_content = requests.get(data["download_url"]).text
return readme_content
diff --git a/src/app/[username]/[repo]/page.tsx b/src/app/[username]/[repo]/page.tsx
index f8091ad..31a2928 100644
--- a/src/app/[username]/[repo]/page.tsx
+++ b/src/app/[username]/[repo]/page.tsx
@@ -6,7 +6,6 @@ import Loading from "~/components/loading";
import MermaidChart from "~/components/mermaid-diagram";
import { useDiagram } from "~/hooks/useDiagram";
import { ApiKeyDialog } from "~/components/api-key-dialog";
-import { Button } from "~/components/ui/button";
import { ApiKeyButton } from "~/components/api-key-button";
export default function Repo() {