
Commit

update notebook
semio committed May 18, 2024
1 parent 650bb90 commit 7352218
Showing 1 changed file with 59 additions and 35 deletions.
94 changes: 59 additions & 35 deletions automation-api/yival_experiments/notebooks/result_data_analysis.py
@@ -23,29 +23,25 @@
# ## prepare data

# +
# results to be analyzed,
# downloaded manually from the AI eval spreadsheet.
result = pd.read_csv('./data/Gapminder AI evaluations - Master Output.csv')
# -

# load ai eval spreadsheet
ai_eval_sheet = read_ai_eval_spreadsheet()

result

# cleanup: normalize column names to snake_case
result.columns = result.columns.map(lambda x: x.lower().replace(' ', '_'))

result
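
# The SQL cells below rely on the normalized column names; a quick sanity
# check (a sketch, covering only the columns referenced later in this notebook):

# +
assert {'question_id', 'model_configuration_id', 'prompt_variation_id', 'result'} <= set(result.columns)
# -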



# + magic_args="--save result_to_analyze " language="sql"
# select
# *,
# CASE
# WHEN ((Result = 'correct')) THEN (3)
@@ -57,7 +53,7 @@
# from result where model_configuration_id not like 'mc026'

# + magic_args="--with result_to_analyze --save result_chn_prompt_renamed" language="sql"
# select
# * exclude (prompt_variation_id),
# replace(prompt_variation_id, '_zh', '') as prompt_variation_id
# from result_to_analyze
@@ -123,7 +119,7 @@

# + magic_args="--save q_and_t" language="sql"
# -- only keep question id and topic list.
# select
# question_id,
# first(human_wrong_percentage) as human_wrong_percentage,
# first(topic_list) as topic_list,
@@ -232,7 +228,7 @@
# other topics

# + magic_args="--save res_with_other_topics" language="sql"
# select
# r.*,
# unnest(q.other_topics) as topic
# from result_to_analyze r left join q_and_t q on r.question_id = q.question_id
@@ -383,11 +379,11 @@
# q.sdg_topic,
# q.other_topics,
# q.human_wrong_percentage,
# case
# when q.sdg_topic is null then other_topics
# else list_append(q.other_topics, q.sdg_topic)
# end as all_topics
#
# from
# model_question_stat_all r
# left join q_and_t q on
@@ -496,7 +492,7 @@
# mean (indecisive_rate) as indecisive_rate,
# median (variance) as variance
# from
# (select
# * exclude (all_topics, sdg_topic, other_topics),
# unnest(all_topics) as topic
# from topic_prompt_family_stat)
@@ -522,19 +518,19 @@
# select * from model_topic_stat;

# + magic_args="model_topic_human_diff <<" language="sql"
# select
# question_id,
# model_configuration_id,
# (100 - correct_rate) as ai_wrong_percentage,
# human_wrong_percentage,
# ai_wrong_percentage - human_wrong_percentage as diff,
# sdg_topic,
# other_topics
# from model_topic_stat
# where diff > 0
# order by
# "sdg_topic",
# cast(other_topics as varchar),
# "model_configuration_id"
# -
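
# Every row kept by the cell above has diff > 0, i.e. questions where the
# model is wrong more often than the human respondents. A quick check (sketch):

# +
assert (model_topic_human_diff.DataFrame()['diff'] > 0).all()
# -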

@@ -547,19 +543,19 @@


# + magic_args="model_topic_monkey_diff <<" language="sql"
# select
# question_id,
# model_configuration_id,
# (100 - correct_rate) as ai_wrong_percentage,
# 100 * (2/3) as monkey_wrong_percentage,
# ai_wrong_percentage - monkey_wrong_percentage as diff,
# sdg_topic,
# other_topics
# from model_topic_stat
# where diff > 0
# order by
# "sdg_topic",
# cast(other_topics as varchar),
# "model_configuration_id"
# -
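
# The monkey baseline of 100 * (2/3) presumably reflects three answer options
# per question, so a random guesser is expected to be wrong two thirds of the
# time (sketch of the arithmetic):

# +
100 * (2 / 3)  # ~66.67% wrong answers expected from random guessing
# -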

@@ -575,13 +571,13 @@
# summary stats for human and monkey vs ai

# + magic_args="summary_human_ai <<" language="sql"
# select
# question_id,
# count(*) as num_of_models,
# mean(diff) as average_diff,
# from
# model_topic_human_diff_df
# group by
# question_id
# ORDER BY
# num_of_models desc,
@@ -595,13 +591,13 @@


# + magic_args="summary_monkey_ai <<" language="sql"
# select
# question_id,
# count(*) as num_of_models,
# mean(diff) as average_diff,
# from
# model_topic_monkey_diff_df
# group by
# question_id
# ORDER BY
# num_of_models desc,
@@ -628,6 +624,38 @@



# # For climate study questions

climate_questions = ["5", "59", "85", "86", "1524", "1672", "1691", "1706", "1717", "1730", "1731", "1737", "1738", "1741", "1761"]

# + magic_args="--save result_climate_questions" language="sql"
# select
# *
# from result_to_analyze
# where list_contains({{climate_questions}}, question_id) AND model_configuration_id != 'mc028';
# -

# climate_raw_result = %sql select * from result_climate_questions

climate_raw_result.DataFrame().to_csv('./data/outputs/climate_raw.csv', index=False)
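
# Quick consistency check on the filter above (a sketch; column names follow
# the normalized result table):

# +
_climate_df = climate_raw_result.DataFrame()
assert set(_climate_df['question_id'].astype(str)) <= set(climate_questions)
assert 'mc028' not in set(_climate_df['model_configuration_id'])
# -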

# + magic_args="--save correct_by_prompt climate_res << " language="sql"
# select
# model_configuration_id,
# prompt_variation_id,
# count(*),
#
# from result_climate_questions
# where result = 'correct'
# group by model_configuration_id, prompt_variation_id
# -

climate_res.DataFrame().to_csv("./data/outputs/climate_study.csv")
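
# The counts above are correct answers per (model, prompt) pair over the
# climate questions. Assuming each pair answered every question once, dividing
# by the number of questions gives a correct rate (sketch; the count(*)
# column name may differ depending on the SQL engine):

# +
_study_df = climate_res.DataFrame()
_count_col = _study_df.columns[-1]  # the count(*) aggregate column
_study_df['correct_rate'] = _study_df[_count_col] / len(climate_questions) * 100
# -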







# # Check raw outputs
@@ -669,7 +697,3 @@
err2.DataFrame()

err2.DataFrame().shape



