From db59bd059718ccd8219297d6986cb271e9ed0211 Mon Sep 17 00:00:00 2001 From: z3z1ma Date: Thu, 2 Jan 2025 20:20:23 -0700 Subject: [PATCH] feat: ensure we always topologically sort operations in _iter, some small synthesis feature polish --- pyproject.toml | 2 +- src/dbt_osmosis/core/llm.py | 72 +++++++++++++++++++ src/dbt_osmosis/core/osmosis.py | 122 +++++++++++++++++++++++++++----- uv.lock | 2 +- 4 files changed, 180 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 32c483a..7ca0de5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "dbt-osmosis" -version = "1.1.2" +version = "1.1.3" description = "A dbt utility for managing YAML to make developing with dbt more delightful." readme = "README.md" license = { text = "Apache-2.0" } diff --git a/src/dbt_osmosis/core/llm.py b/src/dbt_osmosis/core/llm.py index 9fada9b..52c587d 100644 --- a/src/dbt_osmosis/core/llm.py +++ b/src/dbt_osmosis/core/llm.py @@ -12,6 +12,7 @@ __all__ = [ "generate_model_spec_as_json", "generate_column_doc", + "generate_table_doc", ] @@ -123,6 +124,45 @@ def _create_llm_prompt_for_column( ] +def _create_llm_prompt_for_table( + sql_content: str, table_name: str, upstream_docs: list[str] | None = None +) -> list[dict[str, t.Any]]: + """Builds a system + user prompt instructing the model to produce a string description describing a single model.""" + if upstream_docs is None: + upstream_docs = [] + + system_prompt = dedent(f""" + You are a helpful SQL Developer and an Expert in dbt. + Your job is to produce a concise documentation string + for a table named {table_name}. + + IMPORTANT RULES: + 1. DO NOT output extra commentary or Markdown fences. + 2. Provide only the table description text, nothing else. + 3. If upstream docs exist, you may incorporate them. If none exist, + a short placeholder is acceptable. + 4. Avoid speculation. Keep it short and relevant. 
+ """) + + user_message = dedent(f""" + The SQL for the model is: + + >>> SQL CODE START + {sql_content} + >>> SQL CODE END + + The upstream documentation is: + {os.linesep.join(upstream_docs)} + + Please return only the text suitable for the "description" field. + """) + + return [ + {"role": "system", "content": system_prompt.strip()}, + {"role": "user", "content": user_message.strip()}, + ] + + def generate_model_spec_as_json( sql_content: str, upstream_docs: list[str] | None = None, @@ -207,6 +247,38 @@ def generate_column_doc( return content.strip() +def generate_table_doc( + sql_content: str, + table_name: str, + upstream_docs: list[str] | None = None, + model_engine: str = "gpt-4o", + temperature: float = 0.7, +) -> str: + """Calls OpenAI to generate documentation for a single column in a table. + + Args: + sql_content (str): The SQL code for the table + table_name (str | None): Name of the table/model (optional) + upstream_docs (list[str] | None): Optional docs or references you might have + model_engine (str): The OpenAI model to use (e.g., 'gpt-3.5-turbo') + temperature (float): OpenAI completion temperature + + Returns: + str: A short docstring suitable for a "description" field + """ + messages = _create_llm_prompt_for_table(sql_content, table_name, upstream_docs) + response = openai.chat.completions.create( + model=model_engine, + messages=messages, # pyright: ignore[reportArgumentType] + temperature=temperature, + ) + + content = response.choices[0].message.content + if not content: + raise ValueError("OpenAI returned an empty response") + return content.strip() + + if __name__ == "__main__": # Kitchen sink sample_sql = """ diff --git a/src/dbt_osmosis/core/osmosis.py b/src/dbt_osmosis/core/osmosis.py index f1e57fd..396d641 100644 --- a/src/dbt_osmosis/core/osmosis.py +++ b/src/dbt_osmosis/core/osmosis.py @@ -11,7 +11,7 @@ import time import typing as t import uuid -from collections import OrderedDict +from collections import ChainMap, OrderedDict, 
defaultdict, deque from collections.abc import Iterable, Iterator from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait from dataclasses import dataclass, field @@ -663,6 +663,59 @@ def _get_node_path(node: ResultNode) -> Path | None: return None +def _topological_sort( + candidate_nodes: list[tuple[str, ResultNode]], +) -> list[tuple[str, ResultNode]]: + """ + Perform a topological sort on the given candidate_nodes (uid, node) pairs + based on their dependencies. If a cycle is detected, raise a ValueError. + + Kahn’s Algorithm: + 1) Build adjacency list: parent -> {child, child, ...} + (Because if node 'child' depends on 'parent', we have an edge parent->child). + 2) Compute in-degrees for all nodes. + 3) Collect all nodes with in-degree == 0 into a queue. + 4) Repeatedly pop from queue and 'visit' that node, + then decrement the in-degree of its children. + If any child's in-degree becomes 0, push it into the queue. + 5) If we visited all nodes, we have a valid topological order. + Otherwise, a cycle exists. + """ + adjacency: defaultdict[str, set[str]] = defaultdict(set) + in_degree: defaultdict[str, int] = defaultdict(int) + + all_uids = set(uid for uid, _ in candidate_nodes) + + for uid, _ in candidate_nodes: + in_degree[uid] = 0 + + for uid, node in candidate_nodes: + for dep_uid in node.depends_on_nodes: + if dep_uid in all_uids: + adjacency[dep_uid].add(uid) + in_degree[uid] += 1 + + queue: deque[str] = deque([uid for uid, deg in in_degree.items() if deg == 0]) + sorted_uids: list[str] = [] + + while queue: + parent_uid = queue.popleft() + sorted_uids.append(parent_uid) + + for child_uid in adjacency[parent_uid]: + in_degree[child_uid] -= 1 + if in_degree[child_uid] == 0: + queue.append(child_uid) + + if len(sorted_uids) < len(candidate_nodes): + raise ValueError( + "Cycle detected in node dependencies. Cannot produce a valid topological order." 
+ ) + + uid_to_node = dict(candidate_nodes) + return [(uid, uid_to_node[uid]) for uid in sorted_uids] + + def _iter_candidate_nodes( context: YamlRefactorContext, ) -> Iterator[tuple[str, ResultNode]]: @@ -689,10 +742,14 @@ def f(node: ResultNode) -> bool: logger.debug(":white_check_mark: Node => %s passed filtering logic.", node.unique_id) return True + candidate_nodes: list[t.Any] = [] items = chain(context.project.manifest.nodes.items(), context.project.manifest.sources.items()) for uid, dbt_node in items: if f(dbt_node): - yield uid, dbt_node + candidate_nodes.append((uid, dbt_node)) + + for uid, node in _topological_sort(candidate_nodes): + yield uid, node # Introspection @@ -1875,7 +1932,11 @@ def synthesize_missing_documentation_with_openai( ) -> None: """Synthesize missing documentation for a dbt node using OpenAI's GPT-4o API.""" try: - from dbt_osmosis.core.llm import generate_column_doc, generate_model_spec_as_json + from dbt_osmosis.core.llm import ( + generate_column_doc, + generate_model_spec_as_json, + generate_table_doc, + ) except ImportError: raise ImportError("Please install the 'dbt-osmosis[openai]' extra to use this feature.") if node is None: @@ -1892,8 +1953,23 @@ def synthesize_missing_documentation_with_openai( ":no_entry_sign: No columns to synthesize documentation for => %s", node.unique_id ) return - documented = len([n for n in node.columns.values() if n.description]) - if total - documented > 10: + documented = len([ + column + for column in node.columns.values() + if column.description and column.description not in context.placeholders + ]) + node_map = ChainMap( + t.cast(dict[str, ResultNode], context.project.manifest.nodes), + t.cast(dict[str, ResultNode], context.project.manifest.sources), + ) + upstream_docs: list[str] = [] + for uid in node.depends_on_nodes: + dep = node_map.get(t.cast(str, uid)) + if dep is not None: + upstream_docs.append(f"{uid}: {dep.description}") + if ( # NOTE a semi-arbitrary limit by which its probably 
better to one shot the table versus many smaller requests + total - documented > 10 + ): logger.info( ":robot: Synthesizing bulk documentation for => %s columns in node => %s", total - documented, @@ -1903,29 +1979,43 @@ def synthesize_missing_documentation_with_openai( getattr( node, "raw_code", f"SELECT {', '.join(node.columns)} FROM {node.schema}.{node.name}" ), - [context.project.manifest.nodes[n].description for n in node.depends_on_nodes], - f"{node.name} ({node.resource_type}) -- {node.description}", + upstream_docs=upstream_docs, + existing_context=f"{node.unique_id} -- {node.description}", temperature=0.4, ) if not node.description or node.description in context.placeholders: node.description = spec.get("description", node.description) for synth_col in spec.get("columns", []): - cur_col = node.columns.get(synth_col["name"]) - if cur_col and (not cur_col.description or cur_col.description in context.placeholders): - cur_col.description = synth_col.get("description", cur_col.description) + usr_col = node.columns.get(synth_col["name"]) + if usr_col and (not usr_col.description or usr_col.description in context.placeholders): + usr_col.description = synth_col.get("description", usr_col.description) else: - for col_name, col in node.columns.items(): + if not node.description or node.description in context.placeholders: + logger.info( + ":robot: Synthesizing documentation for node => %s", + node.unique_id, + ) + node.description = generate_table_doc( + getattr( + node, + "raw_code", + f"SELECT {', '.join(node.columns)} FROM {node.schema}.{node.name}", + ), + table_name=node.relation_name or node.name, + upstream_docs=upstream_docs, + ) + for column_name, col in node.columns.items(): if not col.description or col.description in context.placeholders: logger.info( ":robot: Synthesizing documentation for column => %s in node => %s", - col_name, + column_name, node.unique_id, ) col.description = generate_column_doc( - col_name, - f"{node.name} ({node.resource_type}) 
-- {node.description}", - node.relation_name or node.name, - [context.project.manifest.nodes[n].description for n in node.depends_on_nodes], + column_name, + existing_context=f"{node.unique_id} -- {node.description}", + table_name=node.relation_name or node.name, + upstream_docs=upstream_docs, temperature=0.7, ) diff --git a/uv.lock b/uv.lock index d1e0dd0..e8b5cec 100644 --- a/uv.lock +++ b/uv.lock @@ -391,7 +391,7 @@ wheels = [ [[package]] name = "dbt-osmosis" -version = "1.1.1" +version = "1.1.3" source = { editable = "." } dependencies = [ { name = "click" },