-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/adaamko/POTATO into main
- Loading branch information
Showing
5 changed files
with
307 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
import json | ||
import os | ||
import numpy as np | ||
import logging | ||
from typing import List, Dict | ||
from argparse import ArgumentParser, ArgumentError | ||
|
||
from sklearn.model_selection import train_test_split | ||
from xpotato.dataset.explainable_dataset import ExplainableDataset | ||
from xpotato.models.trainer import GraphTrainer | ||
from xpotato.dataset.utils import save_dataframe | ||
|
||
|
||
def read_json(file_path: str) -> Dict[str, List[Dict[str, List[str]]]]:
    """Group the posts of an annotated JSON dataset by their agreed target.

    The file must contain a JSON object whose values are posts with the keys
    ``post_tokens`` (list of tokens), ``annotators`` (list of dicts with a
    ``label`` and a list of ``target`` names) and ``rationales`` (list of
    per-token masks).

    A post is kept only when at least two annotators share a label and the
    post resolves to exactly one target:

    * if the most frequent label is ``"normal"`` the target is ``"None"``;
    * otherwise the targets named by more than one annotation are used.

    Args:
        file_path: Path of the JSON dataset to read.

    Returns:
        A mapping from target name to the list of kept posts. Every post is a
        dict with ``tokens``, ``sentence`` (tokens joined by a space) and
        ``rationale`` (the element-wise mean of the rationale masks rounded
        to the nearest integer, or ``[]`` when no usable mask exists).
    """
    data_by_target = {}
    with open(file_path, encoding="utf-8") as dataset:
        data = json.load(dataset)

    for post in data.values():
        tokens = post["post_tokens"]
        annotators = post["annotators"]

        # Count how often each label and each target was chosen.
        labels = {}
        targets = {}
        for annotation in annotators:
            labels[annotation["label"]] = labels.get(annotation["label"], 0) + 1
            for target_i in annotation["target"]:
                targets[target_i] = targets.get(target_i, 0) + 1

        target = []
        # Fewer distinct labels than annotators means at least two agree.
        if len(labels) != len(annotators):
            # Ties are resolved in favour of the label seen first.
            label = max(labels.items(), key=lambda x: x[1])[0]
            if label == "normal":
                target = ["None"]
            else:
                target = [name for name, votes in targets.items() if votes > 1]

        # Posts without a single unambiguous target are dropped.
        if len(target) != 1:
            continue

        # Only masks aligned with the tokens can be averaged.
        rats = [n for n in post["rationales"] if len(n) == len(tokens)]
        # Averaging an empty list would yield NaN instead of a list.
        rationale = (
            np.round(np.mean(rats, axis=0), decimals=0).tolist() if rats else []
        )

        data_by_target.setdefault(target[0], []).append(
            {
                "tokens": tokens,
                "sentence": " ".join(tokens),
                "rationale": rationale,
            }
        )
    return data_by_target
|
||
|
||
def process(data_path: str, groups: List[str], target: str, just_none: bool):
    """Build, train on and split an explainable dataset for one target group.

    For every selected group the file ``<data_path>/<group>.json`` is loaded;
    groups whose file is missing are skipped with a warning. Examples of the
    target group are labelled with the capitalized target name and keep the
    tokens whose rationale value equals 1; all other examples are labelled
    ``"None"`` with an empty rationale.

    The examples are turned into an ``ExplainableDataset``, parsed with the
    ``"ud"`` graph format and handed to ``GraphTrainer.prepare_and_train``.
    An 80/20 split of the dataframe is saved as ``train.tsv`` and ``val.tsv``
    in ``data_path``; the trained features are dumped to ``features.json``
    relative to the current working directory.

    Args:
        data_path: Directory holding the per-group JSON files.
        groups: Group names to load when ``just_none`` is false.
        target: Name of the group treated as the positive class.
        just_none: When true, only ``"none"`` and ``target`` are loaded.
    """
    positive_label = target.capitalize()

    selected_groups = groups
    if just_none:
        selected_groups = ["none", target]

    sentences = []
    for group in selected_groups:
        group_path = os.path.join(data_path, f"{group}.json")
        if not os.path.isfile(group_path):
            logging.warning(f"Skipping {group}, because {group_path} does not exist.")
            continue

        with open(group_path, "r") as group_json:
            group_list = json.load(group_json)

        is_target_group = group == target
        for example in group_list:
            text = example["sentence"]
            if is_target_group:
                label = positive_label
                # Keep only the tokens marked by the rationale mask.
                marked_tokens = [
                    tok
                    for rat, tok in zip(example["rationale"], example["tokens"])
                    if rat == 1
                ]
            else:
                label = "None"
                marked_tokens = []
            sentences.append((text, label, marked_tokens))

    potato_dataset = ExplainableDataset(
        sentences, label_vocab={"None": 0, positive_label: 1}, lang="en"
    )
    parsed_graphs = potato_dataset.parse_graphs(graph_format="ud")
    potato_dataset.set_graphs(parsed_graphs)
    df = potato_dataset.to_dataframe()

    # Features are trained on the full dataframe, before it is split.
    features = GraphTrainer(df).prepare_and_train()

    train, val = train_test_split(df, test_size=0.2, random_state=1234)
    save_dataframe(train, os.path.join(data_path, "train.tsv"))
    save_dataframe(val, os.path.join(data_path, "val.tsv"))

    with open("features.json", "w+") as f:
        json.dump(features, f)
