-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1607e81
commit ad2a88d
Showing
7 changed files
with
588 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,55 @@ | ||
from transformers import BertTokenizer, BertForSequenceClassification | ||
import torch | ||
import subprocess | ||
import os | ||
import google.auth | ||
from google.auth.transport.requests import Request | ||
from google.auth import default | ||
from google.auth import exceptions | ||
from google.oauth2.credentials import Credentials | ||
from google.auth.transport.requests import AuthorizedSession | ||
from google.auth import impersonated_credentials | ||
from google.auth.transport import requests | ||
|
||
from google.cloud import aiplatform | ||
|
||
# Function to authenticate and get the credentials | ||
def authenticate_with_google_cloud():
    """Load Application Default Credentials and ensure they hold a token.

    Credentials are requested with the ``cloud-platform`` scope through
    ``google.auth.default``. If they are not yet valid they are refreshed
    once with a ``Request`` transport.

    Returns:
        The ``(credentials, project)`` tuple returned by
        ``google.auth.default``.

    Raises:
        google.auth.exceptions.RefreshError: If no access token could be
            obtained for the discovered credentials.
    """
    credentials, project = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )

    if not credentials.valid:
        # The previous fallback called google.auth.oauth2client, which does
        # not exist, and it gated the refresh on ``expired``/``refresh_token``,
        # which not every credential type provides. Refreshing directly is
        # the supported way to obtain a token for default credentials.
        try:
            credentials.refresh(Request())
        except google.auth.exceptions.RefreshError as err:
            raise google.auth.exceptions.RefreshError(
                "Could not obtain a Google Cloud access token. Run "
                "`gcloud auth application-default login` or set "
                "GOOGLE_APPLICATION_CREDENTIALS, then retry."
            ) from err

    return credentials, project
|
||
# Function to export Vertex AI data to a .txt file | ||
def export_vertex_ai_data():
    """Dump the project's Vertex AI datasets, models and endpoints to a file.

    Authenticates through ``authenticate_with_google_cloud``, initialises the
    ``aiplatform`` SDK with the returned project and credentials, and writes
    one ``Name: ..., Display Name: ...`` line per resource to
    ``vertex_ai_data.txt`` in the current working directory, grouped under a
    heading per resource kind. The output path is printed when done.
    """
    credentials, project_id = authenticate_with_google_cloud()
    aiplatform.init(project=project_id, credentials=credentials)

    export_file = "vertex_ai_data.txt"

    # Display names are user-supplied text; write UTF-8 explicitly so the
    # export does not depend on the platform's default locale encoding.
    with open(export_file, "w", encoding="utf-8") as file:
        # NOTE(review): confirm the installed SDK exposes ``aiplatform.Dataset``;
        # some releases only export the typed dataset classes.
        sections = (
            ("Datasets:\n", aiplatform.Dataset),
            ("\nModels:\n", aiplatform.Model),
            ("\nEndpoints:\n", aiplatform.Endpoint),
        )
        for heading, resource_cls in sections:
            resources = resource_cls.list()
            file.write(heading)
            for resource in resources:
                file.write(
                    f"Name: {resource.name}, Display Name: {resource.display_name}\n"
                )

    print(f"Data exported to {export_file}")


if __name__ == "__main__":
    export_vertex_ai_data()
# Export a fine-tuned BERT stance classifier and package it as a TorchServe
# model archive (.mar) under SAVE_PATH/model_store.
import sys  # needed to invoke pip for the interpreter running this script

MODEL_PATH = 'C:\\Users\\LENOVO\\Desktop\\CSC5382_SP24_FINALPROJECT\\scripts\\bert-election2024-twitter-stance-biden'
SAVE_PATH = 'C:\\Users\\LENOVO\\Desktop\\saved_model'
HANDLER_PATH = 'C:\\Users\\LENOVO\\Desktop\\CSC5382_SP24_FINALPROJECT\\scripts\\transformers_handler.py'  # Update this path as needed

# Load the model and tokenizer from the local checkpoint directory.
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)

# Re-save both so config and vocabulary files end up next to the weights.
tokenizer.save_pretrained(SAVE_PATH)
model.save_pretrained(SAVE_PATH)

# Write the raw state dict as well; this is the file handed to the archiver
# through --serialized-file below.
model_file_path = os.path.join(SAVE_PATH, "pytorch_model.bin")
torch.save(model.state_dict(), model_file_path)

torch_version = torch.__version__
print(f"PyTorch version: {torch_version}")

# Make sure torch-model-archiver is available. Install it with the running
# interpreter's pip (a bare "pip" may belong to another environment) and stop
# if the installation fails instead of continuing without the tool.
try:
    import torch_model_archiver
except ImportError:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torch-model-archiver"],
        check=True,
    )

# Create the model_store directory if it doesn't exist.
model_store_path = os.path.join(SAVE_PATH, "model_store")
os.makedirs(model_store_path, exist_ok=True)

# Extra files are passed to the archiver as a single comma-separated list.
extra_files = ",".join(
    os.path.join(SAVE_PATH, file_name)
    for file_name in ("config.json", "vocab.txt")
)

archive_command = [
    "torch-model-archiver",
    "--model-name", "bert-election2024",
    "--version", "1.0",
    "--serialized-file", model_file_path,
    "--handler", HANDLER_PATH,
    "--export-path", model_store_path,
    "--extra-files", extra_files,
    "--force",
]
archive_result = subprocess.run(archive_command)

# Check the exit status as well as the file: an archive left over from an
# earlier run would otherwise be reported as a fresh success.
archive_file = os.path.join(model_store_path, "bert-election2024.mar")
if archive_result.returncode == 0 and os.path.exists(archive_file):
    print(f"Model archive created at: {archive_file}")
else:
    print("Failed to create model archive.")
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.