mstyslavity committed on
Commit
47cc53e
·
verified ·
1 Parent(s): 5d72926

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -29
app.py CHANGED
@@ -3,13 +3,19 @@ import pathlib
3
  import random
4
  import string
5
  import tempfile
6
- import time
7
- from concurrent.futures import ThreadPoolExecutor
8
  from typing import Iterable, List
9
 
 
10
  import gradio as gr
11
  import huggingface_hub
12
  import torch
 
 
 
 
 
 
 
13
 
14
  # HF Spaces: needed for ZeroGPU. Safe on CPU Spaces too.
15
  try:
@@ -178,28 +184,189 @@ def _merge_impl(yaml_config: str, hf_token: str, repo_name: str, *, force_cuda:
178
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
179
 
180
 
 
 
 
 
 
 
 
 
181
 
 
 
 
 
 
 
 
 
182
 
183
- def merge_cpu(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
184
- """CPU path (default)."""
185
- yield from _merge_impl(yaml_config, hf_token, repo_name, force_cuda=False)
186
 
187
- if spaces is not None:
188
- # ZeroGPU requires at least one @spaces.GPU function to exist at startup.
189
- @spaces.GPU(duration=60 * 2) # up to 60 min per call; ZeroGPU minutes may cap earlier
190
- def merge_gpu(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
191
- """GPU path for ZeroGPU hardware."""
192
- yield from _merge_impl(yaml_config, hf_token, repo_name, force_cuda=True)
193
- else:
194
- def merge_gpu(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
195
- yield from merge_cpu(yaml_config, hf_token, repo_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Prefer GPU entrypoint on ZeroGPU; harmless fallback on CPU Spaces.
198
- MERGE_FN = merge_gpu if spaces is not None else merge_cpu
 
 
 
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  with gr.Blocks() as demo:
202
  gr.Markdown(MARKDOWN_DESCRIPTION)
 
203
 
204
  with gr.Row():
205
  filename = gr.Textbox(visible=False, label="filename")
@@ -217,6 +384,11 @@ with gr.Blocks() as demo:
217
  label="Repo name",
218
  placeholder="Optional. Will create a random name if empty.",
219
  )
 
 
 
 
 
220
  button = gr.Button("Merge", variant="primary")
221
  logs = LogsView(label="Terminal output")
222
  gr.Examples(
@@ -229,19 +401,7 @@ with gr.Blocks() as demo:
229
  )
230
  gr.Markdown(MARKDOWN_ARTICLE)
231
 
232
- button.click(fn=MERGE_FN, inputs=[config, token, repo_name], outputs=[logs])
233
-
234
- # Empty models might exists if the merge fails abruptly (e.g. if user leaves the Space).
235
- def _garbage_collect_every_hour():
236
- while True:
237
- try:
238
- garbage_collect_empty_models(token=COMMUNITY_HF_TOKEN)
239
- except Exception as e:
240
- print("Error running garbage collection", e)
241
- time.sleep(3600)
242
-
243
 
244
- pool = ThreadPoolExecutor()
245
- pool.submit(_garbage_collect_every_hour)
246
 
247
  demo.queue(default_concurrency_limit=1).launch()
 
3
  import random
4
  import string
5
  import tempfile
 
 
6
  from typing import Iterable, List
7
 
8
+ import spaces
9
  import gradio as gr
10
  import huggingface_hub
11
  import torch
12
+ import yaml
13
+ from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
14
+ from mergekit.config import MergeConfiguration
15
+
16
+ from clean_community_org import garbage_collect_empty_models
17
+ from apscheduler.schedulers.background import BackgroundScheduler
18
+ from datetime import timezone
19
 
20
  # HF Spaces: needed for ZeroGPU. Safe on CPU Spaces too.
21
  try:
 
184
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
185
 
186
 
187
+ def run_merge_cpu(runner: LogsViewRunner, cli: str, merged_path: str, tmpdirname: str):
188
+ # Set tmp HF_HOME to avoid filling up disk Space
189
+ tmp_env = os.environ.copy()
190
+ tmp_env["HF_HOME"] = f"{tmpdirname}/.cache"
191
+ full_cli = cli + f" --lora-merge-cache {tmpdirname}/.lora_cache --transformers-cache {tmpdirname}/.cache"
192
+ yield from runner.run_command(full_cli.split(), cwd=merged_path, env=tmp_env)
193
+ yield ("done", runner.exit_code)
194
+
195
 
196
+ @spaces.GPU(duration=60 * 2)
197
+ def run_merge_gpu(runner: LogsViewRunner, cli: str, merged_path: str, tmpdirname: str):
198
+ yield from run_merge_cpu(
199
+ runner,
200
+ cli + " --cuda --low-cpu-memory --read-to-gpu",
201
+ merged_path,
202
+ tmpdirname,
203
+ )
204
 
 
 
 
205
 
206
+ def run_merge(
207
+ runner: LogsViewRunner,
208
+ cli: str,
209
+ merged_path: str,
210
+ tmpdirname: str,
211
+ use_gpu: bool,
212
+ ):
213
+ if use_gpu:
214
+ yield from run_merge_gpu(runner, cli, merged_path, tmpdirname)
215
+ else:
216
+ yield from run_merge_cpu(runner, cli, merged_path, tmpdirname)
217
+
218
+
219
+ def prefetch_models(
220
+ runner: LogsViewRunner,
221
+ merge_config: MergeConfiguration,
222
+ hf_home: str,
223
+ lora_merge_cache: str,
224
+ ):
225
+ for model in merge_config.referenced_models():
226
+ yield runner.log(f"Downloading {model}...")
227
+ model = model.merged(cache_dir=lora_merge_cache, trust_remote_code=False)
228
+ local_path = model.local_path(cache_dir=hf_home)
229
+ yield runner.log(f"\tDownloaded {model} to {local_path}")
230
+
231
+
232
+ def merge(yaml_config: str, hf_token: str, repo_name: str, private: bool) -> Iterable[List[Log]]:
233
+ runner = LogsViewRunner()
234
+
235
+ if not yaml_config:
236
+ yield runner.log("Empty yaml, enter your config or pick an example below", level="ERROR")
237
+ return
238
+ try:
239
+ merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
240
+ except Exception as e:
241
+ yield runner.log(f"Invalid yaml {e}", level="ERROR")
242
+ return
243
+
244
+ is_community_model = False
245
+ if not hf_token:
246
+ if private:
247
+ yield runner.log(
248
+ "Cannot upload model as private without a token. Please provide a HF token.",
249
+ level="ERROR",
250
+ )
251
+ return
252
+ if "/" in repo_name and not repo_name.startswith("mergekit-community/"):
253
+ yield runner.log(
254
+ f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
255
+ level="ERROR",
256
+ )
257
+ return
258
+ yield runner.log(
259
+ "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization."
260
+ )
261
+ is_community_model = True
262
+ if not COMMUNITY_HF_TOKEN:
263
+ raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
264
+ hf_token = COMMUNITY_HF_TOKEN
265
 
266
+ api = huggingface_hub.HfApi(token=hf_token)
267
+ has_gpu = torch.cuda.is_available()
268
+ cli = "mergekit-yaml config.yaml merge --copy-tokenizer --allow-crimes -v" + (
269
+ " --out-shard-size 1B --lazy-unpickle" if (not has_gpu) else ""
270
+ )
271
 
272
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
273
+ tmpdir = pathlib.Path(tmpdirname)
274
+ merged_path = tmpdir / "merged"
275
+ merged_path.mkdir(parents=True, exist_ok=True)
276
+ config_path = merged_path / "config.yaml"
277
+ config_path.write_text(yaml_config)
278
+ yield runner.log(f"Merge configuration saved in {config_path}")
279
+
280
+ if not repo_name:
281
+ yield runner.log("No repo name provided. Generating a random one.")
282
+ repo_name = f"mergekit-{merge_config.merge_method}"
283
+ # Make repo_name "unique" (no need to be extra careful on uniqueness)
284
+ repo_name += "-" + "".join(random.choices(string.ascii_lowercase, k=7))
285
+ repo_name = repo_name.replace("/", "-").strip("-")
286
+
287
+ if is_community_model and not repo_name.startswith("mergekit-community/"):
288
+ repo_name = f"mergekit-community/{repo_name}"
289
+
290
+ try:
291
+ yield runner.log(f"Creating repo {repo_name}")
292
+ repo_url = api.create_repo(repo_name, exist_ok=True, private=private, repo_type="model")
293
+ yield runner.log(f"Repo created: {repo_url}")
294
+ except Exception as e:
295
+ yield runner.log(f"Error creating repo {e}", level="ERROR")
296
+ return
297
+
298
+ # Prefetch models to avoid downloading them with scarce GPU time
299
+ yield runner.log("Prefetching models...")
300
+ yield from prefetch_models(
301
+ runner,
302
+ merge_config,
303
+ hf_home=f"{tmpdirname}/.cache",
304
+ lora_merge_cache=f"{tmpdirname}/.lora_cache",
305
+ )
306
+ yield runner.log("Models prefetched. Starting merge.")
307
+ exit_code = None
308
+ try:
309
+ for ev in run_merge(
310
+ runner,
311
+ cli,
312
+ merged_path,
313
+ tmpdirname,
314
+ use_gpu=has_gpu,
315
+ ):
316
+ if isinstance(ev, tuple) and ev[0] == "done":
317
+ exit_code = ev[1]
318
+ continue
319
+ yield ev
320
+ except Exception as e:
321
+ yield runner.log(f"Error running merge {e}", level="ERROR")
322
+ yield runner.log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
323
+ api.delete_repo(repo_url.repo_id)
324
+ return
325
+
326
+ if exit_code != 0:
327
+ yield runner.log(f"Exit code: {exit_code}")
328
+ yield runner.log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
329
+ api.delete_repo(repo_url.repo_id)
330
+ return
331
+
332
+ yield runner.log("Model merged successfully. Uploading to HF.")
333
+ yield from runner.run_python(
334
+ api.upload_folder,
335
+ repo_id=repo_url.repo_id,
336
+ folder_path=merged_path / "merge",
337
+ )
338
+ yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
339
+
340
+
341
+ merge.zerogpu = True
342
+ run_merge.zerogpu = True
343
+
344
+ def _restart_space():
345
+ huggingface_hub.HfApi().restart_space(
346
+ repo_id="arcee-ai/mergekit-gui", token=COMMUNITY_HF_TOKEN, factory_reboot=False
347
+ )
348
+
349
+
350
+ # Run garbage collection every hour to keep the community org clean.
351
+ # Empty models might exist if the merge fails abruptly (e.g. if user leaves the Space).
352
+ def _garbage_remover():
353
+ try:
354
+ garbage_collect_empty_models(token=COMMUNITY_HF_TOKEN)
355
+ except Exception as e:
356
+ print("Error running garbage collection", e)
357
+
358
+
359
+ scheduler = BackgroundScheduler()
360
+ restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=21600)
361
+ garbage_remover_job = scheduler.add_job(_garbage_remover, "interval", seconds=3600)
362
+ scheduler.start()
363
+ next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
364
+
365
+ NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC)"
366
 
367
  with gr.Blocks() as demo:
368
  gr.Markdown(MARKDOWN_DESCRIPTION)
369
+ gr.Markdown(NEXT_RESTART)
370
 
371
  with gr.Row():
372
  filename = gr.Textbox(visible=False, label="filename")
 
384
  label="Repo name",
385
  placeholder="Optional. Will create a random name if empty.",
386
  )
387
+ private = gr.Checkbox(
388
+ label="Private",
389
+ value=False,
390
+ info="Upload the model as private. If not checked, will be public. Must provide a token.",
391
+ )
392
  button = gr.Button("Merge", variant="primary")
393
  logs = LogsView(label="Terminal output")
394
  gr.Examples(
 
401
  )
402
  gr.Markdown(MARKDOWN_ARTICLE)
403
 
404
+ button.click(fn=merge, inputs=[config, token, repo_name, private], outputs=[logs])
 
 
 
 
 
 
 
 
 
 
405
 
 
 
406
 
407
  demo.queue(default_concurrency_limit=1).launch()