|
||
|
||
def main(argv=None):
    """Command line entry point: split the dataset and/or build a POTATO dataset.

    Modes:
        * ``distinct`` - write one ``<target>.json`` file per target next to
          the dataset file.
        * ``process`` - run ``process`` on a directory of per-target files.
        * ``both`` - do the two steps after each other (default).

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Invalid argument combinations are reported through
    ``ArgumentParser.error``, which prints the usage and exits with status 2.
    """
    # Local import: needed so the "\n\t" layout of the --mode help survives.
    from argparse import RawTextHelpFormatter

    target_groups = [
        "african",
        "arab",
        "asian",
        "caucasian",
        "christian",
        "disability",
        "economic",
        "hindu",
        "hispanic",
        "homosexual",
        "indian",
        "islam",
        "jewish",
        "men",
        "other",
        "refugee",
        "women",
    ]
    argparser = ArgumentParser(formatter_class=RawTextHelpFormatter)
    argparser.add_argument(
        "--data_path", "-d", help="Path to the json dataset.", required=True
    )
    argparser.add_argument(
        "--mode",
        "-m",
        help="Mode to start the program. Modes:"
        "\n\t- distinct: "
        "cut the dataset.json into distinct categorical json files"
        "\n\t- process: "
        "load the chosen category as the target and every other one as non-target"
        "\n\t- both: "
        "run the distinct and the process after eachother",
        default="both",
        choices=["distinct", "process", "both"],
    )
    argparser.add_argument(
        "--target",
        "-t",
        help="The target group to set as our category.",
        choices=target_groups,
    )
    argparser.add_argument(
        "--just_none",
        "-n",
        action="store_true",
        help="Use only the normal texts as counter.",
    )
    args = argparser.parse_args(argv)

    # ArgumentError needs (argument, message); calling it with a single
    # string raises TypeError, so report problems via the parser instead.
    if args.mode != "distinct" and args.target is None:
        argparser.error(
            "Target is not given! If you want to produce a POTATO dataset "
            "(by running this code in process or both mode), you should specify the target."
        )

    if args.mode == "process":
        # The per-target files are expected next to the given file, or
        # directly inside the given directory.
        dir_path = (
            os.path.dirname(args.data_path)
            if os.path.isfile(args.data_path)
            else args.data_path
        )
    else:
        dataset = (
            args.data_path
            if os.path.isfile(args.data_path)
            else os.path.join(args.data_path, "dataset.json")
        )
        if not os.path.isfile(dataset):
            argparser.error(
                "The specified data path is not a file and does not contain a dataset.json file. "
                "If your file has a different name, please specify."
            )
        dir_path = os.path.dirname(dataset)
        dt_by_target = read_json(dataset)
        for name, list_of_dicts in dt_by_target.items():
            with open(os.path.join(dir_path, f"{name.lower()}.json"), "w") as json_file:
                json.dump(list_of_dicts, json_file, indent=4)

    if args.mode != "distinct":
        process(
            data_path=dir_path,
            groups=target_groups,
            target=args.target,
            just_none=args.just_none,
        )


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import Dict, List, Tuple | ||
|
||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
from xpotato.dataset.dataset import Dataset | ||
from xpotato.dataset.explainable_sample import ExplainableSample | ||
from xpotato.graph_extractor.graph import PotatoGraph | ||
|
||
|
||
class ExplainableDataset(Dataset):
    """Dataset whose samples carry a rationale in addition to text and label."""

    def __init__(
        self,
        examples: List[Tuple[str, str, List[str]]] = None,
        label_vocab: Dict[str, int] = None,
        lang="en",
        path=None,
        binary=False,
        cache_dir=None,
        cache_fn=None,
    ) -> None:
        """Forward all arguments to the base ``Dataset`` constructor.

        ``label_vocab`` defaults to a fresh empty dict per instance, so no
        mutable default object is shared between datasets.
        """
        super().__init__(
            examples=examples,
            label_vocab={} if label_vocab is None else label_vocab,
            lang=lang,
            path=path,
            binary=binary,
            cache_dir=cache_dir,
            cache_fn=cache_fn,
        )

    def read_dataset(
        self,
        examples: List[Tuple[str, str, List[str]]] = None,
        path: str = None,
        binary: bool = False,
    ) -> List[ExplainableSample]:
        """Create the samples from in-memory examples or from a saved file.

        Args:
            examples: ``(text, label, rationale)`` tuples; used when non-empty.
            path: File to load when no examples are given.
            binary: Read ``path`` as a pickle instead of a tab separated file.

        Returns:
            The list of ``ExplainableSample`` objects.

        Raises:
            ValueError: If neither examples nor a path is provided.
        """
        if examples:
            return [ExplainableSample(example) for example in examples]
        elif path:
            if binary:
                # NOTE(review): unpickling executes arbitrary code; only load
                # files from a trusted source.
                df = pd.read_pickle(path)
                graphs_str = self.prune_graphs(df.graph.tolist())
                # Replace the stored graphs with their pruned form.
                df.drop(columns=["graph"], inplace=True)
                df["graph"] = graphs_str
            else:
                df = pd.read_csv(path, sep="\t")
            samples = [
                ExplainableSample(
                    (example["text"], example["label"], example["rationale"]),
                    potato_graph=PotatoGraph(graph_str=example["graph"]),
                    label_id=example["label_id"],
                )
                for _, example in tqdm(df.iterrows())
            ]
            self.graphs = [sample.potato_graph.graph for sample in samples]
            return samples
        else:
            raise ValueError("No examples or path provided")

    def to_dataframe(self, as_penman: bool = False) -> pd.DataFrame:
        """Return the samples as a dataframe.

        The columns are ``text``, ``label``, ``label_id`` (``None`` for
        samples without a label), ``rationale`` and ``graph``.

        Args:
            as_penman: Store ``str(sample.potato_graph)`` with newlines
                replaced by spaces instead of the underlying graph object.
        """
        df = pd.DataFrame(
            {
                "text": [sample.text for sample in self._dataset],
                "label": [sample.label for sample in self._dataset],
                "label_id": [
                    self.label_vocab[sample.label] if sample.label else None
                    for sample in self._dataset
                ],
                "rationale": [sample.rationale for sample in self._dataset],
                "graph": [
                    str(sample.potato_graph).replace("\n", " ")
                    if as_penman
                    else sample.potato_graph.graph
                    for sample in self._dataset
                ],
            }
        )
        return df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Tuple, Dict | ||
