diff --git a/.gitignore b/.gitignore index 60e511a4a..ae0ab4af0 100644 --- a/.gitignore +++ b/.gitignore @@ -312,5 +312,7 @@ dist static/ *.db-shm *.db-wal + +logdir/ .aider* examples/logdir diff --git a/README.md b/README.md index 620736743..8abfd9084 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ To install `ell` and `ell studio`, you can use pip. Follow these steps: 2. Run the following command to install the `ell-ai` package from PyPI: ```bash - pip install ell-ai + pip install ell-ai[all] ``` 3. Verify the installation by checking the version of `ell`: diff --git a/docs/ramblings/0.1.0/chat.md b/docs/ramblings/0.1.0/chat.md index 4f4873125..55bdff178 100644 --- a/docs/ramblings/0.1.0/chat.md +++ b/docs/ramblings/0.1.0/chat.md @@ -154,4 +154,155 @@ class MyAgent(): pass + + +from functools import wraps +from typing import Generator, Any +# Precanned AI responses for demo purposes +ai_responses = [ + "The capital of France is Paris.", + "Population: 2200000", +] + + +def lmp(func): + """ + Decorator that simulates multi-step calls to an LLM API. + Prints each step and collects all yields from a generator. + Returns the collected values. + """ + @wraps(func) + def wrapper(*args, **kwargs): + system_prompt = func.__doc__ + print(f"\033[94mSystem: {system_prompt}\033[0m") + generator = func(*args, **kwargs) + message_history = [] + step = 1 + + try: + user_prompt = next(generator) + while True: + print(f"\033[92mUser: {user_prompt}\033[0m") + message_history.append({"role": "user", "content": user_prompt}) + + # Use precanned AI response + ai_response = ai_responses[step - 1] if step <= len(ai_responses) else f"AI response for step {step}" + print(f"\033[93mAssistant: {ai_response}\033[0m") + + message_history.append({"role": "assistant", "content": ai_response}) + step += 1 + + # Send AI response back to the generator + user_prompt = generator.send(ai_response) + + except StopIteration as e: + return e.value + return wrapper + +@lmp +def multistep_prompt(): + """You are a helpful assistant.""" + assistant_response = yield "What is the capital of France?" + print("City!", assistant_response) + assistant_response_2 = yield "What is the population of that city?" + + # This is allowed in a generator + return int(assistant_response_2.split("Population: ")[-1]) + +# Execute the multi-step prompt +result = multistep_prompt() +print(f"{result}") + + +import asyncio +from functools import wraps + +async def async_lmp(func): + """ + Async decorator that simulates multi-step calls to an LLM API. + Prints each step and collects all yields from an async generator. + Returns the collected values. 
+ """ + @wraps(func) + async def wrapper(*args, **kwargs): + system_prompt = func.__doc__ + print(f"\033[94mSystem: {system_prompt}\033[0m") + generator = func(*args, **kwargs) + message_history = [] + step = 1 + + try: + user_prompt = await anext(generator) + while True: + print(f"\033[92mUser: {user_prompt}\033[0m") + message_history.append({"role": "user", "content": user_prompt}) + + # Use precanned AI response + ai_response = ai_responses[step - 1] if step <= len(ai_responses) else f"AI response for step {step}" + print(f"\033[93mAssistant: {ai_response}\033[0m") + + message_history.append({"role": "assistant", "content": ai_response}) + step += 1 + + # Send AI response back to the generator + user_prompt = await generator.asend(ai_response) + + except StopAsyncIteration as e: + return e.value + return wrapper + +@async_generator +async def async_multistep_prompt(): + """You are a helpful assistant.""" + resp = await yield_("What is the capital of France?") + resp = await yield_("What is the population of that city?") + return int(resp.split("Population: ")[-1]) + +async def main(): + result = await async_multistep_prompt() + print(f"{result}") + +asyncio.run(main()) + + +# so in some sense this is the most natural interface for ell which is just the fucking api iterface with an lmp context for multistep, the yield statment feels just right for multistep though. it's so unclear to me why async generators do not have a return value though. +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield ell.user("What is the capital of France?") + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield [ell.user("What is the capital of France?")] + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield ell.Call(messages=[ell.user("What is the capital of France?")], api_params={"max_tokens": 10}) + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield [ell.user("What is the capital of France?")], {"max_tokens": 10} + + +# This is unacceptable. + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + claude_says = yield "What is the capital of France?", {'model': 'claude'} + gpt_says = yield "What is the capital of France?" 
+ +--> + +def normal_prompt(): + anthropic_client = anthropic.Anthropic() + openai_client = openai.OpenAI() + + claude_says = anthropic_client.messages.create(model="claude-3-opus", messages=[{"role": "user", "content": "What is the capital of France?"}]) + gpt_says = openai_client.chat.completions.create(model="gpt-4o", messages=[ + {"role": "user", "content": "What is the capital of France?"}, + { "role": "assistant", "content": claude_says.content}, + {"role": "user", "content": "What is the capital of France?"} + ]) + + return None + ``` \ No newline at end of file diff --git a/docs/ramblings/eval_db_schema.md b/docs/ramblings/eval_db_schema.md new file mode 100644 index 000000000..2201ff56d --- /dev/null +++ b/docs/ramblings/eval_db_schema.md @@ -0,0 +1,563 @@ +The tree would look like this: +``` +experiment[emotional-empathy] -> + run[leap-for-the-sky-o6] -> # process 1 + invocation[cold-email-1] + invocation[cold-email-2] + invocation[cold-email-3] + run[leap-for-the-sky-o7] -> # process 2 + invocation[cold-email-1] + invocation[cold-email-2] + invocation[cold-email-3] + run[leap-for-the-sky-o8] -> # process 3 + invocation[cold-email-1] + invocation[cold-email-2] + invocation[cold-email-3] +``` + + +With an eval we get + +``` +experiment[emotional-empathy] -> + evaluation[cold-email-writer] -> + evaluation run | + run[leap-for-the-sky-o6] -> # no logner correspodns to a single process execution. + invocation[cold-email-1] + invocation[cold-email-2] + invocation[cold-email-3] + evaluation run | + run[leap-for-the-sky-o7] -> + invocation[cold-email-1] + invocation[cold-email-2] + invocation[cold-email-3] +``` + +In general automatically grouping by "run" Is a bad thing when we think about The production execution with multiple process brokers and processes. + +We can just have experiment labeling as a convenience function where. otherwise We don't label by experiment. And also, this doesn't make sense, necessarily, for production runs as well. + +In that case. We might end up with something that looks like this + +``` +experiment[emotional-empathy] -> + evaluation[cold-email-writer] -> + evaluation run | + run[leap-for-the-sky-o6] -> # no logner correspodns to a single process execution. + invocation[cold-email-1] + invocation[cold-email-2] + invocation[cold-email-3] + invocation[cold-email-4] + invocation[cold-email-5] +invocation[cold-email-6] # happened in dev +invocation[cold-email-7] # happened in production +invocation[cold-email-8] +``` + +```python +class EvaluationRun(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + scores: List[Dict[str, float]] = Field(default_factory=list) + dataset : Dataset = Field(default_factory=list) + lmp: Optional[LMP] = Field(default=None) + outputs: List[Any] = Field(default_factory=list) + api_params: Dict[str, Any] = Field(default_factory=dict) + start_time: datetime = Field(default_factory=datetime.now) + end_time: Optional[datetime] = None + + @property + def inputs(self) -> List[Any]: + return [d['input'] for d in self.dataset] + + + def write(self, serialized_evaluation_run) -> None: + # To link! + pass + +class Evaluation(BaseModel): + """Simple evaluation for prompt engineering rigorously""" + model_config = ConfigDict(arbitrary_types_allowed=True) + name: str + dataset: Dataset + criteria: Optional[Criteria] = Field(default_factory=dict) + default_api_params: Optional[Dict[str, Any]] = Field(default_factory=dict) +``` + +So the chief question here is, what do we do with criteria? 
In the traditional Weave view, criteria are tracked functions, and these tracked functions will appear somewhere in the computation graph. Because they adorn the outputs of invocations, we can trace which invocations led to a given criterion evaluation, and then, in the tables for invocations, we can show the criteria as they relate either to an eval or to an individual invocation. The hard question is that when we evaluate a criterion, it's based on the output of one particular invocation. For example, in a chain that writes a story by first generating draft ideas and then producing a final story, a criterion linked downstream of "generate story ideas" (as realized through the story written from those ideas) wouldn't necessarily pipe through to the other end: we probably wouldn't want to show that particular metric on the "generate story ideas" evaluation, even though tracing does give a direct line through it.
+
+Also, I don't think we really want to show the metrics in the computation graph, but we do want to keep track of them. It would be simple enough to have them as functional invocations that are linked, so you could query "was there an invocation from an LMP of type metric downstream of this one?" and surface that in ell studio without creating any additional links, just a simple SQL query.
+
+So in this view, we would do the following. We metricize all criteria (or call them criterion, whatever): ell.criteria, or maybe ell.metric; I'm not sure yet whether to call them criteria or metrics. We adorn them with decorators that basically say: every time this is invoked, we track it, and it can be invoked any number of times. Even something like np.mean could be adorned; it could be any function from any module. That's also a little bit weird, but okay, in that case the source code isn't tracked, the function itself is pickled, so maybe this is fine. Because it's an external module, we need to think about how we serialize external-module functions like that: if I use a standard eval from some other module, that's just an import, and I don't really need to serialize source code for versioning; it's effectively fixed, because we assume the versioning in production and the versioning here are the same.
+
+OK, so the total picture so far is the following. We wrap criteria in a decorator, and all their invocations show up in the computation graph. Which is kind of **** actually, because if I literally use something from another module, then I get a ton of invocations, so maybe this is not something I want to do. I'm trying to imagine a scenario where that would be the case. Imagine we just do a comparison: say we have a pre-canned thing like ell.is_equal, you choose the field you're comparing against, and you set criteria equal to ell.is_equal on the expected output. Then every time it gets called, basically all the LMPs are going to link to that node in the graph.
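+
+A pre-canned criterion like that might look something like the following minimal sketch (`is_equal` here is hypothetical and not an existing ell API; it just builds a plain comparison function that an eval could then wrap and track as a metric):
+
+```python
+def is_equal(field: str):
+    """Build a criterion that compares the output against a named dataset field."""
+    def criterion(datapoint: dict, output: str) -> float:
+        # 1.0 if the output exactly matches the labeled field, else 0.0
+        return float(output == datapoint[field])
+    return criterion
+
+# e.g. criteria={"matches_expected": is_equal("expected_output")}
+```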
+
+I guess we don't particularly have to show that in the graph, but they all do connect on that basis. And now, because we're decorating the actual source function, we're going to slow down every invocation of that simple is_equal function by so much. I'm also not convinced you want the traditional Weave thing of decorating the function yourself; when you decorate the function on your own, that's going to have a certain cost associated with it. So maybe what we do is allow you to decorate with a metric decorator, knowing that if you do, it will technically be more expensive, but the actual realization of this inside ell studio is going to be a little bit different. If you're passing in a metric, we will use that actual metric function for invocations; if you're passing in any random function, we will wrap it in a metric lambda, which for all intents and purposes does a pass-through to the underlying source, but only when we invoke the metric within the criteria. It doesn't take your function and then decorate it, right? Because that's unexpected behavior. In fact, it's just a wrapper for the purposes of the eval, and the name logged in code or shown in ell studio would be something like eval.whatever. That will be the metric that's being logged, I guess. Okay, so let's say we're now adding that LMP to the computation graph. In ell studio we would differentiate that because... well, I don't know.
+
+```python
+@ell.metric
+def my_metric(datapoint, output):
+    return output == datapoint['expected_output']
+
+eval = ell.evaluation(name="my_eval", criteria=[my_metric])
+```
+
+Does this become:
+
+```
+LMPs:
+    my_metric
+    my_eval.my_metric
+
+Invocations:
+    my_metric[0]
+    my_eval.my_metric[0]
+    my_eval.my_metric[1]
+    ...
+
+EvaluationRuns:
+    my_eval
+
+```
+
+Yeah, this feels sort of messy, but I also don't like the idea of exceptionalism for metrics when they are just ell functions. And that did appear fairly clean on the other side.
+
+I think one of the bigger problems here is that we haven't really solved structured outputs, or arbitrary output parsing in general. I liked the Alex Dixon approach to this, but it wasn't exactly clear that it would be intuitive to someone who's using Python; for instance, I don't think you can get at the return statement of a yield.
+
+OK, so if you take a look at the yield .py or .md document (whichever it is) in the ramblings, it actually provides a solution to structured outputs that makes me quite happy. And now, thinking about it, I think there should actually be a decorator or something like this. I'm not sure, but it is now possible to do structured outputs very, very simply, and this is actually a really cool unification of functions and LMPs.
+
+OK, so back on this: do we actually wrap evaluations, or the criteria of evaluations, in invocations, and just leave the DB schema very, very clean? I guess the only problem is this: let's say everything is using yield statements, and I pass in an ell.simple LMP, or an ell.complex LMP, that's already decorated. Then it doesn't make sense to redecorate it. But now the metric doesn't have a clean name like my_eval.whatever; it's just using some metric I've defined somewhere else in my program. So there is a bit of a problem there.
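+
+Concretely, the wrap-unless-already-a-metric behavior discussed above might be sketched like this (purely illustrative: `ell.metric` as a wrapper and the `__ell_metric__` marker attribute are assumptions, not a real API):
+
+```python
+import ell
+
+def wrap_criteria(criteria: dict) -> dict:
+    """Wrap plain callables as eval-scoped metrics, leaving pre-decorated ones untouched."""
+    return {
+        name: fn if getattr(fn, "__ell_metric__", False)  # assumed marker: already a tracked metric
+        else ell.metric(fn)                               # hypothetical wrapper, as discussed above
+        for name, fn in criteria.items()
+    }
+```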
+ + + +Okay and we can just simply solve this by literally creating evaluation criterion or literally just evaluation with like just as a whole object. We don't care about invocations. We don't actually serialize anything and all we're going to do is if you have an LMP score like that just it doesn't matter. It doesn't get written. We don't see it in indication views. We just know that that invocation appeared in eval and then we can go look at the invocation eval run scores and pull. out the score for the L. M. P. From the invocation eval and the score for an L. M. P. And the invocation Eval will look like a table with The invocation Uh the score name or the criterion name the criterion. Id And the Actual float value of the criterion. Along with the invocation run ID. + + + + +Yeah, this feels a bit cleaner, I suppose. But the problem is now Will never have first class support for like these weave ops that we want to log. But of course the. Ultimate score from the lm is now logged. So that's nice. + + +``` +Evaluation Run: + id + dataset + lmp + outputs + api_params + start_time + end_time + +Evaluation Score: + id + invocation_id + evaluation_run_id + criterion_id + value + +Evaluation Criterion (LMP, but not really; kind of annoying.) + id + name + description + source + dependencies + evaluation_id + +Evaluation: + id + name + dataset +``` + + +Now if we use LMP + +``` +Evaluation Run: + id + dataset + lmp_id + evaluation_id + outputs + api_params + start_time + end_time + +Evaluation Score: + id + criterion_lmp_id + evaluation_run_id + evaluation_id # redundant + value + +Criterion: + id + optional_lmp_id + name #as defined by the criteria. + evaluation + +Evaluation: + id + name + dataset + (criteria) + (runs) + +``` + +in this case + + +```python + +class Evaluation(BaseModel): + criteria: Dict[str, Callable] + dataset: Dataset + name: str + + def __init__(self, name: str, dataset: Dataset, criteria: Dict[str, Callable]): + wrapped_criteria = { + name: ell.metric(criterion) for name, criterion in criteria.items() + } + self.criteria = wrapped_criteria + +@ell.simple(model="gpt-4o") +def write_a_poem(topic : str): + response = yield f"Write a poem about {topic}" + return response + +# EVAL CRITERION!!! + +@ell.simple(model="gpt-4o") +def is_good_poem(datapoint, output): + response = yield f"Is this a good poem? {output}" + return "yes" in response.lower() + +@ell.simple(model="gpt-4o") +def on_topic(datapoint, output): + response = yield [ + ell.system("You are a helpful assistant. Always answer with 'yes' or 'no'."), + ell.user(f"Is this poem about {datapoint['topic']}? {output}") + ] + return "yes" in response.lower() + + +eval = Evaluation(name="my_eval", dataset=poem_prompt_dataset, criteria={ + "matches_expert_poem": lambda datapoint, output: output == datapoint['expert_poem'], + "is_good": is_good_poem, + "on_topic": on_topic, +}) + +eval.run(write_a_poem) +``` +s + +So are we really going to actually Turn the past criteria directly into. Language model programs. I feel like that would be a problem if I use some pre canned or massive criteria. But we have to serialize them anyway. So let's just say that we put it into the computation graph like this and then we have a nice clean separation. 
I'll keep track of all invocaitons of metrics (I would have anyway) BUT do we seperate them + +say for example +```python +eval= Evaluation(name="my_eval", dataset=poem_prompt_dataset, criteria={ + "metric", np.equals +}) + +if not isinstance((metric := np.equals), ell.metric): + # This is so hacky, but it's waht wandb does. + + metric = ell.metric(lambda x: metric(x)) + + +``` + + + + +``` +Evaluation Run: + id + dataset + lmp_id + evaluation_id + outputs + api_params + start_time + end_time + +Evaluation Score: + id + criterion_lmp_id + evaluation_run_id + evaluation_id # redundant + value + +Criterion(LMP): + id + optional_lmp_id + name #as defined by the criteria. + evaluation + + src + dependencies + +Evaluation: + id + name + dataset + (criteria) + (runs) + +``` + +We need criterion to be LMPs for verisoning sake because of freevars and so on.. +Okay what does it look like if we have vals in the ocmputation graph without being an LMP? + +If we really accept that it needs to be an LMP, lets call it a criterion LMP, what does the src of + +lambda x: np.mean(x) look like? + + it will actually get the line of defiition for the lambda.. + this raises this issue. (https://github.com/MadcowD/ell/issues/288) but this is neither here nor there. + + So what are we going to do here? + + +If we actually do this Real situation is that everything gets wrapped in a metric Including metrics that get passed in unless they're actually explicitly metrics, and then we hide the metric type underneath. This lets you reuse criterion. + +```python + + + + +class Evaluation(BaseModel): + + def __init__(self, name: str, dataset: Dataset, criteria: Dict[str, Callable]): + wrapped_criteria = { + name: ell.metric(criterion) if not isinstance(criterion, ell.metric) else criterion for name, criterion in criteria.items() + } + self.criteria = wrapped_criteria + +@ell.simple(model="gpt-4o") +def write_a_poem(topic : str): + response = yield f"Write a poem about {topic}" + return response + + +# So let's say we want to track this... I see so we don't actually want to +@ell.lmp +def is_good_poem(datapoint, output): + response = yield f"Is this a good poem? {output}" + return "yes" in response.lower() + +# I think it shouldn't appear in the computation graph unless a user specifies it! hence ell.metric shouldn't do anyhting and it hsould be @ell.function(_hidden=True) or something like that. +``` + +But then a criterion isn't exactly an LMP: + + +Evaluation Run: + id + dataset + lmp_id + evaluation_id + outputs + api_params + start_time + end_time + + This is a problem because these are just invocation ids......... +**Evaluation Score:** + id + criterion_id + evaluation_run_id + evaluation_id # redundant + value + +**Criterion**: + id + lmp_id + name #as defined by the criteria. + evaluation + +Evaluation: + id + name + dataset + (criteria) + (runs) + + + +To visualize an invocation with all of its corresponding criterion scores. 
What would that look like : +```python +import sqlmodel +from models import EvaluationRun, Evaluation, EvaluationCriterion, EvaluationScore, LMP, Invocation + +# Get all of the LMP's with their corresponding critterionscore +def get_invocation_with_criterion_scores(invocation_id: str, session: sqlmodel.Session): + # Query the invocation + invocation = session.query(Invocation).filter(Invocation.id == invocation_id).first() + + if not invocation: + raise ValueError(f"No invocation found with id {invocation_id}") + + # Query the evaluation run associated with this invocation + evaluation_run = session.query(EvaluationRun).filter(EvaluationRun.id == invocation.evaluation_run_id).first() + + if not evaluation_run: + raise ValueError(f"No evaluation run found for invocation {invocation_id}") + + # Query the evaluation associated with this run + evaluation = session.query(Evaluation).filter(Evaluation.id == evaluation_run.evaluation_id).first() + + if not evaluation: + raise ValueError(f"No evaluation found for evaluation run {evaluation_run.id}") + + # Query all criteria for this evaluation + criteria = session.query(EvaluationCriterion).filter(EvaluationCriterion.evaluation_id == evaluation.id).all() + + # Query all scores for this invocation + scores = session.query(EvaluationScore).filter( + EvaluationScore.evaluation_run_id == evaluation_run.id, + EvaluationScore.invocation_id == invocation.id + ).all() + + # Organize scores by criterion + score_by_criterion = {score.criterion_id: score.value for score in scores} + + # Construct the result + result = { + "invocation_id": invocation.id, + "lmp_id": invocation.lmp_id, + "output": invocation.output, + "evaluation_run_id": evaluation_run.id, + "evaluation_id": evaluation.id, + "evaluation_name": evaluation.name, + "criteria_scores": [ + { + "criterion_id": criterion.id, + "criterion_name": criterion.name, + "score": score_by_criterion.get(criterion.id, None) + } + for criterion in criteria + ] + } + + return result + +``` +See that feels bad.. 
We could do something like + +```python +# Query all invocations of the criteria that are linked to this evaluation run +criterion_invocations = session.query(Invocation).join( + InvocationTrace, Invocation.id == InvocationTrace.invocation_consuming_id +).filter( + InvocationTrace.invocation_consumer_id == invocation.id, + Invocation.lmp_id.in_([criterion.lmp_id for criterion in criteria]) +).all() + +# Organize criterion invocations by criterion LMP ID +criterion_invocations_by_lmp = {inv.lmp_id: inv for inv in criterion_invocations} + +# Update the result to include criterion invocations +result["criteria_scores"] = [ + { + "criterion_id": criterion.id, + "criterion_name": criterion.name, + "score": score_by_criterion.get(criterion.id, None), + "criterion_invocation": { + "id": criterion_invocations_by_lmp[criterion.lmp_id].id, + "output": criterion_invocations_by_lmp[criterion.lmp_id].output + } if criterion.lmp_id in criterion_invocations_by_lmp else None + } + for criterion in criteria +] +``` + +Alternative fast API way of doing this is: + +```python +from sqlmodel import SQLModel, Field, Relationship +from typing import Optional, List +from datetime import datetime +class Invocation(InvocationBase, table=True): + lmp: SerializedLMP = Relationship(back_populates="invocations") + consumed_by: List["Invocation"] = Relationship( + back_populates="consumes", + link_model=InvocationTrace, + sa_relationship_kwargs=dict( + primaryjoin="Invocation.id==InvocationTrace.invocation_consumer_id", + secondaryjoin="Invocation.id==InvocationTrace.invocation_consuming_id", + ), + ) + consumes: List["Invocation"] = Relationship( + back_populates="consumed_by", + link_model=InvocationTrace, + sa_relationship_kwargs=dict( + primaryjoin="Invocation.id==InvocationTrace.invocation_consuming_id", + secondaryjoin="Invocation.id==InvocationTrace.invocation_consumer_id", + ), + ) + used_by: Optional["Invocation"] = Relationship(back_populates="uses", sa_relationship_kwargs={"remote_side": "Invocation.id"}) + uses: List["Invocation"] = Relationship(back_populates="used_by") + contents: InvocationContents = Relationship(back_populates="invocation") + __table_args__ = ( + Index('ix_invocation_lmp_id_created_at', 'lmp_id', 'created_at'), + Index('ix_invocation_created_at_latency_ms', 'created_at', 'latency_ms'), + Index('ix_invocation_created_at_tokens', 'created_at', 'prompt_tokens', 'completion_tokens'), + ) + + + +class EvaluationRun(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + dataset: str + lmp_id: str = Field(foreign_key="serializedlmp.lmp_id") + evaluation_id: int = Field(foreign_key="evaluation.id") + + api_params: str + start_time: datetime + end_time: datetime + + evaluation: "Evaluation" = Relationship(back_populates="runs") + results: List[EvaluationInvocation] = Relationship(back_populates="evaluation_run") + + + +# this linkage is 'ok' + +class EvaluationInvocation(SQLModel, table=True): + evaluation_run_id : int = Field(foreign_key="evaluationrun.id", primary_key=True) + invocation_id : int = Field(foreign_key="invocation.id", primary_key=True) + + evaluation_run : EvaluationRun = Relationship(back_populates="invocations") + invocation : Invocation = Relationship(back_populates="evaluation_invocations") + + # Something like this with no back population + scores : List[Invocation] = Relationship( + link_model=EvaluationCriterionLink, + sa_relationship_kwargs=dict( + primaryjoin="EvaluationInvocation.invocation_id==Invocation.id") + ) + +class 
EvaluationCriterionLink(SQLModel, table=True): + evaluation_invocation_id : int = Field(foreign_key="evaluationinvocation.id") + criterion_invocation_id : int = Field(foreign_key="invocation.id") + + +class Criterion(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + lmp_id: str = Field(foreign_key="serializedlmp.lmp_id") + name: str + evaluation_id: int = Field(foreign_key="evaluation.id") + + evaluation: "Evaluation" = Relationship(back_populates="criteria") + scores: List[EvaluationScore] = Relationship(back_populates="criterion") + +class Evaluation(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + dataset_id: str + dataset_pickle : bytes + + criteria: List[Criterion] = Relationship(back_populates="evaluation") + runs: List[EvaluationRun] = Relationship(back_populates="evaluation") + diff --git a/docs/ramblings/evalspec.md b/docs/ramblings/evalspec.md new file mode 100644 index 000000000..2cc4bd321 --- /dev/null +++ b/docs/ramblings/evalspec.md @@ -0,0 +1,278 @@ + +```python + +# Example implementation based on the ideas discussed + +from typing import List, Dict, Any, Union, Callable +import inspect +import numpy as np + +# Define a flexible Dataset type +Dataset = List[Dict[str, Any]] + +# Example dataset +dataset: Dataset = [ + {"input": "What is the capital of France?", "expected_output": "Paris", "difficulty": "easy"}, + {"input": "What is the square root of 144?", "expected_output": "12", "difficulty": "medium"}, + # ... more data points +] + +# Example LMP (Language Model Program) +def my_lmp(input: str) -> str: + # This is a mock LMP that just returns the input + return input + +# Example score functions +def accuracy_score(expected_output: str, output: str) -> float: + return float(expected_output.lower() == output.lower()) + +def difficulty_weighted_score(difficulty: str, expected_output: str, output: str) -> float: + base_score = float(expected_output.lower() == output.lower()) + difficulty_weight = {"easy": 1.0, "medium": 1.5, "hard": 2.0} + return base_score * difficulty_weight.get(difficulty, 1.0) + +class Evaluation: + def __init__(self, name: str, dataset: Dataset, lmp: Callable, scores: List[Callable]): + self.name = name + self.dataset = dataset + self.lmp = lmp + self.scores = scores + + def run(self) -> Dict[str, List[float]]: + results = {score.__name__: [] for score in self.scores} + + for datapoint in self.dataset: + # Run the LMP + lmp_input = datapoint.get("input") + if isinstance(lmp_input, str): + output = self.lmp(lmp_input) + elif isinstance(lmp_input, dict): + output = self.lmp(**lmp_input) + elif isinstance(lmp_input, list): + output = self.lmp(*lmp_input) + else: + raise ValueError(f"Unsupported input type: {type(lmp_input)}") +``` +Alright, so this part is a bit too magical. Essentially, what it's doing is taking the input object and, if it's a single object, passing it directly into the LMP. Otherwise, it destructures the arguments. I do appreciate the use of **kwargs versus list destructuring; it's quite elegant. We can think of it as handling both args and kwargs, which is fine. However, it's also quite clean to write your dataset as single input elements. 
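+
+For concreteness, here is how that dispatch treats the different `input` shapes (an illustrative sketch only; `my_lmp` is assumed to accept the corresponding signatures):
+
+```python
+dataset = [
+    {"input": "What is the capital of France?"},          # str  -> my_lmp("What is the capital of France?")
+    {"input": {"question": "2 + 2?", "style": "terse"}},  # dict -> my_lmp(question="2 + 2?", style="terse")
+    {"input": ["2 + 2?", "terse"]},                       # list -> my_lmp("2 + 2?", "terse")
+]
+```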
+ +```python + # Calculate scores + for score in self.scores: + args = inspect.signature(score).parameters + datapoint_subset = {k: datapoint.get(k) for k in args if k != 'output'} + score_output = score(**datapoint_subset, output=output) + results[score.__name__].append(score_output) + + return results + +# Usage example +eval = Evaluation( + name="my_evaluation", + dataset=dataset, + lmp=my_lmp, + scores=[accuracy_score, difficulty_weighted_score] +) + +results = eval.run() +print(results) + +# You could then add methods to analyze and visualize the results +# For example: +def analyze_results(results: Dict[str, List[float]]): + for score_name, scores in results.items(): + print(f"{score_name}:") + print(f" Mean: {np.mean(scores):.4f}") + print(f" Median: {np.median(scores):.4f}") + print(f" Min: {np.min(scores):.4f}") + print(f" Max: {np.max(scores):.4f}") + +analyze_results(results) + +``` + +So now let's consider The usability of these input shapes. If we're really going to accept that, there's like some special input data point arg. + + +```python +class DatapointPD(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + input : Dict[str, Any] | List[Any] + labels: Dict[str, Any] + + +# or + +class DatapointTD(TypedDict, total=False): + input : Dict[str, Any] | List[Any] + +# finally +Dataset = List[Datapoint] + + +# This is actually quite in the style of ell where we have input and output in ell studio as either a list of arguments or a dictionary of kwargs. +dataset = [ + Datapoint(input=["What is the capital of France?"], labels={"expected_output": "Paris"}), +] + +# or +dataset = [ + {"input" : {'question' : "What is the capital of France?"}, "answer" : "Paris"}, +] +#/equivalently +dataset = [ + DatapointTD(input=["What is the capital of France?"], labels={"expected_output": "Paris"}), +] + +``` + +This approach is quite elegant. We need to use Pydantic models with `total=False` so we can validate that each entry has an input. + +Imagine defining a dataset in this structured way, where every entry must at least have the shape of an input. You can then add arbitrary fields to the dataset columns. This avoids the issue where the shape of the LMP function needs to be transformed. + +So let's actually write out what the final form of this might actually look like and see if it's palatable. If it's not that's okay. +```python + + +@ell.simple(model="gpt-4o-mini") +def write_a_poem(about :str): + """You are PoetGPT. You always write in iambic pentameter. Only answer with a poem.""" + return f"Write a poem about {about}" + + +@ell.simple(model="gpt-4o-mini") +def iambic_pentameter(poem :str): + return f"Is the following poem in iambic pentameter? {output} answer with yes or no." + + +# This is like OpenAI + weave evals. + +eval = Evaluation( + name="poem-eval", + dataset=[ + Datapoint(input=["a rose"], must_contain="rose", minimum_length=100), + Datapoint(input=["a sunset"], must_contain="sunset", minimum_length=100), + Datapoint(input=["a rainbow"], must_contain="", refuse=True, minimum_length=100), + ], + criterion=[ + lambda datapoint, output: datapoint.must_contain in output, + lambda datapoint, output: len(output) >= datapoint.minimum_length, + lambda datapoint, output: "I refuse to write a poem about that" in output or not datapoint.refuse, + lambda datapoint, output: "yes" in iambic_pentameter(output).lower(), + ] +) + + +eval.run(write_a_poem) +# a set of scores. 
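+# (Illustrative guess at the result shape: one list of per-datapoint scores per criterion,
+# e.g. roughly {"criterion_0": [1.0, 1.0, 0.0], "criterion_1": [1.0, 0.0, 1.0], ...};
+# how unnamed lambda criteria would be keyed is still unclear.)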
+# Then we modify write a poem + + + +@ell.simple(model="gpt-4o-mini") +def write_a_poem(about :str): + """You are PoetGPT. You always write in iambic pentameter. Only answer with a poem. Say I refuse to write a poem about that if you are asked to write about rianbows """ + return f"Write a poem about {about}" + + +# Now the refusal criterion will work. +eval.run(write_a_poem) + +# Now we improve iambic pentameter score by trying to rewrite the poem. + +@ell.simple(model="gpt-4o-mini") +def better_poem_writer(about :str): + """You are a poet. You are a poet who is extremely good at writing iambic pentameter. If the poem says I refuse just copy the refusal""" + initial_poem = write_a_poem(about) + + return f"Rewrite the following poem in iambic pentameter: {initial_poem}" + + +eval.run(better_poem_writer) +# highest score. + +``` + +I think I like this Eval the most from any of the specs I have come up with. You can just throw accuracy criteria in there. It's very easy by specifying how the dataset looks. The Weave guys definitely built a really good abstraction here. Some small changes around where things feel magical make this pretty close to an abstraction that we can use. In the above example, it's extremely readable as to what's going on, and I can imagine a very simple flow where I iteratively improve things. I don't have to worry about what's going on with the individual args or kwargs, as they're specified in the input dict. If there's a mismatch, then I just use arguments instead of kwargs. As for the criterion, you just take in the data point and the output. It's just two positional arguments. The data point is literally just what came from the dataset. So if you ever need to look at the schema, it's all there. Inputs are separated out. Inputs are a requirement for data points. We can validate that when we build the eval. This is a very particular type of dataset, and this lets you very quickly and rapidly develop fast evaluations. + +The only problem here is I think what is very nice about the OpenAI evaluation product is that it comes with tons of evaluations by default. For example, text similarity, text quality, BLEU score, things like this. And because the dataset is so free, we don't have an expected output. We can't run metrics automatically. + +We could, by default, actually include something inside the metric functionality, like a special keyword in the dataset. If we actually use the reserved expected output keyword, then you can just use pre-canned metrics without having to specify them, because then we're sort of moving the transmutation of metrics to the criterion specification, right? But I could automatically run things like BLEU score or text similarity if you use the expected output keyword. Otherwise, I guess we could just make them instantiable, so I might actually prefer this. So let's just do this, for example. + + +```python + + +from ell.evals import cosine_similarity + +@ell.simple(model="gpt-4o-mini") +def write_a_poem(about :str): + """You are PoetGPT. 
Write with cheesy well-known poems if available.""" + return f"Write a poem about {about}" + + +eval = Evaluation( + name="poem-eval", + dataset=[ + # jsonl injection into dataset formula + Datapoint(input=["a rose"], expert_poem="Roses are red, violets are blue, sugar is sweet, and so are you.") + ], + criterion=[ + cosine_similarity("text-embedding-3-small", expected_output="expert_poem", inner_product="normal") + ] +) + +# can automatically do cosine similarity & other nice things +eval.run(write_a_poem) + +``` + + + +# Next up +1. Implementing Evaluation +2. Implement Studio schemas +3. Implement Eval UX in studio + + +An additional note we need to have within this spec is that there are going to be two different modes for evaluations. In particular, you may not specify evaluations at all. You just want to look at the outputs. And that's something that's really not thought about with these frameworks like DSPY, because they are really geared towards avoiding hand prompt engineering. But in an actual use case, like you're trying to tune some email generator in a startup, you do have to kind of engineer by vibes, maybe give feedback and try to reward the model at some point, but there's a lot of data points required to get to that point. And so in the first place, you may not specify criterion at all, and that's something that the OpenAI endpoints actually miss out on in their definition of metrics. So we'll add to the spec the following. + + +```python + + +@ell.simple(model="gpt-4o-mini") +def write_cold_email(name :str, career_history :str, hobbies :str): + """You are an expert email writer. You write cold emails to people.""" + + return f"Write a cold email to {name}. Here is their career history: {career_history}. Here is their hobbies: {hobbies}." + +eval = Evaluation( + name="email-eval", + dataset=[ + Datapoint( + input=dict( + name="jeff", + career_history="Jeff is a VP of marketing", + hobbies="Jeff likes to play the guitar", + ) + ), + Datapoint( + input=dict( + name="george", + career_history="George is a software engineer", + hobbies="George enjoys hiking and photography", + ) + ), + ], + # criterion=[] +) + +# Passing no criterion +result = eval.run(write_cold_email) + +print(result.outputs) +# print(result.scores?? ) # This doesn't actually make sense because when we pass criterion to the eval we don't name them... +# What does weave do> +``` + +We'll have to resolve htis later but this spec isn't the worst. \ No newline at end of file diff --git a/docs/ramblings/humanfeedback.py b/docs/ramblings/humanfeedback.py new file mode 100644 index 000000000..304e31327 --- /dev/null +++ b/docs/ramblings/humanfeedback.py @@ -0,0 +1,264 @@ +from typing import Optional +from pydantic import BaseModel, Field +from ell import Evaluation +import ell + +topic_dataset = [ + {"input": "roses"}, + {"input": "violets"}, + {"input": "sunflowers"}, + {"input": "daisies"}, +] + +@ell.simple(model="gpt-4o") +def write_a_poem(about : str) -> str: + """You are poem GPT. 
Make it 3 sentences long at most.""" + return f"Write a poem about {about}" + + +class PoemFeedback(BaseModel): + """Please provide feedback on the poem.""" + + clarity: float = Field(..., ge=1, le=10, description="The clarity of the poem on a scale of 1 to 10") + + approve : bool = Field(..., description="If the poem is good enough to be approved") + + +eval = Evaluation( + name="eval", + dataset=topic_dataset, + labels={ + "human_feedback": ell.human_feedback(PoemFeedback), + "length": lambda output: len(output) + } +) +eval.run(write_a_poem) + +def Dataset(*args, **kwargs): + return args[0] + +dataset = Dataset([ + {"input": "roses"}, + {"input": "violets"}, + {"input": "sunflowers"}, + {"input": "daisies"}, +]) + + +# SFT. + +class StructuredPoem(BaseModel): + poem: str = Field(..., description="The poem", max_length=100) + + @field_validator("poem") + def poem_length(cls, v): + if len(v) > 100: + raise ValueError("Poem must be 100 characters or less") + # check punctuation + if not v.endswith("."): + raise ValueError("Poem must end with a period") + + return v + notes : Optional[str] = Field(None, description="Any additional notes about the poem") + + +@ell.human(response_format=StructuredPoem) +def write_a_poem(topic): + """You should write a poem about the topic keep all poems under 100 characters""" + return f"Write a poem about {topic}" + +@ell.human(): +def write_a_poem_human(topic): + """You are a human writing poems""" + # no way with this syntax to request sturcutred out for one message. its gotta be yield ell.call or something. but in which case why yield at all? + expert_poem_str = yield f"Write a poem about {topic}" + + return StructuredPoem( + poem=expert_poem_str, + notes= yield "Please provide feedback on the poem." + ) + +# Decide to do the structured poem response format for. Human sft data or something of this form. Then we need to inherently support structured outputs across the entire api in a meaningful way The yield format is quite interesting because it allows us to reconstitute the format of. sort of a dialog between the labeler and Data that is presented and present arbitrary data at any point in time during the human labeling process or human sft data generation process. But this also doesn't allow the dynamic generation of uis that are clean and beautiful in some sense. Of course, we could have markdown data and better renderers for raw data contained within, for example, write a poem about topic and things like this. But this isn't a fully. thought out solution. I do like that this is kind of consistent, right? So by using these resumable generators with async or not async, but just standard send. We're able to actually Ask for various different Structured data and reconstituted directly in python in a nice way, so that the final result is this structured poem. Therefore What we sft on is a multi message context, where you. Have it actually output, like the since ultimately we can only sft on strings at any given point in time. There's no magic occurring. And what we sft on is like this conversation thread here. So if this were like a multiple, like I could say, you know, something like The expert poemster is the first yield and then the additional notes or feedback is the 2nd sort of yield. 
This broader constitution of human feedback is actually kind of interesting + +dc = ell.DataCollection( + dataset=dataset, + number_examples_per_datapoint=10, +) + +dc.run(write_a_poem) + + + + + + +eval = Evaluation( + name="eval", + dataset=topic_dataset, + metrics={ + "human_feedback": ell.human_feedback(PoemFeedback), + "length": lambda output: len(output) + } +) +res = eval.run(write_a_poem) + +res.scores # dict{ +# "length": [10, 10, 10, 10], +# "human_feedback": [ +# # or Deferred +# { +# "clarity": 8, +# "approve": True +# }, +# { +# "clarity": 8, +# "approve": True +# }, +# ] +# } + +# we could construct the db model. + +# or is human feedback not a metric??? +eval = Evaluation( + name="eval", + dataset=topic_dataset, + annotations={ + "length": lambda output: len(output) + "human_feedback": ell.human_feedback(PoemFeedback), + "reversed": lambda output: output[::-1] + }, + criterion=lambda annotations: annotations["human_feedback"].approve and annotations["length"] < 100 +) + + + +eval = Evaluation( + name="eval", + dataset=topic_dataset, + scores={ # only callable[..., float] + "length": lambda output: len(output) + }, + annotations={ # callable[..., Any] # if its a pydantic type then we pull out the relevant fields? + "human_feedback": ell.human_feedback(PoemFeedback), + "reversed": lambda output: output[::-1] + }, + # final passing criterion optionally but any metric itself could be optimized against. + criterion=lambda annotations: annotations["human_feedback"].approve and annotations["length"] < 100 +) + + +# multiple criteria doesnt make sense now and so on + +# we also have this nice thing +EvaluationRunLabelerResult +# whic automatically yields the mean of a particular labeler result.... +# could be like +x = eval.run(write_a_poem).results["human_feedback"] # whic his a EvaluationRunLabelerResult.... +x.mean() # would be None if its a type of string otherwise. no this isn't quite rigjht since we dont get access to the right components + +# criterion is a aggregator funciton over annotations + +# if metric(dp, output) -> float compute summary statistics over the metric. +# otherwise it's just raw data to adorn the final eval with + +# then what does the final result object look like +# Does this make like 3 different types of invocaiton labeler or what if they are all annotaitons & we have one criterion flag for a pass rate. What about computing mean metrics? + +# this actually stratifies the problem too much, if I watn to track multiple metrics then I have to make a new table for each metric. +# in general I want that labels are just simple types with the exception of human feedback. +# [float, str, bool] as in the DB schema we buiklt + +FieldInfo(annotation=bool, default=None, description='If the poem is good enough to be approved') +# its sad because effectively human feedback also matches + +# A problem here is that criterion is just like a final pass fail, but we might actually want to construct final. Scores that are a function of previous scores and so on. But this is just too much overloading of a feature where someone could implement such. A Functionality via. Inheritance. + + +class EvaluationRun(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + # we could get really seperated. 
+ results: EvaluationResults = Field(default_factory=EvaluationResults) + dataset : Dataset = Field(default_factory=list) + lmp: Optional[LMP] = Field(default=None) + outputs: List[Any] = Field(default_factory=list) + api_params: Dict[str, Any] = Field(default_factory=dict) + start_time: datetime = Field(default_factory=datetime.now) + end_time: Optional[datetime] = None + + + +x = eval.run(write_a_poem) + +x.dataset # [{'input': 'roses'}, {'input': 'violets'}, {'input': 'sunflowers'}, {'input': 'daisies'}] +x.outputs # [..., ..., ..., ...] (lmp outptus) +x.results.scores # {"length": [10, 10, 10, 10]} #always can expect a numpy array +x.results.annotations # {"human_feedback": [..., ..., ..., ...], "reversed": [..., ..., ..., ...]} +x.results.aggregate_annotations() # {"human_feedback": ..., "reversed": ...} maybe???? idk im a little sad about this one +x.results.criterion # [True, True, True, True] + + + + + + + + + +ell.init(verbose=True) +def render_poem_and_collect_feedback(topic): + # ASCII art for poem presentation + print(""" + ╭───────────────────────────────────────────╮ + │ 🌸 Poem Feedback 🌸 │ + ╰───────────────────────────────────────────╯ + """) + + # Call write_a_poem function + poem = write_a_poem(topic) + + # Collect human feedback + print(""" + ╔══════════════════════════════════════════╗ + ║ 🎭 Human Feedback Section 🎭 ║ + ╚══════════════════════════════════════════╝ + """) + + feedback_data = {} + for field_name, field in PoemFeedback.model_fields.items(): + if field.annotation == float: + while True: + try: + value = float(input(f" 📊 {field.description} ({field.metadata[0].ge}-{field.metadata[1].le}): ")) + if field.metadata[0].ge <= value <= field.metadata[1].le: + feedback_data[field_name] = value + break + else: + print(f" ⚠️ Please enter a number between {field.metadata[0].ge} and {field.metadata[1].le}.") + except ValueError: + print(" ❌ Please enter a valid number.") + elif field.annotation == str: + feedback_data[field_name] = input(f" 💬 {field.description}: ") + elif field.annotation == bool: + feedback_data[field_name] = input(f" ✅/❌ {field.description} (yes/no): ").lower() == 'yes' + + # Create PoemFeedback object + feedback = PoemFeedback(**feedback_data) + + print(""" + ╔══════════════════════════════════════════╗ + ║ 🙏 Thank You for Your Input 🙏 ║ + ╚══════════════════════════════════════════╝ + """) + return feedback + +# Example usage +if __name__ == "__main__": + for topic in ["roses", "violets", "sunflowers", "daisies"]: + feedback = render_poem_and_collect_feedback(topic) + print(f"\nCollected feedback for poem about {topic}:") + print(feedback) + print("\n" + "="*50 + "\n") + + + diff --git a/docs/ramblings/notes_on_dspy2.py b/docs/ramblings/notes_on_dspy2.py new file mode 100644 index 000000000..73712c258 --- /dev/null +++ b/docs/ramblings/notes_on_dspy2.py @@ -0,0 +1,146 @@ +from typing import List +import ell + + +def chain_of_thought(input : str, instruction : str): + @ell.simple(model='gpt-4o') + def chain(input : str): + return ell.user(input=input, instruction=instruction) + return chain + + +ell.init(verbose=True, store='./logdir') + + +def nn_layer(params, x): + return params.dot(x) + + +class Model: + def __init__(self, layers = 10): + self.layer_weights = [np.randn(10, 10) for _ in range(layers)] + + def save(self): + pass + + def __call__(self, x): + return dspy.variable("You arew a good llm that does ") + x + + for weight in self.layer_weights: + x = x * weight + return x + + +# People really want this? 
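+
+# Hypothetical sketch of the `predict` factory the examples below assume (it isn't defined
+# anywhere; names and behavior are assumptions, in the spirit of `chain_of_thought` above).
+def predict(input, instruction):
+    # `input` only documents the expected input name(s); the returned LMP simply prepends
+    # the instruction to whatever it is called with.
+    @ell.simple(model="gpt-4o")
+    def _predictor(*inputs):
+        return f"{instruction}\n\n" + "\n".join(str(i) for i in inputs)
+    return _predictor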
+expand_example = predict( + input={} + instruction=." +) + +number_to_classify = predict( + input="example", + instruction="Infer the number of companies you need to classify between. Return an integer between 1-5 based on the given example." +) + +TGCIDs = predict( + input=["context", "example", "n"], + instruction="Create a list of viable TGCIDs. TGCIDs should be a comma-separated list of TCGIDs of length n. The TGCIDs should be relevant to the given context and example." +) + +@ell.function() +def generate_answer(self, example : str, k : int): + company_descriptions = expand_example(example) + n = number_to_classidfy(example) + passage="asd" + return TGCIDs() + + + +soft_generate_answer = sft(generate_answer) + +objective_function = lambda: y_pred, x jaracard(y_pred, soft_generate_answer(x)) + + + +for _ in range(100): + loss = objective_function(x,y) + loss.backward() + optimizer.step() + + + + + +# Ver yeasy to try new configurations + +if __name__ == "__main__": + generate_answer("A tech giant known for its search engine and various internet services.") + + + +# Reasonable plan + +# 1. Simple deifned prompts +# 2. Combination of prompts +# 3. Evals +# 4. MLE tyring to prompt. + +# Underneath. +@dataclass +class Closure: + model : Any + prompt : str + +@dataclass +class LMP: + closure : Closure + prompt_fn : Callable[[Any], Union[str, List[Message]]] + + + def __call__(self, x): + prompt = self.prompt_fn(x) + # complex stuff + +# Borrow DSPy + +# A compatible model based interface for LMPs +class ChainOfThought(ell.LMP): + def __init__(self, input_shape, output_shape, instruction : str): + + def prompt(self, input): + return [ + ell.system("You are a helpful assistant. Who solves the following task: " + self.instruction), + + ] + + def parse(self, output): + # Post process of output. + return output + + +# Automatic cration of LMPs from decorators: + +@ell.simple(model="gpt-4o") +def say_hello(name : str): + return f"Say hello to {name}!" + +# -> +# This inherently gets created underneath. +class SayHello(ell.LMP): + def __init__(self, input_shape, output_shape): + super().__init__(input_shape, output_shape) + + def prompt(self, input): + return [ + ell.user(f"Say hello to {input}!") + ] + + def parse(self, output): + return output.text + + + + + + + diff --git a/docs/ramblings/thoughtsonevals.md b/docs/ramblings/thoughtsonevals.md new file mode 100644 index 000000000..8aa34b6cd --- /dev/null +++ b/docs/ramblings/thoughtsonevals.md @@ -0,0 +1,981 @@ +# Evals & Metrics + +We need to build a really good evals and metrics functionality for ell so that prompt engineering becomes more rigorous. This should be extremely approachable for a startup or a person from a non-ML background, but with high quality and rigorous statistical analysis built in. It should work naturally with ell studio. It should be fast. It should be easy. + +We can either design this by coming from the ML background and thinking of the traditional train, test, and validation set loss function for a particular model. I think this is probably useful. But then there's also a bit of UX around the actual iteration loop, where we iterate in concert with changes to versions of various different prompts. We store those iterations over time. And then beyond that typical workflow, we should also consider different types of evals that are not metric-based. There are reward model-based evals, critic-based evals (I'll categorize those as the same), zero-one accuracy evals, and then vibe evals. 
In my own use case, the vibe evals were probably the most relevant, and those are somewhat harder to assess and make functional for the end user. But I can envision a world where you modify your prompt, you define a set of inputs, and then we enable the user in ell studio to very quickly compare inputs and outputs between versions in a blind fashion. This automatically generates comparison data. It also generates an implicit RLHF dataset and lets you rigorously compare between versions instead of the typical workflow of living in Jupyter notebook and looking at the outputs of one change versus the next. We would then want to build a bunch of guides to get people used to this workflow and this is a great introduction to optimizers as well. + +There are a number of features we could build here that might also be interesting. So these are just evaluations which are good for the prompt engineering workflow. They fundamentally have versions themselves, and those versions are a form of data versioning. But we could also consider things like metrics and runs. So what is that particularly in the prompt engineering use case? That looks something like we were doing a prompt engineering session, and I want to iterate on a prompt for one session. So I spin up a run and then I can track the quality of the iteration on that prompt over that run. So this would be the evals and other metrics. This is kind of similar to MLflow runs and actually is somewhat solved for by having the logger, so you don't have to use the same logger over and over again. So this is something we probably don't need to focus on implementing. We just need to think about evals and metrics. But this would effectively allow you to organize versions or prompt engineering sessions or optimizations as runs, and you can look at loss curves and things like this. This is probably extremely important. + +The next thing we would care about: So evals use metrics. Metrics are ways of just like the same probably as TensorBoard metrics, which is that they are collections of XY plot data, histograms, and so on that can be tied to particular ell objects. Would we ever want a world where we tied a metric to an invocation? Let's say we had user feedback. So we could log for every invocation, you know, if there was feedback from the user, you could log it to ell studio. And then in ell studio, you could look at the average feedback for a version or other summary statistics on that. This requires implementation of a very nice plotting library. But yeah, metrics could be logged at any point in time, and they could be tied to any type of invocation. The question is, do we want to sort of automatically tie them? So what's an example of this? The metric is just a number tied to an invocation or tied to an LMP, and so in fact you can plot metrics over any different X axis: time, version, and so on. + +Now we have automatic tracing. And we also have the notion of how do we define a metric. So we have to introduce all of the underlying ell Studio data types or can we just automatically do this? Are metrics only for invocations and LMPs? So what if I just want to log like a random float? It's kind of like MLflow or TensorFlow, right? Like first step, for example. You could totally do that. So there's like two interfaces. There are metrics that are associated with LMPs and classes of metrics there, as well as metrics that are just floating. They need to be associated to something, they'll probably be like a run. 
Because, in fact, TensorBoard does this even with the same TensorBoard storage directory, right? Every time you connect a TensorBoard, you get a new run. It might be the instance of the process, or something like this. Or you could hard code the run ID. So yeah, there's a lot of different concepts here. Let's recap.
+
+We have evals, which are easy-to-run, parallelized dataset runs of an LMP or a function associated with various metrics, and that are versioned themselves. We have metrics, which are ways of keeping track of various numbers associated to invocations or LMPs. And then we might also have metrics that are not associated with anything.
+
+We need to disambiguate and square all of these concepts so that we can come up with a design spec for a really clean interface in ell. Some context here is that ell is a language model programming framework, or prompt engineering framework, which basically enables rapid iteration of prompts on a local machine or in production. And we need to nail this metric and evaluation interface so that prompt engineering becomes extremely easy and rigorous.
+
+
+
+### Ramblings on the interface
+
+```python
+# Would be cool if it opened a UI for you to watch the run as it was going down.
+
+
+eval = Evaluation(
+    name="test",
+    inputs=dataset,
+    labels=dataset,
+    metric=ell.eval.accuracy
+)
+
+# all of the nice scikit-learn stuff on predictors
+# Examples of different evals.
+score = ell.eval.accuracy(predict_capital, inputs=dataset[:,0], labels=dataset[:,1], name="accuracy")
+score = ell.eval.mean_squared_error(predict_capital, inputs=dataset[:,0], labels=dataset[:,1], name="mse")
+score = ell.eval.r2_score(predict_capital, inputs=dataset[:,0], labels=dataset[:,1], name="r2")
+
+
+# An example of why we need to fix the input/output format
+@ell.simple(model="gpt-4o", temperature=0.0)
+def critic_fn(input: str, output: str):
+    """Answer only with 'yes' or 'no'."""
+    return f"Is the following output correct? {output} given the input: {input}"
+
+@ell.function()
+def critic_score(input: str, output: str):
+    return float(critic_fn(input, output) == "yes")
+
+
+# @ell.function()
+# def critic_fn(input: str, output: str):
+#     output = ell.simple(model="gpt-4o", temperature=0.0)(f"I said {output} to the following input: {input}. Is that correct? ")
+#     return output == "yes"
+
+print(ell.eval.eval(predict_capital, inputs=dataset[:,0], score=critic_score))
+
+
+
+class Metric(Protocol):
+    def __call__(self, input : Any, output : Any, label : Optional[Any]) -> float:
+        ...
+
+    def vectorized_call(self, inputs : List[Any], outputs : List[Any], labels : Optional[List[Any]] = None) -> List[float]:
+        ...
+
+
+Evaluation(
+    name="test",
+    inputs=dataset,
+    labels=dataset, # optionally
+    metrics=[
+        ell.eval.accuracy,
+        critic_score
+    ]
+)
+```
+
+
+Counterpoint: metrics don't affect the version of an eval if we view evals as datasets.
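+
+For reference, a minimal sketch of something that would satisfy the `Metric` protocol above (the class is illustrative only; the protocol shape itself is still being debated in this document):
+
+```python
+from typing import Any, List, Optional
+
+class ExactMatch:
+    """Scores 1.0 when the output equals the label, 0.0 otherwise."""
+
+    def __call__(self, input: Any, output: Any, label: Optional[Any]) -> float:
+        return float(output == label)
+
+    def vectorized_call(self, inputs: List[Any], outputs: List[Any], labels: Optional[List[Any]] = None) -> List[float]:
+        # Without labels there is nothing to match against, so score everything 0.0.
+        if labels is None:
+            return [0.0 for _ in outputs]
+        return [float(o == l) for o, l in zip(outputs, labels)]
+```
+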
+```python +eval.run +scikitlearn.metrics.accuracy_score( + predictor, + x = dataset[:,0], + y = dataset[:,1] + "asd", + + +eval = Evaluation( + name="linkedinprofiles", + inputs=dataset[:,0], + labels=dataset[:,1], +) + +# This is a problem with the eval.run interface +{ [0.5, 0.5], [0.5, 0.5] } = eval.run(my_lmp, metrics=[ + ell.eval.accuracy, + critic_score +]) # y_pred + +accuracies = ell.eval.accuracy(y_pred, x = dataset[:,0], y = dataset[:,1]) +critic_scores = ell.eval.critic_score(y_pred, x = dataset[:,0], y = dataset[:,1], critic_fn=critic_fn) + +print(accuracies.mean()) +print(critic_scores.mean()) + + + +@ell.simple(model="gpt-4o") +def write_cold_email(input: str): + return f"Write a cold email to the following person: {input}" + +# most impoeritive case +ell.evaluate( + "linkedin profiles", + write_cold_email, + inputs=dataset[:,0], + labels=dataset[:,1], + metrics=[ + ell.eval.accuracy, + critic_score + ] +) +# Options: +-> [ + { + "accuracy": 0.5, + "critic_score": 0.5 + }, + { + "accuracy": 0.5, + "critic_score": 0.5 + } +] +-> [ + [0.5, 0.5], + [0.5, 0.5] +] +-> EvaluationRun( + scores = [ + [0.5, 0.5], + [0.5, 0.5] + ], + inputs = dataset[:,0], + labels = dataset[:,1], + metrics = [ + ell.eval.accuracy, + critic_score + ] +) + + +# OR: +Evaluation( + name="linkedin profiles eval", + inputs=dataset, + labels=dataset, # optionally + metrics=[ + ell.eval.accuracy, + critic_score + ] +) + +eval.run(write_cold_email) -> (any of the above outptus) + + +run.metrics.critic_score +# im gonna version this for you + + + + +@ell.metric() +def accuracy(y_pred, y_true): + return np.mean(np.array(y_pred) == np.array(y_true)) + +def eval(): + with ell.run("linkedin profiles eval", verbose=False): + from concurrent.futures import ThreadPoolExecutor + import numpy as np + + def process_datapoint(x, y): + out = write_cold_email(x) + return accuracy(y_pred=out, y_true=y) + + with ThreadPoolExecutor() as executor: + scores = list(executor.map(process_datapoint, dataset[:,0], dataset[:,1])) + + scores = np.array(scores) + print(scores.mean()) + ell.log_summary("critic_score", scores.mean()) + + +### OR ### + +eval() + + +# combine with evals +# naive metrics are not a part of evals. + +eval = Evaluation( + name="linkedin profiles eval", + inputs=dataset[:,0], + labels=dataset[:,1], +) + +outputs = eval.run(write_cold_email) # cold emails +scores = [] +for x, y,out in zip(dataset[:,0], dataset[:,1], outputs): + scores.append(accuracy(y_pred=out, y_true=y)) + +print(scores.mean()) +ell.log_summary("avg accuracy", scores.mean()) + +def v(inputs, labels, outputs): + scores = [] + for x, y,out in zip(inputs, labels, outputs): + scores.append(accuracy(y_pred=out, y_true=y)) + ell.log_summary("avg accuracy", scores.mean()) + return scores + +eval = Evaluation( + name="linkedin profiles eval", + inputs=dataset[:,0], + labels=dataset[:,1], + score = callback +) + +scores = eval.run(write_cold_email) + + + + +# Supporting promtp engineering of metrics themeselves, + + + +@ell.simple(model="gpt-4o") +def evaluate_email(cold_email: str): + """ You are an empath who is extremely good at evaluating cold emails. You are given a cold email and you must determine if it is good or bad. You will evaluate the cold email on the following criterion: + + A good cold email is: + - Concise < 5 sentences + - Personalized to the recipient + - Extremely non generic (it must be unique to the sender. 
+ +Your outptu should be in the following format: +Analysis: <5 paragraphs of analysis> +Score of email quality: <1-10> +Is good cold email: + """ + return f"Is the following cold email good for the following person: {cold_email}" + + +@ell.retry(max_retries=3) +@ell.metric() +def is_good_cold_email(x : str): + analysis= is_the_cold_email_good(x) + is_good_cold_email = analysis.split("Is good cold email: ")[1] == "yes" + return float(is_good_cold_email) + + + + + +dataset = Dataset( + name="linkedin profiles", + inputs=dataset[:,0], +) + +ell.init(run="linkedin profiles eval", store="./logdir", verbose=True, tags=["cold-email", "linkedin"]) + + +[[1,1], [0,1]] = dataset.evaluate(write_cold_email, n_workers=4, metrics=[is_good_cold_email, accuracy]) + +ell.log_summary("avg accuracy", outputs[:,1].mean()) +ell.log_summary("avg is_good_cold_email", outputs[:,0].mean()) + + + +outputs = eval.run(write_cold_email) +eval.score(write_cold_email, metric=accuracy) +# list[float] +``` + + +How do we square the following interface: +```python + +run = ell.run("linkedin profiles eval") +# or +ell.init(experiment="linkedin profiles eval", store="./logdir", verbose=True, tags=["cold-email", "linkedin"]) +# or + + +# Logs if it can be traced. +@ell.metric() +def accuracy(y_pred : _lstr , y_true) -> float: + return np.mean(np.array(y_pred) == np.array(y_true)) + +scores = [] +for x,y in dataset: + out = write_cold_email(x) + scores.append(accuracy(y_pred=out, y_true=y)) + +print(scores.mean()) +ell.log("cirtic_score", scores.mean()) +``` + +```python + +eval = Evaluation( + name="linkedin profiles eval", + inputs=dataset[:,0], + labels=dataset[:,1], + metric = callback +) + +outputs = ell.parallel(write_cold_email, inputs=dataset, n_workers=4) +score = np.array(ell.parallel(is_good_cold_email, outputs, n_workers=4)) +print(score.mean()) + +evalrun = eval.run(write_cold_email) + +evalrun.scores +evalrun.outputs + + +``` + + +```python + +dataset = ... +@ell.metric() +def accuracy(y_pred, y_true): + return (np.array(y_pred) == np.array(y_true)) + +500 predicions -> accuracy all ofthem get the same accuracy metric. + +def eval(n_workers=4): + with ell.run("linkedin profiles eval", verbose=False): + from concurrent.futures import ThreadPoolExecutor + import numpy as np + + outputs = ell.parallel(write_cold_email, inputs=dataset, n_workers=n_workers) + + scores = accuracy(outputs, dataset[:,1]) + ell.log("accuracy", scores) + + print(scores.mean()) + ell.log_summary("critic_score", scores.mean()) + +eval() + +### OR ### +@ell.metric() +def accuracy(y_pred, y_true): + return np.mean(np.array(y_pred) == np.array(y_true)) + +dataset = ... +eval = Evaluation( + name="linkedin profiles eval", + inputs=dataset[:,0], + labels=dataset[:,1], + metric=accuracy +) + +eval.run(write_cold_email, n_workers=4 +) + +### OR #### + +dataset = [...] + +@ell.eval("linkedin profiles eval") +def eval(lmp : ell.LMP): + outputs = ell.parallel(lmp, inputs=dataset, n_workers=4) + + scores = accuracy(outputs, dataset[:,1]) + ell.log("accuracy", scores.mean()) + return scores, outputs, "num chars" + str(sum([len(x) for x in outputs])) + +eval(write_cold_email) + + +## Discovery. Metrics are inadequate for describingi indivudal scores afforded to invocations. + +class Metric(Protocol): + def __call__(self, input : List[Any], output : List[Any], label : Optional[List[Any]] = None) -> float: + """Produces a aggregate metric.""" + ... + + +# With this example how do we solve adding more metrics on? + + # X, ypred, ytrue -> float. 
+
+metric = Callable[[List[Any], List[Any], Optional[List[Any]]], float]
+
+def accuracy(y_pred, y_true):
+    return np.mean(np.array(y_pred) == np.array(y_true))
+
+dataset = ...
+eval = Evaluation(
+    name="linkedin profiles eval",
+    inputs=dataset[:,0],
+    labels=dataset[:,1],
+    metric=accuracy
+)
+
+run = eval.run(write_cold_email, n_workers=4)
+
+run.scores # list[float]
+run.result # float
+
+# recall
+y_pred, y_true = run.outputs, run.labels
+print(ell.metrics.recall(y_pred, y_true))
+
+
+# This is bad.
+
+eval = Evaluation(
+    name="linkedin profiles eval",
+    inputs=dataset[:,0],
+    labels=dataset[:,1],
+    metrics=dict(
+        accuracy=ell.metrics.accuracy,
+        recall=ell.metrics.recall,
+        critic_score=ell.metrics.critic_score
+    )
+)
+
+run = eval.run(write_cold_email, n_workers=4)
+
+run.scores # list[float]
+run.result['accuracy'] # list[float]
+run.result['recall'] # list[float]
+run.result['critic_score'] # list[float]
+
+
+# 1. We would want to introduce the following:
+# Evaluation (helper)
+# Evaluation Run
+
+# Metric (always aggregated)
+# Grouping.
+    # Evaluation induces a grouping.
+
+
+#%%
+
+# What does an aggregate reward model look like?
+
+
+@ell.simple(model="gpt-4o")
+def evaluate_email(cold_email: str):
+    """ You are an empath who is extremely good at evaluating cold emails. You are given a cold email and you must determine if it is good or bad. You will evaluate the cold email on the following criteria:
+
+    A good cold email is:
+    - Concise, < 5 sentences
+    - Personalized to the recipient
+    - Extremely non-generic (it must be unique to the sender).
+
+Your output should be in the following format:
+Analysis: <5 paragraphs of analysis>
+Score of email quality: <1-10>
+Is good cold email: <yes/no>
+    """
+    return f"Is the following cold email good for the following person: {cold_email}"
+
+
+def quality_score(y_pred : List[Any], y_true : Optional[List[Any]] = None) -> float:
+    analyses = ell.parallel(evaluate_email, y_pred, n_workers=4)
+    is_good_cold_email = [float(analysis.split("Is good cold email: ")[1] == "yes") for analysis in analyses]
+    return np.mean(is_good_cold_email)
+
+
+ell.metricize(quality_score) # No.
+```
+### Conclusion
+
+1. Evaluations
+   - Evaluation class
+   - Versioned datasets
+   - Versioned "metrics"
+2. Runs (groups of invocations)
+3. Metric is an **aggregate** statistic on a potentially labeled dataset.
+   - This is the scikit-learn metrics interface.
+
+Evaluations, then, are runs along with versioned datasets and metrics.
+
+## Run vs Interprocess Execution.
+Interprocess executions versus runs. Currently, we have evaluations that run on a single process, and those evaluations consist of groups of individual evaluation runs. These evaluation runs are single-process and linked to evaluations. We could envision a scenario where we want to run evaluations or group runs across multiple processes. For instance, let's say I track scores of my evaluations as I perform a prompt engineering process on a thesis. I want to group my runs or invocations by, for example, an emotional empathy thesis. If I'm building a cold email writer, this grouping would be useful.
+
+The only way to effectively flag this would be to make the init function specify the current experiment, or some similar flag:
+```
+ell.init(experiment="emotional empathy", store='./logdir')
+```
+
+The tree would look like this:
+```
+experiment[emotional-empathy] ->
+    run[leap-for-the-sky-o6] -> # process 1
+        invocation[cold-email-1]
+        invocation[cold-email-2]
+        invocation[cold-email-3]
+    run[leap-for-the-sky-o7] -> # process 2
+        invocation[cold-email-1]
+        invocation[cold-email-2]
+        invocation[cold-email-3]
+    run[leap-for-the-sky-o8] -> # process 3
+        invocation[cold-email-1]
+        invocation[cold-email-2]
+        invocation[cold-email-3]
+```
+
+
+With an eval we get:
+
+```
+experiment[emotional-empathy] ->
+    evaluation[cold-email-writer] ->
+        evaluation run |
+        run[leap-for-the-sky-o6] -> # no longer corresponds to a single process execution.
+            invocation[cold-email-1]
+            invocation[cold-email-2]
+            invocation[cold-email-3]
+        evaluation run |
+        run[leap-for-the-sky-o7] ->
+            invocation[cold-email-1]
+            invocation[cold-email-2]
+            invocation[cold-email-3]
+```
+
+In general, automatically grouping by "run" is a bad thing when we think about production execution with multiple process brokers and processes.
+
+We can just have experiment labeling as a convenience feature; otherwise we don't label by experiment. And this doesn't necessarily make sense for production runs either.
+
+In that case, we might end up with something that looks like this:
+
+```
+experiment[emotional-empathy] ->
+    evaluation[cold-email-writer] ->
+        evaluation run |
+        run[leap-for-the-sky-o6] -> # no longer corresponds to a single process execution.
+            invocation[cold-email-1]
+            invocation[cold-email-2]
+            invocation[cold-email-3]
+            invocation[cold-email-4]
+            invocation[cold-email-5]
+invocation[cold-email-6] # happened in dev
+invocation[cold-email-7] # happened in production
+invocation[cold-email-8]
+```
+
+We don't necessarily need to solve this in this PR. If we can just build an abstraction that works for evaluations and individualized metrics without thinking about per-invocation scores, and add those later, then we would be happy. I suppose we could take a look at the OpenAI eval suite and go from there, just to see if they do scoring per invocation. We could additionally just have an invocation score, and that would be many-to-many. You can score invocations in many different ways. For evals, you build a metric function which takes in a bunch of invocations and labels, and then produces an aggregated metric that is one-to-one with evaluations. We version those metrics. We support individual invocation scores in this PR as a part of the individual changes to the DB schema as well as migrations. But otherwise, we don't opinionate the implementation, and we end up with schemas for evaluations and evaluation runs independent of experiments. In fact, this is probably sufficient.
+
+The question would be whether or not we separate evaluation runs and run groupings of invocations, or do we just tie invocations directly to evaluations? If we were to introduce runs as a full-fledged feature within experiments, or something like this, later, then we would have a legacy evaluation-runs concept that we need to get rid of later.
+
+
+### OpenAI Evals
+It is evident that OpenAI evaluations support multiple metrics by default. An interesting aspect is their criterion specification. Overall the UX is lacking: there is no ability to see the outputs of individual invocations (perhaps this is what happens if you share the evals with OpenAI, etc.).
+
+
+Probably not a good model for our evals, but:
+1. Test Data
+   - Import a JSONL file with existing cases or prompts.
+
+2. Generate Responses
+   - Generate responses (Optional)
+   - Prompt
+   - Generated responses can be evaluated using the sample.output_text variable within testing criteria.
+
+
+3. Criteria (multiple)
+   a. Factuality
+      - Check if the content is factually accurate.
+   b. Semantic Similarity
+      - Compare generated text to the reference.
+   c. Sentiment
+      - Identify the emotional tone of the model’s response.
+   d. String Check
+      - Check if the model’s response includes specific string(s).
+   e. Valid JSON or XML
+      - Check if the model’s response is valid JSON or XML.
+   f. Matches Schema
+      - Ensure the model’s response follows the specified structure.
+   g. Criteria Match
+      - Assess if the model’s response matches your criteria.
+   h. Text Quality
+      - Assess response quality with BLEU, ROUGE, or Cosine algorithms.
+   i. Custom Prompt
+      - Create a test criterion by writing your own custom prompt.
+
+Concrete recommendations for our case are as follows: we want pre-canned criteria, and having per-invocation criteria seems to be important for evaluations. It is also important to note that most of these evaluations are done by models, which is how OpenAI prefers to run them. In our case, we are not designing evaluations for individual per-invocation criteria, but rather aggregate metrics to be more in line with scikit-learn and similar frameworks.
+
+If we were to go in the opposite direction and support per-invocation criteria, allowing arbitrary datasets in an evaluation and providing a programmatic API for OpenAI-like invocations, it might be worthwhile. If I were to envision that API shape, it would look like the following:
+
+
+
+```python
+def criterion1(datapoint):
+    input = datapoint[0]
+    output = datapoint[1]
+    return float(output == "yes")
+
+
+@ell.simple(model="gpt-4o-mini")
+def was_grammatically_correct(datapoint):
+    return f"Did the following cold email have any grammatical errors? {datapoint[1]}. Answer with yes or no."
+
+def criterion2(datapoint):
+    return float("yes" in was_grammatically_correct(datapoint).lower())
+
+evaluation = Evaluation(
+    name="cold-email-evaluation",
+    dataset=dataset,
+    criterion=[
+        criterion1,
+        criterion2,
+        criterion3
+    ]
+)
+
+evaluation.run()
+```
+
+What's really weird about this is that it's completely dependent on a dataset of input-output pairs, and it doesn't rerun on the same prompt, which is effectively the goal of evaluations in DSPy.
+
+We could kind of re-envision this. If we think about one part of the OpenAI evals, you were allowed to generate a response on top of a dataset, and that's the kind of thing we'll put here. So the dataset will include all the output labels, and then we have a generate-response function that takes in some data point and then generates responses as a result of that. But then the user has to index into what would be the correct input. By allowing datasets to have arbitrary data points in them, including the labels and arbitrary columns, we can just think about them as JSON objects. Then you can have criteria that fit all sorts of different settings, right? So I could have inputs, I could have 100 different labels, and I can assess the total criterion of everything. And then I can think about pass/fail in general. And that might actually be the right thing to do. 
So let's imagine now we have response generation as a key component.
+
+```python
+
+def criterion1(datapoint, output):
+    input = datapoint['input']
+    desired_output = datapoint['output']
+    desired_sentiment = datapoint['desired_sentiment']
+
+...
+
+
+evaluation = Evaluation(
+    name="cold-email-evaluation",
+    dataset=dataset, # pandas dataframe??
+    criterion=[
+        criterion1,
+        criterion2,
+        criterion3
+    ]
+)
+
+
+evaluation.run(
+    my_lmp
+)
+
+This works for criteria, but it's not clear what input the LMP should take.
+We could separate out input and output into two separate columns.
+
+
+evaluation = Evaluation(
+    name="cold-email-evaluation",
+    dataset=dataset, # pandas dataframe??
+    labels=label_dataset,
+    criterion=[
+        criterion1,
+        criterion2,
+        criterion3
+    ]
+)
+
+@ell.simple(model="gpt-4o-mini")
+def my_lmp(datapoint):
+    # Columns of the input dataset are the args to my_lmp?
+    # How would I actually want this in practice?
+    pass
+
+
+evaluation.run(
+    my_lmp
+)
+```
+
+
+
+Weave does a much better job at organizing this. Their opinion is the following:
+
+
+```python
+dataset = Dataset(
+    [
+        {"some_random_shit" : "some_value", "expected_output" : "some_value", "other_column" : "other_value"},
+        {"some_random_shit" : "some_value", "expected_output" : "some_value", "other_column" : "other_value"},
+        {"some_random_shit" : "some_value", "expected_output" : "some_value", "other_column" : "other_value"},
+    ]
+)
+
+```
+
+Now this dataset automatically gets versioned.
+
+```python
+eval = Evaluation(
+    name="my_evaluation",
+    dataset=dataset,
+    scores=[
+        score1,
+        score2,
+    ]
+)
+```
+
+
+So you basically define scorers, and they will automatically extract the relevant columns from the dataset based on the inspected parameter arguments.
+
+```python
+
+def score1(expected_output, output):
+    return np.mean(np.array(expected_output) == np.array(output))
+
+def score2(other_column, output):
+    return np.mean(np.array(other_column) == np.array(output))
+
+args = inspect.signature(score1).parameters
+
+...
+
+class Evaluation:
+    def run(self):
+        results = {}
+        for datapoint in self.dataset:
+            for score in self.scores:
+                args = inspect.signature(score).parameters
+                datapoint_subset = {k : datapoint[k] for k in args}
+                score_output = score(**datapoint_subset)
+                results.setdefault(score.__name__, []).append(score_output)
+        return results
+
+```
+
+
+I do like the idea of being able to publish a dataset. I think we should have parity there. I do like that evaluations are automatically versioned in some sense. We should also have parity there.
+
+I'm not convinced that the shape of the evaluation function should look like model output, etc. Also, the model evaluation itself doesn't seem very clean, right? If I'm developing an LMP that I'm going to use somewhere else in my software stack, and I want to evaluate it, now I have to wrap it in some additional function. This layer of indirection between me and the evaluation might cause failure later. The fundamental data shape of the evaluation should be kind of like whatever I'm always using in my LMPs plus labels. I don't want to think about inputs versus outputs in my criterion. The input data shape is holy, and in some sense, I don't want to have to change the LMP's source code just because my input data shape has changed. As for scoring functions, yeah, I think there's some convenience in being able to pull out from the rows of a dataset. In that way, it makes sense. Also, look how clean it is to specify datasets like that. It is beautiful what we've done there, though the Thursday AI guys are not going to like that I've changed this in the way that I have.
+
+Also, datasets fundamentally are very weird. If they only exist for the purpose of evals, then maybe the abstraction doesn't make any sense. We can do a lot of things with datasets, right? We can fine-tune on them. The traditional ML literature doesn't exactly make sense for the RL use case, which is where we want to head with this. So I think by developing a dataset abstraction now, we're going to cause problems later when we decide to do RL on prompts and things of this nature.
+
+So let's suppose we just ignore the traditional dataset abstraction for now, right? And so we ship evals as a feature. Evals as a feature just have inputs and outputs. Now, the problem with inputs and outputs, right? When I was actually doing stuff in production with these models, we would often take in non-serializable objects. I don't want the user to have to think about whether or not their object is truly JSON serializable. But it's not clear exactly how we would define the dataset if it weren't just unpacking some sort of dictionary into the kwargs, right? The shape always has to be a dictionary, and there's the other convenience of having datasets where the labels and the inputs and outputs are inline. People are very used to working that way. So if we had separate inputs and outputs, they would have to zip these together, which doesn't make a lot of sense. Okay, but we can be magical like Weave.
+
+When you define the dataset, you do rows of dictionaries, where rows correspond to named kwargs of your LMP. Again, the named-kwargs thing is actually bad because certain LMPs are also positional. So being able to swap in an LMP and another LMP, one that uses slightly different named kwargs, will totally break the process. So that's not acceptable. What we can do is serialize sort of positional and non-positional named kwargs in the dataset formulation. Each row can contain an input object, and that input object is either a list or a dictionary. You can name the kwargs or not. It's probably a bad idea to name the kwargs because then you can't swap in different LMPs. Then we always use the inputs here. And this is just typed dicts. I guess we have a typed dictionary. And we always use that. Then the rest of the outputs, you can do whatever you want. Your score function will take in the row and the model output. What Weave did was they literally said, "Hey, your thing has to accept model output as a kwarg." So they validate the score functions by inspecting keyword arguments. Then everything's wrapped in a Weave op because those Weave ops automatically log to wandb. So they unify the interface of logging in that way.
+
+```python
+
+# Example implementation based on the ideas discussed
+
+from typing import List, Dict, Any, Union, Callable
+import inspect
+import numpy as np
+
+# Define a flexible Dataset type
+Dataset = List[Dict[str, Any]]
+
+# Example dataset
+dataset: Dataset = [
+    {"input": "What is the capital of France?", "expected_output": "Paris", "difficulty": "easy"},
+    {"input": "What is the square root of 144?", "expected_output": "12", "difficulty": "medium"},
+    # ... 
more data points +] + +# Example LMP (Language Model Program) +def my_lmp(input: str) -> str: + # This is a mock LMP that just returns the input + return input + +# Example score functions +def accuracy_score(expected_output: str, output: str) -> float: + return float(expected_output.lower() == output.lower()) + +def difficulty_weighted_score(difficulty: str, expected_output: str, output: str) -> float: + base_score = float(expected_output.lower() == output.lower()) + difficulty_weight = {"easy": 1.0, "medium": 1.5, "hard": 2.0} + return base_score * difficulty_weight.get(difficulty, 1.0) + +class Evaluation: + def __init__(self, name: str, dataset: Dataset, lmp: Callable, scores: List[Callable]): + self.name = name + self.dataset = dataset + self.lmp = lmp + self.scores = scores + + def run(self) -> Dict[str, List[float]]: + results = {score.__name__: [] for score in self.scores} + + for datapoint in self.dataset: + # Run the LMP + lmp_input = datapoint.get("input") + if isinstance(lmp_input, str): + output = self.lmp(lmp_input) + elif isinstance(lmp_input, dict): + output = self.lmp(**lmp_input) + elif isinstance(lmp_input, list): + output = self.lmp(*lmp_input) + else: + raise ValueError(f"Unsupported input type: {type(lmp_input)}") +``` +Alright, so this part is a bit too magical. Essentially, what it's doing is taking the input object and, if it's a single object, passing it directly into the LMP. Otherwise, it destructures the arguments. I do appreciate the use of **kwargs versus list destructuring; it's quite elegant. We can think of it as handling both args and kwargs, which is fine. However, it's also quite clean to write your dataset as single input elements. + +```python + # Calculate scores + for score in self.scores: + args = inspect.signature(score).parameters + datapoint_subset = {k: datapoint.get(k) for k in args if k != 'output'} + score_output = score(**datapoint_subset, output=output) + results[score.__name__].append(score_output) + + return results + +# Usage example +eval = Evaluation( + name="my_evaluation", + dataset=dataset, + lmp=my_lmp, + scores=[accuracy_score, difficulty_weighted_score] +) + +results = eval.run() +print(results) + +# You could then add methods to analyze and visualize the results +# For example: +def analyze_results(results: Dict[str, List[float]]): + for score_name, scores in results.items(): + print(f"{score_name}:") + print(f" Mean: {np.mean(scores):.4f}") + print(f" Median: {np.median(scores):.4f}") + print(f" Min: {np.min(scores):.4f}") + print(f" Max: {np.max(scores):.4f}") + +analyze_results(results) + +``` + +So now let's consider The usability of these input shapes. If we're really going to accept that, there's like some special input data point arg. + + +```python +class DatapointPD(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + input : Dict[str, Any] | List[Any] + labels: Dict[str, Any] + + +# or + +class DatapointTD(TypedDict, total=False): + input : Dict[str, Any] | List[Any] + +# finally +Dataset = List[Datapoint] + + +# This is actually quite in the style of ell where we have input and output in ell studio as either a list of arguments or a dictionary of kwargs. 
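+# (A sketch, not part of the proposal above: one way the "every entry must at
+# least have an input" rule could be enforced when an eval is constructed.
+# `validate_datapoint` is a hypothetical helper named here for illustration.)
+def validate_datapoint(dp: DatapointTD) -> None:
+    if "input" not in dp or not isinstance(dp["input"], (list, dict)):
+        raise ValueError("each datapoint needs an 'input' that is a list of args or a dict of kwargs")
+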
+
+dataset = [
+    Datapoint(input=["What is the capital of France?"], labels={"expected_output": "Paris"}),
+]
+
+# or
+dataset = [
+    {"input" : {'question' : "What is the capital of France?"}, "answer" : "Paris"},
+]
+# or, equivalently
+dataset = [
+    DatapointTD(input=["What is the capital of France?"], labels={"expected_output": "Paris"}),
+]
+
+```
+
+This approach is quite elegant. We can use a TypedDict with `total=False` (or a Pydantic model) so we can validate that each entry at least has an input.
+
+Imagine defining a dataset in this structured way, where every entry must at least have the shape of an input. You can then add arbitrary fields to the dataset columns. This avoids the issue where the shape of the LMP function needs to be transformed.
+
+So let's actually write out what the final form of this might look like and see if it's palatable. If it's not, that's okay.
+```python
+
+
+@ell.simple(model="gpt-4o-mini")
+def write_a_poem(about : str):
+    """You are PoetGPT. You always write in iambic pentameter. Only answer with a poem."""
+    return f"Write a poem about {about}"
+
+
+@ell.simple(model="gpt-4o-mini")
+def iambic_pentameter(poem : str):
+    return f"Is the following poem in iambic pentameter? {poem} Answer with yes or no."
+
+
+# This is like OpenAI + Weave evals.
+
+eval = Evaluation(
+    name="poem-eval",
+    dataset=[
+        Datapoint(input=["a rose"], must_contain="rose", minimum_length=100),
+        Datapoint(input=["a sunset"], must_contain="sunset", minimum_length=100),
+        Datapoint(input=["a rainbow"], must_contain="", refuse=True, minimum_length=100),
+    ],
+    criterion=[
+        lambda datapoint, output: datapoint.must_contain in output,
+        lambda datapoint, output: len(output) >= datapoint.minimum_length,
+        lambda datapoint, output: "I refuse to write a poem about that" in output or not datapoint.refuse,
+        lambda datapoint, output: "yes" in iambic_pentameter(output).lower(),
+    ]
+)
+
+
+eval.run(write_a_poem)
+# a set of scores.
+# Then we modify write_a_poem.
+
+
+
+@ell.simple(model="gpt-4o-mini")
+def write_a_poem(about : str):
+    """You are PoetGPT. You always write in iambic pentameter. Only answer with a poem. Say "I refuse to write a poem about that" if you are asked to write about rainbows."""
+    return f"Write a poem about {about}"
+
+
+# Now the refusal criterion will work.
+eval.run(write_a_poem)
+
+# Now we improve the iambic pentameter score by trying to rewrite the poem.
+
+@ell.simple(model="gpt-4o-mini")
+def better_poem_writer(about : str):
+    """You are a poet who is extremely good at writing iambic pentameter. If the poem says "I refuse", just copy the refusal."""
+    initial_poem = write_a_poem(about)
+
+    return f"Rewrite the following poem in iambic pentameter: {initial_poem}"
+
+
+eval.run(better_poem_writer)
+# highest score.
+
+```
+
+I think I like this eval the most of any of the specs I have come up with. You can just throw accuracy criteria in there, and it's made very easy by specifying how the dataset looks. The Weave guys definitely built a really good abstraction here. Some small changes around where things feel magical make this pretty close to an abstraction that we can use. In the above example, it's extremely readable as to what's going on, and I can imagine a very simple flow where I iteratively improve things. I don't have to worry about what's going on with the individual args or kwargs, as they're specified in the input dict. If there's a mismatch, then I just use arguments instead of kwargs. As for the criteria, you just take in the data point and the output. 
It's just two positional arguments. The data point is literally just what came from the dataset. So if you ever need to look at the schema, it's all there. Inputs are separated out. Inputs are a requirement for data points. We can validate that when we build the eval. This is a very particular type of dataset, and this lets you very quickly and rapidly develop fast evaluations. + + +The only problem here is I think what is very nice about the OpenAI evaluation product is that it comes with tons of evaluations by default. For example, text similarity, text quality, BLEU score, things like this. And because the dataset is so free, we don't have an expected output. We can't run metrics automatically. + +We could, by default, actually include something inside the metric functionality, like a special keyword in the dataset. If we actually use the reserved expected output keyword, then you can just use pre-canned metrics without having to specify them, because then we're sort of moving the transmutation of metrics to the criterion specification, right? But I could automatically run things like BLEU score or text similarity if you use the expected output keyword. Otherwise, I guess we could just make them instantiable, so I might actually prefer this. So let's just do this, for example. + + +```python + + +from ell.evals import cosine_similarity + +@ell.simple(model="gpt-4o-mini") +def write_a_poem(about :str): + """You are PoetGPT. Write with cheesy well-known poems if available.""" + return f"Write a poem about {about}" + + +eval = Evaluation( + name="poem-eval", + dataset=[ + # jsonl injection into dataset formula + Datapoint(input=["a rose"], expert_poem="Roses are red, violets are blue, sugar is sweet, and so are you.") + ], + criterion=[ + cosine_similarity("text-embedding-3-small", expected_output="expert_poem", inner_product="normal") + ] +) + +# can automatically do cosine similarity & other nice things +eval.run(write_a_poem) + +``` + + diff --git a/docs/ramblings/yield_ell.py b/docs/ramblings/yield_ell.py new file mode 100644 index 000000000..7812b2ca0 --- /dev/null +++ b/docs/ramblings/yield_ell.py @@ -0,0 +1,152 @@ + +from functools import wraps +from typing import Generator, Any +# Precanned AI responses for demo purposes +ai_responses = [ + "The capital of France is Paris.", + "Population: 2200000", +] + + +def lmp(func): + """ + Decorator that simulates multi-step calls to an LLM API. + Prints each step and collects all yields from a generator. + Returns the collected values. + """ + @wraps(func) + def wrapper(*args, **kwargs): + system_prompt = func.__doc__ + print(f"\033[94mSystem: {system_prompt}\033[0m") + generator = func(*args, **kwargs) + message_history = [] + step = 1 + + try: + user_prompt = next(generator) + while True: + print(f"\033[92mUser: {user_prompt}\033[0m") + message_history.append({"role": "user", "content": user_prompt}) + + # Use precanned AI response + ai_response = ai_responses[step - 1] if step <= len(ai_responses) else f"AI response for step {step}" + print(f"\033[93mAssistant: {ai_response}\033[0m") + + message_history.append({"role": "assistant", "content": ai_response}) + step += 1 + + # Send AI response back to the generator + user_prompt = generator.send(ai_response) + + except StopIteration as e: + return e.value + return wrapper + +@lmp +def multistep_prompt(): + """You are a helpful assistant.""" + assistant_response = yield "What is the capital of France?" 
+ print("City!", assistant_response) + assistant_response_2 = yield "What is the population of that city?" + + # This is allowed in a generator + return int(assistant_response_2.split("Population: ")[-1]) + +# Execute the multi-step prompt +result = multistep_prompt() +print(f"{result}") + + +import asyncio +from functools import wraps + +async def async_lmp(func): + """ + Async decorator that simulates multi-step calls to an LLM API. + Prints each step and collects all yields from an async generator. + Returns the collected values. + """ + @wraps(func) + async def wrapper(*args, **kwargs): + system_prompt = func.__doc__ + print(f"\033[94mSystem: {system_prompt}\033[0m") + generator = func(*args, **kwargs) + message_history = [] + step = 1 + + try: + user_prompt = await anext(generator) + while True: + print(f"\033[92mUser: {user_prompt}\033[0m") + message_history.append({"role": "user", "content": user_prompt}) + + # Use precanned AI response + ai_response = ai_responses[step - 1] if step <= len(ai_responses) else f"AI response for step {step}" + print(f"\033[93mAssistant: {ai_response}\033[0m") + + message_history.append({"role": "assistant", "content": ai_response}) + step += 1 + + # Send AI response back to the generator + user_prompt = await generator.asend(ai_response) + + except StopAsyncIteration as e: + return e.value + return wrapper + +@async_generator +async def async_multistep_prompt(): + """You are a helpful assistant.""" + resp = await yield_("What is the capital of France?") + resp = await yield_("What is the population of that city?") + return int(resp.split("Population: ")[-1]) + +async def main(): + result = await async_multistep_prompt() + print(f"{result}") + +asyncio.run(main()) + + +# so in some sense this is the most natural interface for ell which is just the fucking api iterface with an lmp context for multistep, the yield statment feels just right for multistep though. it's so unclear to me why async generators do not have a return value though. +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield ell.user("What is the capital of France?") + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield [ell.user("What is the capital of France?")] + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield ell.Call(messages=[ell.user("What is the capital of France?")], api_params={"max_tokens": 10}) + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + resp = yield [ell.user("What is the capital of France?")], {"max_tokens": 10} + + +# This is unacceptable. + +@ell.lmp(model="gpt-4o", temperature=0.0, api_params={"max_tokens": 1000}) +def my_prompt(): + claude_says = yield "What is the capital of France?", {'model': 'claude'} + gpt_says = yield "What is the capital of France?" 
+ +--> + +def normal_prompt(): + anthropic_client = anthropic.Anthropic() + openai_client = openai.OpenAI() + + claude_says = anthropic_client.messages.create(model="claude-3-opus", messages=[{"role": "user", "content": "What is the capital of France?"}]) + gpt_says = openai_client.chat.completions.create(model="gpt-4o", messages=[ + {"role": "user", "content": "What is the capital of France?"}, + { "role": "assistant", "content": claude_says.content}, + {"role": "user", "content": "What is the capital of France?"} + ]) + + return None + + + diff --git a/ell-studio/package-lock.json b/ell-studio/package-lock.json index ea00162d3..5f0d335fd 100644 --- a/ell-studio/package-lock.json +++ b/ell-studio/package-lock.json @@ -14,12 +14,15 @@ "@radix-ui/react-scroll-area": "^1.1.0", "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-slot": "^1.1.0", + "@radix-ui/react-tooltip": "^1.1.3", "@tanstack/react-query": "^5.51.21", "@testing-library/jest-dom": "^5.17.0", "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", "axios": "^1.6.0", "base64-js": "^1.5.1", + "chart.js": "^4.4.4", + "chartjs-chart-error-bars": "^4.4.2", "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", "d3-force": "^3.0.0", @@ -33,6 +36,7 @@ "npm": "^10.8.2", "prismjs": "^1.29.0", "react": "^18.3.1", + "react-chartjs-2": "^5.2.0", "react-dom": "^18.3.1", "react-hot-toast": "^2.4.1", "react-icons": "^5.2.1", @@ -41,6 +45,7 @@ "react-responsive": "^10.0.0", "react-router-dom": "^6.18.0", "react-scripts": "^5.0.1", + "react-sparklines": "^1.7.0", "react-syntax-highlighter": "^15.5.0", "reactflow": "^11.11.4", "recharts": "^2.12.7", @@ -3512,6 +3517,12 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@kurkle/color": { + "version": "0.3.2", + "resolved": "https://registry.npmjs.org/@kurkle/color/-/color-0.3.2.tgz", + "integrity": "sha512-fuscdXJ9G1qb7W8VdHi+IwRqij3lBkosAm4ydQtEmbY58OzHXqQhvlxqEkoz0yssNVn38bcpRWgA9PP+OGoisw==", + "license": "MIT" + }, "node_modules/@leichtgewicht/ip-codec": { "version": "2.0.5", "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.5.tgz", @@ -4047,6 +4058,130 @@ } } }, + "node_modules/@radix-ui/react-tooltip": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.1.3.tgz", + "integrity": "sha512-Z4w1FIS0BqVFI2c1jZvb/uDVJijJjJ2ZMuPV81oVgTZ7g3BZxobplnMVvXtFWgtozdvYJ+MFWtwkM5S2HnAong==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-context": "1.1.1", + "@radix-ui/react-dismissable-layer": "1.1.1", + "@radix-ui/react-id": "1.1.0", + "@radix-ui/react-popper": "1.2.0", + "@radix-ui/react-portal": "1.1.2", + "@radix-ui/react-presence": "1.1.1", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-slot": "1.1.0", + "@radix-ui/react-use-controllable-state": "1.1.0", + "@radix-ui/react-visually-hidden": "1.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-context": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz", + "integrity": 
"sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==", + "license": "MIT", + "peerDependencies": { + "@types/react": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-dismissable-layer": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.1.tgz", + "integrity": "sha512-QSxg29lfr/xcev6kSz7MAlmDnzbP1eI/Dwn3Tp1ip0KT5CUELsxkekFEMVBEoykI3oV39hKT4TKZzBNMbcTZYQ==", + "license": "MIT", + "dependencies": { + "@radix-ui/primitive": "1.1.0", + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-use-callback-ref": "1.1.0", + "@radix-ui/react-use-escape-keydown": "1.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-portal": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.2.tgz", + "integrity": "sha512-WeDYLGPxJb/5EGBoedyJbT0MpoULmwnIPMJMSldkuiMsBAv7N1cRdsTWZWht9vpPOiN3qyiGAtbK2is47/uMFg==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-primitive": "2.0.0", + "@radix-ui/react-use-layout-effect": "1.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, + "node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-presence": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.1.1.tgz", + "integrity": "sha512-IeFXVi4YS1K0wVZzXNrbaaUvIJ3qdY+/Ih4eHFhWA9SwGR9UDX7Ck8abvL57C4cv3wwMvUE0OG69Qc3NCcTe/A==", + "license": "MIT", + "dependencies": { + "@radix-ui/react-compose-refs": "1.1.0", + "@radix-ui/react-use-layout-effect": "1.1.0" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc", + "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-use-callback-ref": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.0.tgz", @@ -7677,6 +7812,27 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/chart.js": { + "version": "4.4.4", + "resolved": "https://registry.npmjs.org/chart.js/-/chart.js-4.4.4.tgz", + "integrity": "sha512-emICKGBABnxhMjUjlYRR12PmOXhJ2eJjEHL2/dZlWjxRAZT1D8xplLFq5M0tMQK8ja+wBS/tuVEJB5C6r7VxJA==", + "license": "MIT", + "dependencies": { + "@kurkle/color": "^0.3.0" + }, + "engines": { + "pnpm": ">=8" + } + }, + "node_modules/chartjs-chart-error-bars": { + "version": "4.4.2", + "resolved": 
"https://registry.npmjs.org/chartjs-chart-error-bars/-/chartjs-chart-error-bars-4.4.2.tgz", + "integrity": "sha512-rRjfAKjwgoCc6Dt1WE8OWJlBYkNFk68xwVRGAK9J3pAjcUICOgN7E+VVqsCUXb6BBMd3SRFlwtr7u2M5L9AClQ==", + "license": "MIT", + "peerDependencies": { + "chart.js": "^4.1.0" + } + }, "node_modules/check-types": { "version": "11.2.3", "resolved": "https://registry.npmjs.org/check-types/-/check-types-11.2.3.tgz", @@ -21044,6 +21200,16 @@ "integrity": "sha512-kY1AZVr2Ra+t+piVaJ4gxaFaReZVH40AKNo7UCX6W+dEwBo/2oZJzqfuN1qLq1oL45o56cPaTXELwrTh8Fpggg==", "license": "MIT" }, + "node_modules/react-chartjs-2": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/react-chartjs-2/-/react-chartjs-2-5.2.0.tgz", + "integrity": "sha512-98iN5aguJyVSxp5U3CblRLH67J8gkfyGNbiK3c+l1QI/G4irHMPQw44aEPmjVag+YKTyQ260NcF82GTQ3bdscA==", + "license": "MIT", + "peerDependencies": { + "chart.js": "^4.1.1", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + } + }, "node_modules/react-dev-utils": { "version": "12.0.1", "resolved": "https://registry.npmjs.org/react-dev-utils/-/react-dev-utils-12.0.1.tgz", @@ -21526,6 +21692,19 @@ "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" } }, + "node_modules/react-sparklines": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/react-sparklines/-/react-sparklines-1.7.0.tgz", + "integrity": "sha512-bJFt9K4c5Z0k44G8KtxIhbG+iyxrKjBZhdW6afP+R7EnIq+iKjbWbEFISrf3WKNFsda+C46XAfnX0StS5fbDcg==", + "license": "MIT", + "dependencies": { + "prop-types": "^15.5.10" + }, + "peerDependencies": { + "react": "*", + "react-dom": "*" + } + }, "node_modules/react-style-singleton": { "version": "2.2.1", "resolved": "https://registry.npmjs.org/react-style-singleton/-/react-style-singleton-2.2.1.tgz", diff --git a/ell-studio/package.json b/ell-studio/package.json index 9d7e8c133..d7b22c097 100644 --- a/ell-studio/package.json +++ b/ell-studio/package.json @@ -9,12 +9,15 @@ "@radix-ui/react-scroll-area": "^1.1.0", "@radix-ui/react-select": "^2.1.1", "@radix-ui/react-slot": "^1.1.0", + "@radix-ui/react-tooltip": "^1.1.3", "@tanstack/react-query": "^5.51.21", "@testing-library/jest-dom": "^5.17.0", "@testing-library/react": "^13.4.0", "@testing-library/user-event": "^13.5.0", "axios": "^1.6.0", "base64-js": "^1.5.1", + "chart.js": "^4.4.4", + "chartjs-chart-error-bars": "^4.4.2", "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", "d3-force": "^3.0.0", @@ -28,6 +31,7 @@ "npm": "^10.8.2", "prismjs": "^1.29.0", "react": "^18.3.1", + "react-chartjs-2": "^5.2.0", "react-dom": "^18.3.1", "react-hot-toast": "^2.4.1", "react-icons": "^5.2.1", @@ -36,6 +40,7 @@ "react-responsive": "^10.0.0", "react-router-dom": "^6.18.0", "react-scripts": "^5.0.1", + "react-sparklines": "^1.7.0", "react-syntax-highlighter": "^15.5.0", "reactflow": "^11.11.4", "recharts": "^2.12.7", diff --git a/ell-studio/src/App.js b/ell-studio/src/App.js index d3e1d971f..a27afef1f 100644 --- a/ell-studio/src/App.js +++ b/ell-studio/src/App.js @@ -5,11 +5,14 @@ import Sidebar from './components/Sidebar'; import Home from './pages/Home'; import LMP from './pages/LMP'; import Invocations from './pages/Invocations'; +import Evaluations from './pages/Evaluations'; +import Evaluation from './pages/Evaluation'; import { ThemeProvider } from './contexts/ThemeContext'; import './styles/globals.css'; import './styles/sourceCode.css'; import { useWebSocketConnection } from './hooks/useBackend'; import { Toaster, toast } from 'react-hot-toast'; +import EvaluationRun from './pages/EvaluationRun'; const WebSocketConnectionProvider = ({children}) 
=> { const { isConnected } = useWebSocketConnection(); @@ -86,11 +89,14 @@ function App() {
-
+
} /> } /> } /> + } /> + } /> + } />
diff --git a/ell-studio/src/components/HierarchicalTable.js b/ell-studio/src/components/HierarchicalTable.js index 48fee663c..5c5fb2035 100644 --- a/ell-studio/src/components/HierarchicalTable.js +++ b/ell-studio/src/components/HierarchicalTable.js @@ -68,8 +68,8 @@ const SmoothLine = ({ index, startX, startY, endX: endXPreprocess, special, endY -const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWidth, rowClassName, setRowRef, links, linkColumn }) => { - const { expandedRows, selectedRows, toggleRow, toggleSelection, isItemSelected, setHoveredRow, sortedData } = useHierarchicalTable(); +const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWidth, rowClassName, setRowRef, links, linkColumn, showHierarchical, statusColumn }) => { + const { expandedRows, selectedRows, toggleRow, toggleSelection, isItemSelected, setHoveredRow, sortedData, hoveredRow } = useHierarchicalTable(); const hasChildren = item.children && item.children.length > 0; const isExpanded = expandedRows[item.id]; const isSelected = isItemSelected(item); @@ -132,7 +132,7 @@ const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWid ${customRowClassName} ${isNew ? 'animate-fade-in bg-green-900/30' : ''}`} onClick={() => { - if (onRowClick) onRowClick(item); + if (onRowClick) onRowClick(item, toggleRow); }} onMouseEnter={() => setHoveredRow(item.id)} onMouseLeave={() => setHoveredRow(null)} @@ -144,19 +144,26 @@ const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWid onClick={(e) => e.stopPropagation()} /> - -
- {hasChildren && ( - { e.stopPropagation(); toggleRow(item.id); }} - > - {isExpanded ? : } - - )} -
- + + {showHierarchical ? ( +
+ {hasChildren && ( + { e.stopPropagation(); toggleRow(item.id); }}> + {isExpanded ? : } + + )} +
+ ) : statusColumn?.render ? ( +
+ {statusColumn.render(item)} +
+ ) : null} {schema.columns.map((column, index) => { - const content = column.render ? column.render(item) : item[column.key]; + const content = column.render ? column.render(item, index, { + expanded: isExpanded, + isHovered: hoveredRow === item.id + }) : item[column.key]; const maxWidth = column.maxWidth || Infinity; return ( @@ -184,13 +191,13 @@ const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWid })} {hasChildren && isExpanded && item.children.map(child => ( - + ))} ); }; -const TableHeader = ({ schema, columnWidths, updateWidth }) => { +const TableHeader = ({ schema, columnWidths, updateWidth, showHierarchical, statusColumn }) => { const { isAllSelected, toggleAllSelection, sortConfig, onSort } = useHierarchicalTable(); return ( @@ -202,8 +209,12 @@ const TableHeader = ({ schema, columnWidths, updateWidth }) => { onCheckedChange={(checked) => toggleAllSelection(checked)} /> - - + + {showHierarchical ? ( + + ) : ( + statusColumn?.header || '' + )} {schema.columns.map((column, index) => { const maxWidth = column.maxWidth || Infinity; @@ -233,12 +244,11 @@ const TableHeader = ({ schema, columnWidths, updateWidth }) => { ); }; -const TableBody = ({ schema, onRowClick, columnWidths, updateWidth, rowClassName, setRowRef, links, linkColumn }) => { +const TableBody = ({ schema, onRowClick, columnWidths, updateWidth, rowClassName, setRowRef, links, linkColumn, showHierarchical, statusColumn }) => { const { sortedData } = useHierarchicalTable(); const [, forceUpdate] = useState({}); useEffect(() => { - // Force a re-render to trigger position updates forceUpdate({}); }, [sortedData]); @@ -256,6 +266,8 @@ const TableBody = ({ schema, onRowClick, columnWidths, updateWidth, rowClassName setRowRef={setRowRef} links={links} linkColumn={linkColumn} + showHierarchical={showHierarchical} + statusColumn={statusColumn} /> ))} @@ -312,7 +324,26 @@ const PaginationControls = ({ currentPage, totalPages, onPageChange, pageSize, t ); }; -const HierarchicalTable = ({ schema, data, onRowClick, onSelectionChange, initialSortConfig, rowClassName, currentPage, onPageChange, pageSize, totalItems, omitColumns, expandAll, links, expandedLinkColumn, collapsedLinkColumn }) => { +const HierarchicalTable = ({ + schema, + data, + onRowClick, + onSelectionChange, + initialSortConfig, + rowClassName, + currentPage, + onPageChange, + pageSize, + totalItems, + omitColumns, + expandAll, + links, + expandedLinkColumn, + collapsedLinkColumn, + showHierarchical = true, + statusColumn = null, + hierarchicalSort = false +}) => { const [columnWidths, setColumnWidths] = useState({}); const [isExpanded, setIsExpanded] = useState(false); const [rowRefs, setRowRefs] = useState({}); @@ -372,17 +403,21 @@ const HierarchicalTable = ({ schema, data, onRowClick, onSelectionChange, initia return (
- - {/* Update SVG rendering for direct lines */} + { links && - + /> + + }
{onPageChange && ( useContext(HierarchicalTableContext); -export const HierarchicalTableProvider = ({ children, data, onSelectionChange, initialSortConfig, setIsExpanded, expandAll}) => { +export const HierarchicalTableProvider = ({ children, data, schema, onSelectionChange, initialSortConfig, setIsExpanded, expandAll, hierarchicalSort }) => { const [expandedRows, setExpandedRows] = useState({}); const [selectedRows, setSelectedRows] = useState({}); const [sortConfig, setSortConfig] = useState(initialSortConfig || { key: null, direction: 'asc' }); const [hoveredRow, setHoveredRow] = useState(null); + useEffect(() => { const allParentRowsCollapsed = data.every(item => !expandedRows[item.id]); setIsExpanded(!allParentRowsCollapsed); }, [expandedRows, setIsExpanded, data]); - // expandall specifies if the initial state of row is expanded. - // if a rows expansion state is not specified, it is set to expanded if expandAll is true. + // expandall specifies if the initial state of row is expanded. useEffect(() => { if (expandAll) { data.forEach(item => { @@ -33,6 +33,14 @@ export const HierarchicalTableProvider = ({ children, data, onSelectionChange, i })); }, []); + const isItemSelected = useCallback((item) => { + if (!selectedRows[item.id]) return false; + if (item.children) { + return item.children.every(child => isItemSelected(child)); + } + return true; + }, [selectedRows]); + const toggleSelection = useCallback((item, isSelected) => { setSelectedRows(prev => { const newSelectedRows = { ...prev }; @@ -63,15 +71,7 @@ export const HierarchicalTableProvider = ({ children, data, onSelectionChange, i const isAllSelected = useCallback(() => { return data.every(item => isItemSelected(item)); - }, [data, selectedRows]); - - const isItemSelected = useCallback((item) => { - if (!selectedRows[item.id]) return false; - if (item.children) { - return item.children.every(child => isItemSelected(child)); - } - return true; - }, [selectedRows]); + }, [data, isItemSelected]); const onSort = useCallback((key) => { setSortConfig((prevConfig) => ({ @@ -82,7 +82,27 @@ export const HierarchicalTableProvider = ({ children, data, onSelectionChange, i const sortedData = useMemo(() => { if (!sortConfig.key) return data; - return [...data].sort((a, b) => { + + const sortItems = (items) => { + return [...items].sort((a, b) => { + const column = schema?.columns?.find(col => col.key === sortConfig.key); + const sortFn = column?.sortFn; + + const comparison = sortFn + ? (sortConfig.direction === 'asc' ? sortFn(a, b) : sortFn(b, a)) + : defaultCompare(a, b); + + return comparison; + }).map(item => ({ + ...item, + // Recursively sort children if hierarchicalSort is enabled + children: hierarchicalSort && item.children + ? sortItems(item.children) + : item.children + })); + }; + + const defaultCompare = (a, b) => { if (a[sortConfig.key] < b[sortConfig.key]) { return sortConfig.direction === 'asc' ? -1 : 1; } @@ -90,10 +110,12 @@ export const HierarchicalTableProvider = ({ children, data, onSelectionChange, i return sortConfig.direction === 'asc' ? 
1 : -1; } return 0; - }); - }, [data, sortConfig]); + }; + + return sortItems(data); + }, [data, sortConfig, schema, hierarchicalSort]); - React.useEffect(() => { + useEffect(() => { if (onSelectionChange) { onSelectionChange(selectedRows); } diff --git a/ell-studio/src/components/IORenderer.js b/ell-studio/src/components/IORenderer.js index 00f3e09e0..b7ead6124 100644 --- a/ell-studio/src/components/IORenderer.js +++ b/ell-studio/src/components/IORenderer.js @@ -52,7 +52,7 @@ const preprocessData = (data, currentLevel = 0, typeMatchLevel = 0) => { }; const renderInline = (data, customRenderers) => { - if (data.__lstr) { + if (data?.__lstr) { data = data.content; } @@ -166,7 +166,7 @@ const renderNdarray = (data) => { }; const renderNonInline = (data, customRenderers, level = 0, isArrayItem = false, postfix = '') => { - if (data.__lstr) { + if (data?.__lstr) { data = data.content; } @@ -298,7 +298,9 @@ const Indent = ({ children }) => ( ); const IORenderer = ({ content : content_obj, customRenderers = [], inline = true, typeMatchLevel = 0 }) => { + const content = JSON.stringify(content_obj) + console.log(content); try { const parsedContent = JSON.parse(content); const preprocessedContent = preprocessData(parsedContent, 0, typeMatchLevel); @@ -307,7 +309,9 @@ const IORenderer = ({ content : content_obj, customRenderers = [], inline = true {inline ? renderInline(preprocessedContent, customRenderers) : renderNonInline(preprocessedContent, customRenderers)}
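// A small sketch of the per-column sortFn contract used by the provider above: sortFn(a, b)
// compares two row objects for ascending order, and when hierarchicalSort is enabled the same
// ordering is applied recursively to item.children. The "latency" field is illustrative.
const latencyColumn = {
  key: 'latency',
  header: 'Latency (ms)',
  sortFn: (a, b) => (a.latency ?? 0) - (b.latency ?? 0),
};

// Rough equivalent of what sortedData produces for this column in ascending order:
const sortTree = (items) =>
  [...items]
    .sort(latencyColumn.sortFn)
    .map((item) => ({
      ...item,
      children: item.children ? sortTree(item.children) : item.children,
    }));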
); - } catch { + } catch (e) { + + console.error(e); return {content}; } }; diff --git a/ell-studio/src/components/LMPDetailsSidePanel.js b/ell-studio/src/components/LMPDetailsSidePanel.js index 882c612d1..8668e3383 100644 --- a/ell-studio/src/components/LMPDetailsSidePanel.js +++ b/ell-studio/src/components/LMPDetailsSidePanel.js @@ -7,7 +7,7 @@ import { useInvocationsFromLMP } from '../hooks/useBackend'; import { LMPCardTitle } from './depgraph/LMPCardTitle'; import { format } from 'date-fns'; import SidePanel from './common/SidePanel'; -import MetricChart from './MetricChart'; +import MetricChart from './oldgraph/OldMetricChart'; import { motion } from 'framer-motion'; import {Card} from './common/Card'; diff --git a/ell-studio/src/components/Card.js b/ell-studio/src/components/OldCard.js similarity index 78% rename from ell-studio/src/components/Card.js rename to ell-studio/src/components/OldCard.js index 27cf55bc4..4f706e895 100644 --- a/ell-studio/src/components/Card.js +++ b/ell-studio/src/components/OldCard.js @@ -1,6 +1,6 @@ import React from "react"; -export function Card({ children, title, noMinW, ...rest }) { +export function OldCard({ children, title, noMinW, ...rest }) { return (
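// Usage sketch for the IORenderer change above: the component now stringifies its `content`
// prop itself, and the data?.__lstr checks tolerate null-ish entries in the parsed payload.
// The payload shape here is illustrative.
const invocationOutput = { __lstr: true, content: 'The capital of France is Paris.' };
// <IORenderer content={invocationOutput} inline typeMatchLevel={1} />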
{ const [isExpanded, setIsExpanded] = useState(false); @@ -40,6 +41,7 @@ const Sidebar = () => { @@ -60,4 +62,4 @@ const Sidebar = () => { ); }; -export default Sidebar; \ No newline at end of file +export default Sidebar; diff --git a/ell-studio/src/components/VersionBadge.js b/ell-studio/src/components/VersionBadge.js index 793d2514e..3992ae1aa 100644 --- a/ell-studio/src/components/VersionBadge.js +++ b/ell-studio/src/components/VersionBadge.js @@ -1,34 +1,35 @@ -import React from 'react'; +import React, { useRef, useEffect, useState } from 'react'; const getColorFromVersion = (version) => { - const hue = (version * 137.508) % 360; // Golden angle approximation - return `hsl(${hue}, 40%, 70%)`; // Base color + const hue = (version * 137.508) % 360; + return `hsl(${hue}, 40%, 70%)`; }; -const VersionBadge = ({ version, hash, className = '', shortVersion = false }) => { +const VersionBadge = ({ version, hash, className = '', shortVersion = false, truncationLength = 9 }) => { + const [isOverflowing, setIsOverflowing] = useState(false); + const badgeRef = useRef(null); const baseColor = getColorFromVersion(version); - const lighterColor = `hsl(${baseColor.match(/\d+/)[0]}, 40%, 75%)`; // Slightly lighter - const textColor = 'text-gray-900'; // Dark text for contrast + const lighterColor = `hsl(${baseColor.match(/\d+/)[0]}, 40%, 75%)`; + const textColor = 'text-gray-900'; + + useEffect(() => { + const checkOverflow = () => { + if (badgeRef.current) { + setIsOverflowing(badgeRef.current.scrollWidth > badgeRef.current.clientWidth); + } + }; + + checkOverflow(); + window.addEventListener('resize', checkOverflow); + return () => window.removeEventListener('resize', checkOverflow); + }, [version, hash]); + + const useShortVersion = shortVersion || isOverflowing; return ( -
-
- {shortVersion ? `v${version}` : `Version ${version}`} -
- {hash && ( -
- {hash.substring(0, 9)} -
- )} +
+
{useShortVersion ? `v${version}` : `Version ${version}`}
+ {hash && !useShortVersion &&
{hash.substring(0, truncationLength)}
}
); }; diff --git a/ell-studio/src/components/VersionHistoryPane.js b/ell-studio/src/components/VersionHistoryPane.js index fdc27a68b..52801b367 100644 --- a/ell-studio/src/components/VersionHistoryPane.js +++ b/ell-studio/src/components/VersionHistoryPane.js @@ -1,11 +1,19 @@ import React from 'react'; import { FiGitCommit, FiClock, FiCopy, FiChevronRight } from 'react-icons/fi'; -import { useNavigate, useParams } from 'react-router-dom'; +import { useNavigate, useLocation } from 'react-router-dom'; import VersionBadge from './VersionBadge'; -const VersionHistoryPane = ({ versions, onSelect }) => { +const VersionHistoryPane = ({ + versions, + onSelect, + config: { + getPath, + getId, + isCurrentVersion + } +}) => { const navigate = useNavigate(); - const { id: currentLmpId } = useParams(); + const location = useLocation(); const formatDate = (timestamp) => { const date = new Date(timestamp); @@ -39,7 +47,8 @@ const VersionHistoryPane = ({ versions, onSelect }) => { let totalIndex = 0; const handleVersionClick = (version) => { - navigate(`/lmp/${version.name}/${version.lmp_id}`); + const path = getPath(version); + navigate(path); if (onSelect) { onSelect(version); } @@ -55,44 +64,45 @@ const VersionHistoryPane = ({ versions, onSelect }) => { const commitLines = (version.commit_message || 'Commit message not available').split('\n'); const commitTitle = commitLines[0] || 'Commit message not available'; const commitDetails = commitLines.slice(1).join('\n').trim(); + const versionId = getId(version); return ( -
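// A sketch of the new config contract for VersionHistoryPane: instead of assuming LMP routes,
// the caller supplies path/id accessors. This mirrors the previous hard-coded LMP behaviour;
// the factory name and the onSelect handler are illustrative.
const makeLmpVersionConfig = (currentLmpId) => ({
  getPath: (version) => `/lmp/${version.name}/${version.lmp_id}`,
  getId: (version) => version.lmp_id,
  isCurrentVersion: (version) => version.lmp_id === currentLmpId,
});
// <VersionHistoryPane versions={versions} onSelect={handleSelect} config={makeLmpVersionConfig(lmpId)} />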
-
handleVersionClick(version)}> -
-
- - {commitTitle} -
- {commitDetails && ( -
{commitDetails}
- )} -
- Author - {version.author_name || 'Unknown'} committed - - {formatDate(version.created_at)} +
+
handleVersionClick(version)}> +
+
+ + {commitTitle} +
+ {commitDetails && ( +
{commitDetails}
+ )} +
+ Author + {version.author_name || 'Unknown'} committed + + {formatDate(version.created_at)} +
+
- -
-
-
- - {version.lmp_id.substring(0, 7)} +
+
+ + {versionId.substring(0, 7)} +
+
- -
); })} @@ -102,4 +112,4 @@ const VersionHistoryPane = ({ versions, onSelect }) => { ); }; -export default VersionHistoryPane; \ No newline at end of file +export default VersionHistoryPane; diff --git a/ell-studio/src/components/common/MetricCard.js b/ell-studio/src/components/common/MetricCard.js deleted file mode 100644 index 10d29b92c..000000000 --- a/ell-studio/src/components/common/MetricCard.js +++ /dev/null @@ -1,17 +0,0 @@ -import React from 'react'; -import MetricChart from '../MetricChart'; - -const MetricCard = ({ title, rawData, dataKey, color, yAxisLabel, aggregation }) => ( -
- -
-); - -export default MetricCard; \ No newline at end of file diff --git a/ell-studio/src/components/common/Spinner.js b/ell-studio/src/components/common/Spinner.js new file mode 100644 index 000000000..5528864bc --- /dev/null +++ b/ell-studio/src/components/common/Spinner.js @@ -0,0 +1,30 @@ +import React from 'react'; + +export const Spinner = ({ size = 'md' }) => { + const sizeClasses = { + sm: 'w-4 h-4', + md: 'w-6 h-6', + lg: 'w-8 h-8', + }; + + return ( +
+ + + + +
+ ); +}; \ No newline at end of file diff --git a/ell-studio/src/components/depgraph/DependencyGraph.js b/ell-studio/src/components/depgraph/DependencyGraph.js index 1fe114c51..c714f3917 100644 --- a/ell-studio/src/components/depgraph/DependencyGraph.js +++ b/ell-studio/src/components/depgraph/DependencyGraph.js @@ -17,13 +17,15 @@ import ReactFlow, { } from "reactflow"; import { getBezierPath } from 'reactflow'; import { Link } from "react-router-dom"; -import { LMPCardTitle } from "./LMPCardTitle"; // Add this import -import { Card } from "../Card"; +import { LMPCardTitle } from "./LMPCardTitle"; +import { OldCard } from "../OldCard"; +import EvaluationCard from "../evaluations/EvaluationCard"; // Update this import import "reactflow/dist/style.css"; import { ZoomIn, ZoomOut, Lock, Maximize, Unlock } from 'lucide-react'; import { Button } from "components/common/Button"; -import { useLayoutedElements, getInitialGraph } from "./graphUtils"; +import { getInitialGraph } from "./graphUtils"; +import { useLayoutedElements } from "./layoutUtils"; function LMPNode({ data }) { @@ -32,13 +34,28 @@ function LMPNode({ data }) { return ( <> - - + + - - + + + + + + ); +} +function EvalNode({ data }) { + const { evaluation } = data; + + return ( + <> + +
{/* Adjust the width as needed */} + +
+ @@ -56,18 +73,21 @@ const LayoutFlow = ({ initialNodes, initialEdges }) => { useEffect(() => { if (initialised && !didInitialSimulation) { setDidInitialSimulation(true); - toggle(); + // toggle(); fitView({ duration: 500, padding: 0.1 }); setTimeout(() => { - toggle(); + // toggle(); // Fit view after the simulation has run fitView({ duration: 500, padding: 0.1 }); }, 1000); } }, [initialised, didInitialSimulation, toggle, fitView]); - const nodeTypes = useMemo(() => ({ lmp: LMPNode }), []); + const nodeTypes = useMemo(() => ({ + lmp: LMPNode, + evaluation: EvalNode // Add the new EvalNode type + }), []); return (
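// Sketch of the node objects the new "evaluation" type consumes: getInitialGraph emits nodes
// with type "evaluation" and the evaluation object in data, which React Flow routes to EvalNode
// through the nodeTypes map registered above. The field values are illustrative (the id follows
// the evaluation-id format seen in reference.json).
const exampleEvalNode = {
  id: 'evaluation-e0ec6b331a8c820b280f39650ad87eca',
  type: 'evaluation',
  data: {
    label: 'poem_eval',
    evaluation: { name: 'poem_eval', labelers: [], runs: [] },
    width: 400,
    height: 160,
  },
  position: { x: 0, y: 0 },
};
// <ReactFlow nodes={[exampleEvalNode]} edges={[]} nodeTypes={nodeTypes} />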
@@ -119,11 +139,11 @@ function CustomControls() { ); } -export function DependencyGraph({ lmps, traces, ...rest }) { +export function DependencyGraph({ lmps, traces, evals, ...rest }) { // construct ndoes from LMPS const { initialEdges, initialNodes } = useMemo( - () => getInitialGraph(lmps, traces), - [lmps, traces] + () => getInitialGraph(lmps, traces, evals), + [lmps, traces, evals] ); return ( @@ -136,4 +156,4 @@ export function DependencyGraph({ lmps, traces, ...rest }) {
); -} \ No newline at end of file +} diff --git a/ell-studio/src/components/depgraph/LMPCardTitle.js b/ell-studio/src/components/depgraph/LMPCardTitle.js index f81e66831..bff36fa81 100644 --- a/ell-studio/src/components/depgraph/LMPCardTitle.js +++ b/ell-studio/src/components/depgraph/LMPCardTitle.js @@ -14,28 +14,36 @@ export function LMPCardTitle({ shortVersion = false, paddingClassOverride = '', nameOverride = null, - showInvocationCount = true, // New prop to control invocation count display + showInvocationCount = true, + outlineStyle = 'solid', + nameOverridePrint = null, // New prop for printing name override ...rest }) { const paddingClass = paddingClassOverride ? paddingClassOverride : padding ? 'p-2' : ''; + const scaleClass = `scale-${scale}`; const hoverClass = clickable ? ' duration-200 ease-in-out hover:bg-opacity-80 hover:bg-gray-700' : ''; const cursorClass = clickable ? 'cursor-pointer' : ''; + // Define outline styles + const outlineClasses = { + solid: lmp.is_lmp ? 'bg-blue-100 text-blue-800' : 'bg-yellow-100 text-yellow-800', + dashed: lmp.is_lmp ? 'bg-transparent text-blue-500 border border-dotted border-blue-400' : 'bg-transparent text-yellow-500 border border-dotted border-yellow-400' + }; + return (
{lmp.lmp_type === "LM" ? - //
- // LMP logo - //
: lmp.lmp_type === "TOOL" ? + : lmp.lmp_type === "METRIC" ? + : }
- {nameOverride ? nameOverride : - {lmp.name}() + {nameOverride ? nameOverride : + {nameOverridePrint || lmp.name}() } {displayVersion && } {showInvocationCount && lmp.num_invocations > 0 && ( diff --git a/ell-studio/src/components/depgraph/graphUtils.js b/ell-studio/src/components/depgraph/graphUtils.js index dc9c6dfa5..620ef9b75 100644 --- a/ell-studio/src/components/depgraph/graphUtils.js +++ b/ell-studio/src/components/depgraph/graphUtils.js @@ -1,269 +1,143 @@ -import { - forceSimulation, - forceLink, - forceManyBody, - forceX, - forceY, -} from "d3-force"; - import { useMemo } from "react"; - import ReactFlow, { - Panel, - useNodesState, - useEdgesState, useReactFlow, - Background, - Controls, - Handle, useStore, - Position, - ReactFlowProvider, MarkerType, } from "reactflow"; -import collide from "./collide"; - -const simulation = forceSimulation() - .force("charge", forceManyBody().strength(-1000)) - .force("x", forceX().x(0).strength(0.03)) - .force("y", forceY().y(0).strength(0.03)) - .force("collide", collide()) - .alphaTarget(0.05) - .stop(); - -function getLayout(nodes, edges) { - const nodeMap = nodes.reduce((map, node) => { - map[node.id] = node; - return map; - }, {}); - - /* An algorithm that counts the number of references each node has: - For each ndoe get all of its children and add 1 - Then we can put the y coordinate as the number of references - as for X we're going to determien its sort order using the trace order - */ - - // Create a map to store the number of references each node has - const referenceCount = nodes.reduce((map, node) => { - map[node.id] = 0; - return map; - }, {}); - - function increaseReferenceCountOfFamilyTree(node, visited = null) { - if (!visited) visited = new Set(); - if (visited.has(node.id)) return; - visited.add(node.id); - edges - .filter( - (edge) => edge.source === node.id && edge.sourceHandle !== "outputs" - ) - .forEach((edge) => { - referenceCount[edge.target] += 1; - increaseReferenceCountOfFamilyTree(nodeMap[edge.target], visited); - }); +import { computeLayout } from "./layoutUtils"; + +// Add this new function at the top of the file +const calculateNodeDimensions = (nodeType, data) => { + switch (nodeType) { + case 'evaluation': + // EvaluationCard has a fixed width of 400px and variable height + // We'll estimate the height based on the content + const baseHeight = 160; // Base height for an evaluation with 2 metrics + const labelerCount = data.labelers?.length || 0; + const heightPerMetric = (288 - 190 - 2*baseHeight) / (3 - 1); // Slope: (height difference) / (metric difference) + const estimatedHeight = baseHeight + (labelerCount * heightPerMetric); + return { width: 400, height: Math.round(estimatedHeight) }; + case 'lmp': + // LMPNode is more compact, using a Card component + // The size might vary based on the LMP name length + const nameLength = data.name?.length || 0; + const lmpWidth = Math.max(180, (80 + nameLength * 15)); // Min 180px, max 300px + return { width: lmpWidth, height: 100 }; + default: + // Default size for unknown node types + return { width: 150, height: 60 }; } - nodes.forEach((node) => { - increaseReferenceCountOfFamilyTree(node); - }); - - // Now get hte trace order (if a traces into b a < b. 
for cycles just put them at the end so) - // if a < b < c < d < a then the order should be (a,b,c,d) that is we should have order within a local order - // Implement cycle-aware topological sorting - const traceOrder = []; - const visited = new Set(); - const tempVisited = new Set(); - - function dfs(nodeId) { - if (tempVisited.has(nodeId)) { - // Cycle detected, skip this node - return; - } - if (visited.has(nodeId)) { - return; - } - tempVisited.add(nodeId); - - const outgoingEdges = edges.filter( - (edge) => edge.source === nodeId && edge.sourceHandle === "outputs" - ); - for (const edge of outgoingEdges) { - dfs(edge.target); - } - - tempVisited.delete(nodeId); - visited.add(nodeId); - traceOrder.unshift(nodeId); - } - - // Perform DFS for each node - nodes.forEach((node) => { - if (!visited.has(node.id)) { - dfs(node.id); - } - }); - - // Assign x-coordinates based on the trace order - traceOrder.forEach((nodeId, index) => { - nodeMap[nodeId].position.x = index * 60; - }); - - // Group nodes by all the unique reference count levels - const referenceCountLevels = new Set(Object.values(referenceCount)); - referenceCountLevels.forEach((level) => { - // get all the nodes at this level - const nodesAtLevel = Object.entries(referenceCount) - .filter(([id, count]) => count === level) - .map(([id, count]) => nodeMap[id]); - // for each node at this level, set its x coordinate to be the index of the node +}; - nodesAtLevel.forEach((node, i) => { - node.position.y = (-level * 100 + Math.random() * 10); +/** + * Generates the initial graph structure from LMPs, traces, and evaluations. + * @param {Array} lmps - List of LMP objects. + * @param {Array} traces - List of trace objects. + * @param {Array} evals - List of evaluation objects. + * @returns {Object} - Contains initial nodes and edges. + */ +export const getInitialGraph = (lmps, traces, evals = []) => { + if(!lmps || !traces) return { initialNodes: [], initialEdges: [] }; + const lmpIds = new Set(lmps.map(lmp => lmp.lmp_id)); + const evalLmpIds = new Set(); + const lmpToEvalMap = new Map(); + + // Create evaluation nodes and map LMPs to their evaluations + const evalNodes = (evals || []).map(eval_ => { + eval_.labelers.forEach(labeler => { + evalLmpIds.add(labeler.labeling_lmp_id); + lmpToEvalMap.set(labeler.labeling_lmp_id, eval_.id); }); - }); -} - -export const useLayoutedElements = () => { - const { getNodes, setNodes, getEdges, fitView } = useReactFlow(); - const initialised = useStore((store) => - [...store.nodeInternals.values()].every((node) => node.width && node.height) - ); - - return useMemo(() => { - let nodes = getNodes().map((node) => ({ - ...node, - x: node.position.x, - y: node.position.y, - })); - let edges = getEdges().map((edge) => edge); - let running = false; - - // If React Flow hasn't initialised our nodes with a width and height yet, or - // if there are no nodes in the flow, then we can't run the simulation! - if (!initialised || nodes.length === 0) return [false, {}]; - - simulation.nodes(nodes).force( - "link", - forceLink(edges) - .id((d) => d.id) - .strength(0.10) - .distance(100) - ); - - // The tick function is called every animation frame while the simulation is - // running and progresses the simulation one step forward each time. 
- const tick = () => { - getNodes().forEach((node, i) => { - const dragging = Boolean( - document.querySelector(`[data-id="${node.lmp_id}"].dragging`) - ); - - // Setting the fx/fy properties of a node tells the simulation to "fix" - // the node at that position and ignore any forces that would normally - // cause it to move. - nodes[i].fx = dragging ? node.position.x : null; - nodes[i].fy = dragging ? node.position.y : null; - }); - - simulation.tick(); - setNodes( - nodes.map((node) => ({ ...node, position: { x: node.x, y: node.y } })) - ); - - window.requestAnimationFrame(() => { - // Give React and React Flow a chance to update and render the new node - // positions before we fit the viewport to the new layout. - // fitView(); - - // If the simulation hasn't be stopped, schedule another tick. - if (running) tick(); - }); + const dimensions = calculateNodeDimensions('evaluation', eval_); + return { + id: `${eval_.id}`, + type: "evaluation", + data: { + label: eval_.name, + evaluation: eval_, + ...dimensions + }, + position: { x: 0, y: 0 }, }; + }); - const toggle = () => { - running = !running; - running && window.requestAnimationFrame(tick); - }; - - const isRunning = () => running; - - return [true, { toggle, isRunning }]; - }, [initialised]); -}; -export function getInitialGraph(lmps, traces) { - const lmpIds = new Set(lmps.map(lmp => lmp.lmp_id)); + // Create LMP nodes, excluding those that are part of evaluations and those of type "metric" + const lmpNodes = lmps.filter(Boolean) + .filter(lmp => !evalLmpIds.has(lmp.lmp_id) && lmp.lmp_type !== "LABELER") + .map(lmp => { + const dimensions = calculateNodeDimensions('lmp', lmp); + console.log(lmp); + return { + id: `${lmp.lmp_id}`, + type: "lmp", + data: { + label: lmp.name, + lmp, + isEvalLabeler: evalLmpIds.has(lmp.lmp_id), + ...dimensions + }, + position: { x: 0, y: 0 }, + }; + }); - const initialNodes = - lmps - .filter((x) => !!x) - .map((lmp) => { + const deadNodes = lmps.flatMap(lmp => + (lmp.uses || []) + .filter(use => !lmpIds.has(use.lmp_id) && !evalLmpIds.has(use.lmp_id)) + .map(use => { + const dimensions = calculateNodeDimensions('lmp', use); return { - id: `${lmp.lmp_id}`, + id: `${use.lmp_id}`, type: "lmp", - data: { label: lmp.name, lmp }, + data: { + label: `Outdated LMP ${use.name}`, + lmp: { + lmp_id: use.lmp_id, + name: `Outdated LMP (${use.name})`, + version_number: use.version_number, + }, + ...dimensions + }, position: { x: 0, y: 0 }, + style: { opacity: 0.5 }, }; - }) || []; - - // Create dead nodes for missing LMPs - const deadNodes = lmps - .filter((x) => !!x) - .flatMap((lmp) => - (lmp.uses || []).filter(use => !lmpIds.has(use.lmp_id)).map(use => ({ - id: `${use.lmp_id}`, - type: "lmp", - data: { label: `Outdated LMP ${use.name}`, lmp: { lmp_id: use.lmp_id, name: `Outdated LMP (${use.name})`, version_number: use.version_number } }, - position: { x: 0, y: 0 }, - style: { opacity: 0.5 }, // Make dead nodes visually distinct - })) - ); - - initialNodes.push(...deadNodes); + }) + ); - const initialEdges = - lmps - .filter((x) => !!x) - .flatMap((lmp) => { - if (lmp.is_old) return []; - return ( - lmp?.uses?.map((use) => { - return { - id: `uses-${lmp.lmp_id}-${use.lmp_id}`, - target: `${lmp.lmp_id}`, - source: `${use.lmp_id}`, - animated: false, - type: "default", - }; - }) || [] - ); - }) || []; + const initialNodes = [...evalNodes, ...lmpNodes, ...deadNodes]; + + const initialEdges = lmps.flatMap(lmp => + lmp.is_old ? [] : (lmp.uses || []).map(use => { + const sourceId = evalLmpIds.has(use.lmp_id) ? 
lmpToEvalMap.get(use.lmp_id) : `${use.lmp_id}`; + const targetId = evalLmpIds.has(lmp.lmp_id) ? lmpToEvalMap.get(lmp.lmp_id) : `${lmp.lmp_id}`; + return { + id: `uses-${sourceId}-${targetId}`, + source: sourceId, + sourceHandle: "uses", + target: targetId, + targetHandle: "usedby", + animated: false, + type: "default", + }; + }) + ); - // Add horizontal trace edges - if (traces && traces.length > 0) { - traces.forEach((trace, index) => { - initialEdges.push({ - id: `trace-${trace.consumed}-${trace.consumer}`, - source: `${trace.consumed}`, - sourceHandle: "outputs", - target: `${trace.consumer}`, - targetHandle: "inputs", - animated: true, - markerEnd: { - type: MarkerType.ArrowClosed, - width: 30, - height: 30, - }, - style: { - stroke: "#ff7f50", // Coral color - strokeWidth: 1, - }, - labelStyle: { fill: "#ff7f50", fontWeight: 700 }, - // label: 'Trace', - }); + traces?.forEach(trace => { + const sourceId = evalLmpIds.has(trace.consumed) ? lmpToEvalMap.get(trace.consumed) : `${trace.consumed}`; + const targetId = evalLmpIds.has(trace.consumer) ? lmpToEvalMap.get(trace.consumer) : `${trace.consumer}`; + initialEdges.push({ + id: `trace-${sourceId}-${targetId}`, + source: sourceId, + sourceHandle: "outputs", + target: targetId, + targetHandle: "inputs", + animated: true, + markerEnd: { type: MarkerType.ArrowClosed, width: 30, height: 30 }, + style: { stroke: "#ff7f50", strokeWidth: 1 }, + labelStyle: { fill: "#ff7f50", fontWeight: 700 }, }); - } - - getLayout(initialNodes, initialEdges); + }); - return { initialEdges, initialNodes }; -} + computeLayout(initialNodes, initialEdges); + return { initialNodes, initialEdges }; +}; diff --git a/ell-studio/src/components/depgraph/layoutUtils.js b/ell-studio/src/components/depgraph/layoutUtils.js new file mode 100644 index 000000000..1410e211f --- /dev/null +++ b/ell-studio/src/components/depgraph/layoutUtils.js @@ -0,0 +1,318 @@ +import { + forceSimulation, + forceLink, + forceManyBody, + forceX, + forceY, + forceCollide, +} from "d3-force"; + +import { useMemo } from "react"; +import ReactFlow, { + useReactFlow, + useStore, + MarkerType, +} from "reactflow"; + +import collide from "./collide"; + +// Initialize the D3 force simulation +const simulation = forceSimulation() + .force("charge", forceManyBody().strength(-500)) + .force("link", forceLink().id(d => d.id).distance(150)) + .force("x", forceX().strength(0.1)) + .force("y", forceY().strength(0.1)) + .force("collide", collide()) + .alphaTarget(0.05) + .stop(); + +/** + * Maps nodes by ID and initializes reference counts. + * @param {Array} nodes - List of node objects. + * @returns {Object} - Contains nodeMap and referenceCount. + */ +const initializeNodes = (nodes) => { + const nodeMap = {}; + const referenceCount = {}; + nodes.forEach(node => { + nodeMap[node.id] = node; + referenceCount[node.id] = 0; + }); + return { nodeMap, referenceCount }; +}; + +/** + * Recursively increments reference counts for dependencies. + * @param {Object} node - Current node. + * @param {Object} nodeMap - Mapping of node IDs to nodes. + * @param {Array} edges - List of edge objects. + * @param {Object} referenceCount - Reference counts. + * @param {Set} visited - Visited nodes to prevent cycles. 
+ */ +const incrementRefs = (node, nodeMap, edges, referenceCount, visited = new Set()) => { + if (visited.has(node.id)) return; + visited.add(node.id); + edges + .filter(edge => edge.source === node.id && edge.sourceHandle !== "outputs") + .forEach(edge => { + referenceCount[edge.target]++; + incrementRefs(nodeMap[edge.target], nodeMap, edges, referenceCount, visited); + }); +}; + +/** + * Performs a cycle-aware topological sort. + * @param {Array} nodes - List of node objects. + * @param {Array} edges - List of edge objects. + * @returns {Array} - Ordered list of node IDs. + */ +const getTraceOrder = (nodes, edges) => { + const traceOrder = []; + const visited = new Set(); + const tempVisited = new Set(); + + const dfs = (nodeId) => { + if (tempVisited.has(nodeId) || visited.has(nodeId)) return; + tempVisited.add(nodeId); + edges + .filter(edge => edge.source === nodeId && edge.sourceHandle === "outputs") + .forEach(edge => dfs(edge.target)); + tempVisited.delete(nodeId); + visited.add(nodeId); + traceOrder.unshift(nodeId); + }; + + nodes.forEach(node => !visited.has(node.id) && dfs(node.id)); + return traceOrder; +}; + +/** + * Finds connected components in the graph. + * @param {Array} nodes - List of node objects. + * @param {Array} edges - List of edge objects. + * @returns {Array} - Array of connected components, each component is an array of node IDs. + */ +const getConnectedComponents = (nodes, edges) => { + const visited = new Set(); + const components = []; + + const dfs = (nodeId, component) => { + visited.add(nodeId); + component.push(nodeId); + edges.forEach(edge => { + if (edge.source === nodeId && !visited.has(edge.target)) { + dfs(edge.target, component); + } else if (edge.target === nodeId && !visited.has(edge.source)) { + dfs(edge.source, component); + } + }); + }; + + nodes.forEach(node => { + if (!visited.has(node.id)) { + const component = []; + dfs(node.id, component); + components.push(component); + } + }); + + return components; +}; + +/** + * Assigns positions to nodes by topologically separating groups (components), + * laying out within the separate topological groups, and then packing them + * in a non-overlapping way based on their bounding boxes considering node sizes. + * @param {Array} nodes - List of node objects. + * @param {Array} edges - List of edge objects. 
+ */ +const assignPositions = (nodes, edges) => { + // Get connected components + const components = getConnectedComponents(nodes, edges); + const allNodeMap = {}; + nodes.forEach(node => allNodeMap[node.id] = node); + + let offsetX = 0; + let offsetY = 0; + const groupSpacing = 100; // Base space between groups + + components.forEach((componentNodeIds, index) => { + const componentNodes = componentNodeIds.map(id => ({ ...allNodeMap[id] })); + const componentEdges = edges.filter(edge => componentNodeIds.includes(edge.source) && componentNodeIds.includes(edge.target)); + + // Assign positions within the component + assignPositionsToComponent(componentNodes, componentEdges); + + // Compute bounding box of the component considering node sizes + const xs = componentNodes.map(node => node.position.x + (node.data.width || 100)); + const ys = componentNodes.map(node => node.position.y + (node.data.height || 50)); + const minX = Math.min(...componentNodes.map(node => node.position.x)); + const minY = Math.min(...componentNodes.map(node => node.position.y)); + const maxX = Math.max(...xs); + const maxY = Math.max(...ys); + const width = maxX - minX; + const height = maxY - minY; + + // Offset component positions to prevent overlap + const offsetXForGroup = offsetX - minX; + const offsetYForGroup = offsetY - minY; + + componentNodes.forEach(node => { + node.position.x += offsetXForGroup; + node.position.y += offsetYForGroup; + // Update the position in the original nodes array + allNodeMap[node.id].position = node.position; + }); + + // Update offsetX for the next group based on current group's width + offsetX += width + groupSpacing; + + // Optionally, arrange groups in rows if offsetX exceeds a certain limit + // For example, start a new row after 2000px + const maxRowWidth = 1000; + if (offsetX > maxRowWidth) { + offsetX = 0; + offsetY += height + groupSpacing; + } + }); +}; + +/** + * Assigns positions to nodes within a single component based on dependencies and edge types, + * accounting for each node's width and height to prevent overlaps. + * @param {Array} nodes - List of node objects in the component. + * @param {Array} edges - List of edge objects in the component. 
+ */ +const assignPositionsToComponent = (nodes, edges) => { + const { nodeMap, referenceCount } = initializeNodes(nodes); + nodes.forEach(node => incrementRefs(node, nodeMap, edges, referenceCount)); + + const traceOrder = getTraceOrder(nodes, edges); + + // Assign horizontal levels based on inputs/outputs + const horizontalLevels = {}; + traceOrder.forEach(id => { + const incomingOutputEdges = edges.filter(edge => edge.target === id && edge.sourceHandle === "outputs"); + if (incomingOutputEdges.length === 0) { + horizontalLevels[id] = 0; + } else { + horizontalLevels[id] = Math.max(...incomingOutputEdges.map(edge => horizontalLevels[edge.source] + 1)); + } + }); + + // Assign vertical levels based on uses/usedby + const verticalLevels = {}; + traceOrder.forEach(id => { + const incomingUseEdges = edges.filter(edge => edge.target === id && edge.sourceHandle === "uses"); + if (incomingUseEdges.length === 0) { + verticalLevels[id] = 0; + } else { + verticalLevels[id] = Math.max(...incomingUseEdges.map(edge => verticalLevels[edge.source] + 1)); + } + }); + + // Group nodes by horizontal and vertical levels + const nodesByPosition = {}; + nodes.forEach(node => { + const xLevel = horizontalLevels[node.id] || 0; + const yLevel = verticalLevels[node.id] || 0; + const key = `${xLevel}-${yLevel}`; + if (!nodesByPosition[key]) nodesByPosition[key] = []; + nodesByPosition[key].push(node); + }); + + // Determine maximum width and height per horizontal and vertical level + const maxWidthPerXLevel = {}; + const maxHeightPerYLevel = {}; + + nodes.forEach(node => { + const xLevel = horizontalLevels[node.id] || 0; + const yLevel = verticalLevels[node.id] || 0; + const nodeWidth = node.data.width || 100; + const nodeHeight = node.data.height || 50; + + if (!maxWidthPerXLevel[xLevel] || nodeWidth > maxWidthPerXLevel[xLevel]) { + maxWidthPerXLevel[xLevel] = nodeWidth; + } + + if (!maxHeightPerYLevel[yLevel] || nodeHeight > maxHeightPerYLevel[yLevel]) { + maxHeightPerYLevel[yLevel] = nodeHeight; + } + }); + + // Assign positions within the component + let currentYOffsets = {}; + const levelSpacingX = 50; // Horizontal spacing between levels + const levelSpacingY = 50; // Vertical spacing between levels + + Object.keys(nodesByPosition).forEach(key => { + const [xLevel, yLevel] = key.split('-').map(Number); + const nodesAtPosition = nodesByPosition[key]; + const maxHeight = maxHeightPerYLevel[yLevel] || 50; + + // Initialize Y offset for the xLevel if not present + if (!currentYOffsets[xLevel]) currentYOffsets[xLevel] = 0; + + nodesAtPosition.forEach((node, index) => { + const nodeWidth = node.data.width || 100; + const nodeHeight = node.data.height || 50; + + node.position = { + x: xLevel * (maxWidthPerXLevel[xLevel] + levelSpacingX), + y: currentYOffsets[xLevel], + }; + + // Update the Y offset for the next node in this level + currentYOffsets[xLevel] += nodeHeight + levelSpacingY; + }); + }); +}; + +/** + * Computes the initial layout of the graph. + * @param {Array} nodes - List of node objects. + * @param {Array} edges - List of edge objects. + */ +export const computeLayout = (nodes, edges) => nodes.length && assignPositions(nodes, edges); + +/** + * Custom hook to manage layouted elements within React Flow. + * @returns {Array} - Tuple containing a boolean and control methods. 
+ */ +export const useLayoutedElements = () => { + const { getNodes, setNodes, getEdges } = useReactFlow(); + const isInitialized = useStore(store => + [...store.nodeInternals.values()].every(node => node.width && node.height) + ); + + return useMemo(() => { + if (!isInitialized || !getNodes().length) return [false, {}]; + + const nodes = getNodes().map(node => ({ ...node, x: node.position.x, y: node.position.y })); + const edges = getEdges(); + + simulation.nodes(nodes).force("link", forceLink(edges).id(d => d.id).strength(0.1).distance(100)); + + let isRunning = false; + + const tick = () => { + nodes.forEach(node => { + const dragging = Boolean(document.querySelector(`[data-id="${node.id}"].dragging`)); + node.fx = dragging ? node.position.x : null; + node.fy = dragging ? node.position.y : null; + }); + + simulation.tick(); + setNodes(nodes.map(node => ({ ...node, position: { x: node.x, y: node.y } }))); + + if (isRunning) requestAnimationFrame(tick); + }; + + const toggleSimulation = () => { + isRunning = !isRunning; + if (isRunning) requestAnimationFrame(tick); + }; + + return [true, { toggle: toggleSimulation, isRunning: () => isRunning }]; + }, [isInitialized, getNodes, getEdges, setNodes]); +}; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/EvaluationCard.js b/ell-studio/src/components/evaluations/EvaluationCard.js new file mode 100644 index 000000000..34d3fdb06 --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationCard.js @@ -0,0 +1,157 @@ +import React, { useState, useMemo } from 'react'; +import { Link } from 'react-router-dom'; +import { FiBarChart2, FiClock, FiDatabase, FiTag, FiZap, FiCode, FiChevronDown, FiChevronUp } from 'react-icons/fi'; +import { Card, CardContent } from '../common/Card'; +import VersionBadge from '../VersionBadge'; +import { getTimeAgo } from '../../utils/lmpUtils'; +import { LMPCardTitle } from '../depgraph/LMPCardTitle'; +import RunSummary from './RunSummary'; +import { EvaluationCardTitle } from './EvaluationCardTitle'; + +const INITIAL_LMP_DISPLAY_COUNT = 2; + +const EvaluationCard = ({ evaluation, isGraphMode = false }) => { + const [showAllLMPs, setShowAllLMPs] = useState(false); + const totalRuns = evaluation.runs.length; + const successfulRuns = evaluation.runs.filter(run => run.success).length; + + const groupedRuns = useMemo(() => { + const groups = {}; + evaluation.runs.forEach(run => { + const lmpName = run.evaluated_lmp.name; + if (!groups[lmpName]) groups[lmpName] = []; + groups[lmpName].push(run); + }); + return groups; + }, [evaluation.runs]); + + const latestRuns = useMemo(() => + Object.values(groupedRuns).map(runs => + runs.reduce((latest, current) => + new Date(current.end_time) > new Date(latest.end_time) ? current : latest + ) + ), [groupedRuns]); + + const evaluatedLMPs = latestRuns.map(run => run.evaluated_lmp) + .sort((a, b) => b.version_number - a.version_number); + + const displayedLMPs = showAllLMPs ? evaluatedLMPs : evaluatedLMPs.slice(0, INITIAL_LMP_DISPLAY_COUNT); + + return ( + + + +
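// A compact sketch of the grouping EvaluationCard performs above: bucket runs by evaluated LMP
// name, then keep the run with the newest end_time per bucket. The run objects are illustrative.
const runs = [
  { evaluated_lmp: { name: 'write_a_good_poem' }, end_time: '2024-10-10T15:29:35Z' },
  { evaluated_lmp: { name: 'write_a_good_poem' }, end_time: '2024-10-10T15:43:27Z' },
];
const byLmp = runs.reduce((groups, run) => {
  const name = run.evaluated_lmp.name;
  groups[name] = groups[name] || [];
  groups[name].push(run);
  return groups;
}, {});
const latestPerLmp = Object.values(byLmp).map((group) =>
  group.reduce((latest, current) =>
    new Date(current.end_time) > new Date(latest.end_time) ? current : latest
  )
);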
+
+ + +
+
+ + {totalRuns} runs ({successfulRuns} successful) +
+
+ + {evaluation.n_evals} datapoints +
+
+ + Dataset: {evaluation.dataset_id.substring(0, 8)} +
+
+ + {evaluation.labelers.length} metrics +
+
+ + Created: {getTimeAgo(new Date(evaluation.created_at))} +
+
+ + {!isGraphMode && evaluatedLMPs.length > 0 && ( +
+

+ Evaluated LMPs +

+
+ + {displayedLMPs.map((lmp) => ( + + + + ))} + +
+ {evaluatedLMPs.length > INITIAL_LMP_DISPLAY_COUNT && ( + + )} +
+ )} +
+ + {totalRuns > 0 && ( +
+

+ Latest Run Summary +

+
+ + + + + +
+ {!isGraphMode && ( +
+ +
+ )} +
+ )} +
+ + {evaluation.commit_message && ( +

{evaluation.commit_message}

+ )} +
+
+ + ); +}; + +export default EvaluationCard; diff --git a/ell-studio/src/components/evaluations/EvaluationCardTitle.js b/ell-studio/src/components/evaluations/EvaluationCardTitle.js new file mode 100644 index 000000000..8c562e6cd --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationCardTitle.js @@ -0,0 +1,58 @@ +import React from "react"; +import { FiZap } from "react-icons/fi"; +import VersionBadge from "../VersionBadge"; +import EvaluationsIcon from "./EvaluationsIcon"; + +export function EvaluationCardTitle({ + evaluation, + fontSize = "sm", + displayVersion = true, + padding = true, + scale = 1, + additionalClassName = '', + clickable = true, + shortVersion = false, + paddingClassOverride = '', + nameOverride = null, + showRunCount = true, + outlineStyle = 'solid', + ...rest +}) { + const paddingClass = paddingClassOverride ? paddingClassOverride : padding ? 'p-2' : ''; + + const scaleClass = `scale-${scale}`; + const hoverClass = clickable ? 'duration-200 ease-in-out hover:bg-opacity-80 hover:bg-gray-700' : ''; + const cursorClass = clickable ? 'cursor-pointer' : ''; + + const outlineClasses = { + solid: 'bg-blue-100 text-blue-800', + dashed: 'bg-transparent text-blue-500 border border-dotted border-blue-400' + }; + + return ( +
+
+ +
+ {nameOverride ? nameOverride : ( + + {evaluation.name} + + )} + {displayVersion && ( + + )} + {showRunCount && evaluation.runs && evaluation.runs.length > 0 && ( +
+ + {evaluation.runs.length} +
+ )} +
+ ); +} diff --git a/ell-studio/src/components/evaluations/EvaluationDataset.js b/ell-studio/src/components/evaluations/EvaluationDataset.js new file mode 100644 index 000000000..ac442b65a --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationDataset.js @@ -0,0 +1,109 @@ +import React, { useMemo, useState } from 'react'; +import { useDataset } from '../../hooks/useBackend'; +import HierarchicalTable from '../HierarchicalTable'; +import { ContentsRenderer } from '../invocations/ContentsRenderer'; +import SearchAndFiltersBar from './runs/SearchAndFiltersBar'; + +function EvaluationDataset({ evaluation }) { + const { data: datasetData, isLoading, isError, error } = useDataset(evaluation?.dataset_id); + const [searchQuery, setSearchQuery] = useState(''); + + const filteredData = useMemo(() => { + if (!datasetData?.data) return []; + if (!searchQuery) return datasetData.data; + + const query = searchQuery.toLowerCase(); + + return datasetData.data.filter(item => { + // Convert the entire item to a string for searching + const itemString = JSON.stringify(item).toLowerCase(); + return itemString.includes(query); + }); + }, [datasetData, searchQuery]); + + const columns = useMemo(() => { + if (!datasetData?.data?.[0]) return []; + + return Object.keys(datasetData.data[0]).map(key => ({ + header: key, + key: key, + render: (item) => ( + + ), + maxWidth: 300, + sortable: true, + sortFn: (a, b) => { + const aValue = a[key]; + const bValue = b[key]; + + if (typeof aValue === 'number' && typeof bValue === 'number') { + return aValue - bValue; + } + return String(aValue).localeCompare(String(bValue)); + } + })); + }, [datasetData]); + + // Handle case where there's no dataset + if (!evaluation?.dataset_id) { + return ( +
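// Sketch of the column derivation and search filter used by EvaluationDataset above: one column
// per key of the first datapoint, and a row matches when its JSON serialization contains the
// lowercased query. The toy dataset is illustrative.
const datasetRows = [
  { prompt: 'Write a poem', score: 0.8 },
  { prompt: 'Write a haiku', score: 0.4 },
];
const columnKeys = Object.keys(datasetRows[0]); // ['prompt', 'score'] -> one column per key
const query = 'haiku';
const matches = datasetRows.filter((item) =>
  JSON.stringify(item).toLowerCase().includes(query)
); // -> [{ prompt: 'Write a haiku', score: 0.4 }]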
+ This evaluation was run without a dataset. Each example was generated on the fly. +
+ ); + } + + return ( +
+ {isLoading && ( +
Loading dataset...
+ )} + {isError && ( +
+ {error?.response?.data?.detail || 'Error loading dataset'} +
+ )} + {datasetData?.data && datasetData.data.length > 0 && ( +
+
+ Dataset size: {Math.round(datasetData.size / 1024)} KB • + {' '}{datasetData.data.length} examples + {searchQuery && ` • ${filteredData.length} matches`} +
+ + + + +
+ )} + {!datasetData?.data && ( +
+ This evaluation was created without a dataset and is purely "generative". +
+ )} +
+ ); +} + +export default EvaluationDataset; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/EvaluationDetailsSidebar.js b/ell-studio/src/components/evaluations/EvaluationDetailsSidebar.js new file mode 100644 index 000000000..17da3b53c --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationDetailsSidebar.js @@ -0,0 +1,75 @@ +import React from 'react'; +import { FiZap, FiBarChart2, FiDatabase, FiTag, FiClock, FiHash } from 'react-icons/fi'; +import MetricTable from './MetricTable'; +import { getTimeAgo } from '../../utils/lmpUtils'; +import SidePanel from '../common/SidePanel'; +import { motion } from 'framer-motion'; +import VersionBadge from '../VersionBadge'; + +function EvaluationDetailsSidebar({ evaluation }) { + return ( + + +
+
+

Version Info

+ +
+
+
+ + Created: +
+
{getTimeAgo(new Date(evaluation?.created_at))}
+
+ + Runs: +
+
{evaluation?.runs.length}
+
+ + Datapoints: +
+
{evaluation?.n_evals}
+
+ + Dataset: +
+
{evaluation?.dataset_id.substring(0, 8)}
+
+ + Metrics: +
+
{evaluation?.labelers.length}
+
+
+ {/* TODO ADD MROE INFO. */} +
+ {/*

Metrics

*/} + {/* {evaluation && ( + // summary.is_scalar)} + // historicalData={evaluation.runs.reduce((acc, run) => { + // run.labeler_summaries.forEach(summary => { + // if (!acc[summary.evaluation_labeler_id]) { + // acc[summary.evaluation_labeler_id] = []; + // } + // acc[summary.evaluation_labeler_id].push(summary.data); + // }); + // return acc; + // }, {})} + // isVertical={true} + // /> + )} */} +
+
+
+ ); +} + +export default EvaluationDetailsSidebar; diff --git a/ell-studio/src/components/evaluations/EvaluationOverview.js b/ell-studio/src/components/evaluations/EvaluationOverview.js new file mode 100644 index 000000000..813f73186 --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationOverview.js @@ -0,0 +1,43 @@ +import React, { useState } from 'react'; +import { FiBarChart2, FiClock, FiDatabase, FiTag, FiZap } from 'react-icons/fi'; +import { Card, CardContent } from '../common/Card'; +import RunSummary from './RunSummary'; +import VersionBadge from '../VersionBadge'; +import { getTimeAgo } from '../../utils/lmpUtils'; +import MetricGraphGrid from './MetricGraphGrid'; + +function EvaluationOverview({ evaluation, groupedRuns, onActiveIndexChange }) { + const [activeIndex, setActiveIndex] = useState(null); + + const handleActiveIndexChange = (index) => { + setActiveIndex(index); + onActiveIndexChange(index); + }; + + return ( + <> +
+

+ Evaluation + + +

+ {evaluation.labelers ? ( + + ) : ( +
+
+
+
+
+ )} +
+ + ); +} + +export default EvaluationOverview; diff --git a/ell-studio/src/components/evaluations/EvaluationsAnalyticsSidePanel.js b/ell-studio/src/components/evaluations/EvaluationsAnalyticsSidePanel.js new file mode 100644 index 000000000..84a45aaa8 --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationsAnalyticsSidePanel.js @@ -0,0 +1,42 @@ +import React, { useMemo } from 'react'; +import { FiBarChart2, FiClock, FiDatabase } from 'react-icons/fi'; + +const EvaluationsAnalyticsSidePanel = ({ evaluations }) => { + const analytics = useMemo(() => { + const totalEvaluations = evaluations.length; + const activeEvaluations = evaluations.filter(e => e.status === 'Active').length; + const completedEvaluations = evaluations.filter(e => e.status === 'Completed').length; + const totalDatapoints = evaluations.reduce((sum, e) => sum + e.n_evals, 0); + + return { totalEvaluations, activeEvaluations, completedEvaluations, totalDatapoints }; + }, [evaluations]); + + return ( +
+

Evaluation Analytics

+ +
+
+ Total Evaluations + {analytics.totalEvaluations} +
+
+ Active Evaluations + {analytics.activeEvaluations} +
+
+ Completed Evaluations + {analytics.completedEvaluations} +
+
+ Total Datapoints + {analytics.totalDatapoints} +
+
+ + {/* You can add more analytics or charts here */} +
+ ); +}; + +export default EvaluationsAnalyticsSidePanel; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/EvaluationsIcon.js b/ell-studio/src/components/evaluations/EvaluationsIcon.js new file mode 100644 index 000000000..23852bf81 --- /dev/null +++ b/ell-studio/src/components/evaluations/EvaluationsIcon.js @@ -0,0 +1,11 @@ +import React from 'react'; +import { FiBarChart2, FiClipboard } from 'react-icons/fi'; + +const EvaluationsIcon = ({ className = '' }) => ( +
+ + +
+); + +export default EvaluationsIcon; diff --git a/ell-studio/src/components/evaluations/LabelDisplay.js b/ell-studio/src/components/evaluations/LabelDisplay.js new file mode 100644 index 000000000..7b7521f31 --- /dev/null +++ b/ell-studio/src/components/evaluations/LabelDisplay.js @@ -0,0 +1,60 @@ +import React from 'react'; + +const LabelDisplay = ({ + value : valueNumberish, // This is the mean + isAggregate = false, + stats = null // { min, max, stdDev } +}) => { + const value = typeof valueNumberish === 'boolean' ? Number(valueNumberish) : valueNumberish; + + if (typeof value !== 'number') { + return
{value}
; + } + + if (!isAggregate || !stats) { + return
{value.toFixed(2)}
; + } + + const { min, max, stdDev } = stats; + const mean = value; + + // Handle the case where min equals max + const isConstant = min === max; + + // Calculate positions as percentages, clamping to the range + const meanPos = isConstant ? 50 : ((mean - min) / (max - min)) * 100; + const leftStdDevPos = isConstant ? 50 : Math.max(((mean - stdDev - min) / (max - min)) * 100, 0); + const rightStdDevPos = isConstant ? 50 : Math.min(((mean + stdDev - min) / (max - min)) * 100, 100); + const boxWidth = rightStdDevPos - leftStdDevPos; + + return ( +
+
{value.toFixed(2)}
+
+ {/* Base bar */} +
+ {/* StdDev box - only show if there's variation */} + {!isConstant && ( +
+ )} + + {/* Mean marker - made slightly larger when it's a constant value */} +
+
+
+ ); +}; + +export default LabelDisplay; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/MetricDisplay.js b/ell-studio/src/components/evaluations/MetricDisplay.js new file mode 100644 index 000000000..130a75a3a --- /dev/null +++ b/ell-studio/src/components/evaluations/MetricDisplay.js @@ -0,0 +1,74 @@ +import React from 'react'; +import { FiTrendingUp, FiTrendingDown, FiMinus } from 'react-icons/fi'; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '../common/Tooltips'; + +const getColorForTrend = (trend) => { + if (trend > 0) return 'text-emerald-400'; + if (trend < 0) return 'text-rose-400'; + return 'text-gray-400'; +}; + +const getTrendIcon = (trend) => { + if (trend > 0) return ; + if (trend < 0) return ; + return ; +}; + +const MetricDisplay = ({ currentValue : nonFloatCurrentValue, previousValue, label, showTooltip = true, showTrend = true }) => { + const currentValue = Number(nonFloatCurrentValue); + const percentChange = previousValue !== undefined && previousValue !== 0 + ? ((currentValue - previousValue) / Math.abs(previousValue) * 100).toFixed(1) + : (currentValue !== 0 ? '100.0' : '0.0'); + + const trendColorClass = getColorForTrend(parseFloat(percentChange)); + const trendIcon = getTrendIcon(parseFloat(percentChange)); + const [isHighlighted, setIsHighlighted] = React.useState(false); + React.useEffect(() => { + setIsHighlighted(true); + // Reduce timeout from 150ms to 100ms + const timer = setTimeout(() => setIsHighlighted(false), 100); + return () => clearTimeout(timer); + }, [currentValue]); + + const content = ( +
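// Worked example of the trend math in MetricDisplay above: the change is taken relative to
// |previousValue|, and a missing or zero baseline falls back to 100% / 0%.
const pct = (current, previous) =>
  previous !== undefined && previous !== 0
    ? (((current - previous) / Math.abs(previous)) * 100).toFixed(1)
    : current !== 0 ? '100.0' : '0.0';

pct(0.75, 0.5); // '50.0'  -> upward trend, rendered in the emerald colour
pct(0.25, 0.5); // '-50.0' -> downward trend, rendered in the rose colour
pct(0.3, 0);    // '100.0' -> no usable baseline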
+
+ + {currentValue.toFixed(2)} + +
+ {showTrend && ( +
+ {trendIcon}{Math.abs(parseFloat(percentChange)).toFixed(1)}% +
+ )} +
+ ); + + if (!showTooltip) { + return content; + } + + + return ( + + + +
+ {content} +
+
+ +
+

{label}

+

Current: {currentValue.toFixed(4)}

+

Previous: {previousValue !== undefined && previousValue !== null ? previousValue.toFixed(4) : 'N/A'}

+

Change: {previousValue !== undefined && previousValue !== null ? (currentValue - previousValue).toFixed(4) : 'N/A'}

+
+
+
+
+ ); +}; + +export default MetricDisplay; diff --git a/ell-studio/src/components/evaluations/MetricGraphGrid.js b/ell-studio/src/components/evaluations/MetricGraphGrid.js new file mode 100644 index 000000000..7f4fd8b0e --- /dev/null +++ b/ell-studio/src/components/evaluations/MetricGraphGrid.js @@ -0,0 +1,163 @@ +import React, { useState, useCallback } from 'react'; +import { LMPCardTitle } from '../depgraph/LMPCardTitle'; +import { Card } from '../common/Card'; +import Graph from '../graphing/Graph'; +import { GraphProvider } from '../graphing/GraphSystem'; +import MetricDisplay from './MetricDisplay'; +import { Link } from 'react-router-dom'; + +const MetricGraphGrid = ({ evaluation, groupedRuns, onActiveIndexChange }) => { + const [activeIndex, setActiveIndex] = useState(null); + + const getHistoricalData = (labeler) => { + if (!labeler) return { means: [], stdDevs: [], errors: [], confidenceIntervals: [] }; + return Object.values(groupedRuns).reduce((acc, runs) => { + runs.forEach(run => { + const summary = run.labeler_summaries.find(s => s.evaluation_labeler_id === labeler.id); + if (summary) { + const { mean, std, min, max } = summary.data; + const count = summary.count; + console.log(count) + + // Calculate Standard Error of the Mean (SEM) + const sem = std / Math.sqrt(count); + + // Z-score for 95% confidence + const zScore = 1.96; + + // Margin of Error + let marginOfError = zScore * sem; + + // Bounded Confidence Interval + let lowerBound = Math.max(mean - marginOfError, min); + let upperBound = Math.min(mean + marginOfError, max); + + acc.means.push(mean); + acc.stdDevs.push(std); + acc.errors.push(marginOfError); + acc.confidenceIntervals.push({ low: lowerBound, high: upperBound }); + } + }); + return acc; + }, { means: [], stdDevs: [], errors: [], confidenceIntervals: [] }); + }; + + const xData = Array.from({ length: getHistoricalData(evaluation.labelers?.[0]).means.length}, (_, i) => `${i + 1}`); + + const handleHover = useCallback((index) => { + setActiveIndex(index); + onActiveIndexChange(index); + }, [onActiveIndexChange]); + + const handleLeave = useCallback(() => { + setActiveIndex(null); + onActiveIndexChange(null); + }, [onActiveIndexChange]); + + const hasMultipleValues = getHistoricalData(evaluation.labelers[0]).means.length > 1; + return ( + +
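// Worked example of the interval math in getHistoricalData above, using the "length" metric from
// run 5 in reference.json (mean 293.6, std ~11.59, count 10, min 278, max 310).
const ci95 = ({ mean, std, min, max }, count) => {
  const sem = std / Math.sqrt(count); // standard error of the mean
  const margin = 1.96 * sem;          // 95% z-interval
  return { low: Math.max(mean - margin, min), high: Math.min(mean + margin, max) };
};

ci95({ mean: 293.6, std: 11.586, min: 278, max: 310 }, 10);
// -> { low: ~286.4, high: ~300.8 }, clamped to [min, max] so the band never leaves the observed range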
+ {evaluation.labelers.map((labeler) => { + const { means: historicalData, stdDevs, confidenceIntervals } = getHistoricalData(labeler); + if (historicalData.length === 0) return null; + + const currentValue = activeIndex !== null ? historicalData[activeIndex] : historicalData[historicalData.length - 1]; + const previousValue = activeIndex !== null && activeIndex > 0 ? historicalData[activeIndex - 1] : historicalData[historicalData.length - 2]; + + return ( + +
+ + + + +
+ {hasMultipleValues && ( +
+ historicalData[0] ? 'rgba(52, 211, 153, 0.8)' : 'rgba(239, 68, 68, 0.8)', + config: { + backgroundColor: currentValue > historicalData[0] ? 'rgba(52, 211, 153, 0.2)' : 'rgba(239, 68, 68, 0.2)', + borderColor: currentValue > historicalData[0] ? 'rgba(52, 211, 153, 0.8)' : 'rgba(239, 68, 68, 0.8)', + fill: true, + tension: 0.4, + borderWidth: 1, + pointRadius: 3, + } + } + ]} + /> +
+ )} +
+ ); + })} +
+
+ ); +}; + +export default MetricGraphGrid; diff --git a/ell-studio/src/components/evaluations/MetricTable.js b/ell-studio/src/components/evaluations/MetricTable.js new file mode 100644 index 000000000..e79869355 --- /dev/null +++ b/ell-studio/src/components/evaluations/MetricTable.js @@ -0,0 +1,51 @@ +import React, { useState } from 'react'; +import { FiBarChart2 } from 'react-icons/fi'; +import TrendLine from '../graphing/TrendLine'; +import MetricDisplay from './MetricDisplay'; + +const MetricTable = ({ summaries, historicalData, isVertical }) => { + const [hoverIndex, setHoverIndex] = useState(null); + + return ( +
+ {summaries.map((summary, index) => { + const currentValue = hoverIndex !== null ? historicalData[summary.evaluation_labeler_id][hoverIndex].mean : summary.data.mean; + const previousValue = historicalData[summary.evaluation_labeler_id][historicalData[summary.evaluation_labeler_id].length - 2]?.mean; + + return ( + +
+
+ + + {summary.evaluation_labeler.name} + +
+
+
+ d.mean)} + hoverIndex={hoverIndex} + onHover={setHoverIndex} + /> +
+
+ +
+
+ {index < summaries.length - 1 && ( +
+ )} + + ); + })} +
+ ); +}; + +export default MetricTable; diff --git a/ell-studio/src/components/evaluations/RunSummary.js b/ell-studio/src/components/evaluations/RunSummary.js new file mode 100644 index 000000000..943321b2e --- /dev/null +++ b/ell-studio/src/components/evaluations/RunSummary.js @@ -0,0 +1,41 @@ +import React, { useState } from 'react'; +import { LMPCardTitle } from '../depgraph/LMPCardTitle'; +import MetricTable from './MetricTable'; + +const RunSummary = ({ groupedRuns, isVertical }) => { + const latestRuns = Object.values(groupedRuns).map(runs => runs[runs.length - 1]); + const mostRecentRun = latestRuns.reduce((latest, current) => + new Date(current.end_time) > new Date(latest.end_time) ? current : latest + ); + + const scalarSummaries = mostRecentRun.labeler_summaries.filter(summary => summary.is_scalar); + + const historicalData = scalarSummaries.reduce((acc, summary) => { + acc[summary.evaluation_labeler_id] = groupedRuns[mostRecentRun.evaluated_lmp.name] + .map(run => run.labeler_summaries + .find(s => s.evaluation_labeler_id === summary.evaluation_labeler_id)?.data + ) + .filter(Boolean); + return acc; + }, {}); + + return ( +
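// Sketch of the per-labeler series RunSummary builds above: for the most recent LMP, collect each
// run's summary.data (mean/std/min/max) keyed by labeler id, so MetricTable and TrendLine can plot
// the means over runs. The ids and values here are illustrative.
const runsForLmp = [
  { labeler_summaries: [{ evaluation_labeler_id: 'critic_score', is_scalar: true, data: { mean: 0.0 } }] },
  { labeler_summaries: [{ evaluation_labeler_id: 'critic_score', is_scalar: true, data: { mean: 0.2 } }] },
];
const historicalData = {
  critic_score: runsForLmp
    .map((run) => run.labeler_summaries.find((s) => s.evaluation_labeler_id === 'critic_score')?.data)
    .filter(Boolean),
};
const means = historicalData.critic_score.map((d) => d.mean); // [0.0, 0.2]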
+ + +
+ ); +}; + +export default RunSummary; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/reference.json b/ell-studio/src/components/evaluations/reference.json new file mode 100644 index 000000000..48f82706a --- /dev/null +++ b/ell-studio/src/components/evaluations/reference.json @@ -0,0 +1,463 @@ + + { + "id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "name": "poem_eval", + "created_at": "2024-10-10T22:29:35Z", + "dataset_id": "38b3eff8baf56627478ec76a704e9b52", + "n_evals": 10, + "version_number": 2, + "commit_message": null, + "labelers": [ + { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "name": "critic_score", + "type": "metric", + "labeling_lmp_id": "lmp-05680bf1e9e658bbcddad24cbb3e5a54", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "labeling_rubric": null, + "labeling_lmp": { + "name": "test_poem_eval..score", + "source": "def score(datapoint, output):\n return \"yes\" in is_good_poem(output)\n", + "created_at": "2024-10-10T22:29:33.592968Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": "Add @ell.simple decorator to is_good_poem function:\n* Introduced @ell.simple decorator with model and temperature parameters.\n* Changed function name from `score` to `is_good_poem`.\n* Updated system prompt to include specific response format.\n* Modified return value of `score` to check for 'yes' in `is_good_poem` output.", + "lmp_id": "lmp-05680bf1e9e658bbcddad24cbb3e5a54", + "dependencies": "@ell.simple(model=\"gpt-4o\", temperature=0.1)\ndef is_good_poem(poem: str):\n \"\"\"Include either 'yes' or 'no' at the end of your response. . .\"\"\"\n return f\"Is this a good poem yes/no? {poem}\"\n", + "lmp_type": "FUNCTION", + "initial_free_vars": { + "is_good_poem": "" + }, + "num_invocations": 38, + "version_number": 1 + } + }, + { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "name": "length", + "type": "metric", + "labeling_lmp_id": "lmp-11b9aab61d2f514a79f3787b0b14960f", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "labeling_rubric": null, + "labeling_lmp": { + "name": "test_poem_eval..", + "source": "\"length\": lambda _, output: len(output) ,\n", + "created_at": "2024-10-10T22:17:07.535510Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": null, + "lmp_id": "lmp-11b9aab61d2f514a79f3787b0b14960f", + "dependencies": "\n", + "lmp_type": "FUNCTION", + "initial_free_vars": {}, + "num_invocations": 70, + "version_number": 0 + } + }, + { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "name": "average_word_length", + "type": "metric", + "labeling_lmp_id": "lmp-924138b48e7bfb9ce2add3d00a284300", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "labeling_rubric": null, + "labeling_lmp": { + "name": "test_poem_eval..", + "source": "\"average_word_length\": lambda _, output: sum(len(word) for word in output.split()) / len(output.split())})\n", + "created_at": "2024-10-10T22:17:07.543092Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": null, + "lmp_id": "lmp-924138b48e7bfb9ce2add3d00a284300", + "dependencies": "\n", + "lmp_type": "FUNCTION", + "initial_free_vars": {}, + "num_invocations": 70, + "version_number": 0 + } + } + ], + "runs": [ + { + "id": 5, + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "evaluated_lmp_id": "lmp-43aa5c2a8bd3f4c03edbef87a97af6fa", + "api_params": {}, + "start_time": 
"2024-10-10T15:29:30.222298Z", + "end_time": "2024-10-10T15:29:35.050561Z", + "success": true, + "error": null, + "evaluated_lmp": { + "name": "test_poem_eval..write_a_good_poem", + "source": "@ell.simple(model=\"gpt-4o\")\ndef write_a_good_poem():\n \"\"\"Your poem must no logner than 60 words.\"\"\"\n return \"Write a really well written poem.\"\n", + "created_at": "2024-10-10T22:17:06.391998Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": null, + "lmp_id": "lmp-43aa5c2a8bd3f4c03edbef87a97af6fa", + "dependencies": "\n", + "lmp_type": "LM", + "initial_free_vars": {}, + "num_invocations": 37, + "version_number": 0 + }, + "labeler_summaries": [ + { + "evaluation_run_id": 5, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "created_at": "2024-10-10T22:29:35Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 5.7, + "mean": 5.156704689308549, + "min": 4.775510204081633, + "std": 0.2926856773010382 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "labeling_lmp_id": "lmp-924138b48e7bfb9ce2add3d00a284300", + "labeling_rubric": null, + "name": "average_word_length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 5, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "created_at": "2024-10-10T22:29:35Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 0, + "mean": 0, + "min": 0, + "std": 0 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "labeling_lmp_id": "lmp-05680bf1e9e658bbcddad24cbb3e5a54", + "labeling_rubric": null, + "name": "critic_score", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 5, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "created_at": "2024-10-10T22:29:35Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 310, + "mean": 293.6, + "min": 278, + "std": 11.586198686368192 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "labeling_lmp_id": "lmp-11b9aab61d2f514a79f3787b0b14960f", + "labeling_rubric": null, + "name": "length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + } + ] + }, + { + "id": 6, + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "evaluated_lmp_id": "lmp-35a382a3f39e32779d8e9194a902b63c", + "api_params": {}, + "start_time": "2024-10-10T15:29:35.074010Z", + "end_time": "2024-10-10T15:29:38.401680Z", + "success": true, + "error": null, + "evaluated_lmp": { + "name": "test_poem_eval..write_a_bad_poem", + "source": "@ell.simple(model=\"gpt-4o\")\ndef write_a_bad_poem():\n \"\"\"Your poem must no logner than 75 words.\"\"\"\n return \"Write a really poorly written poem.\"\n", + "created_at": "2024-10-10T22:19:12.331014Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": null, + "lmp_id": "lmp-35a382a3f39e32779d8e9194a902b63c", + "dependencies": "\n", + "lmp_type": "LM", + "initial_free_vars": {}, + "num_invocations": 22, + "version_number": 0 + }, + "labeler_summaries": [ + { + 
"evaluation_run_id": 6, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "created_at": "2024-10-10T22:29:38Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 4.267857142857143, + "mean": 3.8121618030312967, + "min": 3.2903225806451615, + "std": 0.27904990397259494 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "labeling_lmp_id": "lmp-924138b48e7bfb9ce2add3d00a284300", + "labeling_rubric": null, + "name": "average_word_length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 6, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "created_at": "2024-10-10T22:29:38Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 0, + "mean": 0, + "min": 0, + "std": 0 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "labeling_lmp_id": "lmp-05680bf1e9e658bbcddad24cbb3e5a54", + "labeling_rubric": null, + "name": "critic_score", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 6, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "created_at": "2024-10-10T22:29:38Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 364, + "mean": 295.8, + "min": 220, + "std": 43.05531326096699 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "labeling_lmp_id": "lmp-11b9aab61d2f514a79f3787b0b14960f", + "labeling_rubric": null, + "name": "length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + } + ] + }, + { + "id": 7, + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "evaluated_lmp_id": "lmp-43aa5c2a8bd3f4c03edbef87a97af6fa", + "api_params": {}, + "start_time": "2024-10-10T15:43:24.667034Z", + "end_time": "2024-10-10T15:43:27.680376Z", + "success": true, + "error": null, + "evaluated_lmp": { + "name": "test_poem_eval..write_a_good_poem", + "source": "@ell.simple(model=\"gpt-4o\")\ndef write_a_good_poem():\n \"\"\"Your poem must no logner than 60 words.\"\"\"\n return \"Write a really well written poem.\"\n", + "created_at": "2024-10-10T22:17:06.391998Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": null, + "lmp_id": "lmp-43aa5c2a8bd3f4c03edbef87a97af6fa", + "dependencies": "\n", + "lmp_type": "LM", + "initial_free_vars": {}, + "num_invocations": 37, + "version_number": 0 + }, + "labeler_summaries": [ + { + "evaluation_run_id": 7, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "created_at": "2024-10-10T22:43:27Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 5.738095238095238, + "mean": 5.1218758094739005, + "min": 4.695652173913044, + "std": 0.2912878241552186 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "labeling_lmp_id": "lmp-924138b48e7bfb9ce2add3d00a284300", + "labeling_rubric": null, + "name": "average_word_length", + "type": "metric", + "evaluation_id": 
"evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 7, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "created_at": "2024-10-10T22:43:27Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 0, + "mean": 0, + "min": 0, + "std": 0 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "labeling_lmp_id": "lmp-05680bf1e9e658bbcddad24cbb3e5a54", + "labeling_rubric": null, + "name": "critic_score", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 7, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "created_at": "2024-10-10T22:43:27Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 342, + "mean": 293.3, + "min": 263, + "std": 21.062051182161717 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "labeling_lmp_id": "lmp-11b9aab61d2f514a79f3787b0b14960f", + "labeling_rubric": null, + "name": "length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + } + ] + }, + { + "id": 8, + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca", + "evaluated_lmp_id": "lmp-bcda1f9818be7e4baaac4bf9a9842008", + "api_params": {}, + "start_time": "2024-10-10T15:43:27.700337Z", + "end_time": "2024-10-10T15:43:32.134376Z", + "success": true, + "error": null, + "evaluated_lmp": { + "name": "test_poem_eval..write_a_bad_poem", + "source": "@ell.simple(model=\"gpt-4o\")\ndef write_a_bad_poem():\n \"\"\"Your poem must no logner than 100 words.\"\"\"\n return \"Write a really poorly written poem.\"\n", + "created_at": "2024-10-10T22:43:29.575229Z", + "api_params": null, + "initial_global_vars": {}, + "commit_message": "Increase word limit in system prompt from 75 to 100: \n* Updated system prompt to require no longer than 100 words.", + "lmp_id": "lmp-bcda1f9818be7e4baaac4bf9a9842008", + "dependencies": "\n", + "lmp_type": "LM", + "initial_free_vars": {}, + "num_invocations": 10, + "version_number": 1 + }, + "labeler_summaries": [ + { + "evaluation_run_id": 8, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "created_at": "2024-10-10T22:43:32Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 4.148148148148148, + "mean": 3.867277533272864, + "min": 3.4705882352941178, + "std": 0.20228923985745975 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-average_word_length-METRIC", + "labeling_lmp_id": "lmp-924138b48e7bfb9ce2add3d00a284300", + "labeling_rubric": null, + "name": "average_word_length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 8, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "created_at": "2024-10-10T22:43:32Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 0, + "mean": 0, + "min": 0, + "std": 0 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-critic_score-METRIC", + "labeling_lmp_id": "lmp-05680bf1e9e658bbcddad24cbb3e5a54", + "labeling_rubric": null, + 
"name": "critic_score", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + }, + { + "evaluation_run_id": 8, + "evaluation_labeler_id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "created_at": "2024-10-10T22:43:32Z", + "updated_at": null, + "finalized_at": null, + "is_scalar": true, + "data": { + "max": 458, + "mean": 382, + "min": 300, + "std": 58.535459338763204 + }, + "count": 10, + "evaluation_labeler": { + "id": "labeler-evaluation-e0ec6b331a8c820b280f39650ad87eca-length-METRIC", + "labeling_lmp_id": "lmp-11b9aab61d2f514a79f3787b0b14960f", + "labeling_rubric": null, + "name": "length", + "type": "metric", + "evaluation_id": "evaluation-e0ec6b331a8c820b280f39650ad87eca" + } + } + ] + } + ] + } + \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/runs/EvaluationRunDetailsSidebar.js b/ell-studio/src/components/evaluations/runs/EvaluationRunDetailsSidebar.js new file mode 100644 index 000000000..0d754fe2a --- /dev/null +++ b/ell-studio/src/components/evaluations/runs/EvaluationRunDetailsSidebar.js @@ -0,0 +1,76 @@ +import React from 'react'; +import { FiZap, FiClock, FiActivity, FiCpu, FiCheck, FiAlertCircle } from 'react-icons/fi'; +import { motion } from 'framer-motion'; +import SidePanel from '../../common/SidePanel'; +import { Card } from '../../common/Card'; +import { getTimeAgo } from '../../../utils/lmpUtils'; + +function EvaluationRunDetailsSidebar({ run, results }) { + const totalInvocations = results?.length || 0; + const duration = run?.end_time && run?.start_time ? + new Date(run.end_time) - new Date(run.start_time) : null; + + return ( + + +
+

Run Info

+
+
+ + Status: +
+
+ {run?.success ? ( + + Success + + ) : run?.success === false ? ( + + Failed + + ) : ( + Running + )} +
+ +
+ + Started: +
+
+ {run?.start_time ? getTimeAgo(new Date(run.start_time)) : 'N/A'} +
+ +
+ + Duration: +
+
+ {duration ? `${(duration / 1000).toFixed(1)}s` : 'N/A'} +
+ +
+ + Invocations: +
+
{totalInvocations}
+
+ + {run?.error && ( +
+ Error: {run.error} +
+ )} +
+
+
+ ); +} + +export default EvaluationRunDetailsSidebar; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/runs/EvaluationRunMetrics.js b/ell-studio/src/components/evaluations/runs/EvaluationRunMetrics.js new file mode 100644 index 000000000..25261643b --- /dev/null +++ b/ell-studio/src/components/evaluations/runs/EvaluationRunMetrics.js @@ -0,0 +1,189 @@ +import React, { useMemo } from 'react'; +import { Card } from '../../common/Card'; +import { Link } from 'react-router-dom'; +import { LMPCardTitle } from '../../depgraph/LMPCardTitle'; +import Graph from '../../graphing/Graph'; +import { GraphProvider } from '../../graphing/GraphSystem'; +import MetricDisplay from '../MetricDisplay'; + +function EvaluationRunMetrics({ run, results, fullResults }) { + // Use fullResults (unfiltered) to determine axis scales + const { histogramDataMap, axisScales } = useMemo(() => { + if(!results) return { histogramDataMap: null, axisScales: null }; + + const dataMap = new Map(); + const scales = new Map(); + + run?.labeler_summaries?.forEach(summary => { + const labelerId = summary.evaluation_labeler_id; + + // Get values from the full dataset to determine scales + const allValues = (fullResults || results) + .flatMap(result => + result.labels + .filter(label => label.labeler_id === labelerId) + .map(label => label.label_invocation?.contents?.results) + ) + .filter(value => typeof value === 'number' || typeof value === 'boolean'); + + if (allValues.length === 0) return; + + // Calculate global min and max from full dataset + const globalMin = Math.min(...allValues); + const globalMax = Math.max(...allValues); + + scales.set(labelerId, { min: globalMin, max: globalMax }); + + // Now get values from filtered results for the histogram + const filteredValues = results + .flatMap(result => + result.labels + .filter(label => label.labeler_id === labelerId) + .map(label => label.label_invocation?.contents?.results) + ) + .filter(value => typeof value === 'number' || typeof value === 'boolean'); + + if (filteredValues.length === 0) return; + + // Use global min/max for binning, even with filtered data + if (globalMin === globalMax) { + const padding = Math.abs(globalMin * 0.1) || 0.1; + dataMap.set(labelerId, { + binLabels: [(globalMin - padding).toFixed(2), globalMin.toFixed(2), (globalMin + padding).toFixed(2)], + counts: [0, filteredValues.length, 0] + }); + return; + } + + const numBins = 10; + const binWidth = (globalMax - globalMin) / numBins; + + const histogramData = Array(numBins).fill(0); + + filteredValues.forEach(value => { + const binIndex = Math.min( + Math.floor((value - globalMin) / binWidth), + numBins - 1 + ); + if (binIndex >= 0 && binIndex < numBins) { + histogramData[binIndex]++; + } + }); + + const binLabels = Array.from({ length: numBins }, (_, i) => { + const binStart = globalMin + (i * binWidth); + return ((binStart + (binStart + binWidth)) / 2).toFixed(2); + }); + + dataMap.set(labelerId, { + binLabels, + counts: histogramData + }); + }); + + return { histogramDataMap: dataMap, axisScales: scales }; + }, [results, fullResults, run?.labeler_summaries]); + + if(!results) return null; + return ( +
+

Metrics

+ +
+ {run?.labeler_summaries?.map((summary, index) => { + const histogramData = histogramDataMap?.get(summary.evaluation_labeler_id); + const scale = axisScales?.get(summary.evaluation_labeler_id); + + if (!histogramData) return null; + + return ( + +
+ + + + +
+
+ +
+
+ ); + })} +
+
+
+ ); +} + +export default EvaluationRunMetrics; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/runs/EvaluationRunOverview.js b/ell-studio/src/components/evaluations/runs/EvaluationRunOverview.js new file mode 100644 index 000000000..e0a51ac42 --- /dev/null +++ b/ell-studio/src/components/evaluations/runs/EvaluationRunOverview.js @@ -0,0 +1,55 @@ +import React from 'react'; +import { Link } from 'react-router-dom'; +import { Card, CardContent } from '../../common/Card'; +import { EvaluationCardTitle } from '../EvaluationCardTitle'; +import { LMPCardTitle } from '../../depgraph/LMPCardTitle'; +import LMPSourceView from '../../source/LMPSourceView'; + +function EvaluationRunOverview({ run }) { + return ( +
+
+ + +
+ +
+
+ + +
+ Run #{run?.id} +
+
+ +
+

Evaluated LMP

+
+ + +
+ +
+
+
+
+
+ ); +} + +export default EvaluationRunOverview; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/runs/EvaluationRunResultsTable.js b/ell-studio/src/components/evaluations/runs/EvaluationRunResultsTable.js new file mode 100644 index 000000000..4d33dda1f --- /dev/null +++ b/ell-studio/src/components/evaluations/runs/EvaluationRunResultsTable.js @@ -0,0 +1,357 @@ +import React, { useMemo, useEffect } from 'react'; +import HierarchicalTable from '../../HierarchicalTable'; +import { Card } from '../../common/Card'; +import { ContentsRenderer } from '../../invocations/ContentsRenderer'; +import LabelDisplay from '../LabelDisplay'; +import InvocationDetailsPopover from '../../invocations/details/InvocationDetailsPopover'; + +const MAX_PREVIEW_ITEMS = 3; + +const OutputPreview = ({ outputs, invocation, isExpanded }) => { + const totalOutputs = outputs.length; + + if (isExpanded) { + return ( +
+ {totalOutputs} outputs +
+ ); + } + + const previewOutputs = outputs.slice(0, MAX_PREVIEW_ITEMS); + + return ( +
+ {previewOutputs.map((output, idx) => ( +
+ +
+ ))} + {totalOutputs > MAX_PREVIEW_ITEMS && ( +
+ & {totalOutputs - MAX_PREVIEW_ITEMS} more... +
+ )} +
+ ); +}; + +const EvaluationRunResultsTable = ({ + results, + currentPage, + setCurrentPage, + pageSize, + selectedTrace, + setSelectedTrace, + searchQuery, + onFilteredResultsChange +}) => { + const createInvocationWithLabels = (item, results) => { + const result = results.find(r => r.id === item.id); + return { + ...item.invocation, + labels: result?.labels || [] + }; + }; + + const resultsTableData = useMemo(() => { + if (!results) return []; + + // Group results by input hash + const groupedByInput = results.reduce((acc, result) => { + const inputHash = JSON.stringify(result.invocation_being_labeled.contents.params); + if (!acc[inputHash]) { + acc[inputHash] = { + items: [], + input: result.invocation_being_labeled.contents.params, + }; + } + acc[inputHash].items.push(result); + return acc; + }, {}); + + // Calculate mean values and stats for each group + let tableData = Object.entries(groupedByInput).map(([inputHash, group]) => { + // If there's only one item in the group, return it directly without grouping + if (group.items.length === 1) { + const result = group.items[0]; + return { + id: result.id, + invocation: result.invocation_being_labeled, + labels: result.labels.reduce((acc, label) => { + acc[label.labeler_id] = label.label_invocation.contents.results; + return acc; + }, {}), + children: [] + }; + } + + // Rest of the existing grouping logic for multiple items + const children = group.items.map(result => ({ + id: result.id, + invocation: result.invocation_being_labeled, + labels: result.labels.reduce((acc, label) => { + acc[label.labeler_id] = label.label_invocation.contents.results; + return acc; + }, {}), + children: [] + })); + + // Calculate stats for the group + const labelStats = {}; + if (children.length > 0) { + const firstChild = children[0]; + Object.keys(firstChild.labels).forEach(labelerId => { + const values = children + .map(child => child.labels[labelerId]) + .filter(value => typeof value === 'number' || typeof value === 'boolean'); + + if (values.length > 0) { + const mean = values.reduce((a, b) => a + b, 0) / values.length; + const stdDev = Math.sqrt( + values.reduce((acc, val) => acc + Math.pow(val - mean, 2), 0) / values.length + ); + labelStats[labelerId] = { + mean, + stdDev, + min: Math.min(...values), + max: Math.max(...values) + }; + } + }); + } + + // Get all outputs for the preview + const outputs = children.map(child => + child.invocation.contents.results?.content || child.invocation.contents.results + ); + + return { + id: inputHash, + invocation: { + contents: { + params: group.input, + results: outputs + } + }, + labels: Object.fromEntries( + Object.entries(labelStats).map(([key, stats]) => [key, stats.mean]) + ), + labelStats, + children: children, + isGroup: true + }; + }); + + // Apply search filter if there's a search query + if (searchQuery) { + const query = searchQuery.toLowerCase(); + + // Helper function to check if an item matches the search query + const itemMatches = (item) => { + // For leaf nodes (children) + if (!item.isGroup) { + const inputMatch = JSON.stringify(item.invocation.contents.params) + .toLowerCase() + .includes(query); + + const outputMatch = JSON.stringify(item.invocation.contents.results) + .toLowerCase() + .includes(query); + + const labelMatch = Object.values(item.labels).some(value => + String(value).toLowerCase().includes(query) + ); + + return inputMatch || outputMatch || labelMatch; + } + + // For group nodes, check if any children match + return item.children.some(child => itemMatches(child)); + }; + + // 
Filter the table data, keeping groups that have matching children + tableData = tableData.map(group => { + const matchingChildren = group.children.filter(itemMatches); + + if (matchingChildren.length > 0) { + return { + ...group, + children: matchingChildren + }; + } + return null; + }).filter(Boolean); + + // Create filtered results array for metrics + const filteredResults = tableData.flatMap(group => + group.children.map(child => { + // Find original result that matches this child + return results.find(result => result.id === child.id); + }) + ); + + // Notify parent component of filtered results + onFilteredResultsChange(filteredResults); + } else { + // If no search query, reset filtered results + onFilteredResultsChange(null); + } + + return tableData; + }, [results, searchQuery, onFilteredResultsChange]); + + const labelerColumns = useMemo(() => { + if (!results?.[0]?.labels) return []; + + return results[0].labels.map(label => ({ + header: label.labeler_id.split('-')[3] || 'Label', + key: label.labeler_id, + render: (item) => { + return ( + + ); + }, + maxWidth: 150, + sortable: true, + sortFn: (a, b) => { + const aValue = a.labels[label.labeler_id] ?? -Infinity; + const bValue = b.labels[label.labeler_id] ?? -Infinity; + return aValue - bValue; + } + })); + }, [results]); + + const columns = [ + { + header: 'Input', + key: 'input', + render: (item, _, { expanded, isHovered }) => ( +
+ +
+ ), + maxWidth: 300, + sortable: true, + sortFn: (a, b) => { + const aInput = JSON.stringify(a.invocation.contents.params); + const bInput = JSON.stringify(b.invocation.contents.params); + return aInput.localeCompare(bInput); + } + }, + { + header: 'Output', + key: 'output', + render: (item, _, { expanded }) => ( + item.isGroup ? ( + + ) : ( + + ) + ), + maxWidth: 300, + sortable: true, + sortFn: (a, b) => { + const aOutput = JSON.stringify(a.invocation.contents.results); + const bOutput = JSON.stringify(b.invocation.contents.results); + return aOutput.localeCompare(bOutput); + } + }, + ...labelerColumns, + ]; + + const handleRowClick = (item, toggleRow) => { + if (item.isGroup) { + toggleRow(item.id); + } else { + const trace = createInvocationWithLabels(item, results); + setSelectedTrace(trace); + } + }; + + const hasNextPage = resultsTableData.length === pageSize; + + useEffect(() => { + const handleKeyDown = (e) => { + if (e.key === 'Escape') { + setSelectedTrace(null); + return; + } + + if (selectedTrace) { + // Get all navigable items - both ungrouped items and children of grouped items + const allItems = resultsTableData.flatMap(item => + item.isGroup ? item.children : [item] + ); + + const currentIndex = allItems.findIndex(item => + item.invocation.id === selectedTrace.id + ); + + if (e.key === 'ArrowUp' && currentIndex > 0) { + e.preventDefault(); + const prevItem = allItems[currentIndex - 1]; + const trace = createInvocationWithLabels(prevItem, results); + setSelectedTrace(trace); + } else if (e.key === 'ArrowDown' && currentIndex < allItems.length - 1) { + e.preventDefault(); + const nextItem = allItems[currentIndex + 1]; + const trace = createInvocationWithLabels(nextItem, results); + setSelectedTrace(trace); + } + } + }; + + window.addEventListener('keydown', handleKeyDown); + return () => { + window.removeEventListener('keydown', handleKeyDown); + }; + }, [resultsTableData, selectedTrace, setSelectedTrace, results]); + + return ( + + !item.isGroup && item.invocation.id === selectedTrace?.id ? 
'bg-blue-600 bg-opacity-30' : '' + } + /> + ); +}; + +export default EvaluationRunResultsTable; \ No newline at end of file diff --git a/ell-studio/src/components/evaluations/runs/EvaluationRunsTable.js b/ell-studio/src/components/evaluations/runs/EvaluationRunsTable.js new file mode 100644 index 000000000..f255e5a21 --- /dev/null +++ b/ell-studio/src/components/evaluations/runs/EvaluationRunsTable.js @@ -0,0 +1,174 @@ +import React, { useMemo } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { LMPCardTitle } from '../../depgraph/LMPCardTitle'; +import HierarchicalTable from '../../HierarchicalTable'; +import { Card } from '../../common/Card'; +import { getTimeAgo } from '../../../utils/lmpUtils'; +import VersionBadge from '../../VersionBadge'; +import { Spinner } from '../../common/Spinner'; +import { FiXCircle } from 'react-icons/fi'; +import LabelDisplay from '../LabelDisplay'; +import { IoCheckmarkCircleOutline } from 'react-icons/io5'; + +const EvaluationRunsTable = ({ runs, currentPage, setCurrentPage, pageSize, onSelectRun, currentlySelectedRun, activeIndex }) => { + const navigate = useNavigate(); + + const onClickLMP = (run) => { + navigate(`/lmp/${run.evaluated_lmp.name}/${run.evaluated_lmp.lmp_id}`); + }; + + const handleRowClick = (run) => { + // if (onSelectRun) { + // onSelectRun(run); + // } + navigate(`/evaluation-runs/${run.id}`); + }; + + const runsTableData = useMemo(() => { + return runs.map((run, index) => ({ + ...run, + id: run.id, + name: run.evaluated_lmp.name, + version: run.evaluated_lmp.version_number + 1, + created_at: run.end_time ? new Date(run.end_time) : null, + runIndex: index, + })); + }, [runs]); + + const getMetricColumns = () => { + if (runs.length === 0 || !runs[0].labeler_summaries) return []; + + return runs[0].labeler_summaries.map(summary => ({ + header: summary.evaluation_labeler.name, + key: summary.evaluation_labeler.id, + render: (item) => { + const metricSummary = item.labeler_summaries.find(s => s.evaluation_labeler_id === summary.evaluation_labeler_id); + const isRunning = !item.end_time && item.success === null; + + if (isRunning) { + return ; + } + + if (!metricSummary?.data) return null; + + return ( + + ); + }, + maxWidth: 150, + sortable: true, + })); + }; + + const statusColumn = { + header: '', + key: 'success', + render: (item) => { + const isRunning = !item.end_time && item.success === null; + + if (isRunning) { + return ; + } + + return item.success ? ( + + ) : ( + + ); + } + }; + + const columns = [ + { + header: 'LMP', + key: 'name', + render: (item) => ( + + { + e.stopPropagation(); + onClickLMP(item); + }} + showInvocationCount={false} + /> + ), + sortable: true, + maxWidth: 200, + }, + { + header: 'Version', + key: 'version', + render: (item) => ( + + ), + maxWidth: 150, + sortable: true + }, + ...getMetricColumns(), + { + header: 'Finished', + key: 'created_at', + render: (item) => { + const isRunning = !item.end_time && item.success === null; + return ( + + {isRunning ? 'Running...' 
: getTimeAgo(item.created_at)} + + ); + }, + maxWidth: 150, + sortable: true + }, + ]; + + const initialSortConfig = { key: 'created_at', direction: 'desc' }; + + const hasNextPage = runsTableData.length === pageSize; + + return ( + { + let className = ''; + if (item.runIndex === activeIndex) { + className += 'bg-blue-600 bg-opacity-30 '; + } else if (activeIndex !== null) { + className += 'opacity-50 '; + } + if (item.id === currentlySelectedRun?.id) { + className += 'border-2 border-blue-600 '; + } + return className.trim(); + }} + currentPage={currentPage} + onPageChange={setCurrentPage} + pageSize={pageSize} + hasNextPage={hasNextPage} + links={[]} + showHierarchical={false} + statusColumn={statusColumn} + /> + ); +}; + +export default EvaluationRunsTable; diff --git a/ell-studio/src/components/evaluations/runs/SearchAndFiltersBar.js b/ell-studio/src/components/evaluations/runs/SearchAndFiltersBar.js new file mode 100644 index 000000000..aa3e58ca3 --- /dev/null +++ b/ell-studio/src/components/evaluations/runs/SearchAndFiltersBar.js @@ -0,0 +1,20 @@ +import React from 'react'; + +function SearchAndFiltersBar({ searchQuery, setSearchQuery }) { + return ( +
+
+      <input
+        type="text"
+        value={searchQuery}
+        onChange={(e) => setSearchQuery(e.target.value)}
+        className="w-full px-4 py-2 rounded-md border border-border bg-background text-foreground focus:outline-none focus:ring-2 focus:ring-primary"
+      />
+ {/* We can add more filters here later */} +
+ ); +} + +export default SearchAndFiltersBar; \ No newline at end of file diff --git a/ell-studio/src/components/graphing/ErrorBarPlugin.js b/ell-studio/src/components/graphing/ErrorBarPlugin.js new file mode 100644 index 000000000..c180f8683 --- /dev/null +++ b/ell-studio/src/components/graphing/ErrorBarPlugin.js @@ -0,0 +1,95 @@ +import { Chart as ChartJS } from 'chart.js'; + +// Add this new function to fade the color +const fadeColor = (color, opacity) => { + if (color.startsWith('#')) { + // Convert hex to RGB + const r = parseInt(color.slice(1, 3), 16); + const g = parseInt(color.slice(3, 5), 16); + const b = parseInt(color.slice(5, 7), 16); + return `rgba(${r}, ${g}, ${b}, ${opacity})`; + } else if (color.startsWith('rgb')) { + // If it's already RGB or RGBA, just change the opacity + const rgb = color.match(/\d+/g); + return `rgba(${rgb[0]}, ${rgb[1]}, ${rgb[2]}, ${opacity})`; + } + // If color format is not recognized, return the original color + return color; +}; + +const drawErrorBar = (ctx, x, y, errorLow, errorHigh, color, width) => { + ctx.save(); + ctx.strokeStyle = fadeColor(color, 0.3); + ctx.lineWidth = width; + + // Draw vertical line + ctx.beginPath(); + ctx.moveTo(x, y - errorHigh); + ctx.lineTo(x, y + errorLow); + ctx.stroke(); + + // Draw horizontal caps + const capLength = 5; + ctx.beginPath(); + ctx.moveTo(x - capLength, y - errorHigh); + ctx.lineTo(x + capLength, y - errorHigh); + ctx.moveTo(x - capLength, y + errorLow); + ctx.lineTo(x + capLength, y + errorLow); + ctx.stroke(); + + ctx.restore(); +}; + +const ErrorBarPlugin = { + id: 'errorBar', + beforeInit(chart) { + chart.errorBarData = {}; + }, + afterDatasetsDraw(chart, args, options) { + const { ctx } = chart; + + if (!options.draw) { + return; + } + + chart.data.datasets.forEach((dataset, datasetIndex) => { + if (dataset.errorBars) { + const meta = chart.getDatasetMeta(datasetIndex); + + dataset.data.forEach((datapoint, index) => { + if (dataset.errorBars[index] !== undefined) { + const { x, y } = meta.data[index].getCenterPoint(); + + let errorLow, errorHigh; + if (typeof dataset.errorBars[index] === 'object') { + errorLow = dataset.errorBars[index].low; + errorHigh = dataset.errorBars[index].high; + } else { + errorLow = datapoint - dataset.errorBars[index]; + errorHigh = datapoint + dataset.errorBars[index]; + } + + // Store error bar data for tooltip access + if (!chart.errorBarData[datasetIndex]) { + chart.errorBarData[datasetIndex] = []; + } + chart.errorBarData[datasetIndex][index] = { low: errorLow, high: errorHigh }; + + // Convert to pixel values for drawing + const errorLowPx = Math.abs(chart.scales.y.getPixelForValue(datapoint) - + chart.scales.y.getPixelForValue(errorLow)); + const errorHighPx = Math.abs(chart.scales.y.getPixelForValue(datapoint) - + chart.scales.y.getPixelForValue(errorHigh)); + + drawErrorBar(ctx, x, y, errorLowPx, errorHighPx, dataset.borderColor, dataset.borderWidth || 1); + } + }); + } + }); + } +}; + +export default ErrorBarPlugin; + +// Register the plugin +ChartJS.register(ErrorBarPlugin); diff --git a/ell-studio/src/components/graphing/Graph.js b/ell-studio/src/components/graphing/Graph.js new file mode 100644 index 000000000..ce81d697d --- /dev/null +++ b/ell-studio/src/components/graphing/Graph.js @@ -0,0 +1,22 @@ +import React from 'react'; +import { GraphRenderer, MetricAdder, useGraph } from './GraphSystem'; + +const Graph = ({ graphId, metrics, type = 'line' }) => { + useGraph(graphId); + + return ( + <> + {metrics.map((metric, index) => ( + + ))} + + + ); +}; 
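+
+// Graph is a thin declarative wrapper over the GraphSystem primitives imported
+// above: useGraph(graphId) registers this chart with the surrounding
+// GraphProvider, each entry in `metrics` is handed to a MetricAdder (label,
+// x/y data, color, error bars, chart type), and a single GraphRenderer then
+// draws every metric registered under this graphId. Typical usage (as in
+// EvaluationRunMetrics) nests one or more Graph elements inside a
+// GraphProvider that supplies the shared xData and sharedConfig.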
+ +export default Graph; diff --git a/ell-studio/src/components/graphing/GraphSystem.js b/ell-studio/src/components/graphing/GraphSystem.js new file mode 100644 index 000000000..d9b971e79 --- /dev/null +++ b/ell-studio/src/components/graphing/GraphSystem.js @@ -0,0 +1,264 @@ +import React, { createContext, useContext, useState, useCallback } from 'react'; +import { Chart as ChartJS, CategoryScale, LinearScale, PointElement, LineElement, Title, Tooltip, Legend, BarElement } from 'chart.js'; +import { Line, Bar } from 'react-chartjs-2'; +import { SharedVerticalIndicator, useSharedVerticalIndicator } from './SharedVerticalIndicator'; +import ErrorBarPlugin from './ErrorBarPlugin'; + +ChartJS.register(CategoryScale, LinearScale, PointElement, LineElement, BarElement, Title, Tooltip, Legend, ErrorBarPlugin); + +const GraphContext = createContext(); + +export const useGraphContext = () => useContext(GraphContext); + +export const GraphProvider = ({ children, xData, sharedConfig, onHover, onLeave, shared = true }) => { + const [graphs, setGraphs] = useState({}); + const [activeIndicatorIndex, setActiveIndicatorIndex] = useState(null); + const [sharedIndicatorY, setSharedIndicatorY] = useState(null); + + const addGraph = useCallback((graphId) => { + setGraphs(prevGraphs => { + if (!prevGraphs[graphId]) { + return { + ...prevGraphs, + [graphId]: { metrics: [] } + }; + } + return prevGraphs; + }); + }, []); + + const removeGraph = useCallback((graphId) => { + setGraphs(prevGraphs => { + const { [graphId]: removed, ...rest } = prevGraphs; + return rest; + }); + }, []); + + const addMetric = useCallback((graphId, metric) => { + setGraphs(prevGraphs => ({ + ...prevGraphs, + [graphId]: { + ...prevGraphs[graphId], + metrics: [...(prevGraphs[graphId]?.metrics || []), metric] + } + })); + }, []); + + const removeMetric = useCallback((graphId, metricId) => { + setGraphs(prevGraphs => ({ + ...prevGraphs, + [graphId]: { + ...prevGraphs[graphId], + metrics: prevGraphs[graphId]?.metrics.filter(m => m.id !== metricId) || [] + } + })); + }, []); + + const setActiveIndicator = useCallback((index, y) => { + setActiveIndicatorIndex(index); + setSharedIndicatorY(y); + if (onHover) { + onHover(index); + } + }, [onHover]); + + const clearActiveIndicator = useCallback(() => { + setActiveIndicatorIndex(null); + setSharedIndicatorY(null); + if (onLeave) { + onLeave(); + } + }, [onLeave]); + + return ( + + {children} + + ); +}; + +export const GraphRenderer = ({ graphId }) => { + const { + xData, + graphs, + sharedConfig, + activeIndicatorIndex, + sharedIndicatorY, + setActiveIndicator, + clearActiveIndicator, + shared + } = useGraphContext(); + const graph = graphs[graphId]; + const chartRef = React.useRef(null); + const indicatorState = useSharedVerticalIndicator(chartRef, activeIndicatorIndex, sharedIndicatorY, clearActiveIndicator); + + if (!graph || !graph.metrics || graph.metrics.length === 0) { + return
<div>Loading graph...</div>
; + } + + const chartType = graph.metrics[0]?.type || 'line'; + + const labels = shared ? xData : graph.metrics[0]?.xData || xData; + + const data = { + labels, + datasets: graph.metrics.map(metric => ({ + label: metric.label, + data: metric.yData, + xData: metric.xData, + borderColor: metric.color, + backgroundColor: chartType === 'line' ? metric.color : `${metric.color}80`, + errorBars: metric.errorBars, + barPercentage: chartType === 'histogram' ? 1.0 : 0.9, + categoryPercentage: chartType === 'histogram' ? 1.0 : 0.8, + ...metric.config, + })), + }; + + const hasNonZeroErrorBars = data.datasets.some(dataset => + dataset.errorBars && dataset.errorBars.some(error => error > 0 || (error.low - error.high > 0)) + ); + + let yAxisScale = { + ...sharedConfig.options.scales.y,}; + if (hasNonZeroErrorBars || true) { + const minMaxValues = data.datasets.reduce((acc, dataset) => { + dataset.data.forEach((value, index) => { + const errorBar = dataset.errorBars ? dataset.errorBars[index] : 0; + if (typeof errorBar === 'number') { + acc.min = Math.min(acc.min, value - errorBar); + acc.max = Math.max(acc.max, value + errorBar); + } else if (errorBar && typeof errorBar === 'object') { + console.log('errorBar', errorBar); + acc.min = Math.min(acc.min, errorBar.low); + acc.max = Math.max(acc.max, errorBar.high); + } else { + acc.min = Math.min(acc.min, value); + acc.max = Math.max(acc.max, value); + } + }); + return acc; + }, { min: Infinity, max: -Infinity }); + + const yAxisPadding = (minMaxValues.max - minMaxValues.min) * 0.1; + + yAxisScale = { + y: { + ...sharedConfig.options.scales.y, + beginAtZero: false, + min: Math.max(0, minMaxValues.min - yAxisPadding), + max: minMaxValues.max + yAxisPadding, + } + }; + } + + const options = { + responsive: true, + maintainAspectRatio: false, + hover: { mode: 'index', intersect: false }, + ...sharedConfig.options, + onHover: (event, elements, chart) => { + if (elements && elements.length > 0) { + const rect = chart.canvas.getBoundingClientRect(); + const y = rect.bottom - rect.top + 40; + setActiveIndicator(elements[0].index, y); + } else { + clearActiveIndicator(); + } + }, + scales: { + ...sharedConfig.options.scales, + ...yAxisScale, + }, + plugins: { + ...sharedConfig.options.plugins, + tooltip: { + callbacks: { + label: function(context) { + const value = context.parsed.y; + const chartType = context.chart.config.type; + const metricLabel = context.dataset.label || ''; + + if (chartType === 'histogram' || chartType === 'bar') { + // For histograms, show the bin range and count + const binLabel = context.label; + const binWidth = parseFloat(context.dataset.xData?.[1]) - parseFloat(context.dataset.xData?.[0]); + const binStart = parseFloat(binLabel) - (binWidth / 2); + const binEnd = parseFloat(binLabel) + (binWidth / 2); + return `${metricLabel} ${value} samples`; + } + + // For line charts, keep the existing logic + let label = `${metricLabel}: ${value.toFixed(2)}`; + + const chartErrorBarData = context.chart.errorBarData; + const errorData = chartErrorBarData?.[context.datasetIndex]?.[context.dataIndex]; + + if (errorData) { + label += ` (95% CI: [${errorData.low.toFixed(2)}, ${errorData.high.toFixed(2)}])`; + } + + return label; + } + } + }, + errorBar: { + draw: chartType === 'line', + }, + }, + }; + + const ChartComponent = chartType === 'line' ? Line : Bar; + + return ( +
+ + {shared && chartType === 'line' && ( + + )} +
+ ); +}; + +export const MetricAdder = ({ graphId, label, yData, xData, color, config, errorBars, type }) => { + const { addMetric, removeMetric } = useGraphContext(); + + React.useEffect(() => { + const metricId = Date.now(); + addMetric(graphId, { id: metricId, label, yData, xData, color, config, errorBars, type }); + return () => removeMetric(graphId, metricId); + }, [graphId, label, yData, xData, color, config, errorBars, type, addMetric, removeMetric]); + + return null; +}; + +export const useGraph = (graphId) => { + const { addGraph, removeGraph } = useGraphContext(); + + React.useEffect(() => { + addGraph(graphId); + return () => removeGraph(graphId); + }, [graphId, addGraph, removeGraph]); + + return useGraphContext().graphs[graphId]; +}; diff --git a/ell-studio/src/components/graphing/SharedVerticalIndicator.js b/ell-studio/src/components/graphing/SharedVerticalIndicator.js new file mode 100644 index 000000000..38cafd37a --- /dev/null +++ b/ell-studio/src/components/graphing/SharedVerticalIndicator.js @@ -0,0 +1,89 @@ +import React, { useState, useEffect } from 'react'; + +const formatNumber = (value) => { + if (typeof value === 'number') { + if (Math.abs(value) < 1) { + return value.toFixed(4); + } else if (Math.abs(value) < 100) { + return value.toFixed(2); + } else { + return value.toFixed(0); + } + } + return value; +}; + +export const SharedVerticalIndicator = ({ visible, position, labels, datasets, activeIndex, chartHeight }) => { + const [animatedPosition, setAnimatedPosition] = useState(position); + + useEffect(() => { + if (visible) { + setAnimatedPosition(position); + } + }, [visible, position]); + + if (!visible || activeIndex === null) return null; + + // Use the color of the first dataset, or default to black if no datasets + const lineColor = datasets.length > 0 ? datasets[0].borderColor : 'rgba(0,0,0,0.7)'; + return ( + <> +
+ + ); +}; + +export const useSharedVerticalIndicator = (chartRef, activeIndicatorIndex, sharedIndicatorY, clearActiveIndicator) => { + const [indicatorLocation, setIndicatorLocation] = useState({ visible: false, position: { x: 0, y: 0 } }); + const [chartHeight, setChartHeight] = useState(0); + + useEffect(() => { + const chart = chartRef.current; + if (!chart) return; + + const updateSharedVerticalIndicator = () => { + if (!chart || !chart.canvas) return; + + if (activeIndicatorIndex !== null && activeIndicatorIndex >= 0 && activeIndicatorIndex < chart.data.labels.length) { + const meta = chart.getDatasetMeta(0); + if (!meta || !meta.data || activeIndicatorIndex >= meta.data.length) return; + + const activeElement = meta.data[activeIndicatorIndex]; + + setIndicatorLocation({ visible: true, position: { x: activeElement.x, y: sharedIndicatorY } }); + setChartHeight(chart.height); + } else { + setIndicatorLocation(prev => ({ ...prev, visible: false })); + } + }; + + updateSharedVerticalIndicator(); + + const handleMouseLeave = () => { + clearActiveIndicator(); + }; + + chart.canvas.addEventListener('mouseleave', handleMouseLeave); + + return () => { + if (chart.canvas) { + chart.canvas.removeEventListener('mouseleave', handleMouseLeave); + } + }; + }, [activeIndicatorIndex, sharedIndicatorY, clearActiveIndicator, chartRef]); + + return { ...indicatorLocation, chartHeight }; +}; diff --git a/ell-studio/src/components/graphing/TrendLine.js b/ell-studio/src/components/graphing/TrendLine.js new file mode 100644 index 000000000..c137efd48 --- /dev/null +++ b/ell-studio/src/components/graphing/TrendLine.js @@ -0,0 +1,58 @@ +import React from 'react'; +import { Line } from 'react-chartjs-2'; +import { Chart as ChartJS, CategoryScale, LinearScale, LineElement, PointElement, Tooltip as ChartTooltip, Filler } from 'chart.js'; + +ChartJS.register(CategoryScale, LinearScale, LineElement, PointElement, ChartTooltip, Filler); + +const TrendLine = ({ data, hoverIndex, onHover }) => { + const trend = data[data.length - 1] - data[0]; + const trendColor = trend > 0 ? 'rgba(52, 211, 153, 0.8)' : 'rgba(239, 68, 68, 0.8)'; + const fillColor = trend > 0 ? 'rgba(52, 211, 153, 0.2)' : 'rgba(239, 68, 68, 0.2)'; + + const chartData = { + labels: data.map((_, index) => index + 1), + datasets: [{ + data, + borderColor: trendColor, + backgroundColor: fillColor, + pointRadius: 0, + borderWidth: 1, + tension: 0.4, + fill: true, + }], + }; + + const options = { + responsive: true, + maintainAspectRatio: false, + plugins: { + legend: { display: false }, + tooltip: { enabled: false } + }, + scales: { + x: { display: false }, + y: { + display: false, + min: Math.min(...data) * 0.95, + max: Math.max(...data) * 1.05, + } + }, + }; + + return ( +
+    <div
+      onMouseMove={(e) => {
+        const rect = e.currentTarget.getBoundingClientRect();
+        const x = e.clientX - rect.left;
+        const index = Math.round((x / rect.width) * (data.length - 1));
+        onHover(index);
+      }}
+      onMouseLeave={() => onHover(null)}
+    >
+      <Line data={chartData} options={options} />
+    </div>
+ ); +}; + +export default TrendLine; diff --git a/ell-studio/src/components/invocations/ContentsRenderer.js b/ell-studio/src/components/invocations/ContentsRenderer.js index 8df6b6105..a154a6002 100644 --- a/ell-studio/src/components/invocations/ContentsRenderer.js +++ b/ell-studio/src/components/invocations/ContentsRenderer.js @@ -4,6 +4,7 @@ import IORenderer from '../IORenderer'; export function ContentsRenderer({ item, field, ...rest }) { const contents = item.contents; + console.log(contents[field]); if (contents.is_external && !contents.is_external_loaded) { return
<div>Loading...</div>
; diff --git a/ell-studio/src/components/invocations/InvocationInfoPane.js b/ell-studio/src/components/invocations/InvocationInfoPane.js index 2dab572fd..9f83abe2d 100644 --- a/ell-studio/src/components/invocations/InvocationInfoPane.js +++ b/ell-studio/src/components/invocations/InvocationInfoPane.js @@ -44,7 +44,7 @@ export function InvocationInfoPane({ invocation, isFullWidth }) {
- {invocation.lmp.is_lm ? "LM" : "LMP"} + {invocation.lmp?.is_lm ? "LM" : "LMP"}
diff --git a/ell-studio/src/components/invocations/InvocationsAnalyticsSidePanel.js b/ell-studio/src/components/invocations/InvocationsAnalyticsSidePanel.js index 8a98b8e7f..95f042cb1 100644 --- a/ell-studio/src/components/invocations/InvocationsAnalyticsSidePanel.js +++ b/ell-studio/src/components/invocations/InvocationsAnalyticsSidePanel.js @@ -2,7 +2,7 @@ import React from 'react'; import { FiZap, FiClock, FiHash, FiUsers, FiPercent, FiBox } from 'react-icons/fi'; import { Link } from 'react-router-dom'; import SidePanel from '../common/SidePanel'; -import MetricChart from '../MetricChart'; +import MetricChart from '../oldgraph/OldMetricChart'; import { motion } from 'framer-motion'; import { LMPCardTitle } from '../depgraph/LMPCardTitle'; import { Card } from '../common/Card'; diff --git a/ell-studio/src/components/invocations/InvocationsTable.js b/ell-studio/src/components/invocations/InvocationsTable.js index cb3f946e7..3c34f9123 100644 --- a/ell-studio/src/components/invocations/InvocationsTable.js +++ b/ell-studio/src/components/invocations/InvocationsTable.js @@ -1,7 +1,7 @@ import { LMPCardTitle } from '../depgraph/LMPCardTitle'; import HierarchicalTable from '../HierarchicalTable'; import React, { useMemo, useCallback, useEffect, useState } from 'react'; -import { Card } from '../Card'; +import { OldCard } from '../OldCard'; import { getTimeAgo } from '../../utils/lmpUtils'; import VersionBadge from '../VersionBadge'; import { useNavigate } from 'react-router-dom'; @@ -38,7 +38,7 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize, const invocationsMap = new Map(); const rootInvocations = []; - // First pass: map all invocations and identify roots + // First pass: map all invocations invocations.forEach(invocation => { const mappedInvocation = mapInvocation(invocation); invocationsMap.set(invocation.id, mappedInvocation); @@ -48,6 +48,92 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize, } }); + // Helper function to build clusters within a single level + const buildLevelClusters = (levelInvocations) => { + const visited = new Set(); + const levelClusters = new Map(); + let clusterId = 0; + + const buildCluster = (startId, clusterId) => { + const stack = [startId]; + const clusterItems = []; + + while (stack.length > 0) { + const currentId = stack.pop(); + if (visited.has(currentId)) continue; + + // Only visit nodes that are part of this level + const currentItem = levelInvocations.find(inv => inv.id === currentId); + if (!currentItem) continue; + + visited.add(currentId); + currentItem.clusterId = clusterId; + clusterItems.push(currentItem); + + // Find linked items within this level + const linkedIds = []; + if (currentItem.consumes) { + currentItem.consumes.forEach(c => { + if (c.id !== currentId && levelInvocations.some(inv => inv.id === c.id)) { + linkedIds.push(c.id); + } + }); + } + // Add items that consume currentItem (within this level) + levelInvocations.forEach(inv => { + if (inv.consumes && inv.consumes.some(c => c.id === currentId)) { + if (inv.id !== currentId) linkedIds.push(inv.id); + } + }); + + linkedIds.forEach(id => { + if (!visited.has(id)) stack.push(id); + }); + } + return clusterItems; + }; + + // Build clusters for this level + levelInvocations.forEach(inv => { + if (!visited.has(inv.id)) { + const clusterItems = buildCluster(inv.id, clusterId); + const earliestDate = Math.min(...clusterItems.map(item => new Date(item.created_at).getTime())); + clusterItems.forEach(item => { + 
item.clusterDate = earliestDate; + }); + levelClusters.set(clusterId, clusterItems); + clusterId++; + } + }); + + // Sort clusters and items within clusters + const sortedClusters = Array.from(levelClusters.values()) + .sort((a, b) => b[0].clusterDate - a[0].clusterDate); + + sortedClusters.forEach(cluster => { + cluster.sort((a, b) => new Date(b.created_at) - new Date(a.created_at)); + }); + + return sortedClusters.flat(); + }; + + // Recursive function to process each level of the tree + const processTreeLevel = (nodes) => { + if (!nodes || nodes.length === 0) return []; + + // Sort current level nodes using clustering + const sortedNodes = buildLevelClusters(nodes); + + // Process children recursively + sortedNodes.forEach(node => { + if (node.children && node.children.length > 0) { + node.children = processTreeLevel(node.children); + } + }); + + return sortedNodes; + }; + // Second pass: build the tree structure invocations.forEach(invocation => { if (invocation.used_by_id) { @@ -56,12 +142,13 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize, if (!parent.children) parent.children = []; parent.children.push(invocationsMap.get(invocation.id)); } else { - // If parent is not found, treat as a root rootInvocations.push(invocationsMap.get(invocation.id)); } } }); - return rootInvocations; + + // Process and sort each level of the tree + return processTreeLevel(rootInvocations); }, [invocations]); const links = useMemo(() => { @@ -85,11 +172,18 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize, useEffect(() => { const handleKeyDown = (e) => { + if (e.key === 'Escape') { + onSelectTrace(null); + return; + } + if (currentlySelectedTrace) { const currentIndex = invocationTableData.findIndex(trace => trace.id === currentlySelectedTrace.id); if (e.key === 'ArrowUp' && currentIndex > 0) { + e.preventDefault(); onSelectTrace(invocationTableData[currentIndex - 1]); } else if (e.key === 'ArrowDown' && currentIndex < invocationTableData.length - 1) { + e.preventDefault(); onSelectTrace(invocationTableData[currentIndex + 1]); } } @@ -110,7 +204,7 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize, header: 'LMP', key: 'name', render: (item) => ( - + - + ), - sortable: true, + sortable: true }, { @@ -164,7 +258,7 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize, }, ]; - const initialSortConfig = { key: 'created_at', direction: 'desc' }; + const initialSortConfig = null; // Sorting is handled during tree construction const hasNextPage = invocationTableData.length === pageSize; diff --git a/ell-studio/src/components/invocations/details/InvocationDataPane.js b/ell-studio/src/components/invocations/details/InvocationDataPane.js index c0c2106a5..a9dcb6872 100644 --- a/ell-studio/src/components/invocations/details/InvocationDataPane.js +++ b/ell-studio/src/components/invocations/details/InvocationDataPane.js @@ -1,7 +1,8 @@ import React, { useState, useMemo } from "react"; import { CodeSection } from '../../source/CodeSection'; import IORenderer from '../../IORenderer'; - +import MetricDisplay from '../../evaluations/MetricDisplay'; +import { FiBarChart2 } from 'react-icons/fi'; const SkeletonLoader = () => (
@@ -30,6 +31,20 @@ const InvocationDataPane = ({ invocation }) => { invocation.contents?.results !== undefined; }, [invocation.contents?.results]); + const metrics = useMemo(() => { + console.log('InvocationDataPane metrics calculation:', { + hasLabels: !!invocation.labels, + labels: invocation.labels + }); + + if (!invocation.labels?.length) return null; + return invocation.labels.map(label => ({ + labelerId: label.labeler_id, + value: label.label_invocation?.contents?.results, + name: label.labeler_name || label.labeler_id.split('-')[3] || 'Score' + })); + }, [invocation.labels]); + const renderCodeSection = (title, content, expanded, setExpanded, typeMatchLevel) => ( { return (
+ {/* Metrics Section */} + {metrics && metrics.length > 0 && ( +
+
+ {metrics.map((metric, index) => ( + <> +
+
+ + + {metric.name} + +
+ +
+ {index < metrics.length - 1 && ( +
+ )} + + ))} +
+
+ )} + {(hasKwargs || isExternalLoading) && renderCodeSection( "Input", invocation.contents?.params, diff --git a/ell-studio/src/components/invocations/details/InvocationDetailsPopover.js b/ell-studio/src/components/invocations/details/InvocationDetailsPopover.js index 7966cdaf9..46e0241b5 100644 --- a/ell-studio/src/components/invocations/details/InvocationDetailsPopover.js +++ b/ell-studio/src/components/invocations/details/InvocationDetailsPopover.js @@ -7,12 +7,21 @@ import InvocationDataPane from './InvocationDataPane'; import { motion } from 'framer-motion'; import { LMPCardTitle } from "../../depgraph/LMPCardTitle"; import { Card } from "../../common/Card"; +import { useLMPs } from "../../../hooks/useBackend"; import { ScrollArea } from "@radix-ui/react-scroll-area"; -const InvocationDetailsPopover = ({ invocation, onClose, onResize }) => { +const InvocationDetailsPopover = ({ invocation : invocationWithoutLMP, onClose, onResize }) => { const [activeTab, setActiveTab] = useState("I/O"); const [sidebarWidth, setSidebarWidth] = useState(window.innerWidth / 2); const [isExpanded, setIsExpanded] = useState(false); + const { data: lmpData } = useLMPs(null, invocationWithoutLMP.lmp_id); + const lmp = lmpData && lmpData.length > 0 ? lmpData[0] : invocationWithoutLMP.lmp; + + const invocation = { + ...invocationWithoutLMP, + lmp, + labels: invocationWithoutLMP.labels + }; const handleResize = (newWidth) => { setSidebarWidth(newWidth); @@ -34,6 +43,10 @@ const InvocationDetailsPopover = ({ invocation, onClose, onResize }) => { const location = useLocation(); const isLmpPage = location.pathname.startsWith('/lmp'); + if(!invocation.lmp) { + return null; + } + return ( { transition={{ duration: 0.3 }} >
-
+
{!isLmpPage && ( - - - +
+ + + +
)} -
-
- +
+
+ {isNarrowForInfo ? `${invocation.id.slice(0, 8)}...` : invocation.id}
-
- -
diff --git a/ell-studio/src/components/layouts/GenericPageLayout.js b/ell-studio/src/components/layouts/GenericPageLayout.js index 06a4aec24..638edeec7 100644 --- a/ell-studio/src/components/layouts/GenericPageLayout.js +++ b/ell-studio/src/components/layouts/GenericPageLayout.js @@ -9,16 +9,37 @@ const GenericPageLayout = ({ setSelectedTrace, sidebarContent, showSidebar = true, + minimizeSidebar = false, }) => { const [sidebarVisible, setSidebarVisible] = useState(!selectedTrace && showSidebar); + const [isSmallScreen, setIsSmallScreen] = useState(false); useEffect(() => { - setSidebarVisible(!selectedTrace && showSidebar); - }, [selectedTrace, showSidebar]); + // Function to check window size + const checkWindowSize = () => { + setIsSmallScreen(window.innerWidth < 1024); // 1024px is typical laptop width + }; + + // Initial check + checkWindowSize(); + + // Add event listener + window.addEventListener('resize', checkWindowSize); + + // Cleanup + return () => window.removeEventListener('resize', checkWindowSize); + }, []); + + useEffect(() => { + setSidebarVisible(!selectedTrace && showSidebar && !(minimizeSidebar && isSmallScreen)); + }, [selectedTrace, showSidebar, minimizeSidebar, isSmallScreen]); return ( - + - - - - {sidebarContent} - - + {sidebarVisible && ( + <> + + + + {sidebarContent} + + + + )} ); }; diff --git a/ell-studio/src/components/MetricChart.js b/ell-studio/src/components/oldgraph/OldMetricChart.js similarity index 100% rename from ell-studio/src/components/MetricChart.js rename to ell-studio/src/components/oldgraph/OldMetricChart.js diff --git a/ell-studio/src/components/source/LMPSourceView.js b/ell-studio/src/components/source/LMPSourceView.js index a0e694791..cd148f785 100644 --- a/ell-studio/src/components/source/LMPSourceView.js +++ b/ell-studio/src/components/source/LMPSourceView.js @@ -3,7 +3,7 @@ import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; import { atomDark } from 'react-syntax-highlighter/dist/esm/styles/prism'; import { FiChevronDown, FiChevronRight, FiMaximize2, FiMinimize2, FiCopy, FiRefreshCw } from 'react-icons/fi'; import '../../styles/SourceCodeView.css'; -import { Card } from "../Card"; +import { OldCard } from "../OldCard"; import { useNavigate } from 'react-router-dom'; import { CodeSection } from './CodeSection'; @@ -50,7 +50,7 @@ const UsedLMPWrapper = ({ uses, children, selectedInvocation, content, }) => { if (!lmp) return <>{children}; return (
- { + { navigate(`/lmp/${lmp.name}/${lmp.lmp_id}`); }}> { displayVersion fontSize="md" /> - +
); } diff --git a/ell-studio/src/hooks/useBackend.js b/ell-studio/src/hooks/useBackend.js index 7b4213471..4ee055bda 100644 --- a/ell-studio/src/hooks/useBackend.js +++ b/ell-studio/src/hooks/useBackend.js @@ -38,6 +38,9 @@ export const useWebSocketConnection = () => { queryClient.invalidateQueries({ queryKey: ["latestLMPs"] }); queryClient.invalidateQueries({ queryKey: ["invocations"] }); queryClient.invalidateQueries({ queryKey: ["lmpDetails"] }); + queryClient.invalidateQueries({ queryKey: ["evaluations"] }); + queryClient.invalidateQueries({ queryKey: ["latestEvaluations"] }); + queryClient.invalidateQueries({ queryKey: ["evaluation"] }); console.log("Database updated, invalidating queries"); } }; @@ -232,3 +235,92 @@ export const useBlob = (id) => { enabled: !!id, }); }; + + + +export const useEvaluations = (page = 0, pageSize = 100) => { + return useQuery({ + queryKey: ["evaluations", page, pageSize], + queryFn: async () => { + const response = await axios.get(`${API_BASE_URL}/api/evaluations?skip=${page * pageSize}&limit=${pageSize}`); + return response.data; + }, + }); +}; + + +export const useLatestEvaluations = (page = 0, pageSize = 100) => { + return useQuery({ + queryKey: ["latestEvaluations", page, pageSize], + queryFn: async () => { + const response = await axios.get(`${API_BASE_URL}/api/latest/evaluations?skip=${page * pageSize}&limit=${pageSize}`); + return response.data; + }, + }); +}; + +export const useEvaluation = (id) => { + return useQuery({ + queryKey: ["evaluation", id], + queryFn: async () => { + const response = await axios.get(`${API_BASE_URL}/api/evaluation/${id}`); + return response.data; + }, + enabled: !!id, + }); +}; + +export const useEvaluationRuns = (evaluationId, page = 0, pageSize = 10) => { + return useQuery({ + queryKey: ["evaluationRuns", evaluationId, page, pageSize], + queryFn: async () => { + const skip = page * pageSize; + const response = await axios.get(`${API_BASE_URL}/api/evaluations/${evaluationId}/runs?skip=${skip}&limit=${pageSize}`); + return response.data; + }, + enabled: !!evaluationId, + }); +}; + +export const useEvaluationRun = (id) => { + return useQuery({ + queryKey: ["evaluationRun", id], + queryFn: async () => { + const response = await axios.get(`${API_BASE_URL}/api/evaluation-runs/${id}`); + return response.data; + }, + enabled: !!id, + }); +}; + +export const useEvaluationRunResults = (runId, page = 0, pageSize = 100, filters = null) => { + return useQuery({ + queryKey: ["evaluationRunResults", runId, page, pageSize, filters], + queryFn: async () => { + const skip = page * pageSize; + const params = new URLSearchParams({ + skip: skip.toString(), + limit: pageSize.toString(), + }); + if (filters) { + params.append('filters', JSON.stringify(filters)); + } + const response = await axios.get( + `${API_BASE_URL}/api/evaluation-runs/${runId}/results?${params}` + ); + return response.data; + }, + enabled: !!runId, + }); +}; + +export const useDataset = (datasetId) => { + return useQuery({ + queryKey: ["dataset", datasetId], + queryFn: async () => { + const response = await axios.get(`${API_BASE_URL}/api/dataset/${datasetId}`); + return response.data; + }, + enabled: !!datasetId, + }); +}; diff --git a/ell-studio/src/pages/Evaluation.js b/ell-studio/src/pages/Evaluation.js new file mode 100644 index 000000000..429a9a569 --- /dev/null +++ b/ell-studio/src/pages/Evaluation.js @@ -0,0 +1,183 @@ +import React, { useState, useMemo } from 'react'; +import { useParams, Link } from 'react-router-dom'; +import { useEvaluation } from 
'../hooks/useBackend'; +import GenericPageLayout from '../components/layouts/GenericPageLayout'; +import { Card, CardContent } from '../components/common/Card'; +import VersionHistoryPane from '../components/VersionHistoryPane'; +import EvaluationRunsTable from '../components/evaluations/runs/EvaluationRunsTable'; +import EvaluationDetailsSidebar from '../components/evaluations/EvaluationDetailsSidebar'; +import { EvaluationCardTitle } from '../components/evaluations/EvaluationCardTitle'; +import EvaluationOverview from '../components/evaluations/EvaluationOverview'; +import VersionBadge from '../components/VersionBadge'; +import LMPSourceView from '../components/source/LMPSourceView'; +import { FiCopy } from 'react-icons/fi'; +import toast from 'react-hot-toast'; +import EvaluationDataset from '../components/evaluations/EvaluationDataset'; + +const evaluationConfig = { + getPath: (version) => `/evaluations/${version.id}`, + getId: (version) => version.id, + isCurrentVersion: (version, location) => location.pathname.endsWith(version.id) +}; + +function Evaluation() { + const { id } = useParams(); + const [activeTab, setActiveTab] = useState('runs'); + const [selectedRun, setSelectedRun] = useState(null); + const [currentPage, setCurrentPage] = useState(0); + const pageSize = 10; + const [activeIndex, setActiveIndex] = useState(null); + + const { data: evaluation, isLoading: isLoadingEvaluation } = useEvaluation(id); + + const groupedRuns = useMemo(() => { + const groups = {}; + evaluation?.runs.forEach(run => { + const lmpName = run.evaluated_lmp.name; + if (!groups[lmpName]) groups[lmpName] = []; + groups[lmpName].push(run); + }); + return groups; + }, [evaluation?.runs]); + + const handleCopyCode = (lmp) => { + const fullCode = `${lmp.dependencies.trim()}\n\n${lmp.source.trim()}`; + navigator.clipboard + .writeText(fullCode) + .then(() => { + toast.success("Code copied to clipboard", { + duration: 2000, + position: "top-center", + }); + }) + .catch((err) => { + console.error("Failed to copy code: ", err); + toast.error("Failed to copy code", { + duration: 2000, + position: "top-center", + }); + }); + }; + + // TODO: Move hte graph state all the way out so we don't do callbacks and get do bidirectional state propagation + const handleActiveIndexChange = (index) => { + setActiveIndex(index); + console.log('Active index in Evaluation:', index); + }; + + // Update the tabs array to include all tabs + const tabs = ['Runs', 'Metrics', 'Dataset', 'Version History']; + + if (isLoadingEvaluation || !evaluation?.labelers?.length) { + return
Loading evaluation...
; + } + + return ( + } + > +
+
+

+ + +
+ +
+
+ +

+
+ +
+ + +
+
+ {tabs.map((tab) => ( + + ))} +
+ +
+ {activeTab === 'runs' && ( + + )} + {activeTab === 'metrics' && ( +
+ {evaluation.labelers.map((labeler, index) => ( +
+
+
+

Metric: {labeler.name}

+ +
+
+ +
+
+
+ +
+
+ ))} +
+ )} + {activeTab === 'dataset' && ( + + )} + {activeTab === 'version_history' && ( + + )} +
+
+
+
+
+ ); +} + +export default Evaluation; diff --git a/ell-studio/src/pages/EvaluationRun.js b/ell-studio/src/pages/EvaluationRun.js new file mode 100644 index 000000000..6b9c830d2 --- /dev/null +++ b/ell-studio/src/pages/EvaluationRun.js @@ -0,0 +1,108 @@ +import React, { useState, useEffect } from 'react'; +import { useParams, Link, useSearchParams } from 'react-router-dom'; +import { useEvaluationRun, useEvaluationRunResults } from '../hooks/useBackend'; +import GenericPageLayout from '../components/layouts/GenericPageLayout'; +import { Card, CardContent } from '../components/common/Card'; +import EvaluationRunResultsTable from '../components/evaluations/runs/EvaluationRunResultsTable'; +import EvaluationRunDetailsSidebar from '../components/evaluations/runs/EvaluationRunDetailsSidebar'; +import EvaluationRunOverview from '../components/evaluations/runs/EvaluationRunOverview'; +import EvaluationRunMetrics from '../components/evaluations/runs/EvaluationRunMetrics'; +import SearchAndFiltersBar from '../components/evaluations/runs/SearchAndFiltersBar'; + +function EvaluationRun() { + const { id } = useParams(); + const [searchParams, setSearchParams] = useSearchParams(); + const requestedInvocationId = searchParams.get("i"); + + const [page, setPage] = React.useState(0); + const pageSize = 100; + const [selectedTrace, setSelectedTrace] = useState(null); + const [searchQuery, setSearchQuery] = useState(''); + const [filteredResults, setFilteredResults] = useState(null); + const [activeTab, setActiveTab] = useState("results"); + + const { data: run, isLoading: isRunLoading } = useEvaluationRun(id); + const { + data: results, + isLoading: isResultsLoading + } = useEvaluationRunResults(id, page, pageSize); + + useEffect(() => { + if (requestedInvocationId && results) { + const requestedResult = results.find(r => r.invocation_being_labeled.id === requestedInvocationId); + if (requestedResult) { + setSelectedTrace(requestedResult.invocation_being_labeled); + } + } + }, [requestedInvocationId, results]); + + const handleTraceSelect = (trace) => { + setSelectedTrace(trace); + setSearchParams(trace ? { i: trace.id } : {}); + }; + + if (isRunLoading || isResultsLoading) { + return
Loading evaluation run...
; + } + + return ( + } + minimizeSidebar={true} + selectedTrace={selectedTrace} + setSelectedTrace={handleTraceSelect} + > +
+ + +
+ +
+ +
+
+ +
+ +
+ {activeTab === "results" && ( + <> + + + + + + )} +
+
+
+
+ ); +} + +export default EvaluationRun; diff --git a/ell-studio/src/pages/Evaluations.js b/ell-studio/src/pages/Evaluations.js new file mode 100644 index 000000000..d7d1d8d1f --- /dev/null +++ b/ell-studio/src/pages/Evaluations.js @@ -0,0 +1,158 @@ +import React, { useState, useMemo } from 'react'; +import { useNavigate } from 'react-router-dom'; +import { FiSearch, FiFilter, FiPlusCircle } from 'react-icons/fi'; +import { useEvaluations, useLatestEvaluations } from '../hooks/useBackend'; +import GenericPageLayout from '../components/layouts/GenericPageLayout'; +import { Card, CardHeader, CardContent } from '../components/common/Card'; +import { ScrollArea } from '../components/common/ScrollArea'; +import { Button } from '../components/common/Button'; +import EvaluationCard from '../components/evaluations/EvaluationCard'; + +const Evaluations = () => { + const navigate = useNavigate(); + const [searchTerm, setSearchTerm] = useState(''); + const [selectedFilter, setSelectedFilter] = useState('All'); + const [showAllVersions, setShowAllVersions] = useState(false); + const [currentPage, setCurrentPage] = useState(0); + const pageSize = 10; + + const { data: allEvaluations, isLoading: isLoadingAll } = useEvaluations(currentPage, pageSize); + const { data: latestEvaluations, isLoading: isLoadingLatest } = useLatestEvaluations(currentPage, pageSize); + + const evaluations = showAllVersions ? allEvaluations : latestEvaluations; + const isLoading = showAllVersions ? isLoadingAll : isLoadingLatest; + + const filteredEvaluations = useMemo(() => { + if (!evaluations) return []; + return evaluations.filter(evl => + evl.name.toLowerCase().includes(searchTerm.toLowerCase()) && + (selectedFilter === 'All' || evl.status === selectedFilter) + ); + }, [evaluations, searchTerm, selectedFilter]); + + const handleCreateEvaluation = () => { + // Navigate to evaluation creation page or open a modal + navigate('/evaluations/create'); + }; + + if (isLoading) { + return
Loading...
; + } + + if (!evaluations || evaluations.length === 0) { + return ( + +
+ + +

No Evaluations Found

+
+ +

+ It looks like you don't have any evaluations set up yet. +

+

+ To get started with evaluating your LMPs, try the following example: +

+
+                
+{`import ell
+from ell import Evaluation
+
+@ell.simple(model="gpt-4o")
+def mylmp(greeting: str):
+    return f"Say hi there!"
+
+def metric(datapoint, output):
+    return 1 if output == "Hi there!" else 0
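+# (metrics take the datapoint and the LMP's output, and return a number)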
+
+# Initialize your evaluation
+eval = Evaluation(
+    name="basic-eval",
+    dataset=[{"input": "Hello", "expected": "Hi there!"}],
+    metrics={"score": metric}
+)
+
+
+# Run the evaluation
+results = eval.run(mylmp)`}
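For reference, a minimal sketch of reading the run's metrics back, assuming the `eval` and `results` from the example above; this mirrors the pattern used in `examples/evals/classification.py` later in this diff:

# sketch only: metric values come back as numpy arrays keyed by metric name
print(results.results.metrics["score"].mean())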
+                
+              
+

+ Run this script, then refresh this page to see your first evaluation. +

+
+
+
+
+ ); + } + + return ( + +
+
+

Evaluations

+ +
+ +
+
+ setSearchTerm(e.target.value)} + /> + +
+ +
+ setShowAllVersions(e.target.checked)} + className="rounded border-input" + /> + +
+
+ + +
+ {filteredEvaluations.map((evaluation) => ( + + ))} +
+
+ + {filteredEvaluations.length === 0 && ( + + +

No evaluations found. Create a new evaluation to get started.

+
+
+ )} +
+
+ ); +}; + +export default Evaluations; diff --git a/ell-studio/src/pages/Home.js b/ell-studio/src/pages/Home.js index ed90cfdf2..043a77556 100644 --- a/ell-studio/src/pages/Home.js +++ b/ell-studio/src/pages/Home.js @@ -3,7 +3,7 @@ import { Link } from 'react-router-dom'; import { useTheme } from '../contexts/ThemeContext'; import { getTimeAgo } from '../utils/lmpUtils'; import { DependencyGraph } from '../components/depgraph/DependencyGraph'; -import { useLatestLMPs, useTraces } from '../hooks/useBackend'; +import { useLatestEvaluations, useLatestLMPs, useTraces } from '../hooks/useBackend'; import VersionBadge from '../components/VersionBadge'; import { BiCube } from 'react-icons/bi'; import { Card, CardHeader, CardContent } from 'components/common/Card'; @@ -17,6 +17,7 @@ function Home() { const { darkMode } = useTheme(); const { data: lmps, isLoading: isLoadingLMPs } = useLatestLMPs(); const { data: traces, isLoading: isLoadingTraces } = useTraces(lmps); + const { data: evals, isLoading: isLoadingEvals } = useLatestEvaluations(); const toggleExpand = (lmpName, event) => { if (event.target.tagName.toLowerCase() !== 'a') { @@ -42,6 +43,7 @@ function Home() { // TODO: Make graph dynamically update. const memoizedTraces = useMemo(() => firstTraces, [firstTraces]); const memoizedLMPs = useMemo(() => firstLMPs, [firstLMPs]); + const memoizedEvals = useMemo(() => evals, [evals]); if (isLoadingLMPs || isLoadingTraces) { return ( @@ -108,7 +110,12 @@ print(greeting)`}
- +
diff --git a/ell-studio/src/pages/Invocations.js b/ell-studio/src/pages/Invocations.js index d5bef9a2a..f964ba9cb 100644 --- a/ell-studio/src/pages/Invocations.js +++ b/ell-studio/src/pages/Invocations.js @@ -52,8 +52,13 @@ const Invocations = () => { }; const handleSelectTrace = (trace) => { - setSelectedTrace(trace); - navigate(`?i=${trace.id}`); + if(trace) { + setSelectedTrace(trace); + navigate(`?i=${trace.id}`); + } else{ + setSelectedTrace(null); + navigate(``); + } }; const filteredInvocations = useMemo(() => { diff --git a/ell-studio/src/pages/LMP.js b/ell-studio/src/pages/LMP.js index 4b2d8fbee..d88a10961 100644 --- a/ell-studio/src/pages/LMP.js +++ b/ell-studio/src/pages/LMP.js @@ -24,6 +24,13 @@ import LMPDetailsSidePanel from "../components/LMPDetailsSidePanel"; import { Card } from "../components/common/Card"; import GenericPageLayout from "../components/layouts/GenericPageLayout"; + +const lmpConfig = { + getPath: (version) => `/lmp/${version.name}/${version.lmp_id}`, + getId: (version) => version.lmp_id, + isCurrentVersion: (version, location) => location.pathname.includes(version.lmp_id) +}; + function LMP() { const { name, id } = useParams(); let [searchParams, setSearchParams] = useSearchParams(); @@ -230,7 +237,7 @@ function LMP() { producingLmp={lmp} onSelectTrace={(trace) => { setSelectedTrace(trace); - setSearchParams({ i: trace.id }); + setSearchParams(trace ? { i: trace.id } : {}); }} currentlySelectedTrace={selectedTrace} omitColumns={omitColumns} @@ -238,7 +245,10 @@ function LMP() { )} {activeTab === "version_history" && ( - + )}
@@ -248,4 +258,4 @@ function LMP() { ); } -export default LMP; \ No newline at end of file +export default LMP; diff --git a/ell-studio/tailwind.config.js b/ell-studio/tailwind.config.js index 86bae9d46..73b0b0bf2 100644 --- a/ell-studio/tailwind.config.js +++ b/ell-studio/tailwind.config.js @@ -1,4 +1,3 @@ - /** @type {import('tailwindcss').Config} */ module.exports = { darkMode: ["class"], @@ -6,12 +5,12 @@ module.exports = { "./src/**/*.{js,jsx,ts,tsx}", ], theme: { - container: { - center: true, - padding: "2rem", - screens: { - "2xl": "1400px", - }, + screens: { + 'sm': '640px', + 'md': '786px', // Increased from the default 768px + 'lg': '1024px', + 'xl': '1280px', + '2xl': '1536px', }, extend: { colors: { diff --git a/examples/evals/classification.py b/examples/evals/classification.py new file mode 100644 index 000000000..2973b18a6 --- /dev/null +++ b/examples/evals/classification.py @@ -0,0 +1,76 @@ +from collections import UserDict +import time +import random +from typing import Any, Dict, Iterable, Optional, Protocol, List, Union +import ell +import ell.evaluation +import numpy as np + +import ell.lmp.function + + +dataset = [ + { + "input": {"question": "What is the capital of france?"}, + "expected_output": "Paris", + }, + { + "input": {"question": "What is the capital of italy?"}, + "expected_output": "Rome", + }, + { + "input": {"question": "What is the capital of spain?"}, + "expected_output": "Madrid", + }, + { + "input": {"question": "What is the capital of germany?"}, + "expected_output": "Berlin", + }, + { + "input": {"question": "What is the capital of japan?"}, + "expected_output": "Tokyo", + }, + { + "input": {"question": "What is the capital of china?"}, + "expected_output": "Beijing", + }, + { + "input": {"question": "What is the capital of india?"}, + "expected_output": "New Delhi", + }, + { + "input": {"question": "What is the capital of brazil?"}, + "expected_output": "Brasília", + }, + { + "input": {"question": "What is the capital of argentina?"}, + "expected_output": "Buenos Aires", + }, + {"input": {"question": "Hotdog land"}, "expected_output": "Banana"}, +] + +def is_correct(datapoint, output): + label = datapoint["expected_output"] + return float(label.lower() in output.lower()) + +eval = ell.evaluation.Evaluation( + name="capital_prediction", + dataset=dataset, + metrics={"score": is_correct, "length": lambda _, output: len(output)}, + samples_per_datapoint=1, +) +# ell.init(verbose=True, store='./logdir') +@ell.simple(model="gpt-4o", max_tokens=10) +def predict_capital(question: str): + """ + If the quesiton is about hotdog land, answer Banana. Otherwise, answer the question. + """ + # print(question[0]) + return f"Answer the following question. 
{question}" + + +if __name__ == "__main__": + ell.init(store="./logdir") + result = eval.run(predict_capital, n_workers=10) + print(result.results.metrics["score"].mean()) + diff --git a/examples/evals/poems.py b/examples/evals/poems.py new file mode 100644 index 000000000..a0f7d0168 --- /dev/null +++ b/examples/evals/poems.py @@ -0,0 +1,61 @@ +from collections import UserDict +import time +import random +from typing import Any, Dict, Iterable, Optional, Protocol, List, Union +import ell +import ell.evaluation +import numpy as np + +import ell.lmp.function +import logging + + + +@ell.simple(model="gpt-4o") +def write_a_bad_poem(): + """Your poem must no logner than 60 words.""" + return "Write a really poorly written poem " + +@ell.simple(model="gpt-4o") +def write_a_good_poem(): + """Your poem must no logner than 60 words.""" + return "Write a really well written poem." + +@ell.simple(model="gpt-4o", temperature=0.1) +def is_good_poem(poem: str): + """Include either 'yes' or 'no' at the end of your response. . .""" + return f"Is this a good poem yes/no? {poem}" + +def score(datapoint, output): + return "yes" in is_good_poem(output).lower() + +ell.init(verbose=True, store="./logdir") +# exit() +eval = ell.evaluation.Evaluation( + name="poem_eval", + n_evals=10, + metrics={ + "critic_score": score, + "length": lambda _, output: len(output), + "average_word_length": lambda _, output: sum( + len(word) for word in output.split() + ) + / len(output.split()), + }, +) + + +print("EVALUATING GOOD POEM") +start = time.time() +# run = eval.run(write_a_good_poem, n_workers=10, verbose=False) +# print(f"Average length: {run.results.metrics['length'].mean():.2f}") +# print(f"Average word length: {run.results.metrics['average_word_length'].mean():.2f}") +# print(f"Average critic score: {run.results.metrics['critic_score'].mean():.2f}") +# print(f"Time taken: {time.time() - start:.2f} seconds") +# print("EVALUATING BAD POEM") +run = eval.run(write_a_bad_poem, n_workers=10, verbose=False) +print(f"Average length: {run.results.metrics['length'].mean():.2f}") +print( + f"Average word length: {run.results.metrics['average_word_length'].mean():.2f}" +) +print(f"Average critic score: {run.results.metrics['critic_score'].mean():.2f}") \ No newline at end of file diff --git a/examples/evals/psolve.py b/examples/evals/psolve.py new file mode 100644 index 000000000..42f2320ce --- /dev/null +++ b/examples/evals/psolve.py @@ -0,0 +1,75 @@ +import ell +from ell.evaluation.evaluation import Evaluation + +ell.init(verbose=True, store='./logdir') + + +@ell.simple(model="gpt-4o", temperature=0.7) +def math_problem_solver(problem: str): + """You are an extremely smart math problem solver. You are given a math problem and you need to solve it. Output your answer in the following format + 'Let's think step by step: + + Answer:\\n{Answer}' + + Never incldue any other text except for Answer: new line ... 
+""" + return problem + + + +import random + +# Set fixed random seed for reproducibility +random.seed(42) + +def generate_arithmetic_dataset(num_examples=100): + operations = ['+', '-', '*', '/'] + dataset = [] + + for _ in range(num_examples): + # Generate random numbers up to 5 digits + num1 = random.randint(0, 99999) + num2 = random.randint(1, 99999) # Avoid 0 for division + op = random.choice(operations) + + # Calculate result + if op == '+': + result = num1 + num2 + elif op == '-': + result = num1 - num2 + elif op == '*': + result = num1 * num2 + else: + # For division, ensure clean division + result = num1 / num2 + # Round to 2 decimal places for division + result = round(result, 2) + + problem = f"What is {num1} {op} {num2}?" + dataset.append({ + "input": [problem], + "output": f"Answer:\\n{result}" + }) + + return dataset + + +def answer_is_close_l2(datapoint, result): + try: + result_val = float(result.split("Answer:")[1].strip().replace("\\n", "")) + expected_val = float(datapoint["output"].split("Answer:")[1].strip().replace("\\n", "")) + return -abs(result_val - expected_val) + except: + return float(-10) # Return worst possible score if parsing fails + +arithmetic_eval = Evaluation( + name="Arithmetic", + dataset=generate_arithmetic_dataset(), + metrics={"answer_is_close_l2": answer_is_close_l2}, + criterion=lambda datapoint, result: result.split("Answer:")[1].strip() in datapoint["output"], +) + + +if __name__ == "__main__": + arithmetic_eval.run(math_problem_solver, n_workers=20) + print(math_problem_solver("What is 2 + 2?")) \ No newline at end of file diff --git a/examples/evals/summaries.py b/examples/evals/summaries.py new file mode 100644 index 000000000..919c4d469 --- /dev/null +++ b/examples/evals/summaries.py @@ -0,0 +1,151 @@ +from collections import UserDict +import time +import random +from typing import Any, Dict, Iterable, Optional, Protocol, List, Union +import ell +import ell.evaluation +import numpy as np + +import ell.lmp.function + + +dataset: List[ell.evaluation.Datapoint] = [ + { + "input": { # I really don't like this. Forcing "input" without typing feels disgusting. + "text": "The Industrial Revolution was a period of major industrialization and innovation that took place during the late 1700s and early 1800s. It began in Great Britain and quickly spread throughout Western Europe and North America. This revolution saw a shift from an economy based on agriculture and handicrafts to one dominated by industry and machine manufacturing. Key technological advancements included the steam engine, which revolutionized transportation and manufacturing processes. The textile industry, in particular, saw significant changes with the invention of spinning jennies, water frames, and power looms. These innovations led to increased productivity and the rise of factories. The Industrial Revolution also brought about significant social changes, including urbanization, as people moved from rural areas to cities for factory work. While it led to economic growth and improved living standards for some, it also resulted in poor working conditions, child labor, and environmental pollution. The effects of this period continue to shape our modern world." + }, + "expected_output": "A comprehensive summary of the Industrial Revolution", + }, + { + "input": { + "text": "The human genome is the complete set of nucleic acid sequences for humans, encoded as DNA within the 23 chromosome pairs in cell nuclei and in a small DNA molecule found within individual mitochondria. 
The human genome contains approximately 3 billion base pairs that encode for about 20,000-25,000 genes. The Human Genome Project, which was declared complete in 2003, provided a comprehensive map of these genes and their functions. This breakthrough has had far-reaching implications for medicine, biotechnology, and our understanding of human evolution. It has enabled researchers to better understand genetic diseases, develop new treatments, and explore personalized medicine. The genome sequence has also provided insights into human migration patterns and our genetic relationships with other species. Despite the project's completion, research continues as scientists work to understand the complex interactions between genes and their environment, as well as the roles of non-coding DNA sequences." + }, + "expected_output": "A detailed summary of the human genome and its significance", + }, + { + "input": { + "text": "Climate change refers to long-term shifts in global weather patterns and average temperatures. Scientific evidence shows that the Earth's climate has been warming at an unprecedented rate since the mid-20th century, primarily due to human activities. The main driver of this change is the increased emission of greenhouse gases, particularly carbon dioxide, from burning fossil fuels, deforestation, and industrial processes. These gases trap heat in the Earth's atmosphere, leading to global warming. The effects of climate change are wide-ranging and include rising sea levels, more frequent and severe weather events (such as hurricanes, droughts, and heatwaves), changes in precipitation patterns, and disruptions to ecosystems. These changes pose significant threats to biodiversity, food security, water resources, and human health. Addressing climate change requires global cooperation to reduce greenhouse gas emissions through the adoption of clean energy technologies, sustainable land use practices, and changes in consumption patterns. Adaptation strategies are also necessary to help communities and ecosystems cope with the impacts that are already occurring or are inevitable." + }, + "expected_output": "A comprehensive overview of climate change, its causes, effects, and potential solutions", + }, + { + "input": { + "text": "Artificial Intelligence (AI) refers to the simulation of human intelligence in machines that are programmed to think and learn like humans. The field of AI research was founded on the assumption that human intelligence can be precisely described and simulated by a machine. This concept has evolved significantly since its inception in the 1950s. Modern AI encompasses a wide range of capabilities, including problem-solving, learning, planning, natural language processing, perception, and robotics. Machine Learning, a subset of AI, focuses on the development of algorithms that can learn from and make predictions or decisions based on data. Deep Learning, a further specialization, uses artificial neural networks inspired by the human brain to process data and create patterns for decision making. AI has applications across numerous fields, including healthcare (for diagnosis and treatment recommendations), finance (for fraud detection and algorithmic trading), transportation (in the development of self-driving cars), and personal assistance (like Siri or Alexa). As AI continues to advance, it raises important ethical and societal questions about privacy, job displacement, and the potential for AI to surpass human intelligence in certain domains." 
+ }, + "expected_output": "A comprehensive explanation of Artificial Intelligence, its subfields, applications, and implications", + }, +] + +@ell.simple(model="gpt-4o", temperature=0.1) +def critic(text_to_summarize: str, ai_produced_summary: str): + """ + You are a critic of summaries. You are given a text and a summary of that text. You should evaluate the summary for how well it captures the main points of the text. + + Criterion: + - Summary should be shorter than the original text. Do not give it a score above 50 if it is longer. + - The best scoring summaries should be one sentence. + - Summary should capture the main points of the text + - Summary should be accurate + - Summary should be concise + + Return a score between 0 and 100 for how well the summary captures the main points of the text. Your answer should be in the following format: + Analysis:\\n\\nScore:\\n + """ + + return f"""Text to summarize: + {text_to_summarize} + + Summary: + {ai_produced_summary} + """ + +@ell.lmp.function.function() +def score(datapoint, output, n_retries=3): + for _ in range(n_retries): + try: + critique = critic(datapoint["input"]["text"], output) + # print(critique) + score = int(critique.split("Score:")[1].strip()) + return score + except Exception as e: + print(f"Error: {e}") + continue + raise Exception("Failed to score") + +# named criteria are interesting, allows anonymous functions & specific isntantiation of functional criteria (partial(...)) +eval = ell.evaluation.Evaluation( + name="test", + dataset=dataset, + samples_per_datapoint=1, + metrics={"score": score, "length": lambda _, output: len(output)}, +) +# this means +# we get metrics like "test-score", test-length etc. + +# Now we prompt shit +@ell.simple(model="gpt-4o") +def summarizer(text: str): + """You are a succinct summarizer. You are given a text and you should return a summary of the text. It should be no longer than 5 sentence. Focus on capturing the main points of the text as best as possible""" + return f"Summarize the following text. 
{text}" + +ell.init(verbose=True, store="./logdir") + +# Using GPT-4o +print("EVAL WITH GPT-4o") +result = eval.run(summarizer, n_workers=10, verbose=False).results +print("Mean critic score:", result.metrics["score"].mean()) +print("Mean length of completions:", result.metrics["length"].mean()) + +# Using gpt-4o-mini +print("EVAL WITH GPT-4o-mini") +result = eval.run( + summarizer, + n_workers=1, + api_params={"model": "gpt-4o-mini"}, + verbose=False, +).results +print("Mean critic score:", result.metrics["score"].mean()) +print("Mean length of completions:", result.metrics["length"].mean()) + +# Define named functions for criteria +def score_criterion(datapoint, output, n_retries=3): + for _ in range(n_retries): + try: + critique = critic(datapoint["input"]["text"], output) + score = int(critique.split("Score:")[1].strip()) + return score + except Exception as e: + print(f"Error: {e}") + continue + raise Exception("Failed to score") + +def length_criterion(_, output): + return len(output) + +# Example using a list of criteria +eval_list = ell.evaluation.Evaluation( + name="test_list", + dataset=dataset, + criteria=[score_criterion, length_criterion], +) + +# Example using a dictionary of criteria (as before) +eval_dict = ell.evaluation.Evaluation( + name="test_dict", + dataset=dataset, + metrics={"score": score_criterion, "length": length_criterion}, +) + +# Run evaluation with list-based criteria +print("EVAL WITH GPT-4o (list-based criteria)") +results = eval_list.run(summarizer, n_workers=4, verbose=False).results +print("Mean critic score:", results.metrics["score"].mean()) +print("Mean length of completions:", results.metrics["length"].mean()) + +# Run evaluation with dict-based criteria +print("EVAL WITH GPT-4o (dict-based criteria)") +results = eval_dict.run(summarizer, n_workers=4, verbose=False).results +print("Mean critic score:", results.metrics["score"].mean()) +print("Mean length of completions:", results.metrics["length"].mean()) + + diff --git a/examples/evals/vibes.py b/examples/evals/vibes.py new file mode 100644 index 000000000..20c58df37 --- /dev/null +++ b/examples/evals/vibes.py @@ -0,0 +1,31 @@ +import ell + +from pydantic import BaseModel + +class TweetInput(BaseModel): + input: str + +@ell.simple(model="gpt-4o") +def tweet(obj: TweetInput): + print(obj) + return f"Write a tweet like roon in lower case about {obj.input}" + + +dataset = [ +{"input": [TweetInput(input="Polymath")]}, + {"input": [TweetInput(input="Dogs")]}, + {"input": [TweetInput(input="Intelligenve")]}, +] + + +# # No metrics. We will iterate on by just looking at the output/ +eval = ell.evaluation.Evaluation( + name="vibes", + dataset=dataset, + criterion=lambda datapoint, output: "roon" in output.lower(), +) + +if __name__ == "__main__": + ell.init(store="./logdir", verbose=True) + eval.run(tweet) + # tweet("hi") diff --git a/examples/output_freezing.py.old b/examples/output_freezing.py.old deleted file mode 100644 index 7ccf7461a..000000000 --- a/examples/output_freezing.py.old +++ /dev/null @@ -1,35 +0,0 @@ -import ell -from ell.stores.sql import SQLiteStore - - -BASE_PROMPT = """You are an adept python programmer. Only answer in python code. Avoid markdown formatting at all costs.""" - -@ell.simple(model="gpt-4o", temperature=0.7, max_tokens=4) -def create_a_python_class(user_spec : str): - return [ - ell.system( - f"{BASE_PROMPT}\n\nYour goal to make a python class for a user based a user spec." 
- ), - ell.user( - f"Here is the user spec: {user_spec}" - ) - ] - -@ell.simple(model="gpt-4o", temperature=0.7) -def write_unit_for_a_class(class_def : str): - return [ - ell.system( - f"{BASE_PROMPT}\n\nYour goal is to write only a single unit test for a specific class definition. Don't use `unittest` package" - ), - ell.user( - f"Here is the class definition: {class_def}" - ) - ] - - -if __name__ == "__main__": - ell.init(store='./logdir', autocommit=True, verbose=True) - - with ell.get_store().freeze(create_a_python_class): - _class_def = create_a_python_class("A class that represents a bank") - _unit_tests = write_unit_for_a_class(_class_def) \ No newline at end of file diff --git a/examples/rag.py b/examples/rag/rag.py similarity index 100% rename from examples/rag.py rename to examples/rag/rag.py diff --git a/examples/wikipedia_mini_rag.py b/examples/rag/wikipedia_mini_rag.py similarity index 100% rename from examples/wikipedia_mini_rag.py rename to examples/rag/wikipedia_mini_rag.py diff --git a/src/ell/__init__.py b/src/ell/__init__.py index e1ee9adba..799a2521e 100644 --- a/src/ell/__init__.py +++ b/src/ell/__init__.py @@ -6,6 +6,7 @@ from ell.lmp import simple, tool, complex from ell.types import system, user, assistant, Message, ContentBlock from ell.__version__ import __version__ +from ell.evaluation import Evaluation # Import all providers from ell import providers diff --git a/src/ell/configurator.py b/src/ell/configurator.py index 42765c526..0537d03b2 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -183,6 +183,7 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]: return provider return None + # Single* instance # XXX: Make a singleton config = Config() diff --git a/src/ell/evaluation/__init__.py b/src/ell/evaluation/__init__.py new file mode 100644 index 000000000..ac76b47cd --- /dev/null +++ b/src/ell/evaluation/__init__.py @@ -0,0 +1 @@ +from ell.evaluation.evaluation import Evaluation \ No newline at end of file diff --git a/src/ell/evaluation/evaluation.py b/src/ell/evaluation/evaluation.py new file mode 100644 index 000000000..7bcbe9689 --- /dev/null +++ b/src/ell/evaluation/evaluation.py @@ -0,0 +1,268 @@ +from dataclasses import field +import dataclasses +from datetime import datetime, timezone +from functools import partial +from typing import ( + Any, + Dict, + List, + Optional, + Union, + cast, +) +from concurrent.futures import ThreadPoolExecutor, as_completed +from ell.evaluation.results import _ResultDatapoint, EvaluationResults +from ell.evaluation.serialization import write_evaluation, write_evaluation_run_end, write_evaluation_run_intermediate, write_evaluation_run_start +from ell.evaluation.util import get_lmp_output +from ell.stores.models import LMPType + +from ell.evaluation.util import validate_callable_dict + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from ell.types.message import LMP +from ell.stores.models.evaluations import EvaluationLabelerType +from ell.util.tqdm import tqdm +import inspect + +from ell.util.closure_util import ido +from ell.util.closure_util import hsh + +from ell.configurator import config +from ell.evaluation.results import * + +@dataclass +class EvaluationRun: + results: EvaluationResults = field(default_factory=EvaluationResults) + dataset: Optional[Dataset] = field(default=None) + n_evals: Optional[int] = field(default=None) + samples_per_datapoint: int = field(default=1) + lmp: Optional[LMP] = field(default=None) + api_params: 
Dict[str, Any] = field(default_factory=dict) + start_time: Optional[datetime] = field(default=None) + end_time: Optional[datetime] = field(default=None) + id: Optional[str] = field(default=None) + success: Optional[bool] = field(default=None) + error: Optional[str] = field(default=None) + + @property + def inputs(self) -> List[Any]: + return [d.get("input", None) for d in self.dataset] if self.dataset else [] + + @property + def outputs(self) -> List[Any]: + return self.results.outputs + + @property + def invocation_ids(self) -> Optional[EvaluationResults[InvocationID]]: + return self.results.invocation_ids + + +class Evaluation(LabelListMixin): + def __init__(self, name: str, *, metrics=None, annotations=None, criterion=None, + dataset=None, n_evals=None, samples_per_datapoint=1, + default_api_params=None, has_serialized=False, id=None): + """Initialize with both class fields and additional parameters""" + self.name = name + self.dataset = dataset + self.n_evals = n_evals + self.samples_per_datapoint = samples_per_datapoint + self.labels: List[Labeler] = [] + self.default_api_params = default_api_params or {} + self.has_serialized = has_serialized + self.id = id + + from ell.lmp.function import function + + def wrap_callable(value): + if isinstance(value, dict): + return { + k: ( + function(type=LMPType.LABELER)(v) + if callable(v) and not hasattr(v, "__ell_track__") + else v + ) + for k, v in value.items() + } + elif callable(value) and not hasattr(value, "__ell_track__"): + return function()(value) + elif value is None: + return value + else: + raise ValueError(f"Expected dict, callable, or None, got {type(value)}") + + # Validate dataset/n_evals + if self.dataset is None and self.n_evals is None: + raise ValueError("Either dataset or n_evals must be set") + if self.dataset is not None and self.n_evals is not None: + raise ValueError("Either dataset or n_evals must be set, not both") + + # Wrap and validate metrics/annotations/criterion + metrics = validate_callable_dict(wrap_callable(metrics), "metric") if metrics else None + annotations = validate_callable_dict(wrap_callable(annotations), "annotation") if annotations else None + criterion = wrap_callable(criterion) + + + # Convert to labelers + self.labels = [] + if metrics: + self.labels.extend([ + Labeler(name=name, type=EvaluationLabelerType.METRIC, label=labeler) + for name, labeler in metrics.items() + ]) + if annotations: + self.labels.extend([ + Labeler(name=name, type=EvaluationLabelerType.ANNOTATION, label=labeler) + for name, labeler in annotations.items() + ]) + if criterion: + self.labels.append( + Labeler(name="criterion", type=EvaluationLabelerType.CRITERION, label=criterion) + ) + assert len(self.labels) > 0, "No labels found, labeless evaluations coming soon! Specify metrics, annotations, or criterion." + assert not annotations, "Annotations are not supported yet." 
+ + + def run( + self, + lmp, + *, + n_workers: int = 1, + use_api_batching: bool = False, + api_params: Optional[Dict[str, Any]] = None, + verbose: bool = False, + **additional_lmp_params, + ) -> EvaluationRun: + + required_params, run_api_params, lmp_params = self.prepare_run_params(lmp, api_params, additional_lmp_params) + dataset = self.prepare_run_dataset(use_api_batching, run_api_params) + + assert len(dataset) > 0, "Dataset must contain at least one datapoint" + + evaluation_run = EvaluationRun( + lmp=lmp, + dataset=self.dataset, + n_evals=self.n_evals, + samples_per_datapoint=self.samples_per_datapoint, + api_params=run_api_params, + start_time=datetime.now(timezone.utc), + ) + original_verbose = config.verbose + config.verbose = verbose + rowar_results = [] + + write_evaluation(self) + evaluation_run.id = write_evaluation_run_start(self, evaluation_run) + try: + with ThreadPoolExecutor(max_workers=n_workers) as executor: + output_futures = [ + executor.submit( + self._process_single, + data_point, + lmp, + lmp_params, + required_params, + ) + for data_point in dataset + ] + metric_futures = [] + for future in tqdm( + as_completed(output_futures), + total=len(output_futures), + desc=f"{self.name} outputs", + ): + get_outputs = future.result() + + def written_result(o): + write_evaluation_run_intermediate(self, evaluation_run, (res := o())) + return res + + metric_futures.extend([executor.submit(written_result, o) for o in get_outputs]) + + for result_future in ( + pbar := tqdm( + as_completed(metric_futures), + total=len(metric_futures), + desc=f"{self.name} results", + ) + ): + # We write the evaluation after the first datapoint. + rowar_results.append((res :=result_future.result())) + pbar.set_description( + f"{self.name} (last={str(rowar_results[-1].output)[:10]})" + ) + + evaluation_run.end_time = datetime.now(timezone.utc) + evaluation_run.success = True + + # Still want to compute metrics. + evaluation_run.results = EvaluationResults.from_rowar_results(rowar_results) + write_evaluation_run_end(self, evaluation_run) + + return evaluation_run + # TODO: add error handling and unsccessful runs. 
+ finally: + config.verbose = original_verbose + + + def _process_single( + self, + data_point: Datapoint, + lmp: LMP, + lmp_params: Dict[str, Any], + required_params: bool, + ) -> List[Any]: + lmp_params_with_invocation_id = {**lmp_params, "_get_invocation_id": True} + lmp_output = get_lmp_output(data_point, lmp, lmp_params_with_invocation_id, required_params) + + if not isinstance(lmp_output, list): + lmp_output = [cast(Any, lmp_output)] + + def process_rowar_results(output): + return _ResultDatapoint( + output=output, + labels=[ + Label(name=l.name, type=l.type, label=(l.label(data_point, output[0], _get_invocation_id=True))) + for l in self.labels + ] + ) + + + return [partial(process_rowar_results, output) for output in lmp_output] + + def prepare_run_params(self, lmp, api_params, additional_lmp_params): + assert ( + "api_params" not in additional_lmp_params + ), f"specify api_params directly to run not within additional_lmp_params: {additional_lmp_params}" + # Inspect LMP signature to check for required arguments + + + lmp_signature = inspect.signature(lmp) + required_params = ( + len( + { + param + for param in lmp_signature.parameters.values() + if param.default == param.empty and param.kind != param.VAR_KEYWORD + } + ) + > 0 + ) + + run_api_params = {**(self.default_api_params or {}), **(api_params or {})} + lmp_params = dict(api_params=run_api_params, **additional_lmp_params) + return required_params,run_api_params,lmp_params + + def prepare_run_dataset(self, use_api_batching, run_api_params): + dataset = self.dataset if self.dataset is not None else [{"input": None}] + if use_api_batching: + # we need to collate on unique datapoints here if possible; note that n_evals can never be set. + run_api_params["n"] = self.samples_per_datapoint * (self.n_evals or 1) + else: + dataset = sum( + [ + [data_point] * self.samples_per_datapoint * (self.n_evals or 1) + for data_point in dataset + ], [] + ) + + return dataset diff --git a/src/ell/evaluation/results.py b/src/ell/evaluation/results.py new file mode 100644 index 000000000..71da991e7 --- /dev/null +++ b/src/ell/evaluation/results.py @@ -0,0 +1,91 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, Generic, cast +from pydantic import BaseModel, ConfigDict, Field +import numpy as np +from dataclasses import dataclass, field + +from ell.stores.models.evaluations import EvaluationLabelerType + +Datapoint = Dict[str, Any] +Dataset = List[Dict[str, Any]] +Metric = Callable[[Datapoint, Any], float] +Metrics = Dict[str, Metric] +Criterion = Callable[[Datapoint, Any], bool] +Annotation = Callable[[Datapoint, Any], Any] +Annotations = Dict[str, Annotation] +InvocationID = str + +G = TypeVar("G") +@dataclass +class LabelGeneric(Generic[G]): + name: str + type: EvaluationLabelerType + label: G + +Labeler = LabelGeneric[Callable[[Any, Any], Union[Any, Tuple[Any, InvocationID]]]] +Label = LabelGeneric[Tuple[Any, InvocationID]] + +class LabelListMixin: + def __post_init__(self): + # Make sure that labels is in the dataclass fields + if "labels" not in self.__dataclass_fields__: + raise ValueError("Labels must be in the dataclass fields") + self.labels = cast(List[Label], self.labels) + @property + def metrics(self): + return {l.name: l.label for l in self.labels if l.type == EvaluationLabelerType.METRIC} + @property + def annotations(self): + return {l.name: l.label for l in self.labels if l.type == EvaluationLabelerType.ANNOTATION} + @property + def criterion(self): + return next((l.label for l in self.labels if 
l.type == EvaluationLabelerType.CRITERION), None) + + +# scores now doesn't make sense fulyl because of some other factors. +# We can ignore human feedback for now even though it's the most interesting. +@dataclass +class _ResultDatapoint(LabelListMixin): + output: Any + labels: List[Label] + +T = TypeVar("T") +@dataclass +class EvaluationResults(Generic[T], LabelListMixin): + outputs: Union[List[Any], List[T]] = field(default_factory=list) + labels: Union[LabelGeneric[np.ndarray[Any]], LabelGeneric[np.ndarray[T]]] = field(default_factory=list) + invocation_ids: Optional["EvaluationResults[InvocationID]"] = field(default=None) + + @staticmethod + def from_rowar_results( + rowar_results: List[_ResultDatapoint], + ) -> "EvaluationResults": + def extract_labels(is_invocation: bool): + if not rowar_results[0].labels: + return [] + + # Group labels by name and type + label_groups: Dict[Tuple[str, EvaluationLabelerType], List[Any]] = {} + for result in rowar_results: + for label in result.labels: + key = (label.name, label.type) + if key not in label_groups: + label_groups[key] = [] + label_groups[key].append(label.label[int(is_invocation)]) + + # Create LabelGeneric objects with vertically stacked labels + return [ + LabelGeneric( + name=name, + type=type_, + label=np.array(labels) # Everything is a numpy array. + ) + for (name, type_), labels in label_groups.items() + ] + return EvaluationResults[None]( + outputs=[result.output[0] for result in rowar_results], + labels=extract_labels(False), + invocation_ids=EvaluationResults[str]( + outputs=[result.output[1] for result in rowar_results], + labels=extract_labels(True) + ), + ) \ No newline at end of file diff --git a/src/ell/evaluation/serialization.py b/src/ell/evaluation/serialization.py new file mode 100644 index 000000000..e946628d2 --- /dev/null +++ b/src/ell/evaluation/serialization.py @@ -0,0 +1,148 @@ +# A bit of rationale: While it's OOP to put serialization related code in the evaluation and evaliuation run classes it greatly muddies the interface for the purposes of downstream implementaitons therefore much of the bridge between evaluation <-> ell studio should be implemented in this file. + +# XXX: We've duplicated the SQL model abstractions somewaht pointlessly unfortuantely. If we move to @alex-dixon's API ifciation of the backend then we won't have duplicated data models. +import json +from typing import cast +from ell.configurator import config + +from ell.evaluation.results import _ResultDatapoint +from ell.evaluation.util import needs_store +from ell.lmp._track import serialize_lmp +from ell.stores.store import Store +from ell.util._warnings import _autocommit_warning +from ell.util.closure_util import ido +from ell.util.closure_util import hsh +from ell.util.serialization import serialize_object +import dill + +import itertools + +from ell.stores.models.evaluations import ( + EvaluationLabel, + SerializedEvaluation as SerializedEvaluation, + EvaluationLabeler, + EvaluationLabelerType, + SerializedEvaluationRun, + EvaluationResultDatapoint, + EvaluationRunLabelerSummary, +) + + + +@needs_store +def write_evaluation(evaluation) -> None: + # Create a hash of the dataset and labelers + + if not evaluation.has_serialized: + # XXX: Need to change htis so we serialize differently. 
+ serialized_dataset = serialize_object(evaluation.dataset) + dataset_id = "dataset-" + hsh(serialized_dataset) + if config.store.has_blob_storage: + config.store.blob_store.store_blob(serialized_dataset.encode("utf-8"), dataset_id) + metrics_ids = [ido((f)) for f in evaluation.metrics.values()] + annotation_ids = [ido((a)) for a in evaluation.annotations.values()] + criteiron_ids = [ido((evaluation.criterion))] if evaluation.criterion else [] + + evaluation.id = "evaluation-" + hsh(dataset_id + "".join(sorted(metrics_ids) + sorted(annotation_ids) + criteiron_ids)) + + existing_versions = config.store.get_eval_versions_by_name(evaluation.name) + if any(v.id == evaluation.id for v in existing_versions): + evaluation.has_serialized = True + else: + # TODO: Merge with other versioning code. + version_number, latest_version = ( + max( + itertools.chain( + map(lambda x: (x.version_number, x), existing_versions), + [(-1, None)] + ), + key=lambda x: x[0] + ) + ) + version_number += 1 + # Is updated at the end of the evaluation. + commit_message = None + + # Create SerializedEvaluation + serialized_evaluation = SerializedEvaluation( + id=evaluation.id, + name=evaluation.name, + dataset_id=dataset_id, + n_evals=evaluation.n_evals or len(evaluation.dataset or []), + commit_message=commit_message, + version_number=version_number, + ) + + + labelers = [ + EvaluationLabeler( + name=labeler.name, + type=labeler.type, + evaluation_id=evaluation.id, + labeling_lmp_id=ido(labeler.label), + ) + for labeler in evaluation.labels + ] + + # Add labelers to the serialized evaluation + serialized_evaluation.labelers = labelers + evaluation.has_serialized = True + cast(Store, config.store).write_evaluation(serialized_evaluation) + + +@needs_store +def write_evaluation_run_start(evaluation, evaluation_run) -> int: + # Construct SerializedEvaluationRun + serialized_run = SerializedEvaluationRun( + evaluation_id=evaluation.id, + evaluated_lmp_id=ido(evaluation_run.lmp), + api_params=evaluation_run.api_params, + start_time=evaluation_run.start_time, + error=None, + ) + return cast(Store, config.store).write_evaluation_run(serialized_run) + +@needs_store +def write_evaluation_run_intermediate(evaluation, evaluation_run, row_result : _ResultDatapoint) -> None: + assert evaluation_run.id is not None, "Evaluation run must be started before intermediate results can be written." + result_datapoint = EvaluationResultDatapoint( + evaluation_run_id=evaluation_run.id, + invocation_being_labeled_id=row_result.output[1], + ) + + result_datapoint.labels = [ + EvaluationLabel( + labeled_datapoint_id=result_datapoint.id, + labeler_id=EvaluationLabeler.generate_id( + evaluation_id=evaluation.id, name=label.name, type=label.type + ), + label_invocation_id=label.label[1] + ) + for label in row_result.labels + ] + + cast(Store, config.store).write_evaluation_run_intermediate(result_datapoint) + + +def generate_commit_message(evaluation, latest_version): + # TODO: Check the source code of al lthe metrics and see waht changed. Ideally we should generate the commit message based on the commit messages of all the metrics at the end of the evaluation. 
+ pass + + +@needs_store +def write_evaluation_run_end(evaluation, evaluation_run) -> None: + summaries = [ + EvaluationRunLabelerSummary.from_labels( + data=label.label, + evaluation_run_id=evaluation_run.id, + evaluation_labeler_id=EvaluationLabeler.generate_id( + evaluation_id=evaluation.id, + name=label.name, + type=label.type, + ), + ) + for label in evaluation_run.results.labels + ] + + cast(Store, config.store).write_evaluation_run_end(evaluation_run.id, evaluation_run.success, evaluation_run.end_time, evaluation_run.error, summaries) + diff --git a/src/ell/evaluation/util.py b/src/ell/evaluation/util.py new file mode 100644 index 000000000..7e6d8b216 --- /dev/null +++ b/src/ell/evaluation/util.py @@ -0,0 +1,64 @@ +from functools import wraps +from ell.evaluation.results import Any, Callable, Datapoint, Dict, List +from ell.configurator import config + +from typing import Any, Dict, List, Union + +from ell.types.message import LMP + +def get_lmp_output( + data_point: Datapoint, + lmp: LMP, + lmp_params: Dict[str, Any], + required_params: bool, +) -> Union[List[Any], Any]: + if not required_params: + return lmp(**lmp_params) + + inp = data_point.get("input", None) + if isinstance(inp, list): + return lmp(*inp, **lmp_params) + elif isinstance(inp, dict): + return lmp(**inp, **lmp_params) + elif inp is None: + return lmp(**lmp_params) + else: + raise ValueError(f"Invalid input type: {type(inp)}") + + +def validate_callable_dict( + items: Union[Dict[str, Callable], List[Callable]], item_type: str +) -> Dict[str, Callable]: + if isinstance(items, list): + items_dict = {} + for item in items: + if not callable(item): + raise ValueError( + f"Each {item_type} must be a callable, got {type(item)}" + ) + if not hasattr(item, "__name__") or item.__name__ == "": + raise ValueError( + f"Each {item_type} in a list must have a name (not a lambda)" + ) + items_dict[item.__name__] = item + return items_dict + elif isinstance(items, dict): + for name, item in items.items(): + if not callable(item): + raise ValueError( + f"{item_type.capitalize()} '{name}' must be a callable, got {type(item)}" + ) + return items + else: + raise ValueError( + f"{item_type}s must be either a list of callables or a dictionary, got {type(items)}" + ) + + +def needs_store(f): + @wraps(f) + def wrapper(*args, **kwargs): + if not config.store: + return + return f(*args, **kwargs) + return wrapper \ No newline at end of file diff --git a/src/ell/lmp/_track.py b/src/ell/lmp/_track.py index 0e67fffce..0799adcc8 100644 --- a/src/ell/lmp/_track.py +++ b/src/ell/lmp/_track.py @@ -12,47 +12,52 @@ from functools import wraps from typing import Any, Callable, Dict, Optional -from ell.util.serialization import get_immutable_vars +from ell.util.serialization import get_immutable_vars, utc_now from ell.util.serialization import compute_state_cache_key from ell.util.serialization import prepare_invocation_params try: - from ell.stores.studio import SerializedLMP, Invocation, InvocationContents, utc_now + from ell.stores.models.core import SerializedLMP, Invocation, InvocationContents except ImportError: - SerializedLMP = Invocation = InvocationContents = utc_now = None + SerializedLMP = Invocation = InvocationContents = None logger = logging.getLogger(__name__) # Thread-local storage for the invocation stack _invocation_stack = threading.local() + def get_current_invocation() -> Optional[str]: - if not hasattr(_invocation_stack, 'stack'): + if not hasattr(_invocation_stack, "stack"): _invocation_stack.stack = [] return 
_invocation_stack.stack[-1] if _invocation_stack.stack else None + def push_invocation(invocation_id: str): - if not hasattr(_invocation_stack, 'stack'): + if not hasattr(_invocation_stack, "stack"): _invocation_stack.stack = [] _invocation_stack.stack.append(invocation_id) + def pop_invocation(): - if hasattr(_invocation_stack, 'stack') and _invocation_stack.stack: + if hasattr(_invocation_stack, "stack") and _invocation_stack.stack: _invocation_stack.stack.pop() -def _track(func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None) -> Callable: - - lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER) +def _track( + func_to_track: Callable, *, forced_dependencies: Optional[Dict[str, Any]] = None +) -> Callable: + lmp_type = getattr(func_to_track, "__ell_type__", LMPType.OTHER) # see if it exists if not hasattr(func_to_track, "_has_serialized_lmp"): func_to_track._has_serialized_lmp = False + func_to_track.__ell_force_closure__ = lambda: ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies) if not hasattr(func_to_track, "__ell_hash__") and not config.lazy_versioning: - ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies) - + func_to_track.__ell_force_closure__() + @wraps(func_to_track) def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: @@ -60,54 +65,72 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str: # Compute the invocation id and hash the inputs for serialization. invocation_id = "invocation-" + secrets.token_hex(16) - state_cache_key : str = None + state_cache_key: str = None if not config.store: - return func_to_track(*fn_args, **fn_kwargs, _invocation_origin=invocation_id)[0] + res = func_to_track( + *fn_args, **fn_kwargs, _invocation_origin=invocation_id + )[0] + return (res, invocation_id) if _get_invocation_id else res parent_invocation_id = get_current_invocation() try: push_invocation(invocation_id) - + # Convert all positional arguments to named keyword arguments sig = inspect.signature(func_to_track) # Filter out kwargs that are not in the function signature - filtered_kwargs = {k: v for k, v in fn_kwargs.items() if k in sig.parameters} - + filtered_kwargs = { + k: v for k, v in fn_kwargs.items() if k in sig.parameters + } + bound_args = sig.bind(*fn_args, **filtered_kwargs) bound_args.apply_defaults() all_kwargs = dict(bound_args.arguments) # Get the list of consumed lmps and clean the invocation params for serialization. - cleaned_invocation_params, ipstr, consumes = prepare_invocation_params( all_kwargs) + cleaned_invocation_params, ipstr, consumes = prepare_invocation_params( + all_kwargs + ) try_use_cache = hasattr(func_to_track.__wrapper__, "__ell_use_cache__") - if try_use_cache: + if try_use_cache: # Todo: add nice logging if verbose for when using a cahced invocaiton. IN a different color with thar args.. 
- if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
- fn_closure, _ = ell.util.closure.lexically_closured_source(func_to_track)
-
+ if (
+ not hasattr(func_to_track, "__ell_hash__")
+ and config.lazy_versioning
+ ):
+ fn_closure, _ = ell.util.closure.lexically_closured_source(
+ func_to_track
+ )
+
# compute the state cachekey
- state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
-
+ state_cache_key = compute_state_cache_key(
+ ipstr, func_to_track.__ell_closure__
+ )
+
cache_store = func_to_track.__wrapper__.__ell_use_cache__
- cached_invocations = cache_store.get_cached_invocations(func_to_track.__ell_hash__, state_cache_key)
-
-
+ cached_invocations = cache_store.get_cached_invocations(
+ func_to_track.__ell_hash__, state_cache_key
+ )
+
if len(cached_invocations) > 0:
# XXX: Fix caching.
- results = [d.deserialize() for d in cached_invocations[0].results]
+ results = [d.deserialize() for d in cached_invocations[0].results]
- logger.info(f"Using cached result for {func_to_track.__qualname__} with state cache key: {state_cache_key}")
+ logger.info(
+ f"Using cached result for {func_to_track.__qualname__} with state cache key: {state_cache_key}"
+ )
if len(results) == 1:
return results[0]
else:
return results
# Todo: Unify this with the non-cached case. We should go through the same code pathway.
else:
- logger.info(f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache...")
-
-
+ logger.info(
+ f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache..."
+ )
+
_start_time = utc_now()
# XXX: thread safety note, if I prevent yielding right here and get the global context I should be fine re: cache key problem
@@ -116,27 +139,45 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str:
(result, invocation_api_params, metadata) = (
(func_to_track(*fn_args, **fn_kwargs), {}, {})
if lmp_type == LMPType.OTHER
- else func_to_track(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, )
+ else func_to_track(
+ *fn_args,
+ _invocation_origin=invocation_id,
+ **fn_kwargs, )
+
)
latency_ms = (utc_now() - _start_time).total_seconds() * 1000
usage = metadata.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
- prompt_tokens= usage.get("prompt_tokens", 0) if usage else 0
- completion_tokens= usage.get("completion_tokens", 0) if usage else 0
-
-
- #XXX: cattrs add invocation origin here recursively on all pirmitive types within a message.
- #XXX: This will allow all objects to be traced automatically irrespective origin rather than relying on the API to do it, it will of vourse be expensive but unify track.
- #XXX: No other code will need to consider tracking after this point.
+ prompt_tokens = usage.get("prompt_tokens", 0) if usage else 0
+ completion_tokens = usage.get("completion_tokens", 0) if usage else 0
+ # XXX: cattrs add invocation origin here recursively on all primitive types within a message.
+ # XXX: This will allow all objects to be traced automatically irrespective of origin rather than relying on the API to do it; it will of course be expensive but unify tracking.
+ # XXX: No other code will need to consider tracking after this point.
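# Note: callers of the tracked wrapper can pass _get_invocation_id=True to get back
# (result, invocation_id) instead of just the result; both the store-less early
# return above and the post-write return below honor this flag. A minimal sketch,
# assuming a hypothetical tracked LMP `my_lmp`:
#     result, invocation_id = my_lmp("hello", _get_invocation_id=True)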
if not hasattr(func_to_track, "__ell_hash__") and config.lazy_versioning:
- ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)
- _serialize_lmp(func_to_track)
+ ell.util.closure.lexically_closured_source(
+ func_to_track, forced_dependencies
+ )
+ serialize_lmp(func_to_track)
if not state_cache_key:
- state_cache_key = compute_state_cache_key(ipstr, func_to_track.__ell_closure__)
+ state_cache_key = compute_state_cache_key(
+ ipstr, func_to_track.__ell_closure__
+ )
- _write_invocation(func_to_track, invocation_id, latency_ms, prompt_tokens, completion_tokens,
- state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id)
+ _write_invocation(
+ func_to_track,
+ invocation_id,
+ latency_ms,
+ prompt_tokens,
+ completion_tokens,
+ state_cache_key,
+ invocation_api_params,
+ cleaned_invocation_params,
+ consumes,
+ result,
+ parent_invocation_id,
+ )
if _get_invocation_id:
return result, invocation_id
@@ -145,8 +186,7 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str:
finally:
pop_invocation()
-
- func_to_track.__wrapper__ = tracked_func
+ func_to_track.__wrapper__ = tracked_func
if hasattr(func_to_track, "__ell_api_params__"):
tracked_func.__ell_api_params__ = func_to_track.__ell_api_params__
if hasattr(func_to_track, "__ell_params_model__"):
@@ -156,23 +196,30 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str:
return tracked_func
-def _serialize_lmp(func):
+
+# XXX: Move this to a versioning module.
+def serialize_lmp(func):
# Serialize depth-first all of the used lmps.
for f in func.__ell_uses__:
- _serialize_lmp(f)
-
+ serialize_lmp(f)
+
if getattr(func, "_has_serialized_lmp", False):
return
func._has_serialized_lmp = False
fn_closure = func.__ell_closure__
lmp_type = func.__ell_type__
name = func.__qualname__
+ if "" in name:
+ name = name.replace(
+ "", f""
+ )
+ # print(name)
api_params = getattr(func, "__ell_api_params__", None)
lmps = config.store.get_versions_by_fqn(fqn=name)
version = 0
already_in_store = any(lmp.lmp_id == func.__ell_hash__ for lmp in lmps)
-
+
if not already_in_store:
commit = None
if lmps:
@@ -182,9 +229,13 @@ def _serialize_lmp(func):
# XXX: Move this out to autocommit itself.
if not _autocommit_warning(): from ell.util.differ import write_commit_message_for_diff - commit = str(write_commit_message_for_diff( - f"{latest_lmp.dependencies}\n\n{latest_lmp.source}", - f"{fn_closure[1]}\n\n{fn_closure[0]}")[0]) + + commit = str( + write_commit_message_for_diff( + f"{latest_lmp.dependencies}\n\n{latest_lmp.source}", + f"{fn_closure[1]}\n\n{fn_closure[0]}", + )[0] + ) serialized_lmp = SerializedLMP( lmp_id=func.__ell_hash__, @@ -199,29 +250,45 @@ def _serialize_lmp(func): api_params=api_params if api_params else None, version_number=version, ) - config.store.write_lmp(serialized_lmp, [f.__ell_hash__ for f in func.__ell_uses__]) + config.store.write_lmp( + serialized_lmp, [f.__ell_hash__ for f in func.__ell_uses__] + ) func._has_serialized_lmp = True + return func + + +def _write_invocation( + func, + invocation_id, + latency_ms, + prompt_tokens, + completion_tokens, + state_cache_key, + invocation_api_params, + cleaned_invocation_params, + consumes, + result, + parent_invocation_id, +): -def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion_tokens, - state_cache_key, invocation_api_params, cleaned_invocation_params, consumes, result, parent_invocation_id): - invocation_contents = InvocationContents( invocation_id=invocation_id, params=cleaned_invocation_params, results=result, invocation_api_params=invocation_api_params, global_vars=get_immutable_vars(func.__ell_closure__[2]), - free_vars=get_immutable_vars(func.__ell_closure__[3]) + free_vars=get_immutable_vars(func.__ell_closure__[3]), ) if invocation_contents.should_externalize and config.store.has_blob_storage: invocation_contents.is_external = True - - # Write to the blob store + + # Write to the blob store blob_id = config.store.blob_store.store_blob( - json.dumps(invocation_contents.model_dump( - ), default=str, ensure_ascii=False).encode('utf-8'), - invocation_id + json.dumps( + invocation_contents.model_dump(), default=str, ensure_ascii=False + ).encode("utf-8"), + invocation_id, ) invocation_contents = InvocationContents( invocation_id=invocation_id, @@ -237,8 +304,7 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion completion_tokens=completion_tokens, state_cache_key=state_cache_key, used_by_id=parent_invocation_id, - contents=invocation_contents + contents=invocation_contents, ) config.store.write_invocation(invocation, consumes) - diff --git a/src/ell/lmp/function.py b/src/ell/lmp/function.py new file mode 100644 index 000000000..636dd9c1c --- /dev/null +++ b/src/ell/lmp/function.py @@ -0,0 +1,32 @@ +from functools import wraps +from typing import Any, Callable +from ell.configurator import config +from ell.lmp._track import _track +from ell.stores.models import LMPType +from ell.util.verbosity import model_usage_logger_pre + +def function(*, exempt_from_tracking: bool = False, _exempt_from_logging: bool = False, type = LMPType.FUNCTION, **function_kwargs): + def function_decorator(fn: Callable[..., Any]): + + @wraps(fn) + def wrapper(*args, _invocation_origin: str = None, **kwargs): + should_log = not exempt_from_tracking and config.verbose and not _exempt_from_logging + if should_log: + model_usage_logger_pre(fn, args, kwargs, "[]", []) + + result = fn(*args, **kwargs) + + return result, {}, {} + + wrapper.__ell_func__ = fn + wrapper.__ell_type__ = type + wrapper.__ell_exempt_from_tracking = exempt_from_tracking + + if exempt_from_tracking: + return wrapper + else: + return _track(wrapper) + + return function_decorator + +# XXX: Fix wrapping of 
the wrong functional decorator. \ No newline at end of file diff --git a/src/ell/providers/bedrock.py b/src/ell/providers/bedrock.py index e99ff75a7..cb3674556 100644 --- a/src/ell/providers/bedrock.py +++ b/src/ell/providers/bedrock.py @@ -13,8 +13,10 @@ from PIL import Image as PILImage try: - from botocore.client import BaseClient + from botocore.eventstream import (EventStream) + from botocore.client import BaseClient + class BedrockProvider(Provider): dangerous_disable_validation = True @@ -146,7 +148,6 @@ def translate_from_provider( if logger: logger(tracked_results[0].text) - usage["prompt_tokens"] = provider_response.get('usage').get("inputTokens", 0) usage["completion_tokens"] = provider_response.get('usage').get("outputTokens", 0) usage["total_tokens"] = usage['prompt_tokens'] + usage['completion_tokens'] diff --git a/src/ell/stores/migrations/README b/src/ell/stores/migrations/README.md similarity index 100% rename from src/ell/stores/migrations/README rename to src/ell/stores/migrations/README.md diff --git a/src/ell/stores/migrations/__init__.py b/src/ell/stores/migrations/__init__.py index 14602592c..4aaad9dfe 100644 --- a/src/ell/stores/migrations/__init__.py +++ b/src/ell/stores/migrations/__init__.py @@ -37,28 +37,30 @@ def init_or_migrate_database(engine) -> None: # Check database state our_tables_v1 = {'serializedlmp', 'invocation', 'invocationcontents', 'invocationtrace', 'serializedlmpuses'} + our_tables_v2 = {'evaluationlabeler', 'evaluationresultdatapoint', 'evaluationrunlabelersummary', 'evaluationlabel'} existing_tables = set(inspector.get_table_names()) has_our_tables = bool(our_tables_v1 & existing_tables) # Intersection has_alembic = 'ell_alembic_version' in existing_tables - - alembic_cfg = get_alembic_config(engine.url) + + alembic_cfg = get_alembic_config(engine.url.render_as_string(hide_password=False)) try: if has_our_tables and not has_alembic: # Case 1: Existing database with our tables but no Alembic # This is likely a database from version <= 0.14 logger.debug("Found existing tables but no Alembic - stamping with initial migration") - - command.stamp(alembic_cfg, "4524fb60d23e") + is_v1 = has_our_tables and not bool(our_tables_v2 & existing_tables) + command.stamp(alembic_cfg, "4524fb60d23e" if is_v1 else "head") + # Verify table was created after_tables = set(inspect(engine).get_table_names()) logger.debug(f"Tables after stamp: {after_tables}") - - # Check if version table has our stamp - with engine.connect() as connection: - version_result = connection.execute(text("SELECT version_num FROM ell_alembic_version")).first() - if not version_result or version_result[0] != "4524fb60d23e": - raise RuntimeError("Failed to stamp database - version table empty or incorrect version") - logger.debug(f"Successfully stamped database with version {version_result[0]}") + if is_v1: + # Check if version table has our stamp + with engine.connect() as connection: + version_result = connection.execute(text("SELECT version_num FROM ell_alembic_version")).first() + if not version_result or version_result[0] != "4524fb60d23e": + raise RuntimeError("Failed to stamp database - version table empty or incorrect version") + logger.debug(f"Successfully stamped database with version {version_result[0]}") has_alembic = True diff --git a/src/ell/stores/migrations/script.py.mako b/src/ell/stores/migrations/script.py.mako index 6ce335109..6b825de16 100644 --- a/src/ell/stores/migrations/script.py.mako +++ b/src/ell/stores/migrations/script.py.mako @@ -10,6 +10,7 @@ from typing import 
Sequence, Union from alembic import op import sqlalchemy as sa import sqlmodel +import ell.stores.models.core ${imports if imports else ""} # revision identifiers, used by Alembic. diff --git a/src/ell/stores/migrations/versions/f6528d04bbbd_evaluations.py b/src/ell/stores/migrations/versions/f6528d04bbbd_evaluations.py new file mode 100644 index 000000000..60c731b6e --- /dev/null +++ b/src/ell/stores/migrations/versions/f6528d04bbbd_evaluations.py @@ -0,0 +1,107 @@ +"""evaluations + +Revision ID: f6528d04bbbd +Revises: 4524fb60d23e +Create Date: 2024-11-19 19:31:38.381105+00:00 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel +import ell.stores.models.core + + +# revision identifiers, used by Alembic. +revision: str = 'f6528d04bbbd' +down_revision: Union[str, None] = '4524fb60d23e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('serializedevaluation', + sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('created_at', ell.stores.models.core.UTCTimestamp(timezone=True), nullable=False), + sa.Column('dataset_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('n_evals', sa.Integer(), nullable=False), + sa.Column('version_number', sa.Integer(), nullable=False), + sa.Column('commit_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('evaluationlabeler', + sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('type', sa.Enum('METRIC', 'ANNOTATION', 'CRITERION', name='evaluationlabelertype'), nullable=False), + sa.Column('labeling_lmp_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('evaluation_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('labeling_rubric', sa.JSON(), nullable=True), + sa.ForeignKeyConstraint(['evaluation_id'], ['serializedevaluation.id'], ), + sa.ForeignKeyConstraint(['labeling_lmp_id'], ['serializedlmp.lmp_id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_evaluationlabeler_labeling_lmp_id'), 'evaluationlabeler', ['labeling_lmp_id'], unique=False) + op.create_table('serializedevaluationrun', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('evaluation_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('evaluated_lmp_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('api_params', sa.JSON(), nullable=True), + sa.Column('start_time', ell.stores.models.core.UTCTimestamp(timezone=True), nullable=True), + sa.Column('end_time', ell.stores.models.core.UTCTimestamp(timezone=True), nullable=True), + sa.Column('success', sa.Boolean(), nullable=True), + sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint(['evaluated_lmp_id'], ['serializedlmp.lmp_id'], ), + sa.ForeignKeyConstraint(['evaluation_id'], ['serializedevaluation.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_serializedevaluationrun_evaluated_lmp_id'), 'serializedevaluationrun', ['evaluated_lmp_id'], unique=False) + op.create_index(op.f('ix_serializedevaluationrun_evaluation_id'), 'serializedevaluationrun', ['evaluation_id'], unique=False) 
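# (The remaining tables below complete the evaluation schema: evaluationresultdatapoint
# ties a serializedevaluationrun to the invocation being labeled, evaluationlabel attaches
# per-labeler labels to each datapoint, and evaluationrunlabelersummary stores per-labeler
# aggregates for a run.)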
+ op.create_table('evaluationresultdatapoint', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('invocation_being_labeled_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('evaluation_run_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['evaluation_run_id'], ['serializedevaluationrun.id'], ), + sa.ForeignKeyConstraint(['invocation_being_labeled_id'], ['invocation.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('evaluationrunlabelersummary', + sa.Column('evaluation_labeler_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('evaluation_run_id', sa.Integer(), nullable=False), + sa.Column('created_at', ell.stores.models.core.UTCTimestamp(timezone=True), nullable=True), + sa.Column('updated_at', ell.stores.models.core.UTCTimestamp(timezone=True), nullable=True), + sa.Column('finalized_at', ell.stores.models.core.UTCTimestamp(timezone=True), nullable=True), + sa.Column('is_scalar', sa.Boolean(), nullable=False), + sa.Column('data', sa.JSON(), nullable=True), + sa.Column('count', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['evaluation_labeler_id'], ['evaluationlabeler.id'], ), + sa.ForeignKeyConstraint(['evaluation_run_id'], ['serializedevaluationrun.id'], ), + sa.PrimaryKeyConstraint('evaluation_labeler_id', 'evaluation_run_id') + ) + op.create_table('evaluationlabel', + sa.Column('labeler_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('labeled_datapoint_id', sa.Integer(), nullable=False), + sa.Column('label_invocation_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('manual_label', sa.JSON(), nullable=True), + sa.ForeignKeyConstraint(['label_invocation_id'], ['invocation.id'], ), + sa.ForeignKeyConstraint(['labeled_datapoint_id'], ['evaluationresultdatapoint.id'], ), + sa.ForeignKeyConstraint(['labeler_id'], ['evaluationlabeler.id'], ), + sa.PrimaryKeyConstraint('labeler_id', 'labeled_datapoint_id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('evaluationlabel') + op.drop_table('evaluationrunlabelersummary') + op.drop_table('evaluationresultdatapoint') + op.drop_index(op.f('ix_serializedevaluationrun_evaluation_id'), table_name='serializedevaluationrun') + op.drop_index(op.f('ix_serializedevaluationrun_evaluated_lmp_id'), table_name='serializedevaluationrun') + op.drop_table('serializedevaluationrun') + op.drop_index(op.f('ix_evaluationlabeler_labeling_lmp_id'), table_name='evaluationlabeler') + op.drop_table('evaluationlabeler') + op.drop_table('serializedevaluation') + # ### end Alembic commands ### diff --git a/src/ell/stores/models/__init__.py b/src/ell/stores/models/__init__.py new file mode 100644 index 000000000..20a64be9c --- /dev/null +++ b/src/ell/stores/models/__init__.py @@ -0,0 +1,2 @@ +from .core import * +from .evaluations import * diff --git a/src/ell/stores/studio.py b/src/ell/stores/models/core.py similarity index 63% rename from src/ell/stores/studio.py rename to src/ell/stores/models/core.py index 3d6aa1ed7..f339bbab6 100644 --- a/src/ell/stores/studio.py +++ b/src/ell/stores/models/core.py @@ -1,33 +1,27 @@ from datetime import datetime, timezone -import enum from functools import cached_property + import sqlalchemy.types as types from ell.types.lmp import LMPType from ell.types.message import Any, Any, Field, Message, Optional +from sqlmodel import Column, Field, SQLModel +from typing import Optional + from typing import Optional -from dataclasses import dataclass -from typing import Dict, List, Literal, Union, Any, Optional +from typing import Dict, List, Union, Any, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel from datetime import datetime from typing import Any, List, Optional from sqlmodel import Field, SQLModel, Relationship, JSON, Column from sqlalchemy import Index, func -import sqlalchemy as sa -from typing import TypeVar, Any - -def utc_now() -> datetime: - """ - Returns the current UTC timestamp. - Serializes to ISO-8601. - """ - return datetime.now(tz=timezone.utc) +from typing import Any class SerializedLMPUses(SQLModel, table=True): """ @@ -36,20 +30,26 @@ class SerializedLMPUses(SQLModel, table=True): This class is used to track which LMPs use or are used by other LMPs. 
""" - lmp_user_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is being used - lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP + lmp_user_id: Optional[str] = Field( + default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True + ) # ID of the LMP that is being used + lmp_using_id: Optional[str] = Field( + default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True + ) # ID of the LMP that is using the other LMP class UTCTimestamp(types.TypeDecorator[datetime]): cache_ok = True impl = types.TIMESTAMP - def process_result_value(self, value: datetime, dialect:Any): - return value.replace(tzinfo=timezone.utc) + + def process_result_value(self, value: Optional[datetime], dialect: Any) -> Optional[datetime]: + if value is not None: + return value.replace(tzinfo=timezone.utc) + return None -def UTCTimestampField(index:bool=False, **kwargs:Any): - return Field( - sa_column=Column(UTCTimestamp(timezone=True), index=index, **kwargs)) +def UTCTimestampField(index: bool = False, **kwargs: Any): + return Field(sa_column=Column(UTCTimestamp(timezone=True), index=index, **kwargs)) class SerializedLMPBase(SQLModel): @@ -60,9 +60,15 @@ class SerializedLMPBase(SQLModel): created_at: datetime = UTCTimestampField(index=True, nullable=False) lmp_type: LMPType - api_params: Optional[Dict[str, Any]] = Field(default_factory=dict, sa_column=Column(JSON)) - initial_free_vars: Optional[Dict[str, Any]] = Field(default_factory=dict, sa_column=Column(JSON)) - initial_global_vars: Optional[Dict[str, Any]] = Field(default_factory=dict, sa_column=Column(JSON)) + api_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, sa_column=Column(JSON) + ) + initial_free_vars: Optional[Dict[str, Any]] = Field( + default_factory=dict, sa_column=Column(JSON) + ) + initial_global_vars: Optional[Dict[str, Any]] = Field( + default_factory=dict, sa_column=Column(JSON) + ) num_invocations: Optional[int] = Field(default=0) commit_message: Optional[str] = Field(default=None) version_number: Optional[int] = Field(default=None) @@ -86,7 +92,10 @@ class SerializedLMP(SerializedLMPBase, table=True): secondaryjoin="SerializedLMP.lmp_id==SerializedLMPUses.lmp_user_id", ), ) - + + evaluation_runs: List["SerializedEvaluationRun"] = Relationship(back_populates="evaluated_lmp") + + class Config: table_name = "serializedlmp" # XXX: THis is not a real constraint. @@ -94,13 +103,19 @@ class Config: class InvocationTrace(SQLModel, table=True): - invocation_consumer_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True) - invocation_consuming_id: str = Field(foreign_key="invocation.id", primary_key=True, index=True) + invocation_consumer_id: str = Field( + foreign_key="invocation.id", primary_key=True, index=True + ) + invocation_consuming_id: str = Field( + foreign_key="invocation.id", primary_key=True, index=True + ) + # Should be subtyped for differnet kidns of LMPS. # XXX: Move all ofh te binary data out to a different table. # XXX: Need a flag that says dont store images. 
# XXX: Deprecate the args columns + class InvocationBase(SQLModel): id: Optional[str] = Field(default=None, primary_key=True) lmp_id: str = Field(foreign_key="serializedlmp.lmp_id", index=True) @@ -109,42 +124,56 @@ class InvocationBase(SQLModel): completion_tokens: Optional[int] = Field(default=None) state_cache_key: Optional[str] = Field(default=None) created_at: datetime = UTCTimestampField(default=func.now(), nullable=False) - used_by_id: Optional[str] = Field(default=None, foreign_key="invocation.id", index=True) - # global_vars and free_vars removed from here + used_by_id: Optional[str] = Field( + default=None, foreign_key="invocation.id", index=True + ) + + +class ExternalizeableModel(SQLModel): + is_external: bool = Field(default=False) + -class InvocationContentsBase(SQLModel): - invocation_id: str = Field(foreign_key="invocation.id", index=True, primary_key=True) +class InvocationContentsBase(ExternalizeableModel): + invocation_id: str = Field( + foreign_key="invocation.id", index=True, primary_key=True + ) params: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) - results: Optional[Union[List[Message], Any]] = Field(default=None, sa_column=Column(JSON)) - invocation_api_params: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) + results: Optional[Union[List[Message], Any]] = Field( + default=None, sa_column=Column(JSON) + ) + invocation_api_params: Optional[Dict[str, Any]] = Field( + default=None, sa_column=Column(JSON) + ) + global_vars: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) free_vars: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON)) - is_external : bool = Field(default=False) @cached_property def should_externalize(self) -> bool: import json - + json_fields = [ self.params, self.results, self.invocation_api_params, self.global_vars, - self.free_vars + self.free_vars, ] - + total_size = sum( len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str, ensure_ascii=False) if isinstance(x, BaseModel) else str(x)), ensure_ascii=False).encode('utf-8')) for field in json_fields if field is not None ) # print("total_size", total_size) - + return total_size > 102400 # Precisely 100kb in bytes + class InvocationContents(InvocationContentsBase, table=True): invocation: "Invocation" = Relationship(back_populates="contents") + class Invocation(InvocationBase, table=True): lmp: SerializedLMP = Relationship(back_populates="invocations") consumed_by: List["Invocation"] = Relationship( @@ -163,12 +192,19 @@ class Invocation(InvocationBase, table=True): secondaryjoin="Invocation.id==InvocationTrace.invocation_consumer_id", ), ) - used_by: Optional["Invocation"] = Relationship(back_populates="uses", sa_relationship_kwargs={"remote_side": "Invocation.id"}) + used_by: Optional["Invocation"] = Relationship( + back_populates="uses", sa_relationship_kwargs={"remote_side": "Invocation.id"} + ) uses: List["Invocation"] = Relationship(back_populates="used_by") contents: InvocationContents = Relationship(back_populates="invocation") - __table_args__ = ( - Index('ix_invocation_lmp_id_created_at', 'lmp_id', 'created_at'), - Index('ix_invocation_created_at_latency_ms', 'created_at', 'latency_ms'), - Index('ix_invocation_created_at_tokens', 'created_at', 'prompt_tokens', 'completion_tokens'), + Index("ix_invocation_lmp_id_created_at", "lmp_id", "created_at"), + Index("ix_invocation_created_at_latency_ms", "created_at", "latency_ms"), + Index( + "ix_invocation_created_at_tokens", + 
"created_at", + "prompt_tokens", + "completion_tokens", + ), ) + evaluation_result_datapoints: List["EvaluationResultDatapoint"] = Relationship(back_populates="invocation_being_labeled") diff --git a/src/ell/stores/models/evaluations.py b/src/ell/stores/models/evaluations.py new file mode 100644 index 000000000..dd587a0eb --- /dev/null +++ b/src/ell/stores/models/evaluations.py @@ -0,0 +1,262 @@ +from datetime import datetime +from enum import Enum +from functools import lru_cache + +import numpy as np + +from ell.types.message import Field, Message + +from sqlmodel import Column, Field, SQLModel, Relationship, JSON +from typing import Dict, List, Literal, Union, Any, Optional, cast + +from pydantic import field_validator + +from sqlalchemy import func + +from .core import Invocation, SerializedLMP, UTCTimestampField + +############################# +### Evaluation & Labeling ### +############################# +class EvaluationLabelerType(str, Enum): + METRIC = "metric" + ANNOTATION = "annotation" + CRITERION = "criterion" + +class EvaluationLabelerBase(SQLModel): + id: str = Field(primary_key=True) + name: str + type: EvaluationLabelerType + labeling_lmp_id: Optional[str] = Field( + default=None, foreign_key="serializedlmp.lmp_id", index=True + ) + evaluation_id: str = Field(foreign_key="serializedevaluation.id") + labeling_rubric: Optional[Dict[str, Any]] = Field( + default=None, sa_column=Column(JSON) + ) + +class EvaluationLabeler(EvaluationLabelerBase, table=True): + evaluation: "SerializedEvaluation" = Relationship(back_populates="labelers") + labeling_lmp: Optional[SerializedLMP] = Relationship() + labels: List["EvaluationLabel"] = Relationship(back_populates="labeler") + evaluation_run_summaries: List["EvaluationRunLabelerSummary"] = Relationship(back_populates="evaluation_labeler") + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.id is None: + self.id = EvaluationLabeler.generate_id(self.evaluation_id, self.name, self.type) + + @field_validator("id") + def validate_id(cls, v): + if v is not None: + assert v.startswith("labeler-") + evaluation, eid, name, type = v.split("-")[1:] + assert evaluation == "evaluation" + assert type in EvaluationLabelerType.__members__ + return v + + @staticmethod + @lru_cache(maxsize=128) + def generate_id(evaluation_id: str, name: str, type: EvaluationLabelerType) -> str: + return f"labeler-{evaluation_id}-{name}-{type.name}" + + @field_validator("labeling_lmp_id", "labeling_rubric") + def validate_labeler_or_instructions(cls, v, values): + if "labeling_lmp_id" not in values and "labeling_rubric" not in values: + raise ValueError("Either labeler_lmp_id or instructions must be set") + return v + +class EvaluationLabelBase(SQLModel): + + labeler_id: str = Field( + foreign_key="evaluationlabeler.id", + primary_key=True, + ) + labeled_datapoint_id: int = Field( + foreign_key="evaluationresultdatapoint.id", + primary_key=True, + ) + label_invocation_id: Optional[str] = Field( + default=None, foreign_key="invocation.id" + ) + manual_label: Optional[Dict[str, Any]] = Field( + default=None, sa_column=Column(JSON) + ) + +class EvaluationLabel(EvaluationLabelBase, table=True): + labeled_datapoint: "EvaluationResultDatapoint" = Relationship(back_populates="labels") + labeler: EvaluationLabeler = Relationship(back_populates="labels") + label_invocation: Optional[Invocation] = Relationship() + +class EvaluationResultDatapointBase(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + invocation_being_labeled_id: str = 
Field( + foreign_key="invocation.id" + ) + evaluation_run_id: int = Field(foreign_key="serializedevaluationrun.id") + +class EvaluationResultDatapoint(EvaluationResultDatapointBase, table=True): + invocation_being_labeled: Invocation = Relationship(back_populates="evaluation_result_datapoints") + evaluation_run: "SerializedEvaluationRun" = Relationship(back_populates="results") + labels: List[EvaluationLabel] = Relationship(back_populates="labeled_datapoint") + +class EvaluationRunLabelerSummaryBase(SQLModel): + evaluation_labeler_id: str = Field(foreign_key="evaluationlabeler.id", primary_key=True) + evaluation_run_id: int = Field(foreign_key="serializedevaluationrun.id", primary_key=True) + created_at: datetime = UTCTimestampField(default=func.now()) + updated_at: Optional[datetime] = UTCTimestampField(default=None) + finalized_at: Optional[datetime] = UTCTimestampField(default=None) + is_scalar: bool = Field(default=False) + data: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON)) + count: int = Field(default=0) + +class EvaluationRunLabelerSummary(EvaluationRunLabelerSummaryBase, table=True): + evaluation_run: "SerializedEvaluationRun" = Relationship(back_populates="labeler_summaries") + evaluation_labeler: EvaluationLabeler = Relationship(back_populates="evaluation_run_summaries") + + def mean( + self, + ) -> Union[float, Dict[str, Any], List[Union[float, Dict[str, Any], None]], None]: + return self._get_value_recursively("_mean") + + def std( + self, + ) -> Union[float, Dict[str, Any], List[Union[float, Dict[str, Any], None]], None]: + return self._get_value_recursively("_std") + + def min( + self, + ) -> Union[float, Dict[str, Any], List[Union[float, Dict[str, Any], None]], None]: + return self._get_value_recursively("_min") + + def max( + self, + ) -> Union[float, Dict[str, Any], List[Union[float, Dict[str, Any], None]], None]: + return self._get_value_recursively("_max") + + @classmethod + def from_labels( + cls, + data: Union[List[float], List[Dict[str, Any]]], + **other_keys + ) -> "EvaluationRunLabelerSummary": + if len(data) == 0: + # XXXL revisit. + raise ValueError( + "Aggregated run cannot contain empty data, at least one datapoint is required." + ) + + stats = lambda x: { + "mean": float(np.mean(x)), + "std": float(np.std(x)), + "min": float(np.min(x)), + "max": float(np.max(x)), + } + try: + return cls( + is_scalar=True, + data=stats(data), + count=len(data), + **other_keys, + ) + except TypeError: + def recursive_aggregate(items): + try: + if all(isinstance(item, dict) for item in items): + result = {} + for key in items[0].keys(): + values = [item[key] for item in items if key in item] + result[key] = recursive_aggregate(values) + return result + else: + return stats(items) + except TypeError: + return { + "mean": None, + "std": None, + "min": None, + "max": None, + } + + aggregated_data = recursive_aggregate(data) + + return cls( + is_scalar=False, data=aggregated_data, count=len(data), **other_keys + ) + + def _get_value_recursively( + self, key: str, value=None + ) -> Union[float, Dict[str, Any], List[Union[float, Dict[str, Any], None]], None]: + # recursively gets the value within the internal data structure + if self.is_scalar: + return self.data[key] + else: + # return the same schema of object as from which it was created but with the mean std min and max on each of the nested objects. + if value is None: + return None + if isinstance(value, dict): + # if _mean is in the dict. 
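# i.e. walk the nested dict produced by recursive_aggregate() in from_labels():
# if the requested statistic is present at this level return it, otherwise recurse
# into each child and rebuild the same dict shape with only that statistic at the leaves.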
+ if (possible_result := value.get(key)) is not None: + return possible_result + else: + return { + k: self._get_value_recursively(key, v) for k, v in value.items() + } + else: + raise RuntimeError( + f"Failed to acceses {key} of the aggregated evaluation run result. The object schema does not conform to the expected schema. Current object: {self.data}" + ) + + def update(self, new_data: Union[float, Dict[str, Any]]): + raise NotImplementedError( + "Ell studio does not currently support updating evaluation run results with new data." + ) + +class SerializedEvaluationRunBase(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + evaluation_id: str = Field( + foreign_key="serializedevaluation.id", index=True + ) + evaluated_lmp_id: str = Field( + foreign_key="serializedlmp.lmp_id", index=True + ) + api_params: Optional[Dict[str, Any]] = Field( + default=None, sa_column=Column(JSON) + ) + start_time: datetime = UTCTimestampField() + end_time: Optional[datetime] = UTCTimestampField(default=None) + success: Optional[bool] = Field(default=None) + error: Optional[str] = Field(default=None) + +class SerializedEvaluationRun(SerializedEvaluationRunBase, table=True): + evaluated_lmp: SerializedLMP = Relationship(back_populates="evaluation_runs") + evaluation: "SerializedEvaluation" = Relationship(back_populates="runs") + results: List[EvaluationResultDatapoint] = Relationship(back_populates="evaluation_run") + labeler_summaries: List[EvaluationRunLabelerSummary] = Relationship(back_populates="evaluation_run") + +class SerializedEvaluationBase(SQLModel): + id: str = Field(primary_key=True) + name: str + created_at: datetime = UTCTimestampField(default=func.now(), nullable=False) + dataset_id: str + n_evals: int + version_number: int = Field(default=0) + commit_message: Optional[str] = Field(default=None) + +class SerializedEvaluation(SerializedEvaluationBase, table=True): + labelers: List[EvaluationLabeler] = Relationship(back_populates="evaluation") + runs: List[SerializedEvaluationRun] = Relationship(back_populates="evaluation") + + @field_validator("id") + def validate_id(cls, v): + if v is not None: + assert v.startswith("evaluation-") + assert v.count("-") == 1 + return v + return v + + def get_labeler(self, type: EvaluationLabelerType, name: Optional[str] = None) -> Optional[EvaluationLabeler]: + for labeler in self.labelers: + if labeler.type == type and (name is None or labeler.name == name): + return labeler + return None \ No newline at end of file diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py index ebfbb9ab5..b2f25d75a 100644 --- a/src/ell/stores/sql.py +++ b/src/ell/stores/sql.py @@ -1,16 +1,32 @@ from datetime import datetime, timedelta import os +from typing import Any, Optional, Dict, List, Set, Union +from pydantic import BaseModel +import sqlalchemy from pathlib import Path from typing import Any, Optional, Dict, List, Set from sqlmodel import Session, SQLModel, create_engine, select from ell.stores.migrations import init_or_migrate_database import ell.stores.store from sqlalchemy.sql import text -from ell.stores.studio import InvocationTrace, SerializedLMP, Invocation +from ell.types._lstr import _lstr +from sqlalchemy import or_, func, and_, extract, FromClause +from sqlalchemy.types import TypeDecorator, VARCHAR +from ell.stores.models import SerializedLMPUses +from ell.stores.models.evaluations import ( + EvaluationLabeler, + EvaluationResultDatapoint, + EvaluationRunLabelerSummary, + SerializedEvaluation, + SerializedEvaluationRun, +) +from 
ell.stores.models.core import InvocationTrace, SerializedLMP, Invocation, InvocationContents from sqlalchemy import func, and_ -from ell.util.serialization import pydantic_ltype_aware_cattr +from ell.util.serialization import pydantic_ltype_aware_cattr, utc_now import gzip import json +from sqlalchemy.exc import IntegrityError + import logging @@ -19,38 +35,62 @@ class SQLStore(ell.stores.store.Store): def __init__(self, db_uri: str, blob_store: Optional[ell.stores.store.BlobStore] = None): # XXX: Use Serialization serialzie_object in incoming PR. - self.engine = create_engine(db_uri, - json_serializer=lambda obj: json.dumps(pydantic_ltype_aware_cattr.unstructure(obj), - sort_keys=True, default=repr, ensure_ascii=False)) + self.engine = create_engine( + db_uri, + json_serializer=lambda obj: json.dumps( + pydantic_ltype_aware_cattr.unstructure(obj), + sort_keys=True, + default=repr, + ensure_ascii=False, + ), + ) init_or_migrate_database(self.engine) self.open_files: Dict[str, Dict[str, Any]] = {} super().__init__(blob_store) - def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]: + def write_lmp( + self, serialized_lmp: SerializedLMP, uses: Dict[str, Any] + ) -> Optional[Any]: with Session(self.engine) as session: - # Bind the serialized_lmp to the session - lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == serialized_lmp.lmp_id)).first() - - if lmp: - # Already added to the DB. - return lmp - else: - session.add(serialized_lmp) - - for use_id in uses: - used_lmp = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == use_id)).first() - if used_lmp: - serialized_lmp.uses.append(used_lmp) - - session.commit() - return None + try: + # Bind the serialized_lmp to the session + lmp = session.exec( + select(SerializedLMP).filter( + SerializedLMP.lmp_id == serialized_lmp.lmp_id + ) + ).first() - def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optional[Any]: + if lmp: + # Already added to the DB. + return lmp + else: + session.add(serialized_lmp) + + for use_id in uses: + used_lmp = session.exec( + select(SerializedLMP).where(SerializedLMP.lmp_id == use_id) + ).first() + if used_lmp: + serialized_lmp.uses.append(used_lmp) + + session.commit() + return None + except sqlalchemy.exc.IntegrityError as e: + session.rollback() + return None + + def write_invocation( + self, invocation: Invocation, consumes: Set[str] + ) -> Optional[Any]: with Session(self.engine) as session: - lmp = session.exec(select(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id)).first() - assert lmp is not None, f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously" - + lmp = session.exec( + select(SerializedLMP).filter(SerializedLMP.lmp_id == invocation.lmp_id) + ).first() + assert ( + lmp is not None + ), f"LMP with id {invocation.lmp_id} not found. Writing invocation erroneously" + # Increment num_invocations if lmp.num_invocations is None: lmp.num_invocations = 1 @@ -59,69 +99,181 @@ def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Option # Add the invocation contents session.add(invocation.contents) - + # Add the invocation session.add(invocation) # Now create traces. 
for consumed_id in consumes: - session.add(InvocationTrace( - invocation_consumer_id=invocation.id, - invocation_consuming_id=consumed_id - )) + session.add( + InvocationTrace( + invocation_consumer_id=invocation.id, + invocation_consuming_id=consumed_id, + ) + ) session.commit() return None - - def get_cached_invocations(self, lmp_id :str, state_cache_key :str) -> List[Invocation]: + + def write_evaluation(self, evaluation: SerializedEvaluation) -> str: + with Session(self.engine) as session: + try: + # Check if the evaluation already exists + existing_evaluation = session.exec( + select(SerializedEvaluation).where( + SerializedEvaluation.id == evaluation.id + ) + ).first() + + if existing_evaluation: + # Update the existing evaluation + existing_evaluation.name = evaluation.name + existing_evaluation.dataset_id = evaluation.dataset_id + existing_evaluation.n_evals = evaluation.n_evals + existing_evaluation.version_number = evaluation.version_number + existing_evaluation.commit_message = evaluation.commit_message + else: + # Add the new evaluation + session.add(evaluation) + + # Process labelers + for labeler in evaluation.labelers: + existing_labeler = session.exec( + select(EvaluationLabeler).where( + (EvaluationLabeler.evaluation_id == evaluation.id) + & (EvaluationLabeler.name == labeler.name) + ) + ).first() + + if existing_labeler: + # Update existing labeler + existing_labeler.type = labeler.type + existing_labeler.labeling_lmp_id = labeler.labeling_lmp_id + existing_labeler.labeling_rubric = labeler.labeling_rubric + else: + # Add new labeler + labeler.evaluation_id = evaluation.id + session.add(labeler) + + session.commit() + return evaluation.id + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error writing evaluation: {str(e)}") + except Exception as e: + session.rollback() + raise e + + def write_evaluation_run(self, evaluation_run: SerializedEvaluationRun) -> int: with Session(self.engine) as session: - return self.get_invocations(session, lmp_filters={"lmp_id": lmp_id}, filters={"state_cache_key": state_cache_key}) + session.add(evaluation_run) + session.commit() + return evaluation_run.id - def get_versions_by_fqn(self, fqn :str) -> List[SerializedLMP]: + def write_evaluation_run_intermediate(self, row_result : EvaluationResultDatapoint) -> None: + # add a new result datapoint + with Session(self.engine) as session: + session.add(row_result) + session.commit() + + def write_evaluation_run_end(self, evaluation_run_id : str, success : bool, end_time : datetime, error : Optional[str], summaries: List[EvaluationRunLabelerSummary]) -> None: + # Update hte evaluation run adn add summaries to it + with Session(self.engine) as session: + evaluation_run = session.exec(select(SerializedEvaluationRun).where(SerializedEvaluationRun.id == evaluation_run_id)).first() + assert evaluation_run is not None, "Evaluation run must exist to write end." 
+ evaluation_run.success = success + evaluation_run.end_time = end_time + evaluation_run.error = error + evaluation_run.labeler_summaries.extend(summaries) + session.add(evaluation_run) + session.commit() + + def write_evaluation_run_labeler_summaries( + self, summaries: List[EvaluationRunLabelerSummary] + ) -> int: + with Session(self.engine) as session: + session.add_all(summaries) + session.commit() + return len(summaries) + + def get_cached_invocations( + self, lmp_id: str, state_cache_key: str + ) -> List[Invocation]: + with Session(self.engine) as session: + return self.get_invocations( + session, + lmp_filters={"lmp_id": lmp_id}, + filters={"state_cache_key": state_cache_key}, + ) + + def get_versions_by_fqn(self, fqn: str) -> List[SerializedLMP]: with Session(self.engine) as session: return self.get_lmps(session, name=fqn) - - ## HELPER METHODS FOR ELL STUDIO! :) - def get_latest_lmps(self, session: Session, skip: int = 0, limit: int = 10) -> List[Dict[str, Any]]: + + ## HELPER METHODS FOR ELL STUDIO! :) + def get_latest_lmps( + self, session: Session, skip: int = 0, limit: int = 10 + ) -> List[Dict[str, Any]]: """ Gets all the lmps grouped by unique name with the highest created at """ subquery = ( - select(SerializedLMP.name, func.max(SerializedLMP.created_at).label("max_created_at")) + select( + SerializedLMP.name, + func.max(SerializedLMP.created_at).label("max_created_at"), + ) .group_by(SerializedLMP.name) .subquery() ) - - filters = { - "name": subquery.c.name, - "created_at": subquery.c.max_created_at - } - - return self.get_lmps(session, skip=skip, limit=limit, subquery=subquery, **filters) - - def get_lmps(self, session: Session, skip: int = 0, limit: int = 10, subquery=None, **filters: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: + filters = {"name": subquery.c.name, "created_at": subquery.c.max_created_at} + + return self.get_lmps( + session, skip=skip, limit=limit, subquery=subquery, **filters + ) + + def get_lmps( + self, + session: Session, + skip: int = 0, + limit: int = 10, + subquery=None, + **filters: Optional[Dict[str, Any]], + ) -> List[Dict[str, Any]]: query = select(SerializedLMP) - + if subquery is not None: - query = query.join(subquery, and_( - SerializedLMP.name == subquery.c.name, - SerializedLMP.created_at == subquery.c.max_created_at - )) - + query = query.join( + subquery, + and_( + SerializedLMP.name == subquery.c.name, + SerializedLMP.created_at == subquery.c.max_created_at, + ), + ) + if filters: for key, value in filters.items(): query = query.where(getattr(SerializedLMP, key) == value) - - query = query.order_by(SerializedLMP.created_at.desc()) # Sort by created_at in descending order + + query = query.order_by( + SerializedLMP.created_at.desc() + ) # Sort by created_at in descending order query = query.offset(skip).limit(limit) results = session.exec(query).all() - + return results - def get_invocations(self, session: Session, lmp_filters: Dict[str, Any], skip: int = 0, limit: int = 10, filters: Optional[Dict[str, Any]] = None, hierarchical: bool = False) -> List[Dict[str, Any]]: - + def get_invocations( + self, + session: Session, + lmp_filters: Dict[str, Any], + skip: int = 0, + limit: int = 10, + filters: Optional[Dict[str, Any]] = None, + hierarchical: bool = False, + ) -> List[Dict[str, Any]]: + query = select(Invocation).join(SerializedLMP) # Apply LMP filters @@ -139,9 +291,9 @@ def get_invocations(self, session: Session, lmp_filters: Dict[str, Any], skip: i invocations = session.exec(query).all() return invocations - def 
get_traces(self, session: Session): - query = text(""" + query = text( + """ SELECT consumer.lmp_id, trace.*, @@ -152,77 +304,187 @@ def get_traces(self, session: Session): invocationtrace AS trace ON consumer.id = trace.invocation_consumer_id JOIN invocation AS consumed ON trace.invocation_consuming_id = consumed.id - """) + """ + ) results = session.exec(query).all() - + traces = [] - for (consumer_lmp_id, consumer_invocation_id, consumed_invocation_id, consumed_lmp_id) in results: - traces.append({ - 'consumer': consumer_lmp_id, - 'consumed': consumed_lmp_id - }) - + for ( + consumer_lmp_id, + consumer_invocation_id, + consumed_invocation_id, + consumed_lmp_id, + ) in results: + traces.append({"consumer": consumer_lmp_id, "consumed": consumed_lmp_id}) + return traces - - def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any] = None, filters: Dict[str, Any] = None, days: int = 30) -> Dict[str, Any]: + + def get_invocations_aggregate( + self, + session: Session, + lmp_filters: Dict[str, Any] = None, + filters: Dict[str, Any] = None, + days: int = 30, + ) -> Dict[str, Any]: # Calculate the start date for the graph data start_date = datetime.utcnow() - timedelta(days=days) # Base subquery base_subquery = ( - select(Invocation.created_at, Invocation.latency_ms, Invocation.prompt_tokens, Invocation.completion_tokens, Invocation.lmp_id) + select( + Invocation.created_at, + Invocation.latency_ms, + Invocation.prompt_tokens, + Invocation.completion_tokens, + Invocation.lmp_id, + ) .join(SerializedLMP, Invocation.lmp_id == SerializedLMP.lmp_id) .filter(Invocation.created_at >= start_date) ) # Apply filters if lmp_filters: - base_subquery = base_subquery.filter(and_(*[getattr(SerializedLMP, k) == v for k, v in lmp_filters.items()])) + base_subquery = base_subquery.filter( + and_(*[getattr(SerializedLMP, k) == v for k, v in lmp_filters.items()]) + ) if filters: - base_subquery = base_subquery.filter(and_(*[getattr(Invocation, k) == v for k, v in filters.items()])) + base_subquery = base_subquery.filter( + and_(*[getattr(Invocation, k) == v for k, v in filters.items()]) + ) - data = session.exec(base_subquery).all() # Calculate aggregate metrics total_invocations = len(data) total_tokens = sum(row.prompt_tokens + row.completion_tokens for row in data) - avg_latency = sum(row.latency_ms for row in data) / total_invocations if total_invocations > 0 else 0 + avg_latency = ( + sum(row.latency_ms for row in data) / total_invocations + if total_invocations > 0 + else 0 + ) unique_lmps = len(set(row.lmp_id for row in data)) # Prepare graph data graph_data = [] for row in data: - graph_data.append({ - "date": row.created_at, - "avg_latency": row.latency_ms, - "tokens": row.prompt_tokens + row.completion_tokens, - "count": 1 - }) + graph_data.append( + { + "date": row.created_at, + "avg_latency": row.latency_ms, + "tokens": row.prompt_tokens + row.completion_tokens, + "count": 1, + } + ) return { "total_invocations": total_invocations, "total_tokens": total_tokens, "avg_latency": avg_latency, "unique_lmps": unique_lmps, - "graph_data": graph_data + "graph_data": graph_data, } + def get_evaluations( + self, session: Session, filters: Dict[str, Any], skip: int = 0, limit: int = 100 + ) -> List[SerializedEvaluation]: + query = select(SerializedEvaluation) + + for key, value in filters.items(): + query = query.where(getattr(SerializedEvaluation, key) == value) + print(key, value) + + query = query.offset(skip).limit(limit) + + results = session.exec(query).all() + return results + + 
def get_latest_evaluations( + self, session: Session, skip: int = 0, limit: int = 100 + ) -> List[SerializedEvaluation]: + # Subquery to get the latest version number for each evaluation name + latest_versions = ( + select( + SerializedEvaluation.name, + func.max(SerializedEvaluation.version_number).label("max_version"), + ) + .group_by(SerializedEvaluation.name) + .subquery() + ) + + # Main query to get the latest evaluations + query = ( + select(SerializedEvaluation) + .join( + latest_versions, + and_( + SerializedEvaluation.name == latest_versions.c.name, + SerializedEvaluation.version_number + == latest_versions.c.max_version, + ), + ) + .order_by(SerializedEvaluation.created_at.desc()) + .offset(skip) + .limit(limit) + ) + + results = session.exec(query).all() + return list(results) + + def get_eval_versions_by_name(self, name: str) -> List[SerializedEvaluation]: + with Session(self.engine) as session: + query = select(SerializedEvaluation).where( + SerializedEvaluation.name == name + ) + query = query.order_by( + SerializedEvaluation.version_number.desc() + ) # Sort by version number in descending order + results = session.exec(query).all() + return list( + results + ) # Convert to list to ensure it's a List[SerializedEvaluation] + + + def get_evaluation_run(self, session: Session, run_id: str) -> SerializedEvaluationRun: + query = select(SerializedEvaluationRun).where( + SerializedEvaluationRun.id == run_id, + + ) + result = session.exec(query).one() + + return result + + def get_evaluation_run_results(self, session: Session, run_id: str, skip: int = 0, limit: int = 100, filters : Optional[Dict[str, Any]] = None) -> List[EvaluationResultDatapoint]: + query = select(EvaluationResultDatapoint).where( + EvaluationResultDatapoint.evaluation_run_id == run_id + ) + + if filters: + for key, value in filters.items(): + query = query.where(getattr(EvaluationResultDatapoint, key) == value) + + query = query.offset(skip).limit(limit) + + results = session.exec(query).all() + print(f"Found {len(results)} results for run {run_id}") + return list(results) + + class SQLiteStore(SQLStore): def __init__(self, db_dir: str): - assert not db_dir.endswith('.db'), "Create store with a directory not a db." - + assert not db_dir.endswith(".db"), "Create store with a directory not a db." + os.makedirs(db_dir, exist_ok=True) self.db_dir = db_dir - db_path = os.path.join(db_dir, 'ell.db') + db_path = os.path.join(db_dir, "ell.db") blob_store = SQLBlobStore(db_dir) - super().__init__(f'sqlite:///{db_path}', blob_store=blob_store) + super().__init__(f"sqlite:///{db_path}", blob_store=blob_store) + class SQLBlobStore(ell.stores.store.BlobStore): def __init__(self, db_dir: str): self.db_dir = db_dir - def store_blob(self, blob: bytes, blob_id : str) -> str: + def store_blob(self, blob: bytes, blob_id: str) -> str: file_path = self._get_blob_path(blob_id) os.makedirs(os.path.dirname(file_path), exist_ok=True) with gzip.open(file_path, "wb") as f: @@ -238,11 +500,13 @@ def _get_blob_path(self, id: str, depth: int = 2) -> str: assert "-" in id, "Blob id must have a single - in it to split on." 
_type, _id = id.split("-") increment = 2 - dirs = [_type] + [_id[i:i+increment] for i in range(0, depth*increment, increment)] - file_name = _id[depth*increment:] + dirs = [_type] + [ + _id[i : i + increment] for i in range(0, depth * increment, increment) + ] + file_name = _id[depth * increment :] return os.path.join(self.db_dir, *dirs, file_name) + class PostgresStore(SQLStore): def __init__(self, db_uri: str): super().__init__(db_uri) - diff --git a/src/ell/stores/store.py b/src/ell/stores/store.py index 408513cef..737743976 100644 --- a/src/ell/stores/store.py +++ b/src/ell/stores/store.py @@ -3,8 +3,10 @@ from datetime import datetime from typing import Any, Optional, Dict, List, Set, Union from ell.types._lstr import _lstr -from ell.stores.studio import SerializedLMP, Invocation +from ell.stores.models.core import SerializedLMP, Invocation from ell.types.message import InvocableLM +from ell.stores.models.evaluations import EvaluationResultDatapoint, EvaluationRunLabelerSummary, SerializedEvaluation, SerializedEvaluationRun +# from ell.types.studio import SerializedEvaluation, SerializedEvaluationRun class BlobStore(ABC): @abstractmethod @@ -52,6 +54,51 @@ def write_invocation(self, invocation: Invocation, consumes: Set[str]) -> Optio """ pass + @abstractmethod + def write_evaluation(self, evaluation: SerializedEvaluation) -> str: + """ + Write an evaluation to the storage. + + :param evaluation: Evaluation object containing all evaluation details. + :param runs: List of EvaluationRun objects representing the evaluation runs. + :return: Optional return value. + """ + pass + + @abstractmethod + def write_evaluation_run(self, evaluation_run: SerializedEvaluationRun) -> int: + """ + Write an evaluation run to the storage. + + :param evaluation_run: EvaluationRun object containing all evaluation run details. + :return: Optional return value. + """ + pass + + @abstractmethod + def write_evaluation_run_intermediate(self, row_result : EvaluationResultDatapoint) -> None: + """ + Write an evaluation run intermediate result to the storage. + """ + pass + + @abstractmethod + def write_evaluation_run_end(self, evaluation_run_id : str, successful : bool, end_time : datetime, error : Optional[str], summaries: List[EvaluationRunLabelerSummary]) -> None: + """ + Write an evaluation run end to the storage. + """ + pass + + @abstractmethod + def write_evaluation_run_labeler_summaries(self, summaries: List[EvaluationRunLabelerSummary]) -> int: + """ + Write evaluation run labeler summaries to the storage. + + :param summaries: List of EvaluationRunLabelerSummary objects containing all evaluation run labeler summary details. + :return: Optional return value. + """ + pass + @abstractmethod def get_cached_invocations(self, lmp_id :str, state_cache_key :str) -> List[Invocation]: """ @@ -66,6 +113,17 @@ def get_versions_by_fqn(self, fqn :str) -> List[SerializedLMP]: """ pass + @abstractmethod + def get_eval_versions_by_name(self, name: str) -> List[SerializedEvaluation]: + """ + Get all versions of an evaluation by its name. + + :param name: The name of the evaluation. + :return: A list of SerializedEvaluation objects representing all versions of the evaluation. 
+ """ + pass + + @contextmanager def freeze(self, *lmps: InvocableLM): diff --git a/src/ell/studio/datamodels.py b/src/ell/studio/datamodels.py index 521929f7e..7b67f5f6d 100644 --- a/src/ell/studio/datamodels.py +++ b/src/ell/studio/datamodels.py @@ -1,7 +1,15 @@ from datetime import datetime from typing import List, Optional, Dict, Any from sqlmodel import SQLModel -from ell.stores.studio import SerializedLMPBase, InvocationBase, InvocationContentsBase +from ell.stores.models.evaluations import ( + EvaluationLabelBase, + EvaluationLabelerBase, + SerializedEvaluationBase, + SerializedEvaluationRunBase, + EvaluationRunLabelerSummaryBase, + EvaluationResultDatapointBase, +) +from ell.stores.models.core import SerializedLMPBase, InvocationBase, InvocationContentsBase class SerializedLMPWithUses(SerializedLMPBase): @@ -17,7 +25,16 @@ class InvocationPublic(InvocationBase): class InvocationPublicWithConsumes(InvocationPublic): consumes: List[InvocationPublic] consumed_by: List[InvocationPublic] - + + +class InvocationPublicWithoutLMP(InvocationBase): + uses : List["InvocationPublicWithoutLMPAndConsumes"] + contents: InvocationContentsBase + + +class InvocationPublicWithoutLMPAndConsumes(InvocationPublicWithoutLMP): + consumes: List[InvocationPublicWithoutLMP] + consumed_by: List[InvocationPublicWithoutLMP] from pydantic import BaseModel @@ -39,3 +56,36 @@ class InvocationsAggregate(BaseModel): # success_rate: float graph_data: List[GraphDataPoint] + +# Update these models at the end of the file +class EvaluationLabelerPublic(EvaluationLabelerBase): + labeling_lmp: Optional[SerializedLMPBase] + +class EvaluationRunLabelerSummaryPublic(EvaluationRunLabelerSummaryBase): + evaluation_labeler: EvaluationLabelerPublic + +class EvaluationRunPublic(SerializedEvaluationRunBase): + evaluated_lmp: SerializedLMPBase + labeler_summaries: List[EvaluationRunLabelerSummaryPublic] + +class EvaluationPublic(SerializedEvaluationBase): + labelers: List[EvaluationLabelerPublic] + runs: List[EvaluationRunPublic] + +# XXXX +class EvaluationPublicWithoutRuns(SerializedEvaluationBase): + labelers: List[EvaluationLabelerPublic] + +# XXXXXX +class EvaluationLabelPublic(EvaluationLabelBase): + label_invocation: Optional[InvocationPublicWithoutLMP] + labeler : EvaluationLabelerBase + +class EvaluationResultDatapointPublic(EvaluationResultDatapointBase): + invocation_being_labeled: InvocationPublicWithoutLMP + labels: List[EvaluationLabelPublic] + +class SpecificEvaluationRunPublic(SerializedEvaluationRunBase): + evaluated_lmp: SerializedLMPBase + evaluation: EvaluationPublicWithoutRuns + labeler_summaries: List[EvaluationRunLabelerSummaryPublic] diff --git a/src/ell/studio/server.py b/src/ell/studio/server.py index 2879ea7a2..40ea6a809 100644 --- a/src/ell/studio/server.py +++ b/src/ell/studio/server.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from sqlmodel import Session from ell.stores.sql import PostgresStore, SQLiteStore @@ -9,11 +9,12 @@ import json from ell.studio.config import Config from ell.studio.connection_manager import ConnectionManager -from ell.studio.datamodels import InvocationPublicWithConsumes, SerializedLMPWithUses +from ell.studio.datamodels import EvaluationResultDatapointPublic, InvocationPublicWithConsumes, SerializedLMPWithUses, EvaluationPublic, SpecificEvaluationRunPublic -from ell.stores.studio import SerializedLMP +from ell.stores.models.core import SerializedLMP from datetime import datetime, timedelta from sqlmodel import select 
diff --git a/src/ell/studio/server.py b/src/ell/studio/server.py
index 2879ea7a2..40ea6a809 100644
--- a/src/ell/studio/server.py
+++ b/src/ell/studio/server.py
@@ -1,4 +1,4 @@
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
 from sqlmodel import Session
 from ell.stores.sql import PostgresStore, SQLiteStore
@@ -9,11 +9,12 @@
 import json
 from ell.studio.config import Config
 from ell.studio.connection_manager import ConnectionManager
-from ell.studio.datamodels import InvocationPublicWithConsumes, SerializedLMPWithUses
+from ell.studio.datamodels import EvaluationResultDatapointPublic, InvocationPublicWithConsumes, SerializedLMPWithUses, EvaluationPublic, SpecificEvaluationRunPublic
 
-from ell.stores.studio import SerializedLMP
+from ell.stores.models.core import SerializedLMP
 from datetime import datetime, timedelta
 from sqlmodel import select
+from ell.stores.models.evaluations import SerializedEvaluation
 
 logger = logging.getLogger(__name__)
 
@@ -220,4 +221,129 @@ def get_invocations_aggregate(
 
-    return app
\ No newline at end of file
+    @app.get("/api/evaluations", response_model=List[EvaluationPublic])
+    def get_evaluations(
+        evaluation_id: Optional[str] = Query(None),
+        lmp_id: Optional[str] = Query(None),
+        skip: int = Query(0, ge=0),
+        limit: int = Query(100, ge=1, le=100),
+        session: Session = Depends(get_session)
+    ):
+        filters: Dict[str, Any] = {}
+        if evaluation_id:
+            filters['id'] = evaluation_id
+        if lmp_id:
+            filters['lmp_id'] = lmp_id
+
+        evaluations = serializer.get_evaluations(
+            session,
+            filters=filters,
+            skip=skip,
+            limit=limit
+        )
+
+
+        return evaluations
+
+    @app.get("/api/latest/evaluations", response_model=List[EvaluationPublic])
+    def get_latest_evaluations(
+        skip: int = Query(0, ge=0),
+        limit: int = Query(100, ge=1, le=100),
+        session: Session = Depends(get_session)
+    ):
+        evaluations = serializer.get_latest_evaluations(
+            session,
+            skip=skip,
+            limit=limit
+        )
+
+        return evaluations
+
+    @app.get("/api/evaluation/{evaluation_id}", response_model=EvaluationPublic)
+    def get_evaluation(
+        evaluation_id: str,
+        session: Session = Depends(get_session)
+    ):
+        evaluation = serializer.get_evaluations(session, filters={"id": evaluation_id})
+        if not evaluation:
+            raise HTTPException(status_code=404, detail="Evaluation not found")
+        return evaluation[0]
+
+
+    @app.get("/api/evaluation-runs/{run_id}", response_model=SpecificEvaluationRunPublic)
+    def get_evaluation_run(
+        run_id: str,
+        session: Session = Depends(get_session)
+    ):
+        runs = serializer.get_evaluation_run(session, run_id)
+        return runs
+
+    @app.get("/api/evaluation-runs/{run_id}/results", response_model=List[EvaluationResultDatapointPublic])
+    def get_evaluation_run_results(
+        run_id: str,
+        skip: int = Query(0, ge=0),
+        limit: int = Query(100, ge=1, le=100),
+        session: Session = Depends(get_session)
+    ):
+        results = serializer.get_evaluation_run_results(
+            session,
+            run_id,
+            skip=skip,
+            limit=limit,
+        )
+        return results
+
+    @app.get("/api/all-evaluations", response_model=List[EvaluationPublic])
+    def get_all_evaluations(
+        skip: int = Query(0, ge=0),
+        limit: int = Query(100, ge=1, le=100),
+        session: Session = Depends(get_session)
+    ):
+        # Get all evaluations ordered by creation date, without deduplication
+        query = (
+            select(SerializedEvaluation)
+            .order_by(SerializedEvaluation.created_at.desc())
+            .offset(skip)
+            .limit(limit)
+        )
+        results = session.exec(query).all()
+        return list(results)
+
+    @app.get("/api/dataset/{dataset_id}")
+    def get_dataset(
+        dataset_id: str,
+        session: Session = Depends(get_session)
+    ):
+        if not serializer.blob_store:
+            raise HTTPException(status_code=400, detail="Blob storage not configured")
+
+        try:
+            # Get the blob data
+            blob_data = serializer.blob_store.retrieve_blob(dataset_id)
+
+            # Check if size is under 5MB
+            if len(blob_data) > 5 * 1024 * 1024:  # 5MB in bytes
+                raise HTTPException(
+                    status_code=413,
+                    detail="Dataset too large to preview (>5MB)"
+                )
+
+            # Decode and parse JSON
+            dataset_json = json.loads(blob_data.decode('utf-8'))
+
+            return {
+                "size": len(blob_data),
+                "data": dataset_json
+            }
+
+        except FileNotFoundError:
+            raise HTTPException(status_code=404, detail="Dataset not found")
+        except json.JSONDecodeError:
+            raise HTTPException(status_code=400, detail="Invalid JSON data")
+        except Exception as e:
+            logger.error(f"Error retrieving dataset: {str(e)}")
+            raise HTTPException(status_code=500, detail="Internal server error")
+
+    return app
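A rough client-side sketch of the new endpoints registered above. The paths come from the routes in this diff; the base URL and the response field names are assumptions (taken from the models referenced above, not verified against the full schema), so adjust them to your deployment.

```python
import requests

BASE = "http://localhost:5555"  # hypothetical ell studio address; change to wherever the server runs

# newest version of each evaluation, paginated with the same skip/limit Query params
latest = requests.get(f"{BASE}/api/latest/evaluations", params={"skip": 0, "limit": 10}).json()

for ev in latest:
    print(ev.get("name"), ev.get("version_number"))

if latest and latest[0].get("runs"):
    run_id = latest[0]["runs"][0]["id"]
    results = requests.get(f"{BASE}/api/evaluation-runs/{run_id}/results",
                           params={"limit": 100}).json()
    print(f"{len(results)} result datapoints for run {run_id}")
```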
diff --git a/src/ell/types/lmp.py b/src/ell/types/lmp.py
index c67df8a0a..383637d5e 100644
--- a/src/ell/types/lmp.py
+++ b/src/ell/types/lmp.py
@@ -4,5 +4,6 @@
 class LMPType(str, enum.Enum):
     LM = "LM"
     TOOL = "TOOL"
-    MULTIMODAL = "MULTIMODAL"
-    OTHER = "OTHER"
+    LABELER = "LABELER"
+    FUNCTION = "FUNCTION"
+    OTHER = "OTHER"
\ No newline at end of file
diff --git a/src/ell/util/WARNING b/src/ell/util/WARNING
new file mode 100644
index 000000000..7f40bc3c8
--- /dev/null
+++ b/src/ell/util/WARNING
@@ -0,0 +1,23 @@
+THIS MODULE WILL BE DEPRECATED. WE ARE MOVING TO LOCAL UTIL MODULES.
+```
+project/
+│
+├── module1/
+│   ├── __init__.py
+│   ├── core.py
+│   └── util.py
+│
+├── module2/
+│   ├── __init__.py
+│   ├── main.py
+│   └── util.py
+│
+├── module3/
+│   ├── __init__.py
+│   ├── handler.py
+│   └── util.py
+│
+└── common/
+    ├── __init__.py
+    └── util.py
+```
\ No newline at end of file
diff --git a/src/ell/util/closure_util.py b/src/ell/util/closure_util.py
index da2a28c7e..80ca6d24a 100644
--- a/src/ell/util/closure_util.py
+++ b/src/ell/util/closure_util.py
@@ -1,4 +1,6 @@
 import ast
+from functools import lru_cache
+import hashlib
 import importlib
 import os
 import black
@@ -157,4 +159,14 @@ def format_source(source: str) -> str:
         return black.format_str(source, mode=black.Mode())
     except:
         # If Black formatting fails, return the original source
-        return source
\ No newline at end of file
+        return source
+
+
+def ido(f):
+    if not hasattr(f.__ell_func__, "__ell_hash__"):
+        f.__ell_force_closure__()
+    return f.__ell_func__.__ell_hash__
+
+@lru_cache(maxsize=128)
+def hsh(x):
+    return hashlib.md5(x.encode()).hexdigest()
\ No newline at end of file
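The `ido`/`hsh` helpers above are small pieces of content-addressed versioning: `hsh` is just a memoized md5 digest of a string of source text. A standalone check of that behavior (the sample source string is arbitrary):

```python
import hashlib
from functools import lru_cache

@lru_cache(maxsize=128)
def hsh(x: str) -> str:
    # md5 hex digest of already-serialized source text; cached so repeated hashing is free
    return hashlib.md5(x.encode()).hexdigest()

digest = hsh("def f():\n    return 1\n")
assert digest == hsh("def f():\n    return 1\n")  # deterministic; the second call is a cache hit
print(digest)            # 32-character hex string, stable for identical source text
print(hsh.cache_info())  # shows hits/misses from lru_cache
```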
+ """ + return datetime.now(tz=timezone.utc) + + diff --git a/src/ell/util/tqdm.py b/src/ell/util/tqdm.py new file mode 100644 index 000000000..e18fdd2c4 --- /dev/null +++ b/src/ell/util/tqdm.py @@ -0,0 +1,37 @@ +"""Copyright (c) 2024, the tiny corp""" +import math +import shutil +import sys +import time +from typing import Optional + + +class tqdm: + def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100): + self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate + self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total + self.set_description(desc) + self.update(0) + def __iter__(self): + for item in self.iterable: + yield item + self.update(1) + self.update(close=True) + def set_description(self, desc:str): self.desc = f"{desc}: " if desc else "" + def update(self, n:int=0, close:bool=False): + self.n, self.i = self.n+n, self.i+1 + if self.disable or (not close and self.i % self.skip != 0): return + prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns + if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1) + def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x) + def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00' + prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}' + elapsed_text = HMS(elapsed) + (f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else '') + it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?" 
diff --git a/src/ell/util/tqdm.py b/src/ell/util/tqdm.py
new file mode 100644
index 000000000..e18fdd2c4
--- /dev/null
+++ b/src/ell/util/tqdm.py
@@ -0,0 +1,37 @@
+"""Copyright (c) 2024, the tiny corp"""
+import math
+import shutil
+import sys
+import time
+from typing import Optional
+
+
+class tqdm:
+  def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
+    self.iterable, self.disable, self.unit, self.unit_scale, self.rate = iterable, disable, unit, unit_scale, rate
+    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
+    self.set_description(desc)
+    self.update(0)
+  def __iter__(self):
+    for item in self.iterable:
+      yield item
+      self.update(1)
+    self.update(close=True)
+  def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
+  def update(self, n:int=0, close:bool=False):
+    self.n, self.i = self.n+n, self.i+1
+    if self.disable or (not close and self.i % self.skip != 0): return
+    prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
+    if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
+    def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
+    def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
+    prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
+    elapsed_text = HMS(elapsed) + (f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else '')
+    it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"
+    suf = f'{prog_text} [{elapsed_text}, {it_text}{self.unit}/s]'
+    sz = max(ncols-len(self.desc)-3-2-2-len(suf), 1)
+    bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
+    print(bar[:ncols+1], flush=True, end='\n'*close, file=sys.stderr)
+
+class trange(tqdm):
+  def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)
\ No newline at end of file
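A short usage sketch for the vendored progress bar above, assuming the module lands at `src/ell/util/tqdm.py` as in this diff. It mirrors a small subset of the familiar tqdm API (wrap an iterable, or use `trange(n)`), and it writes the bar to stderr.

```python
import time
from ell.util.tqdm import tqdm, trange  # the module added in this diff

for _ in tqdm(range(200), desc="scoring", unit="it"):
    time.sleep(0.005)  # renders a bar roughly like: scoring:  42%|█████...| 84/200 [...]

for _ in trange(5, desc="epochs", unit_scale=True):
    time.sleep(0.1)
```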
"mock_annotation" in mock_evaluation.annotations + + +def test_evaluation_run_process_single(mock_evaluation): + data_point = {"input": {"param": "test_input"}} + lmp = MockLMP + required_params = False + + results = mock_evaluation._process_single(data_point, lmp, {}, required_params) + assert len(results) == 1 + assert results[0]().output[0] == "mock_output" + +def test_evaluation_run(mock_evaluation): + lmp = paramless + + evaluation_run = mock_evaluation.run(lmp, n_workers=1, verbose=False) + assert evaluation_run.n_evals == 10 + assert evaluation_run.samples_per_datapoint == 2 + +def test_evaluation_run_with_different_inputs(mock_evaluation): + # Test with list input + data_point = {"input": ["test_input1", "test_input2"]} + lmp = MockLMP + lmp_params = {} + required_params = True + + results = mock_evaluation._process_single(data_point, lmp, lmp_params, required_params) + assert len(results) == 1 + assert results[0]().output[0] == "mock_output" + + # Test with no input + data_point = {} + results = mock_evaluation._process_single(data_point, lmp, lmp_params, required_params) + assert len(results) == 1 + assert results[0]().output[0] == "mock_output" + +def test_evaluation_run_with_invalid_input(mock_evaluation): + data_point = {"input": 123} # Invalid input type + lmp = MockLMP + required_params = True + + with pytest.raises(ValueError, match="Invalid input type: "): + mock_evaluation._process_single(data_point, lmp, {}, required_params) + +def test_evaluation_run_with_missing_params(mock_evaluation): + data_point = {"input": {"param": "test_input"}} + lmp = MockLMP + lmp_params = {} # Missing required params + required_params = False + + results = mock_evaluation._process_single(data_point, lmp, lmp_params, required_params) + assert len(results) == 1 + assert results[0]().output[0] == "mock_output" + + +def test_evaluation_run_with_criterion(mock_evaluation): + # Test with a criterion + data_point = {"input": {"param": "test_input"}} + lmp = MockLMP + required_params = False + + results = mock_evaluation._process_single(data_point, lmp, {}, required_params) + assert len(results) == 1 + assert results[0]().criterion[0] == True + + + diff --git a/tests/test_migrations.py b/tests/test_migrations.py index ebb0cf190..5e0faf13f 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -1,3 +1,4 @@ +import json import os import tempfile from pathlib import Path @@ -8,7 +9,15 @@ from ell.stores.migrations import init_or_migrate_database, get_alembic_config from alembic import command from alembic.config import Config - +# Compare schemas after recursively sorting all lists +def sort_lists_recursively(d): + if isinstance(d, dict): + return {k: sort_lists_recursively(v) for k, v in d.items()} + elif isinstance(d, list): + return list(sorted((sort_lists_recursively(x) for x in d), key=lambda x: json.dumps(x, sort_keys=True))) + else: + return d + def get_table_metadata(engine, exclude_tables=None): """Helper to get table metadata in a consistent format""" inspector = inspect(engine) @@ -83,8 +92,11 @@ def test_empty_db_migration(temp_db_url): # Get schema created by SQLModel sqlmodel_metadata = get_table_metadata(engine2) - # Compare schemas - assert migrated_metadata == sqlmodel_metadata + + + sorted_migrated = sort_lists_recursively(migrated_metadata) + sorted_sqlmodel = sort_lists_recursively(sqlmodel_metadata) + assert sorted_migrated == sorted_sqlmodel def test_existing_tables_no_alembic(temp_db_url): """Test database with existing tables but no alembic version table""" @@ 
diff --git a/tests/test_migrations.py b/tests/test_migrations.py
index ebb0cf190..5e0faf13f 100644
--- a/tests/test_migrations.py
+++ b/tests/test_migrations.py
@@ -1,3 +1,4 @@
+import json
 import os
 import tempfile
 from pathlib import Path
@@ -8,7 +9,15 @@
 from ell.stores.migrations import init_or_migrate_database, get_alembic_config
 from alembic import command
 from alembic.config import Config
-
+# Compare schemas after recursively sorting all lists
+def sort_lists_recursively(d):
+    if isinstance(d, dict):
+        return {k: sort_lists_recursively(v) for k, v in d.items()}
+    elif isinstance(d, list):
+        return list(sorted((sort_lists_recursively(x) for x in d), key=lambda x: json.dumps(x, sort_keys=True)))
+    else:
+        return d
+
 def get_table_metadata(engine, exclude_tables=None):
     """Helper to get table metadata in a consistent format"""
     inspector = inspect(engine)
@@ -83,8 +92,11 @@ def test_empty_db_migration(temp_db_url):
     # Get schema created by SQLModel
     sqlmodel_metadata = get_table_metadata(engine2)
 
-    # Compare schemas
-    assert migrated_metadata == sqlmodel_metadata
+
+
+    sorted_migrated = sort_lists_recursively(migrated_metadata)
+    sorted_sqlmodel = sort_lists_recursively(sqlmodel_metadata)
+    assert sorted_migrated == sorted_sqlmodel
 
 def test_existing_tables_no_alembic(temp_db_url):
     """Test database with existing tables but no alembic version table"""
@@ -103,7 +115,8 @@ def test_existing_tables_no_alembic(temp_db_url):
     with engine.connect() as conn:
         result = conn.execute(text("SELECT version_num FROM ell_alembic_version"))
         version = result.scalar()
-        assert version == "4524fb60d23e"  # Initial migration version
+        # Get current head version from alembic config
+        assert version == "f6528d04bbbd"
 
 def test_multiple_migrations(temp_db_url):
     """Test running multiple migrations in sequence"""
@@ -179,5 +192,7 @@ def test_pure_migration_matches_metadata(temp_db_url):
     # Get schema created by SQLModel
     sqlmodel_metadata = get_table_metadata(engine2)
 
-
-    assert migration_metadata == sqlmodel_metadata
+    # Compare schemas after recursively sorting
+    migration_metadata_sorted = sort_lists_recursively(migration_metadata)
+    sqlmodel_metadata_sorted = sort_lists_recursively(sqlmodel_metadata)
+    assert migration_metadata_sorted == sqlmodel_metadata_sorted
diff --git a/tests/test_results.py b/tests/test_results.py
new file mode 100644
index 000000000..5796b6a21
--- /dev/null
+++ b/tests/test_results.py
@@ -0,0 +1,37 @@
+from ell.evaluation.results import _ResultDatapoint, EvaluationResults, Label
+from ell.stores.models.evaluations import EvaluationLabelerType
+import numpy as np
+
+def test_evaluation_results_from_rowar_results():
+    # Test that from_rowar_results correctly converts rowar_results to EvaluationResults
+    rowar_results = [
+        _ResultDatapoint(
+            output=("output1", "id1"),
+            labels=[
+                Label(name="metric1", type=EvaluationLabelerType.METRIC, label=(0.95, "id1")),
+                Label(name="annotation1", type=EvaluationLabelerType.ANNOTATION, label=("anno1", "id1")),
+                Label(name="criterion", type=EvaluationLabelerType.CRITERION, label=(True, "id1"))
+            ]
+        ),
+        _ResultDatapoint(
+            output=("output2", "id2"),
+            labels=[
+                Label(name="metric1", type=EvaluationLabelerType.METRIC, label=(0.85, "id2")),
+                Label(name="annotation1", type=EvaluationLabelerType.ANNOTATION, label=("anno2", "id2")),
+                Label(name="criterion", type=EvaluationLabelerType.CRITERION, label=(False, "id2"))
+            ]
+        ),
+    ]
+    results = EvaluationResults.from_rowar_results(rowar_results)
+
+    assert results.outputs == ["output1", "output2"]
+    assert (results.metrics["metric1"] == np.array([0.95, 0.85])).all()
+    assert (results.annotations["annotation1"] == np.array(["anno1", "anno2"])).all()
+    assert (results.criterion == np.array([True, False])).all()
+
+    # Check invocation_ids
+    assert results.invocation_ids is not None
+    assert results.invocation_ids.outputs == ["id1", "id2"]
+    assert (results.invocation_ids.metrics["metric1"] == np.array(["id1", "id2"])).all()
+    assert (results.invocation_ids.annotations["annotation1"] == np.array(["id1", "id2"])).all()
+    assert (results.invocation_ids.criterion == np.array(["id1", "id2"])).all()
diff --git a/tests/test_sql_store.py b/tests/test_sql_store.py
index dfdbcb2f9..3adecc490 100644
--- a/tests/test_sql_store.py
+++ b/tests/test_sql_store.py
@@ -4,7 +4,8 @@
 from sqlalchemy import Engine, create_engine, func
 
 from ell.types.lmp import LMPType
-from ell.stores.studio import utc_now
+from ell.util.serialization import utc_now
+
 
 @pytest.fixture
 def in_memory_db():