From 1e03600d7a016c5f675433829e5f67f64f9d2364 Mon Sep 17 00:00:00 2001 From: Reaper0x1 Date: Thu, 24 Oct 2024 16:59:03 +0200 Subject: [PATCH] Initial commit --- LICENSE | 201 +++ README.md | 2 + classifiers.py | 23 + custom_dataset.py | 15 + models.py | 121 ++ preprocess_data.py | 108 ++ requirements.txt | 63 + train.py | 157 ++ train_evaluate_models.ipynb | 3391 +++++++++++++++++++++++++++++++++++ utils.py | 93 + 10 files changed, 4174 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 classifiers.py create mode 100644 custom_dataset.py create mode 100644 models.py create mode 100644 preprocess_data.py create mode 100644 requirements.txt create mode 100644 train.py create mode 100644 train_evaluate_models.ipynb create mode 100644 utils.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..3ec6467 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# VanillaGAN +Experiment with Vanilla GAN. diff --git a/classifiers.py b/classifiers.py new file mode 100644 index 0000000..47be3dd --- /dev/null +++ b/classifiers.py @@ -0,0 +1,23 @@ +import torch.nn as nn +from torch_geometric.nn import GCN, global_mean_pool + + +class GCNModel(nn.Module): + def __init__( + self, in_channels, hidden_channels, out_channels, num_layers=2, dropout=0.6 + ): + super().__init__() + self.gnc_model = GCN( + in_channels, hidden_channels, num_layers, dropout=dropout, jk="cat" + ) + self.fc = nn.Linear(hidden_channels, out_channels) + self.softmax = nn.Softmax(dim=1) + + def forward(self, x, edge_index, edge_weight=None, batch=None): + out = self.gnc_model(x, edge_index, edge_weight=edge_weight, batch=batch) + # print("GCN output shape: ", out.shape) + out = global_mean_pool(out, batch) # [batch_size, hidden_channels] + # print("global mean pooling output shape: ", out.shape) + out = self.softmax(self.fc(out)) + # print("softmax output shape: ", out.shape) + return out diff --git a/custom_dataset.py b/custom_dataset.py new file mode 100644 index 0000000..d902f77 --- /dev/null +++ b/custom_dataset.py @@ -0,0 +1,15 @@ +from torch.utils.data import Dataset + + +class CustomSequenceDataset(Dataset): + def __init__(self, data, labels): + 
self.data = data + self.labels = labels + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data[idx] + label = self.labels[idx] + return sample, label diff --git a/models.py b/models.py new file mode 100644 index 0000000..21f5285 --- /dev/null +++ b/models.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Custom weights initialization for nn.Linear layers +def init_weights(m): + if isinstance(m, nn.Linear): + # Initialize the weights with Xavier (Glorot) uniform initialization + torch.nn.init.xavier_uniform_(m.weight) + # Set biases to a small constant value, e.g., 0.01 + if m.bias is not None: + m.bias.data.fill_(0.01) + + +class Generator(nn.Module): + def __init__( + self, + latent_dim, + seq_len, + hidden_dim, + output_dim, + embed_dim, + dropout, + conditional_info=False, + num_classes=2, + ): + super().__init__() + self.output_dim = output_dim + self.seq_len = seq_len + self.latent_dim = latent_dim + self.input_dim = self.latent_dim + self.conditional_info = conditional_info + if self.conditional_info: + if num_classes is None: + raise ValueError( + "num_classes must be provided if conditional_info is True" + ) + + self.input_dim = latent_dim + embed_dim + self.embedding = nn.Embedding(num_classes, embed_dim) + + self.model = nn.Sequential( + nn.Linear(self.input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, seq_len * output_dim), + nn.ReLU(), # added ReLU here to avoid invalid logits to the gumbel softmax + ) + + def forward(self, z, labels=None, tau=0.5, hard=False): + if self.conditional_info: + if labels is None: + raise ValueError("labels must be provided if conditional_info is True") + self.embedding.to(z.device) + z = torch.cat((z, self.embedding(labels)), dim=1) + + out = self.model(z) + # reshape output to [batch_size, seq_len, output_dim] + # i.e., 
[batch_size, seq_len, vocab_size] + gen_data = F.gumbel_softmax( + out.view(out.shape[0], self.seq_len, self.output_dim), + tau=tau, + hard=hard, + ) + return gen_data + + +class Discriminator(nn.Module): + def __init__( + self, + input_dim, + seq_len, + hidden_dim, + embed_dim, + dropout, + conditional_info=False, + num_classes=None, + ): + super().__init__() + + self.conditional_info = conditional_info + self.input_dim = input_dim * seq_len + if self.conditional_info: + if num_classes is None: + raise ValueError( + "num_classes must be provided if conditional_info is True" + ) + self.embedding = nn.Embedding(num_classes, embed_dim) + self.input_dim = self.input_dim + embed_dim + + self.model = nn.Sequential( + nn.Linear(self.input_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, 1), + nn.Sigmoid(), + ) + + def forward(self, x, labels=None): + # reshape x from [batch_size, seq_len, vocab_size] to [batch_size, seq_len * vocab_size] + x = x.reshape(x.shape[0], -1) + if self.conditional_info: + if labels is None: + raise ValueError("labels must be provided if conditional_info is True") + self.embedding.to(x.device) + # concatenate labels with input + c = self.embedding(labels) + # print(x.shape, "|", c.shape) + x = torch.cat((x, c), dim=-1) + out = self.model(x) + return out diff --git a/preprocess_data.py b/preprocess_data.py new file mode 100644 index 0000000..d58e8b3 --- /dev/null +++ b/preprocess_data.py @@ -0,0 +1,108 @@ +import os +from typing import List + +import networkx as nx +import torch +from torch_geometric.utils.convert import from_networkx + + +def read_subfolder(path: str, label): + sequences = [] + labels = [] + for filename in os.listdir(path): + if filename.endswith(".txt"): + with open(os.path.join(path, filename), "r") as f: + lines = f.readlines() + # print(lines) + + # 
there is only one line + line = list(map(int, lines[0].strip().split())) + sequences.append(line) + labels.append(label) + + return sequences, labels + + +def read_adfa_data(path: str): + """ + path: full path of the data folder (e.g. "/../.../ADFA/Training_Data_Master") + Read all files in the data folder and return the list of sequences and the list of their labels. + """ + sequences = [] + labels = [] + + folder_name = "" + # check labels when reading attack data to be able to assign label to 1 + label = 0 # indicates benign + if "Attack_Data_Master" in path: + folder_name = "Attack_Data_Master" + label = 1 + for sub_folder in list(sorted(os.listdir(path))): + sub_folder_path = os.path.join(path, sub_folder) + if os.path.isdir(sub_folder_path): + # print("processing folder: ", sub_folder) + sub_folder_sequences, sub_folder_labels = read_subfolder( + sub_folder_path, label=label + ) + # print(f"len of sequences: {sub_folder} = {len(sub_folder_sequences)}") + + sequences.extend(sub_folder_sequences) + labels.extend(sub_folder_labels) + + # return a list of sequences, and labels for the attack data + print(f"Read {len(sequences)} sequences from {folder_name}") + return sequences, labels + + # return a list of sequences, and labels for the benign data + + sequences, labels = read_subfolder(path, label=label) + folder_name = path.split("/")[-1] + print(f"Read {len(sequences)} sequences from {folder_name}") + return sequences, labels + + +def sequence_to_graph(L: List, graph_label=None, vocab_size=None): + """ + Convert a sequence of (integers) to a graph. + Currently, we are using already encoded set of integers that represent system calls. If raw data is used, it will be necessary to encode the data first using a dictionary. 
+ """ + # create a graph + G = nx.DiGraph() + for i in range(len(L) - 1): + edge = (L[i], L[i + 1]) + # if edge is not in the graph + if not G.has_edge(*edge): + G.add_edge(*edge, weight=1) + + # if edge is in the graph, just update the weight + else: + u, v = edge + G[u][v]["weight"] += 1 + + # add node attributes + node_attr = [] + # convert networkx graph to pyg graph data + + nodes = torch.tensor(list(G.nodes), dtype=torch.long) + node_attr = [nodes] + + # convert graph to pytorch geometric data + data = from_networkx(G) + data.x = node_attr + if graph_label is not None: + data.y = graph_label + + # validate the data + data.validate(raise_on_error=True) + + return G, data + + +def fetch_graph_data(sequences, labels, vocab_size=342): + graphs = [] + for i in range(len(sequences)): + nx_graph_G, pyg_graph_data = sequence_to_graph( + sequences[i], graph_label=labels[i], vocab_size=vocab_size + ) + graphs.append(pyg_graph_data) + return graphs diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7caf300 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,63 @@ +aiohttp==3.9.5 +aiosignal==1.3.1 +appnope==0.1.4 +asttokens==2.4.1 +attrs==23.2.0 +certifi==2024.2.2 +charset-normalizer==3.3.2 +comm==0.2.2 +contourpy==1.2.1 +cycler==0.12.1 +debugpy==1.8.1 +decorator==5.1.1 +executing==2.0.1 +filelock==3.14.0 +fonttools==4.53.0 +frozenlist==1.4.1 +fsspec==2024.5.0 +idna==3.7 +ipykernel==6.29.4 +ipython==8.25.0 +jedi==0.19.1 +Jinja2==3.1.4 +joblib==1.4.2 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +kiwisolver==1.4.5 +MarkupSafe==2.1.5 +matplotlib==3.9.0 +matplotlib-inline==0.1.7 +mpmath==1.3.0 +multidict==6.0.5 +nest-asyncio==1.6.0 +networkx==3.3 +numpy==1.26.4 +packaging==24.0 +parso==0.8.4 +pexpect==4.9.0 +pillow==10.3.0 +platformdirs==4.2.2 +prompt_toolkit==3.0.45 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +Pygments==2.18.0 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +pyzmq==26.0.3 +requests==2.32.3 +scikit-learn==1.5.0 
+scipy==1.13.1 +six==1.16.0 +stack-data==0.6.3 +sympy==1.12.1 +threadpoolctl==3.5.0 +torch==2.3.0 +torch_geometric==2.5.3 +tornado==6.4 +tqdm==4.66.4 +traitlets==5.14.3 +typing_extensions==4.12.1 +urllib3==2.2.1 +wcwidth==0.2.13 +yarl==1.9.4 diff --git a/train.py b/train.py new file mode 100644 index 0000000..fd229fc --- /dev/null +++ b/train.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +from sklearn.metrics import f1_score +from torch.nn import functional as F +from torch.optim import Adam +from tqdm import tqdm + + +def train_gcn_model(model, train_loader, vocab_size=342, epochs=20, device="cpu"): + optimizer = Adam(model.parameters(), lr=0.001, weight_decay=5e-4, betas=(0.5, 0.99)) + criterion = nn.CrossEntropyLoss() + + train_losses = [] + # val_losses = [] + train_f1_scores = [] + + for epoch in range(epochs): + loss_ = 0.0 + f1_score_ = 0.0 + for data in tqdm(train_loader): + # print(data) + x = F.one_hot(data.x[0], num_classes=vocab_size).float() + x = x.to(device) + edge_index = data.edge_index.to(device) + edge_weight = data.weight.float().to(device) + # y = torch.LongTensor(data.y).to(device) + y = torch.LongTensor(data.y).to(device) + batch = data.batch.to(device) + + optimizer.zero_grad() + + pred = model(x, edge_index, edge_weight=edge_weight, batch=batch) + + # print("pred", pred.shape) + loss = criterion(pred, y) + + # print("loss", loss.item()) + loss.backward() + optimizer.step() + loss_ += loss.item() + + # Convert the tensors to numpy arrays + outputs = torch.argmax(torch.clone(pred).detach().cpu(), dim=1).numpy() + targets = torch.clone(y).detach().cpu().numpy() + # Compute the F1 score + f1 = f1_score(targets, outputs, zero_division=0.0) + f1_score_ += f1 + + # Compute the average loss and F1 score + n_batches = len(train_loader) + loss_ = loss_ / n_batches + f1_score_ = f1_score_ / n_batches + train_losses.append(loss_) + train_f1_scores.append(f1_score_) + + print( + f"Epoch {epoch + 1} / {epochs}, Loss: {loss_:.4f}, F1 Score: { 
f1_score_:.4f}" + ) + + return train_losses, train_f1_scores + + +def train_gan_model( + generator, + discriminator, + gen_optimizer, + disc_optimizer, + train_loader, + epochs, + vocab_size, + device="cpu", + tau=0.1, + hard=False, +): + generator_losses = [] + discriminator_losses = [] + for epoch in range(1, epochs + 1): + print(f"Epoch {epoch} / {epochs}") + + gloss = 0.0 + dloss = 0.0 + for data, labels in tqdm(train_loader): + # print(labels) + + # encode data to one-hot encoding + data = F.one_hot(data.long(), vocab_size).float().to(device) + + labels = labels.long().to(device) + + # Train Discriminator + disc_optimizer.zero_grad() + # generate random noise + latent_dim = generator.latent_dim + z = torch.randn(data.shape[0], latent_dim).to(device) + # generate fake data + gen_data = generator(z, labels) + + # print("gen data shape", gen_data.shape) + + # perform discrete categorical sampling + # gen_data = F.gumbel_softmax(gen_data, tau=temperature, hard=True) + + # feed real data to discriminator + disc_real = discriminator(data, labels) + + # feed fake data to discriminator + disc_fake = discriminator(gen_data.detach(), labels) + + # compute discriminator loss + disc_loss = -torch.mean(torch.log(disc_real) + torch.log(1 - disc_fake)) + + dloss += disc_loss.item() + + # backward pass + disc_loss.backward() + + # update discriminator weights + disc_optimizer.step() + + # Train Generator + gen_optimizer.zero_grad() + z = torch.randn(data.shape[0], latent_dim).to(device) + + # generate fake data + gen_data = generator(z, labels, tau=tau, hard=hard) + + # perform discrete categorical sampling + # print("gen data shape", gen_data.shape) + # gen_data = F.gumbel_softmax(gen_data, tau=tau, hard=False, dim=-1) + + # feed fake data to discriminator + disc_fake = discriminator(gen_data, labels) + + # compute generator loss: minmax GAN's generator loss + gen_loss = torch.mean(torch.log(1 - disc_fake)) + + # for non-saturated minmax GAN's generator loss :==> generator: 
maximize log(D(G(z))) + # gen_loss = -torch.mean(torch.log(disc_fake)) + + gloss += gen_loss.item() + + # backward pass + gen_loss.backward() + + # update generator weights + gen_optimizer.step() + + n_batches = len(train_loader) + dloss = torch.round(torch.tensor(dloss), decimals=4) + gloss = torch.round(torch.tensor(gloss), decimals=4) + + discriminator_losses.append(dloss) + generator_losses.append(gloss) + + print(f"D Loss: {disc_loss.item():.4f}, G Loss: {gen_loss.item():.4f}") + + return generator_losses, discriminator_losses diff --git a/train_evaluate_models.ipynb b/train_evaluate_models.ipynb new file mode 100644 index 0000000..090d102 --- /dev/null +++ b/train_evaluate_models.ipynb @@ -0,0 +1,3391 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import networkx as nx\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import numpy as np\n", + "from collections import Counter\n", + "from torch_geometric.loader import DataLoader\n", + "import torch.nn as nn\n", + "from torch.optim import Adam\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from pathlib import Path\n", + "import os\n", + "\n", + "# custom modules\n", + "from models import Generator, Discriminator\n", + "from custom_dataset import CustomSequenceDataset\n", + "from preprocess_data import read_adfa_data, sequence_to_graph, fetch_graph_data\n", + "from utils import get_device, plot_loss_curve, evaluate_gcn_model\n", + "from classifiers import GCNModel\n", + "from train import train_gan_model\n", + "from train import train_gcn_model\n", + "from models import init_weights" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# for reproducibility\n", + "import random\n", + "\n", + "SEED = 42\n", + "torch.manual_seed(SEED)\n", + "np.random.seed(SEED)\n", + "random.seed(SEED)" + 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# read, preprocess, and fetch the ADFA datasets\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "data_folder = \"ADFA\" # make sure the \"ADFA\" folder is in the parent directory of this project's folder [i.e., the folder containing your code]\n", + "current_directory = Path(os.getcwd())\n", + "parent_path = current_directory.parent.absolute()\n", + "# print(current_directory.parent.absolute())\n", + "\n", + "full_data_folder_path = os.path.join(parent_path, data_folder)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Read 833 sequences from Training_Data_Master\n", + "Read 4372 sequences from Validation_Data_Master\n", + "Read 746 sequences from Attack_Data_Master\n" + ] + }, + { + "data": { + "text/plain": [ + "(4165, 4165, 1786, 1786)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adfa_sub_folders = [\n", + " \"Training_Data_Master\",\n", + " \"Validation_Data_Master\",\n", + " \"Attack_Data_Master\",\n", + "]\n", + "\n", + "benign_training_data_path = os.path.join(full_data_folder_path, adfa_sub_folders[0])\n", + "benign_validation_data_path = os.path.join(full_data_folder_path, adfa_sub_folders[1])\n", + "\n", + "attack_data_path = os.path.join(full_data_folder_path, adfa_sub_folders[2])\n", + "\n", + "\n", + "# read the sub folders\n", + "benign_train_sequences, benign_train_labels = read_adfa_data(benign_training_data_path)\n", + "\n", + "benign_val_sequences, benign_val_labels = read_adfa_data(benign_validation_data_path)\n", + "\n", + "attack_sequences, attack_labels = read_adfa_data(attack_data_path)\n", + "\n", + "\n", + "data = benign_train_sequences + benign_val_sequences + attack_sequences\n", + "labels = benign_train_labels + benign_val_labels + attack_labels\n", + 
"\n", + "# perform 70 % training and 30% testing set split\n", + "\n", + "train_data, test_data, train_labels, test_labels = train_test_split(\n", + " data, labels, test_size=0.3, random_state=42, shuffle=True, stratify=labels\n", + ")\n", + "\n", + "len(train_data), len(train_labels), len(test_data), len(test_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "746" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(attack_sequences)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Counter({0: 3643, 1: 522})" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Data is imbalanced: 3643 samples of benign & 522 samples of attack (malignant) will be used for training the GCN model\n", + "Counter(train_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Counter({0: 1562, 1: 224})" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Counter(test_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a custom dataset\n", + "batch_size = 64\n", + "# convert train_data sequences and val_data sequences into graphs\n", + "\n", + "train_graph_dataset = fetch_graph_data(train_data, train_labels)\n", + "val_graph_dataset = fetch_graph_data(test_data, test_labels)\n", + "\n", + "\n", + "# Create a data loader\n", + "graph_train_loader = DataLoader(\n", + " train_graph_dataset, batch_size=batch_size, shuffle=True\n", + ")\n", + "graph_val_loader = DataLoader(val_graph_dataset, batch_size=batch_size, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 
26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Data(edge_index=[2, 63], weight=[63], num_nodes=17, x=[1], y=0)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# you can see the graph data: it has edge_index, weight, num_nodes, y (label of the graph, i.e., malign or benign), and x (node features). Here, x is a list of integers representing the node features. They need to be converted to a one-hot vector of size vocab_size or vector of embedded representation.\n", + "train_graph_dataset[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 13.31it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 / 35, Loss: 0.5450, F1 Score: 0.0148\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.11it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 2 / 35, Loss: 0.4404, F1 Score: 0.0000\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.03it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 3 / 35, Loss: 0.4313, F1 Score: 0.0000\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4 / 35, Loss: 0.4190, F1 Score: 0.1648\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 5 / 35, Loss: 0.4049, F1 Score: 0.5277\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 17/17 [00:01<00:00, 14.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 6 / 35, Loss: 0.4005, F1 Score: 0.5737\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.51it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 7 / 35, Loss: 0.3953, F1 Score: 0.6032\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.71it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8 / 35, Loss: 0.3955, F1 Score: 0.6018\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 / 35, Loss: 0.3929, F1 Score: 0.6133\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.48it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 10 / 35, Loss: 0.3925, F1 Score: 0.6130\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.57it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 11 / 35, Loss: 0.3924, F1 Score: 0.6147\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 12 / 35, Loss: 0.3903, F1 Score: 0.6165\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.72it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 13 / 35, Loss: 0.3923, F1 Score: 0.6083\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 17/17 [00:01<00:00, 14.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 14 / 35, Loss: 0.3935, F1 Score: 0.6034\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 15 / 35, Loss: 0.3915, F1 Score: 0.6224\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.42it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 16 / 35, Loss: 0.3909, F1 Score: 0.6224\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 17 / 35, Loss: 0.3868, F1 Score: 0.6434\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 18 / 35, Loss: 0.3925, F1 Score: 0.6044\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.08it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 19 / 35, Loss: 0.3896, F1 Score: 0.6274\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.70it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 20 / 35, Loss: 0.3879, F1 Score: 0.6457\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.38it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 21 / 35, Loss: 0.3893, F1 Score: 0.6227\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 17/17 [00:01<00:00, 15.85it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 22 / 35, Loss: 0.3884, F1 Score: 0.6188\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.07it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 23 / 35, Loss: 0.3875, F1 Score: 0.6431\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.07it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 24 / 35, Loss: 0.3837, F1 Score: 0.6437\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.31it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 25 / 35, Loss: 0.3848, F1 Score: 0.6553\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 26 / 35, Loss: 0.3850, F1 Score: 0.6557\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 13.67it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 27 / 35, Loss: 0.3836, F1 Score: 0.6476\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 10.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 28 / 35, Loss: 0.3837, F1 Score: 0.6672\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 10.07it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 29 / 35, Loss: 0.3827, F1 Score: 0.6492\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 17/17 [00:01<00:00, 13.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 30 / 35, Loss: 0.3842, F1 Score: 0.6567\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 13.21it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 31 / 35, Loss: 0.3836, F1 Score: 0.6370\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 32 / 35, Loss: 0.3863, F1 Score: 0.6209\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 14.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 33 / 35, Loss: 0.3829, F1 Score: 0.6680\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 15.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 34 / 35, Loss: 0.3855, F1 Score: 0.6579\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 17/17 [00:01<00:00, 13.84it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 35 / 35, Loss: 0.3835, F1 Score: 0.6545\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# device = get_device()\n", + "vocab_size = 342\n", + "hidden_dim = 128 # also the number of hidden nodes in the GCN. Can be set to 128, 256 or 512. If you have more GPU, you can increase it. However, you may face overfitting. Increase dropout if you face overfitting. 
Like dropout = 0.6, 0.7, 0.8\n", + "num_classes = 2\n", + "dropout = 0.4\n", + "output_dim = num_classes\n", + "# number of GCN layers, depending on your GPU, set it to any integer from 2 to 5.\n", + "num_layers = 3\n", + "EPOCHS_GCN = 35\n", + "\n", + "\n", + "device = torch.device(\"cpu\")\n", + "\n", + "\n", + "gcn_model = GCNModel(\n", + " in_channels=vocab_size,\n", + " hidden_channels=hidden_dim,\n", + " out_channels=num_classes,\n", + " num_layers=num_layers,\n", + " dropout=dropout,\n", + ").to(device)\n", + "\n", + "train_losses, train_f1_scores = train_gcn_model(\n", + " gcn_model,\n", + " graph_train_loader,\n", + " vocab_size=vocab_size,\n", + " epochs=EPOCHS_GCN,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = np.arange(EPOCHS_GCN, dtype=int)\n", + "plt.plot(x, train_losses, color=\"red\", label=\"GCN Training Loss\")\n", + "plt.plot(x, train_f1_scores, color=\"blue\", label=\"GCN Training F1 Score\")\n", + "plt.title(\"Training Loss and F1 Score\")\n", + "plt.xlabel(\"Epochs\")\n", + "plt.ylabel(\"Loss\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GCN performance:\n", + "Accuracy: 0.9429\n", + "Precision: 0.9130\n", + "Recall: 0.8106\n", + "F1 score: 0.8519\n", + "MCC: 0.7163\n" + ] + } + ], + "source": [ + "evaluate_gcn_model(gcn_model, graph_val_loader, vocab_size=vocab_size, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([4165, 4094]), torch.Size([1786, 4494]))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# PREPARE DATASET FOR GAN MODEL: pad the sequences to make them of same length, then trim them to a fixed sequence length to avoid too much padding.\n", + "# pad the train_data and test_data sequences into tensors (each set is padded to its own longest sequence)\n", + "SEQUENCE_LENGTH = 100\n", + "train_data_padded = pad_sequence(\n", + " [torch.tensor(sequence, dtype=torch.long) for sequence in train_data],\n", + " batch_first=True,\n", + ")\n", + "val_data_padded = pad_sequence(\n", + " [torch.tensor(sequence, dtype=torch.long) for sequence in test_data],\n", + " batch_first=True,\n", + ")\n", + "\n", + "train_data_padded.shape, val_data_padded.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "train_data_padded[:, SEQUENCE_LENGTH].shape, val_data_padded[:SEQUENCE_LENGTH].shape  # NOTE(review): value is discarded; the slices likely should be [:, :SEQUENCE_LENGTH] as used below - confirm\n", + "\n", + "# our dataset will be of 
shape (n_samples, SEQUENCE_LENGTH).\n", + "train_dataset = CustomSequenceDataset(\n", + " train_data_padded[:, :SEQUENCE_LENGTH], train_labels\n", + ")\n", + "val_dataset = CustomSequenceDataset(val_data_padded[:, :SEQUENCE_LENGTH], test_labels)\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# hyperparameters of the generator and discriminator of the GAN model\n", + "latent_dim = 100\n", + "hidden_dim = 128\n", + "vocab_size = 342 # we have 342 unique system calls in the ADFA dataset\n", + "seq_len = SEQUENCE_LENGTH # fixed sequence length (SEQUENCE_LENGTH = 100): shorter sequences are padded, longer sequences are truncated.\n", + "output_dim = vocab_size\n", + "embed_dim = 10\n", + "dropout = 0.5\n", + "batch_size = 32 # NOTE(review): train_loader/val_loader above were already built with the previous batch_size; rebuild them if 32 is intended\n", + "n_samples = 1000\n", + "num_classes = 2\n", + "lr = 5e-5 # 2e-6\n", + "epochs = 150\n", + "# criterion = nn.BCELoss()\n", + "criterion = nn.CrossEntropyLoss()\n", + "device = get_device()\n", + "temperature = 0.2 # 0.1\n", + "\n", + "gen = Generator(\n", + " latent_dim,\n", + " seq_len,\n", + " hidden_dim,\n", + " output_dim,\n", + " embed_dim,\n", + " dropout,\n", + " conditional_info=True,\n", + " num_classes=num_classes,\n", + ").to(device)\n", + "# gen = Gen(latent_dim, hidden_dim, seq_len, output_dim, embed_dim, dropout).to(device)\n", + "\n", + "disc = Discriminator(\n", + " vocab_size,\n", + " seq_len,\n", + " hidden_dim,\n", + " embed_dim,\n", + " dropout,\n", + " conditional_info=True,\n", + " num_classes=num_classes,\n", + ").to(device)\n", + "\n", + "\n", + "# To apply the weight initialization to your GAN model, uncomment the two lines of code below:\n", + "\n", + "# gen.apply(init_weights)\n", + "# disc.apply(init_weights)\n", + "\n", + "gen_optimizer = Adam(gen.parameters(), 
lr=lr, betas=(0.5, 0.999))\n", + "\n", + "disc_optimizer = Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "device = get_device()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# input_data = torch.randint(0, vocab_size, (n_samples, seq_len))\n", + "# input_labels = torch.randint(0, num_classes, (n_samples,))\n", + "# train_dataset = CustomSequenceDataset(input_data, input_labels)\n", + "# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 17.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3881, G Loss: -0.6905\n", + "Epoch 2 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 17.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3966, G Loss: -0.6982\n", + "Epoch 3 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 17.68it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3920, G Loss: -0.6938\n", + "Epoch 4 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 16.67it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3851, G Loss: -0.6909\n", + "Epoch 5 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 16.61it/s]\n" + ] + }, + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3786, G Loss: -0.6975\n", + "Epoch 6 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 17.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3862, G Loss: -0.6959\n", + "Epoch 7 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.80it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3948, G Loss: -0.6969\n", + "Epoch 8 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 16.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3930, G Loss: -0.6909\n", + "Epoch 9 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 16.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3790, G Loss: -0.6950\n", + "Epoch 10 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:03<00:00, 16.74it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3793, G Loss: -0.6889\n", + "Epoch 11 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.19it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3917, G Loss: -0.7008\n", + "Epoch 12 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.49it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.4028, G Loss: -0.6943\n", + "Epoch 13 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 66/66 [00:04<00:00, 14.92it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3859, G Loss: -0.6920\n", + "Epoch 14 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3813, G Loss: -0.6742\n", + "Epoch 15 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 16.20it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3886, G Loss: -0.6905\n", + "Epoch 16 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.72it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3937, G Loss: -0.6853\n", + "Epoch 17 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3872, G Loss: -0.6889\n", + "Epoch 18 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3827, G Loss: -0.6926\n", + "Epoch 19 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3806, G Loss: -0.6857\n", + "Epoch 20 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3865, G Loss: -0.6854\n", + "Epoch 21 / 150\n" + ] + }, 
+ { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3915, G Loss: -0.6963\n", + "Epoch 22 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.07it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3851, G Loss: -0.6872\n", + "Epoch 23 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3921, G Loss: -0.6904\n", + "Epoch 24 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.31it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3857, G Loss: -0.6936\n", + "Epoch 25 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3878, G Loss: -0.6937\n", + "Epoch 26 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.99it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3928, G Loss: -0.6915\n", + "Epoch 27 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3869, G Loss: -0.6937\n", + "Epoch 28 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.98it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D 
Loss: 1.3930, G Loss: -0.6900\n", + "Epoch 29 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.02it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3950, G Loss: -0.6898\n", + "Epoch 30 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3910, G Loss: -0.6974\n", + "Epoch 31 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3861, G Loss: -0.7018\n", + "Epoch 32 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3871, G Loss: -0.6907\n", + "Epoch 33 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3811, G Loss: -0.6986\n", + "Epoch 34 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.99it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3876, G Loss: -0.6984\n", + "Epoch 35 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.98it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3971, G Loss: -0.7016\n", + "Epoch 36 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.12it/s]\n" + ] + }, + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3907, G Loss: -0.6984\n", + "Epoch 37 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3854, G Loss: -0.6901\n", + "Epoch 38 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.69it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3917, G Loss: -0.6932\n", + "Epoch 39 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.70it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3929, G Loss: -0.6890\n", + "Epoch 40 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3838, G Loss: -0.6947\n", + "Epoch 41 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.91it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3846, G Loss: -0.7004\n", + "Epoch 42 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.97it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3854, G Loss: -0.6949\n", + "Epoch 43 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 15.05it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3832, G Loss: -0.7010\n", + "Epoch 44 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 66/66 [00:04<00:00, 15.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3922, G Loss: -0.6927\n", + "Epoch 45 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3928, G Loss: -0.6950\n", + "Epoch 46 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.93it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3767, G Loss: -0.6930\n", + "Epoch 47 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.21it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3910, G Loss: -0.6999\n", + "Epoch 48 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.51it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3956, G Loss: -0.6971\n", + "Epoch 49 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.01it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3760, G Loss: -0.6827\n", + "Epoch 50 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3915, G Loss: -0.6922\n", + "Epoch 51 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:08<00:00, 7.94it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3851, G Loss: -0.6912\n", + "Epoch 52 / 150\n" + ] + }, + 
{ + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.97it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3828, G Loss: -0.6919\n", + "Epoch 53 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.99it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3851, G Loss: -0.6938\n", + "Epoch 54 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.74it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3770, G Loss: -0.6997\n", + "Epoch 55 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3866, G Loss: -0.6991\n", + "Epoch 56 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.41it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3796, G Loss: -0.6858\n", + "Epoch 57 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.74it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3809, G Loss: -0.7107\n", + "Epoch 58 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3741, G Loss: -0.6974\n", + "Epoch 59 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.75it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D 
Loss: 1.3949, G Loss: -0.6967\n", + "Epoch 60 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.60it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3866, G Loss: -0.6893\n", + "Epoch 61 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3917, G Loss: -0.6919\n", + "Epoch 62 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.60it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3807, G Loss: -0.6846\n", + "Epoch 63 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.68it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3808, G Loss: -0.6956\n", + "Epoch 64 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.57it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3842, G Loss: -0.6885\n", + "Epoch 65 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3948, G Loss: -0.6959\n", + "Epoch 66 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3981, G Loss: -0.6900\n", + "Epoch 67 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.43it/s]\n" + ] + }, + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3776, G Loss: -0.6922\n", + "Epoch 68 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3830, G Loss: -0.6933\n", + "Epoch 69 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3926, G Loss: -0.6985\n", + "Epoch 70 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3726, G Loss: -0.6915\n", + "Epoch 71 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.74it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3892, G Loss: -0.6962\n", + "Epoch 72 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.72it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3836, G Loss: -0.6913\n", + "Epoch 73 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.42it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3767, G Loss: -0.6861\n", + "Epoch 74 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.81it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3947, G Loss: -0.6899\n", + "Epoch 75 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 66/66 [00:04<00:00, 14.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3792, G Loss: -0.6888\n", + "Epoch 76 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.61it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3880, G Loss: -0.6869\n", + "Epoch 77 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.57it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3836, G Loss: -0.6960\n", + "Epoch 78 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3907, G Loss: -0.6935\n", + "Epoch 79 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3971, G Loss: -0.6990\n", + "Epoch 80 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3867, G Loss: -0.6989\n", + "Epoch 81 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.77it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3725, G Loss: -0.6966\n", + "Epoch 82 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3846, G Loss: -0.6953\n", + "Epoch 83 / 150\n" + ] + }, 
+ { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.68it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3873, G Loss: -0.6934\n", + "Epoch 84 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.67it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3739, G Loss: -0.6930\n", + "Epoch 85 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.41it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3834, G Loss: -0.6918\n", + "Epoch 86 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.24it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3883, G Loss: -0.6991\n", + "Epoch 87 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.75it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3848, G Loss: -0.6867\n", + "Epoch 88 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.11it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3756, G Loss: -0.6936\n", + "Epoch 89 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.91it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3780, G Loss: -0.6895\n", + "Epoch 90 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.28it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D 
Loss: 1.3869, G Loss: -0.6938\n", + "Epoch 91 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.40it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3833, G Loss: -0.6918\n", + "Epoch 92 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 13.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3877, G Loss: -0.6868\n", + "Epoch 93 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.11it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3872, G Loss: -0.6940\n", + "Epoch 94 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3807, G Loss: -0.6942\n", + "Epoch 95 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.12it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3816, G Loss: -0.6941\n", + "Epoch 96 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3835, G Loss: -0.6933\n", + "Epoch 97 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.04it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3814, G Loss: -0.6963\n", + "Epoch 98 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.43it/s]\n" + ] + }, + { 
+ "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3800, G Loss: -0.6879\n", + "Epoch 99 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.74it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3845, G Loss: -0.6889\n", + "Epoch 100 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.93it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3778, G Loss: -0.7008\n", + "Epoch 101 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.24it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3830, G Loss: -0.6888\n", + "Epoch 102 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.68it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3755, G Loss: -0.6938\n", + "Epoch 103 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3852, G Loss: -0.6948\n", + "Epoch 104 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3807, G Loss: -0.6966\n", + "Epoch 105 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.70it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3935, G Loss: -0.6921\n", + "Epoch 106 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.76it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.4002, G Loss: -0.6990\n", + "Epoch 107 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.92it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3825, G Loss: -0.6922\n", + "Epoch 108 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.47it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3974, G Loss: -0.6947\n", + "Epoch 109 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.04it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3838, G Loss: -0.6841\n", + "Epoch 110 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3806, G Loss: -0.6924\n", + "Epoch 111 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.98it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3895, G Loss: -0.6905\n", + "Epoch 112 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.48it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3871, G Loss: -0.6920\n", + "Epoch 113 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.77it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3833, G Loss: -0.6917\n", + "Epoch 
114 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 13.00it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3867, G Loss: -0.6894\n", + "Epoch 115 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.75it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3931, G Loss: -0.6912\n", + "Epoch 116 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.80it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3885, G Loss: -0.6920\n", + "Epoch 117 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.76it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3897, G Loss: -0.6933\n", + "Epoch 118 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.70it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3856, G Loss: -0.6939\n", + "Epoch 119 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.45it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3902, G Loss: -0.6975\n", + "Epoch 120 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.64it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3781, G Loss: -0.6871\n", + "Epoch 121 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.61it/s]\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "D Loss: 1.3979, G Loss: -0.6867\n", + "Epoch 122 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.72it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3786, G Loss: -0.6834\n", + "Epoch 123 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3884, G Loss: -0.6906\n", + "Epoch 124 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:06<00:00, 10.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3929, G Loss: -0.6936\n", + "Epoch 125 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 13.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3909, G Loss: -0.6930\n", + "Epoch 126 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3830, G Loss: -0.6906\n", + "Epoch 127 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.60it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3755, G Loss: -0.6879\n", + "Epoch 128 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.56it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3909, G Loss: -0.6956\n", + "Epoch 129 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|██████████| 66/66 [00:04<00:00, 14.30it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3897, G Loss: -0.6947\n", + "Epoch 130 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3839, G Loss: -0.6909\n", + "Epoch 131 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.48it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3884, G Loss: -0.6919\n", + "Epoch 132 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3821, G Loss: -0.6938\n", + "Epoch 133 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 14.24it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.4008, G Loss: -0.6880\n", + "Epoch 134 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3897, G Loss: -0.6902\n", + "Epoch 135 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.92it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3850, G Loss: -0.6866\n", + "Epoch 136 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.26it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3873, G Loss: -0.6920\n", + "Epoch 137 / 150\n" + 
] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:06<00:00, 9.56it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3868, G Loss: -0.6988\n", + "Epoch 138 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.98it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3922, G Loss: -0.6919\n", + "Epoch 139 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.68it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3843, G Loss: -0.6935\n", + "Epoch 140 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3890, G Loss: -0.6911\n", + "Epoch 141 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.49it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3880, G Loss: -0.6893\n", + "Epoch 142 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.90it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3788, G Loss: -0.6932\n", + "Epoch 143 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3917, G Loss: -0.6935\n", + "Epoch 144 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "D Loss: 1.3919, G Loss: -0.6877\n", + "Epoch 145 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.82it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3799, G Loss: -0.6903\n", + "Epoch 146 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:04<00:00, 13.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3833, G Loss: -0.6957\n", + "Epoch 147 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 13.15it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3920, G Loss: -0.6940\n", + "Epoch 148 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 11.67it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3836, G Loss: -0.6857\n", + "Epoch 149 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3925, G Loss: -0.6955\n", + "Epoch 150 / 150\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [00:05<00:00, 12.71it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D Loss: 1.3828, G Loss: -0.6896\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# note that the temperature is set to any value between 0 and 1.\n", + "g_losses, d_losses = train_gan_model(\n", + " gen,\n", + " disc,\n", + " gen_optimizer,\n", + " disc_optimizer,\n", + " train_loader,\n", + " epochs,\n", + " vocab_size,\n", + " device=device,\n", + " 
tau=temperature,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_loss_curve(d_losses, g_losses, epochs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## generate fake samples\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "fake_labels = torch.randint(0, num_classes, (n_samples,), dtype=torch.long)\n", + "z = torch.randn((n_samples, latent_dim))\n", + "\n", + "# print(\"random noise shape\", z.shape)\n", + "# set generator to eval mode\n", + "gen.eval()\n", + "\n", + "fake_data = gen(z.to(device), fake_labels.to(device)).detach().cpu()\n", + "\n", + "# print the generated data: if the data contains nan values, it means that there is gradient explosion, or other issues. To avoid such problems, change the temperature value to be between 0.1 and 0.5. Also, try changing the learning rate to be between 5e-6 and 1e-4. Of course, this is just a suggestion. You might experiment with other values if you want, and report your findings.\n", + "# print(fake_data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "array([[319, 267, 98, 273, 158, 191, 26, 114, 35, 193, 280, 102, 150,\n", + " 268, 81, 250, 260, 64, 145, 88, 53, 256, 34, 275, 231, 11,\n", + " 63, 62, 232, 194, 144, 133, 255, 38, 255, 228, 153, 265, 130,\n", + " 311, 146, 168, 197, 226, 294, 314, 9, 100, 180, 28, 87, 273,\n", + " 0, 15, 162, 149, 128, 273, 95, 307, 156, 307, 241, 159, 165,\n", + " 129, 68, 65, 171, 98, 304, 4, 142, 325, 18, 18, 297, 5,\n", + " 249, 164, 279, 76, 235, 263, 325, 131, 0, 183, 197, 331, 300,\n", + " 18, 120, 65, 31, 270, 179, 318, 119, 173],\n", + " [111, 136, 83, 50, 130, 225, 241, 201, 328, 269, 94, 98, 59,\n", + " 275, 178, 140, 72, 245, 15, 301, 23, 93, 219, 30, 48, 105,\n", + " 162, 303, 49, 230, 148, 236, 337, 106, 317, 139, 99, 106, 294,\n",
+ " 30, 70, 264, 306, 9, 305, 291, 53, 19, 158, 309, 47, 43,\n", + " 171, 168, 298, 313, 279, 279, 36, 296, 8, 124, 25, 124, 192,\n", + " 5, 184, 107, 118, 122, 220, 289, 299, 284, 318, 64, 130, 258,\n", + " 294, 335, 294, 153, 26, 172, 236, 299, 341, 145, 93, 182, 115,\n", + " 253, 222, 336, 155, 22, 219, 179, 35, 214]])\n" + ] + } + ], + "source": [ + "# get the generated samples as tokens (each sample is a sequence of tokens, i.e. words / characters, represented by integers).\n", + "from pprint import pprint\n", + "\n", + "fake_samples = torch.argmax(fake_data, dim=-1).cpu().numpy()\n", + "\n", + "# now, let's see two generated sequences\n", + "pprint(fake_samples[:2])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Now, we can generate a graph from the generated tokens (i.e., a sequence)\n", + "\n", + "` To do so, we can use the sequence_to_graph() function given in the preprocess_data.py`\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# convert the generated samples into a networkx graph and also pytorch_geometric graph data object\n", + "networkx_graph, pytorch_geometric_data = sequence_to_graph(fake_samples[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# you can see the networkx graph we generated\n", + "nx.draw(networkx_graph, with_labels=True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# The Remaining Steps:\n", + "\n", + "```\n", + "1 - generate many fake samples (sequences and their corresponding labels),\n", + "\n", + "2 - add these fake sequences to train_data sequences, and add their fake_labels (benign or malign) to the trian_labels\n", + "\n", + "3 - train a new GCNModel and compare performance before and after adding fake samples.\n", + "\n", + "\n", + "```\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "graphganvenv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..a089e32 --- /dev/null +++ b/utils.py @@ -0,0 +1,93 @@ +from typing import List + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import ( + accuracy_score, + f1_score, + matthews_corrcoef, + precision_score, + recall_score, +) + + +def get_device(): + """Get the preferred device: + check if there is a GPU available:NVIDIA GPU or MPS (Apple Silcon GPU) if available). 
+ """ + + # check if there is nvidia or cuda gpu + if torch.cuda.is_available(): + return torch.device("cuda") + + # check if there is an apple silicon gpu + if torch.backends.mps.is_available(): + return torch.device("mps") + + # otherwise use the cpu + return torch.device("cpu") + + +def plot_loss_curve(d_losses: List[float], g_losses: List[float], EPOCHS: int) -> None: + """Plot loss curve of critic and generator. + + Args: + d_losses (List[float]): List of Discriminator losses. + g_losses (List[float]): List of generator losses. + EPOCHS (int): Total number of epochs. + """ + # normalize losses to get a nice graph + gen_losses = np.asarray(g_losses) + critic_losses = np.asarray(d_losses) + + gen_losses = (gen_losses - gen_losses.min()) / (gen_losses.max() - gen_losses.min()) + critic_losses = np.asarray(d_losses) + + critic_losses = (critic_losses - critic_losses.min()) / ( + critic_losses.max() - critic_losses.min() + ) + plt.figure(figsize=(20, 10)) + x = range(1, EPOCHS + 1) + plt.plot(x, critic_losses, color="red", label="Discriminator Loss") + plt.plot(x, gen_losses, color="blue", label="Generator Loss") + plt.xlabel("Epoch") + plt.ylabel("Loss Curve") + plt.legend() + plt.savefig("loss_curve.png") + plt.show() + + +def evaluate_gcn_model(gcn_model, graph_val_loader, vocab_size, device): + gcn_model.eval() + # val_losses = [] + # val_f1_scores = [] + + y_true = [] + y_pred = [] + + with torch.no_grad(): + for data in graph_val_loader: + x = data.x[0].to(device) + x = F.one_hot(x, num_classes=vocab_size).float() + edge_index = data.edge_index.to(device) + edge_weight = data.weight.float().to(device) + batch = data.batch.to(device) + + out = gcn_model(x, edge_index, edge_weight, batch) + # loss = F.cross_entropy(out, data.y) + # val_losses.append(loss.item()) + + preds = out.argmax(dim=1).cpu().numpy() + y_true.extend(data.y.cpu().numpy()) + y_pred.extend(preds) + + print("GCN performance:") + print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}") + 
print(f"Precision: {precision_score(y_true, y_pred, average='macro'):.4f}") + print(f"Recall: {recall_score(y_true, y_pred, average='macro'):.4f}") + print( + f"F1 score: {f1_score(y_true, y_pred, average='macro', zero_division=0.0):.4f}" + ) + print(f"MCC: {matthews_corrcoef(y_true, y_pred):.4f}")