import networkx as nx | ||
from xpotato.dataset.sample import Sample | ||
|
||
from xpotato.graph_extractor.graph import PotatoGraph | ||
|
||
|
||
class ExplainableSample(Sample):
    """Sample that additionally stores the rationale of its label."""

    def __init__(
        self,
        example: Tuple[str, str],
        potato_graph: PotatoGraph = None,
        label_id: int = None,
    ) -> None:
        """Create the sample.

        Args:
            example: Sequence whose third element (index 2) is the rationale;
                the whole sequence is also passed on to ``Sample``.
            potato_graph: Optional graph of the sample, passed on to ``Sample``.
            label_id: Optional numeric label, passed on to ``Sample``.
        """
        super().__init__(example=example, potato_graph=potato_graph, label_id=label_id)
        self.rationale = example[2]

    def _postprocess(self, graph: PotatoGraph) -> PotatoGraph:
        """Mark every node of ``graph`` as being part of the rationale or not.

        When the rationale is non-empty, each node receives a boolean
        ``"rationale"`` attribute telling whether its ``"name"`` attribute
        occurs in the rationale. With an empty rationale the graph is
        returned untouched.
        """
        if len(self.rationale) != 0:
            # set_node_attributes needs a node-keyed dict; any other value is
            # treated as one constant and assigned to every node.
            rationale_by_node = {
                node: attr["name"] in self.rationale
                for node, attr in graph.graph.nodes(data=True)
            }
            nx.set_node_attributes(graph.graph, rationale_by_node, "rationale")
        return graph

    def set_graph(self, graph: PotatoGraph) -> None:
        """Store ``graph`` after annotating its nodes with the rationale."""
        self.potato_graph = self._postprocess(graph)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters