forked from Confuzu/CivitAI-Model-grabber
civitAI_Model_downloader.py
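"""Download the model files and preview images that a Civitai user has published.

The script queries the Civitai REST API for the given username, saves each model
version's files and images under <username>_downloads/<model_name>/, and records
the source URLs in a details.txt file inside each model's directory.
"""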
import json  # only json.JSONDecodeError is referenced below; the stdlib module is sufficient
import requests
import urllib.parse
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import logging
import time
import argparse
# Send ERROR-level log messages to a file and suppress anything less severe
logging.basicConfig(filename='civitai_Model_downloader.log', encoding='utf-8', level=logging.ERROR)
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Download model files and images from Civitai API.")
parser.add_argument("username", type=str, help="Enter username you want to download from.")
parser.add_argument("--retry_delay", type=int, default=10, help="Retry delay in seconds.")
parser.add_argument("--max_tries", type=int, default=3, help="Maximum number of retries.")
parser.add_argument("--max_threads", type=int, default=5, help="Maximum number of concurrent threads.Too many produces API Failure.")
args = parser.parse_args()
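# Example invocation (the username here is only a placeholder):
#   python civitAI_Model_downloader.py some_username --max_threads 3 --retry_delay 15 --max_tries 5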
# Command-line arguments
username = args.username
retry_delay = args.retry_delay
max_tries = args.max_tries
max_threads = args.max_threads
# Build the request URL for the given username
base_url = "https://civitai.com/api/v1/models"
params = {
    "username": username,
}
url = f"{base_url}?{urllib.parse.urlencode(params)}"
# Set the headers
headers = {
"Content-Type": "application/json"
}
# Create a session object for making multiple requests
session = requests.Session()
def sanitize_name(name):
    # Replace spaces with underscores and slashes with dashes to get filesystem-safe names
    return name.replace(' ', '_').replace('/', '-')
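# For example (hypothetical name): sanitize_name("Anime Style v1.0/LoRA") -> "Anime_Style_v1.0-LoRA"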
# Download a file or image from the provided URL, retrying up to max_tries times
def download_file_or_image(url, output_path):
    for attempt in range(max_tries):
        progress_bar = None
        try:
            response = session.get(url, stream=True)
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            progress_bar = tqdm(total=total_size, unit='B', unit_scale=True, leave=False)
            with open(output_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        progress_bar.update(len(chunk))
                        file.write(chunk)
            progress_bar.close()
            return
        except (requests.RequestException, TimeoutError, ConnectionResetError):
            if progress_bar:
                progress_bar.close()
            print(f"Error downloading: {url}")
            if attempt < max_tries - 1:
                time.sleep(retry_delay)
    print(f"Giving up on {url} after {max_tries} attempts.")
# Create a lock for thread-safe file writes
file_lock = threading.Lock()
# Download the model files and related preview images for a single model version
def download_model_files(item_name, model_version, item):
    files = model_version.get('files', [])
    images = model_version.get('images', [])
    downloaded = False
    model_id = item['id']
    model_url = f"https://civitai.com/models/{model_id}"
    item_name_sanitized = sanitize_name(item_name)
    item_dir = os.path.join(output_dir, item_name_sanitized)
    os.makedirs(item_dir, exist_ok=True)
    # Treat the version as already downloaded if every one of its files exists on disk
    existing_files_count = 0
    for file in files:
        file_name = file.get('name', '')
        file_name_sanitized = sanitize_name(file_name)
        file_path = os.path.join(item_dir, file_name_sanitized)
        if os.path.exists(file_path):
            existing_files_count += 1
    if existing_files_count == len(files):
        downloaded = True
    model_images = {}  # Dictionary to store image filenames associated with the model
    details_file = os.path.join(item_dir, "details.txt")
    file_name = ''  # Keeps the image filenames valid when a version has no files
    for file in files:
        file_name = file.get('name', '')        # Empty string if the 'name' key is missing
        file_url = file.get('downloadUrl', '')  # Empty string if the 'downloadUrl' key is missing
        # Skip the download if the file already exists
        file_name_sanitized = sanitize_name(file_name)
        file_path = os.path.join(item_dir, file_name_sanitized)
        if os.path.exists(file_path):
            continue
        # Skip the entry if 'name' or 'downloadUrl' is missing
        if not file_name or not file_url:
            print(f"Invalid file entry: {file}")
            continue
        # Download the file
        try:
            download_file_or_image(file_url, file_path)
            downloaded = True
        except (requests.RequestException, TimeoutError):
            print(f"Error downloading file: {file_url}")
        # Update the details file (the lock keeps concurrent writers from interleaving)
        with file_lock:
            with open(details_file, "a") as f:
                f.write(f"Model URL: {model_url}\n")
                f.write(f"File Name: {file_name}\n")
                f.write(f"File URL: {file_url}\n")
    for image in images:
        image_id = image.get('id', '')    # Empty string if the 'id' key is missing
        image_url = image.get('url', '')  # Empty string if the 'url' key is missing
        # Skip the download if the image already exists
        image_filename_raw = f"{item_name}_{image_id}_for_{file_name}.jpeg"
        image_filename_sanitized = sanitize_name(image_filename_raw)
        image_path = os.path.join(item_dir, image_filename_sanitized)
        if os.path.exists(image_path):
            continue
        # Skip the entry if 'id' or 'url' is missing
        if not image_id or not image_url:
            print(f"Invalid image entry: {image}")
            continue
        # Download the image
        try:
            download_file_or_image(image_url, image_path)
            downloaded = True
        except (requests.RequestException, TimeoutError):
            print(f"Error downloading image: {image_url}")
        # Update the details file (the lock keeps concurrent writers from interleaving)
        with file_lock:
            with open(details_file, "a") as f:
                f.write(f"Image ID: {image_id}\n")
                f.write(f"Image URL: {image_url}\n")
        # Store the image filename in the model_images dictionary
        if item_name not in model_images:
            model_images[item_name] = []
        model_images[item_name].append(image_filename_raw)
    # Return: (item name, whether this version was downloaded or already present, {item name: image filenames})
    return item_name, downloaded, model_images
# Create a directory for the username
output_dir = f"{username}_downloads"
os.makedirs(output_dir, exist_ok=True)
# Make the first REST API call using the session object, honoring the command-line retry settings
retry_count = 0
while retry_count < max_tries:
    try:
        response = session.get(url, headers=headers)
        response.raise_for_status()
        data = response.json()
        break
    except (requests.RequestException, TimeoutError, json.JSONDecodeError) as e:
        print(f"Error making API request or decoding JSON response: {e}")
        retry_count += 1
        if retry_count < max_tries:
            print(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            print("Maximum retries exceeded. Exiting.")
            exit()
# Extract data from the JSON response
items = data['items']
metadata = data['metadata']
# Create a thread pool sized by the --max_threads argument
executor = ThreadPoolExecutor(max_workers=max_threads)
# Define a list to store the download futures
download_futures = []
# Define a set to store the name of every item submitted for download
submitted_item_names = set()
# Fetch the remaining result pages and submit their download tasks to the thread pool
def handle_pagination(metadata):
    current_page = metadata['currentPage'] + 1  # the first page is handled by the loop below
    total_pages = metadata['totalPages']
    total_items_count = 0
    max_pages = 50  # safety cap on the number of pages fetched
    for _ in range(max_pages):
        if current_page > total_pages:
            break
        next_page = f"{base_url}?{urllib.parse.urlencode(params)}&page={current_page}"
        try:
            response = session.get(next_page, headers=headers)
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, TimeoutError, json.JSONDecodeError) as e:
            print(f"Error making API request or decoding JSON response: {e}")
            return total_items_count
        page_items = data['items']
        total_items_count += len(page_items)
        # Call download_model_files() for every item on this page
        for item in page_items:
            item_name = item['name']
            model_versions = item['modelVersions']
            submitted_item_names.add(item_name)
            for version in model_versions:
                future = executor.submit(download_model_files, item_name, version, item)
                download_futures.append(future)
        current_page = data['metadata']['currentPage'] + 1
    return total_items_count
# Iterate through the first page of items and model versions, submitting download tasks to the thread pool
for item in items:
    item_name = item['name']
    model_versions = item['modelVersions']
    submitted_item_names.add(item_name)
    for version in model_versions:
        future = executor.submit(download_model_files, item_name, version, item)
        download_futures.append(future)
# Check for pagination, handle subsequent pages, and track the total item count
total_items = len(items)
if metadata['totalPages'] > 1:
    total_items += handle_pagination(metadata)
# Wait for all downloads to complete
download_results = []
for future in tqdm(download_futures, desc="Downloading Files", unit="file", leave=False):
    result = future.result()
    download_results.append(result)
# Shut down the thread pool
executor.shutdown()
# Compare the successfully downloaded items with all submitted item names to find missing items
all_item_names_sanitized = {sanitize_name(name) for name in submitted_item_names}
downloaded_item_names_sanitized = {sanitize_name(item_name) for item_name, downloaded, _ in download_results if downloaded}
missing_items = all_item_names_sanitized - downloaded_item_names_sanitized
# Print summary
print("Download completed successfully.")
print(f"Total items: {total_items}")
print(f"Downloaded items: {len(downloaded_item_names_sanitized)}")
print(f"Missing items: {len(missing_items)}")
if missing_items:
    print("Missing item names:")
    for item_name in missing_items:
        print(item_name)