From 8998b44d65a0c2410e05f729f1bec163049b1e6f Mon Sep 17 00:00:00 2001 From: Ahmed Khaleel Date: Thu, 16 Jan 2025 22:02:58 -0500 Subject: [PATCH] use black formatter --- backend/app/main.py | 10 ++- backend/app/routers/generate.py | 65 ++++++++--------- backend/app/routers/modify.py | 18 +++-- backend/app/services/claude_service.py | 37 ++++------ backend/app/services/github_service.py | 98 ++++++++++++++++++-------- src/app/[username]/[repo]/page.tsx | 1 - 6 files changed, 132 insertions(+), 97 deletions(-) diff --git a/backend/app/main.py b/backend/app/main.py index 769293d..61c02cf 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -13,10 +13,7 @@ app = FastAPI() -origins = [ - "http://localhost:3000", - "https://gitdiagram.com" -] +origins = ["http://localhost:3000", "https://gitdiagram.com"] app.add_middleware( CORSMiddleware, @@ -31,8 +28,9 @@ app.add_middleware(Analytics, api_key=API_ANALYTICS_KEY) app.state.limiter = limiter -app.add_exception_handler(RateLimitExceeded, cast( - ExceptionMiddleware, _rate_limit_exceeded_handler)) +app.add_exception_handler( + RateLimitExceeded, cast(ExceptionMiddleware, _rate_limit_exceeded_handler) +) app.include_router(generate.router) app.include_router(modify.router) diff --git a/backend/app/routers/generate.py b/backend/app/routers/generate.py index 8bac027..5a147bf 100644 --- a/backend/app/routers/generate.py +++ b/backend/app/routers/generate.py @@ -3,8 +3,12 @@ from app.services.github_service import GitHubService from app.services.claude_service import ClaudeService from app.core.limiter import limiter -import os -from app.prompts import SYSTEM_FIRST_PROMPT, SYSTEM_SECOND_PROMPT, SYSTEM_THIRD_PROMPT, ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT +from app.prompts import ( + SYSTEM_FIRST_PROMPT, + SYSTEM_SECOND_PROMPT, + SYSTEM_THIRD_PROMPT, + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT, +) from anthropic._exceptions import RateLimitError from pydantic import BaseModel from functools import lru_cache @@ -29,11 +33,7 @@ def get_cached_github_data(username: str, repo: str): file_tree = github_service.get_github_file_paths_as_list(username, repo) readme = github_service.get_github_readme(username, repo) - return { - "default_branch": default_branch, - "file_tree": file_tree, - "readme": readme - } + return {"default_branch": default_branch, "file_tree": file_tree, "readme": readme} class ApiRequest(BaseModel): @@ -51,7 +51,13 @@ async def generate(request: Request, body: ApiRequest): if len(body.instructions) > 1000: return {"error": "Instructions exceed maximum length of 1000 characters"} - if body.repo in ["fastapi", "streamlit", "flask", "api-analytics", "monkeytype"]: + if body.repo in [ + "fastapi", + "streamlit", + "flask", + "api-analytics", + "monkeytype", + ]: return {"error": "Example repos cannot be regenerated"} # Get cached github data @@ -71,7 +77,7 @@ async def generate(request: Request, body: ApiRequest): return { "error": f"File tree and README combined exceeds token limit (50,000). Current size: {token_count} tokens. This GitHub repository is too large for my wallet, but you can continue by providing your own Anthropic API key.", "token_count": token_count, - "requires_api_key": True + "requires_api_key": True, } elif token_count > 200000: return { @@ -82,10 +88,12 @@ async def generate(request: Request, body: ApiRequest): first_system_prompt = SYSTEM_FIRST_PROMPT third_system_prompt = SYSTEM_THIRD_PROMPT if body.instructions: - first_system_prompt = first_system_prompt + \ - "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT - third_system_prompt = third_system_prompt + \ - "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT + first_system_prompt = ( + first_system_prompt + "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT + ) + third_system_prompt = ( + third_system_prompt + "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT + ) # get the explanation for sysdesign from claude explanation = claude_service.call_claude_api( @@ -93,9 +101,9 @@ async def generate(request: Request, body: ApiRequest): data={ "file_tree": file_tree, "readme": readme, - "instructions": body.instructions + "instructions": body.instructions, }, - api_key=body.api_key + api_key=body.api_key, ) # Check for BAD_INSTRUCTIONS response @@ -104,18 +112,14 @@ async def generate(request: Request, body: ApiRequest): full_second_response = claude_service.call_claude_api( system_prompt=SYSTEM_SECOND_PROMPT, - data={ - "explanation": explanation, - "file_tree": file_tree - } + data={"explanation": explanation, "file_tree": file_tree}, ) # Extract component mapping from the response start_tag = "" end_tag = "" component_mapping_text = full_second_response[ - full_second_response.find(start_tag): - full_second_response.find(end_tag) + full_second_response.find(start_tag) : full_second_response.find(end_tag) ] # get mermaid.js code from claude @@ -124,8 +128,8 @@ async def generate(request: Request, body: ApiRequest): data={ "explanation": explanation, "component_mapping": component_mapping_text, - "instructions": body.instructions - } + "instructions": body.instructions, + }, ) # Check for BAD_INSTRUCTIONS response @@ -134,18 +138,14 @@ async def generate(request: Request, body: ApiRequest): # Process click events to include full GitHub URLs processed_diagram = process_click_events( - mermaid_code, - body.username, - body.repo, - default_branch + mermaid_code, body.username, body.repo, default_branch ) - return {"diagram": processed_diagram, - "explanation": explanation} + return {"diagram": processed_diagram, "explanation": explanation} except RateLimitError as e: raise HTTPException( status_code=429, - detail="Service is currently experiencing high demand. Please try again in a few minutes." + detail="Service is currently experiencing high demand. Please try again in a few minutes.", ) except Exception as e: return {"error": str(e)} @@ -184,12 +184,13 @@ def process_click_events(diagram: str, username: str, repo: str, branch: str) -> Process click events in Mermaid diagram to include full GitHub URLs. Detects if path is file or directory and uses appropriate URL format. """ + def replace_path(match): # Extract the path from the click event - path = match.group(2).strip('"\'') + path = match.group(2).strip("\"'") # Determine if path is likely a file (has extension) or directory - is_file = '.' in path.split('/')[-1] + is_file = "." in path.split("/")[-1] # Construct GitHub URL base_url = f"https://github.com/{username}/{repo}" diff --git a/backend/app/routers/modify.py b/backend/app/routers/modify.py index b99f9e0..2215f9f 100644 --- a/backend/app/routers/modify.py +++ b/backend/app/routers/modify.py @@ -31,10 +31,18 @@ async def modify(request: Request, body: ModifyRequest): # Check instructions length if not body.instructions or not body.current_diagram: return {"error": "Instructions and/or current diagram are required"} - elif len(body.instructions) > 1000 or len(body.current_diagram) > 100000: # just being safe + elif ( + len(body.instructions) > 1000 or len(body.current_diagram) > 100000 + ): # just being safe return {"error": "Instructions exceed maximum length of 1000 characters"} - if body.repo in ["fastapi", "streamlit", "flask", "api-analytics", "monkeytype"]: + if body.repo in [ + "fastapi", + "streamlit", + "flask", + "api-analytics", + "monkeytype", + ]: return {"error": "Example repos cannot be modified"} modified_mermaid_code = claude_service.call_claude_api( @@ -42,8 +50,8 @@ async def modify(request: Request, body: ModifyRequest): data={ "instructions": body.instructions, "explanation": body.explanation, - "diagram": body.current_diagram - } + "diagram": body.current_diagram, + }, ) # Check for BAD_INSTRUCTIONS response @@ -54,7 +62,7 @@ async def modify(request: Request, body: ModifyRequest): except RateLimitError as e: raise HTTPException( status_code=429, - detail="Service is currently experiencing high demand. Please try again in a few minutes." + detail="Service is currently experiencing high demand. Please try again in a few minutes.", ) except Exception as e: return {"error": str(e)} diff --git a/backend/app/services/claude_service.py b/backend/app/services/claude_service.py index 9cc39af..d7c72ae 100644 --- a/backend/app/services/claude_service.py +++ b/backend/app/services/claude_service.py @@ -8,7 +8,9 @@ class ClaudeService: def __init__(self): self.default_client = Anthropic() - def call_claude_api(self, system_prompt: str, data: dict, api_key: str | None = None) -> str: + def call_claude_api( + self, system_prompt: str, data: dict, api_key: str | None = None + ) -> str: """ Makes an API call to Claude and returns the response. @@ -32,40 +34,30 @@ def call_claude_api(self, system_prompt: str, data: dict, api_key: str | None = temperature=0, system=system_prompt, messages=[ - { - "role": "user", - "content": [ - { - "type": "text", - "text": user_message - } - ] - } - ] + {"role": "user", "content": [{"type": "text", "text": user_message}]} + ], ) return message.content[0].text # type: ignore - # autopep8: off def _format_user_message(self, data: dict[str, str]) -> str: """Helper method to format the data into a user message""" parts = [] for key, value in data.items(): - if key == 'file_tree': + if key == "file_tree": parts.append(f"\n{value}\n") - elif key == 'readme': + elif key == "readme": parts.append(f"\n{value}\n") - elif key == 'explanation': + elif key == "explanation": parts.append(f"\n{value}\n") - elif key == 'component_mapping': + elif key == "component_mapping": parts.append(f"\n{value}\n") - elif key == 'instructions' and value != "": + elif key == "instructions" and value != "": parts.append(f"\n{value}\n") - elif key == 'diagram': + elif key == "diagram": parts.append(f"\n{value}\n") - elif key == 'explanation': + elif key == "explanation": parts.append(f"\n{value}\n") return "\n\n".join(parts) - # autopep8: on def count_tokens(self, prompt: str) -> int: """ @@ -79,9 +71,6 @@ def count_tokens(self, prompt: str) -> int: """ response = self.default_client.messages.count_tokens( model="claude-3-5-sonnet-latest", - messages=[{ - "role": "user", - "content": prompt - }] + messages=[{"role": "user", "content": prompt}], ) return response.input_tokens diff --git a/backend/app/services/github_service.py b/backend/app/services/github_service.py index f46af6a..7b1b7ba 100644 --- a/backend/app/services/github_service.py +++ b/backend/app/services/github_service.py @@ -19,8 +19,13 @@ def __init__(self): self.github_token = os.getenv("GITHUB_PAT") # If no credentials are provided, warn about rate limits - if not all([self.client_id, self.private_key, self.installation_id]) and not self.github_token: - print("\033[93mWarning: No GitHub credentials provided. Using unauthenticated requests with rate limit of 60 requests/hour.\033[0m") + if ( + not all([self.client_id, self.private_key, self.installation_id]) + and not self.github_token + ): + print( + "\033[93mWarning: No GitHub credentials provided. Using unauthenticated requests with rate limit of 60 requests/hour.\033[0m" + ) self.access_token = None self.token_expires_at = None @@ -31,10 +36,11 @@ def _generate_jwt(self): payload = { "iat": now, "exp": now + (10 * 60), # 10 minutes - "iss": self.client_id + "iss": self.client_id, } # Convert PEM string format to proper newlines return jwt.encode(payload, self.private_key, algorithm="RS256") # type: ignore + # autopep8: on def _get_installation_token(self): @@ -47,8 +53,8 @@ def _get_installation_token(self): self.installation_id}/access_tokens", headers={ "Authorization": f"Bearer {jwt_token}", - "Accept": "application/vnd.github+json" - } + "Accept": "application/vnd.github+json", + }, ) data = response.json() self.access_token = data["token"] @@ -57,16 +63,17 @@ def _get_installation_token(self): def _get_headers(self): # If no credentials are available, return basic headers - if not all([self.client_id, self.private_key, self.installation_id]) and not self.github_token: - return { - "Accept": "application/vnd.github+json" - } + if ( + not all([self.client_id, self.private_key, self.installation_id]) + and not self.github_token + ): + return {"Accept": "application/vnd.github+json"} # Use PAT if available if self.github_token: return { "Authorization": f"token {self.github_token}", - "Accept": "application/vnd.github+json" + "Accept": "application/vnd.github+json", } # Otherwise use app authentication @@ -74,7 +81,7 @@ def _get_headers(self): return { "Authorization": f"Bearer {token}", "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28" + "X-GitHub-Api-Version": "2022-11-28", } def _check_repository_exists(self, username, repo): @@ -87,15 +94,17 @@ def _check_repository_exists(self, username, repo): if response.status_code == 404: raise ValueError("Repository not found.") elif response.status_code != 200: - raise Exception(f"Failed to check repository: {response.status_code}, {response.json()}") - + raise Exception( + f"Failed to check repository: {response.status_code}, {response.json()}" + ) + def get_default_branch(self, username, repo): """Get the default branch of the repository.""" api_url = f"https://api.github.com/repos/{username}/{repo}" response = requests.get(api_url, headers=self._get_headers()) if response.status_code == 200: - return response.json().get('default_branch') + return response.json().get("default_branch") return None def get_github_file_paths_as_list(self, username, repo): @@ -110,21 +119,43 @@ def get_github_file_paths_as_list(self, username, repo): Returns: str: A filtered and formatted string of file paths in the repository, one per line. """ + def should_include_file(path): # Patterns to exclude excluded_patterns = [ # Dependencies - 'node_modules/', 'vendor/', 'venv/', + "node_modules/", + "vendor/", + "venv/", # Compiled files - '.min.', '.pyc', '.pyo', '.pyd', '.so', '.dll', '.class', + ".min.", + ".pyc", + ".pyo", + ".pyd", + ".so", + ".dll", + ".class", # Asset files - '.jpg', '.jpeg', '.png', '.gif', '.ico', '.svg', '.ttf', '.woff', '.webp', + ".jpg", + ".jpeg", + ".png", + ".gif", + ".ico", + ".svg", + ".ttf", + ".woff", + ".webp", # Cache and temporary files - '__pycache__/', '.cache/', '.tmp/', + "__pycache__/", + ".cache/", + ".tmp/", # Lock files and logs - 'yarn.lock', 'poetry.lock', '*.log', + "yarn.lock", + "poetry.lock", + "*.log", # Configuration files - '.vscode/', '.idea/' + ".vscode/", + ".idea/", ] return not any(pattern in path.lower() for pattern in excluded_patterns) @@ -140,12 +171,15 @@ def should_include_file(path): data = response.json() if "tree" in data: # Filter the paths and join them with newlines - paths = [item['path'] for item in data['tree'] - if should_include_file(item['path'])] + paths = [ + item["path"] + for item in data["tree"] + if should_include_file(item["path"]) + ] return "\n".join(paths) # If default branch didn't work or wasn't found, try common branch names - for branch in ['main', 'master']: + for branch in ["main", "master"]: api_url = f"https://api.github.com/repos/{ username}/{repo}/git/trees/{branch}?recursive=1" response = requests.get(api_url, headers=self._get_headers()) @@ -154,12 +188,16 @@ def should_include_file(path): data = response.json() if "tree" in data: # Filter the paths and join them with newlines - paths = [item['path'] for item in data['tree'] - if should_include_file(item['path'])] + paths = [ + item["path"] + for item in data["tree"] + if should_include_file(item["path"]) + ] return "\n".join(paths) raise ValueError( - "Could not fetch repository file tree. Repository might not exist, be empty or private.") + "Could not fetch repository file tree. Repository might not exist, be empty or private." + ) def get_github_readme(self, username, repo): """ @@ -186,9 +224,11 @@ def get_github_readme(self, username, repo): if response.status_code == 404: raise ValueError("No README found for the specified repository.") elif response.status_code != 200: - raise Exception(f"Failed to fetch README: { - response.status_code}, {response.json()}") + raise Exception( + f"Failed to fetch README: { + response.status_code}, {response.json()}" + ) data = response.json() - readme_content = requests.get(data['download_url']).text + readme_content = requests.get(data["download_url"]).text return readme_content diff --git a/src/app/[username]/[repo]/page.tsx b/src/app/[username]/[repo]/page.tsx index f8091ad..31a2928 100644 --- a/src/app/[username]/[repo]/page.tsx +++ b/src/app/[username]/[repo]/page.tsx @@ -6,7 +6,6 @@ import Loading from "~/components/loading"; import MermaidChart from "~/components/mermaid-diagram"; import { useDiagram } from "~/hooks/useDiagram"; import { ApiKeyDialog } from "~/components/api-key-dialog"; -import { Button } from "~/components/ui/button"; import { ApiKeyButton } from "~/components/api-key-button"; export default function Repo() {