mirror of
https://github.com/ace-step/ACE-Step-1.5.git
synced 2026-07-02 16:37:04 +00:00
Add editorconfig and fix training handler encoding
This commit is contained in:
parent
5c123466a3
commit
c3039332a4
2 changed files with 66 additions and 71 deletions
16
.editorconfig
Normal file
16
.editorconfig
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
|
||||
[*.{bat,cmd,ps1}]
|
||||
end_of_line = crlf
|
||||
|
||||
[*.png]
|
||||
charset = unset
|
||||
end_of_line = unset
|
||||
insert_final_newline = false
|
||||
trim_trailing_whitespace = false
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""
|
||||
"""
|
||||
Event Handlers for Training Tab
|
||||
|
||||
Contains all event handler functions for the dataset builder and training UI.
|
||||
|
|
@ -40,7 +40,7 @@ def scan_directory(
|
|||
Tuple of (table_data, status, slider_update, builder_state)
|
||||
"""
|
||||
if not audio_dir or not audio_dir.strip():
|
||||
return [], "❌ Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
|
||||
return [], "<EFBFBD> Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
|
||||
|
||||
# Create or use existing builder
|
||||
builder = builder_state if builder_state else DatasetBuilder()
|
||||
|
|
@ -97,17 +97,17 @@ def auto_label_all(
|
|||
Tuple of (table_data, status, builder_state)
|
||||
"""
|
||||
if builder_state is None:
|
||||
return [], "❌ Please scan a directory first", builder_state
|
||||
return [], "<EFBFBD> Please scan a directory first", builder_state
|
||||
|
||||
if not builder_state.samples:
|
||||
return [], "❌ No samples to label. Please scan a directory first.", builder_state
|
||||
return [], "<EFBFBD> No samples to label. Please scan a directory first.", builder_state
|
||||
|
||||
# Check if handlers are initialized
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
|
||||
return builder_state.get_samples_dataframe_data(), "<EFBFBD> Model not initialized. Please initialize the service first.", builder_state
|
||||
|
||||
if llm_handler is None or not llm_handler.llm_initialized:
|
||||
return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
||||
return builder_state.get_samples_dataframe_data(), "<EFBFBD> LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
||||
|
||||
def progress_callback(msg):
|
||||
if progress:
|
||||
|
|
@ -208,7 +208,7 @@ def save_sample_edit(
|
|||
Tuple of (table_data, status, builder_state)
|
||||
"""
|
||||
if builder_state is None:
|
||||
return [], "❌ No dataset loaded", builder_state
|
||||
return [], "<EFBFBD> No dataset loaded", builder_state
|
||||
|
||||
idx = int(sample_idx)
|
||||
|
||||
|
|
@ -279,13 +279,13 @@ def save_dataset(
|
|||
Status message
|
||||
"""
|
||||
if builder_state is None:
|
||||
return "❌ No dataset to save. Please scan a directory first.", gr.update()
|
||||
return "<EFBFBD> No dataset to save. Please scan a directory first.", gr.update()
|
||||
|
||||
if not builder_state.samples:
|
||||
return "❌ No samples in dataset.", gr.update()
|
||||
return "<EFBFBD> No samples in dataset.", gr.update()
|
||||
|
||||
if not save_path or not save_path.strip():
|
||||
return "❌ Please enter a save path.", gr.update()
|
||||
return "<EFBFBD> Please enter a save path.", gr.update()
|
||||
|
||||
save_path = save_path.strip()
|
||||
if not save_path.lower().endswith(".json"):
|
||||
|
|
@ -294,7 +294,7 @@ def save_dataset(
|
|||
# Check if any samples are labeled
|
||||
labeled_count = builder_state.get_labeled_count()
|
||||
if labeled_count == 0:
|
||||
return "❌ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
|
||||
return "<EFBFBD>️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
|
||||
|
||||
return builder_state.save_dataset(save_path, dataset_name), gr.update(value=save_path)
|
||||
|
||||
|
|
@ -319,13 +319,13 @@ def load_existing_dataset_for_preprocess(
|
|||
|
||||
if not dataset_path or not dataset_path.strip():
|
||||
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
|
||||
return ("❌ Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
return ("<EFBFBD> Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
|
||||
dataset_path = dataset_path.strip()
|
||||
|
||||
if not os.path.exists(dataset_path):
|
||||
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
|
||||
return (f"❌ Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
return (f"<EFBFBD> Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
|
||||
|
||||
# Create new builder (don't reuse old state when loading a file)
|
||||
builder = DatasetBuilder()
|
||||
|
|
@ -345,12 +345,12 @@ def load_existing_dataset_for_preprocess(
|
|||
|
||||
# Create info text
|
||||
labeled_count = builder.get_labeled_count()
|
||||
info = f"✅ Loaded dataset: {builder.metadata.name}\n"
|
||||
info += f"✅ Samples: {len(samples)} ({labeled_count} labeled)\n"
|
||||
info += f"✅ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
||||
info += "✅ Ready for preprocessing! You can also edit samples below."
|
||||
info = f"<EFBFBD> Loaded dataset: {builder.metadata.name}\n"
|
||||
info += f"<EFBFBD> Samples: {len(samples)} ({labeled_count} labeled)\n"
|
||||
info += f"<EFBFBD><EFBFBD><EFBFBD>️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
||||
info += "<EFBFBD> Ready for preprocessing! You can also edit samples below."
|
||||
if any((s.formatted_lyrics and not s.lyrics) for s in builder.samples):
|
||||
info += "\nℹ️ Showing formatted lyrics where lyrics are empty."
|
||||
info += "\n<EFBFBD>️ Showing formatted lyrics where lyrics are empty."
|
||||
|
||||
# Get first sample preview
|
||||
first_sample = builder.samples[0]
|
||||
|
|
@ -408,20 +408,20 @@ def preprocess_dataset(
|
|||
Status message
|
||||
"""
|
||||
if builder_state is None:
|
||||
return "❌ No dataset loaded. Please scan a directory first."
|
||||
return "<EFBFBD> No dataset loaded. Please scan a directory first."
|
||||
|
||||
if not builder_state.samples:
|
||||
return "❌ No samples in dataset."
|
||||
return "<EFBFBD> No samples in dataset."
|
||||
|
||||
labeled_count = builder_state.get_labeled_count()
|
||||
if labeled_count == 0:
|
||||
return "❌ No labeled samples. Please auto-label or manually label samples first."
|
||||
return "<EFBFBD> No labeled samples. Please auto-label or manually label samples first."
|
||||
|
||||
if not output_dir or not output_dir.strip():
|
||||
return "❌ Please enter an output directory."
|
||||
return "<EFBFBD> Please enter an output directory."
|
||||
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
return "❌ Model not initialized. Please initialize the service first."
|
||||
return "<EFBFBD> Model not initialized. Please initialize the service first."
|
||||
|
||||
def progress_callback(msg):
|
||||
if progress:
|
||||
|
|
@ -449,15 +449,15 @@ def load_training_dataset(
|
|||
Info text about the dataset
|
||||
"""
|
||||
if not tensor_dir or not tensor_dir.strip():
|
||||
return "❌ Please enter a tensor directory path"
|
||||
return "<EFBFBD> Please enter a tensor directory path"
|
||||
|
||||
tensor_dir = tensor_dir.strip()
|
||||
|
||||
if not os.path.exists(tensor_dir):
|
||||
return f"❌ Directory not found: {tensor_dir}"
|
||||
return f"<EFBFBD> Directory not found: {tensor_dir}"
|
||||
|
||||
if not os.path.isdir(tensor_dir):
|
||||
return f"❌ Not a directory: {tensor_dir}"
|
||||
return f"<EFBFBD> Not a directory: {tensor_dir}"
|
||||
|
||||
# Check for manifest
|
||||
manifest_path = os.path.join(tensor_dir, "manifest.json")
|
||||
|
|
@ -471,9 +471,9 @@ def load_training_dataset(
|
|||
name = metadata.get("name", "Unknown")
|
||||
custom_tag = metadata.get("custom_tag", "")
|
||||
|
||||
info = f"✅ Loaded preprocessed dataset: {name}\n"
|
||||
info += f"✅ Samples: {num_samples} preprocessed tensors\n"
|
||||
info += f"✅ Custom Tag: {custom_tag or '(none)'}"
|
||||
info = f"<EFBFBD> Loaded preprocessed dataset: {name}\n"
|
||||
info += f"<EFBFBD> Samples: {num_samples} preprocessed tensors\n"
|
||||
info += f"<EFBFBD><EFBFBD><EFBFBD>️ Custom Tag: {custom_tag or '(none)'}"
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
|
|
@ -483,10 +483,10 @@ def load_training_dataset(
|
|||
pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
|
||||
|
||||
if not pt_files:
|
||||
return f"❌ No .pt tensor files found in {tensor_dir}"
|
||||
return f"<EFBFBD> No .pt tensor files found in {tensor_dir}"
|
||||
|
||||
info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
|
||||
info += "⚠️ No manifest.json found - using all .pt files"
|
||||
info = f"<EFBFBD> Found {len(pt_files)} tensor files in {tensor_dir}\n"
|
||||
info += "<EFBFBD>️ No manifest.json found - using all .pt files"
|
||||
|
||||
return info
|
||||
|
||||
|
|
@ -511,7 +511,6 @@ def _format_duration(seconds):
|
|||
def start_training(
|
||||
tensor_dir: str,
|
||||
dit_handler,
|
||||
llm_handler,
|
||||
lora_rank: int,
|
||||
lora_alpha: int,
|
||||
lora_dropout: float,
|
||||
|
|
@ -531,37 +530,17 @@ def start_training(
|
|||
This is a generator function that yields progress updates.
|
||||
"""
|
||||
if not tensor_dir or not tensor_dir.strip():
|
||||
yield "❌ Please enter a tensor directory path", "", None, training_state
|
||||
yield "<EFBFBD> Please enter a tensor directory path", "", None, training_state
|
||||
return
|
||||
|
||||
tensor_dir = tensor_dir.strip()
|
||||
|
||||
if not os.path.exists(tensor_dir):
|
||||
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
|
||||
yield f"<EFBFBD> Tensor directory not found: {tensor_dir}", "", None, training_state
|
||||
return
|
||||
|
||||
if dit_handler is None or dit_handler.model is None:
|
||||
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
|
||||
return
|
||||
|
||||
# Check for training-incompatible settings
|
||||
incompatible = []
|
||||
if getattr(dit_handler, 'offload_to_cpu', False):
|
||||
incompatible.append('⚠️ "Offload to CPU" is enabled — this will slow down training. Please uncheck it.')
|
||||
if getattr(dit_handler, 'offload_dit_to_cpu', False):
|
||||
incompatible.append('⚠️ "Offload DiT to CPU" is enabled — this will slow down training. Please uncheck it.')
|
||||
if getattr(dit_handler, 'compiled', False):
|
||||
incompatible.append('⚠️ "Compile Model" is enabled — this is incompatible with LoRA training. Please uncheck it.')
|
||||
if getattr(dit_handler, 'quantization', None) is not None:
|
||||
incompatible.append('⚠️ "INT8 Quantization" is enabled — this is incompatible with LoRA training. Please uncheck it.')
|
||||
if llm_handler is not None and getattr(llm_handler, 'llm_initialized', False):
|
||||
incompatible.append('⚠️ "5Hz LM" is initialized — it occupies GPU memory needed for training. Please uncheck "Initialize 5Hz LM".')
|
||||
|
||||
if incompatible:
|
||||
msg = "❌ Training cannot start due to incompatible settings:\n\n"
|
||||
msg += "\n".join(incompatible)
|
||||
msg += "\n\nPlease fix the above settings and re-initialize the DiT model before training."
|
||||
yield msg, "", None, training_state
|
||||
yield "<EFBFBD> Model not initialized. Please initialize the service first.", "", None, training_state
|
||||
return
|
||||
|
||||
# Check for required training dependencies
|
||||
|
|
@ -569,7 +548,7 @@ def start_training(
|
|||
from lightning.fabric import Fabric
|
||||
from peft import get_peft_model, LoraConfig
|
||||
except ImportError as e:
|
||||
yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
|
||||
yield f"<EFBFBD> Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
|
||||
return
|
||||
|
||||
training_state["is_training"] = True
|
||||
|
|
@ -606,7 +585,7 @@ def start_training(
|
|||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
yield f"❌ Starting training from {tensor_dir}...", "", loss_data, training_state
|
||||
yield f"<EFBFBD> Starting training from {tensor_dir}...", "", loss_data, training_state
|
||||
|
||||
# Create trainer
|
||||
trainer = LoRATrainer(
|
||||
|
|
@ -623,7 +602,7 @@ def start_training(
|
|||
for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
|
||||
# Calculate elapsed time and ETA
|
||||
elapsed_seconds = time.time() - start_time
|
||||
time_info = f"❌ Elapsed: {_format_duration(elapsed_seconds)}"
|
||||
time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
|
||||
|
||||
# Parse "Epoch x/y" from status to calculate ETA
|
||||
match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
|
||||
|
|
@ -656,14 +635,14 @@ def start_training(
|
|||
yield display_status, log_text, loss_data, training_state
|
||||
|
||||
if training_state.get("should_stop", False):
|
||||
logger.info("❌ Training stopped by user")
|
||||
log_lines.append("❌ Training stopped by user")
|
||||
yield f"❌ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
|
||||
logger.info("⏹️ Training stopped by user")
|
||||
log_lines.append("⏹️ Training stopped by user")
|
||||
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
|
||||
break
|
||||
|
||||
total_time = time.time() - start_time
|
||||
training_state["is_training"] = False
|
||||
completion_msg = f"❌ Training completed! Total time: {_format_duration(total_time)}"
|
||||
completion_msg = f"<EFBFBD> Training completed! Total time: {_format_duration(total_time)}"
|
||||
|
||||
logger.info(completion_msg)
|
||||
log_lines.append(completion_msg)
|
||||
|
|
@ -675,7 +654,7 @@ def start_training(
|
|||
training_state["is_training"] = False
|
||||
import pandas as pd
|
||||
empty_df = pd.DataFrame({"step": [], "loss": []})
|
||||
yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
|
||||
yield f"<EFBFBD> Error: {str(e)}", str(e), empty_df, training_state
|
||||
|
||||
|
||||
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
|
||||
|
|
@ -685,10 +664,10 @@ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
|
|||
Tuple of (status, training_state)
|
||||
"""
|
||||
if not training_state.get("is_training", False):
|
||||
return "❌ No training in progress", training_state
|
||||
return "<EFBFBD>️ No training in progress", training_state
|
||||
|
||||
training_state["should_stop"] = True
|
||||
return "❌ Stopping training...", training_state
|
||||
return "⏹️ Stopping training...", training_state
|
||||
|
||||
|
||||
def export_lora(
|
||||
|
|
@ -701,7 +680,7 @@ def export_lora(
|
|||
Status message
|
||||
"""
|
||||
if not export_path or not export_path.strip():
|
||||
return "❌ Please enter an export path"
|
||||
return "<EFBFBD> Please enter an export path"
|
||||
|
||||
# Check if there's a trained model to export
|
||||
final_dir = os.path.join(lora_output_dir, "final")
|
||||
|
|
@ -714,13 +693,13 @@ def export_lora(
|
|||
# Find the latest checkpoint
|
||||
checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
|
||||
if not checkpoints:
|
||||
return "❌ No checkpoints found"
|
||||
return "<EFBFBD> No checkpoints found"
|
||||
|
||||
checkpoints.sort(key=lambda x: int(x.split("_")[1]))
|
||||
latest = checkpoints[-1]
|
||||
source_path = os.path.join(checkpoint_dir, latest)
|
||||
else:
|
||||
return f"❌ No trained model found in {lora_output_dir}"
|
||||
return f"<EFBFBD> No trained model found in {lora_output_dir}"
|
||||
|
||||
try:
|
||||
import shutil
|
||||
|
|
@ -733,11 +712,11 @@ def export_lora(
|
|||
|
||||
shutil.copytree(source_path, export_path)
|
||||
|
||||
return f"❌ LoRA exported to {export_path}"
|
||||
return f"<EFBFBD> LoRA exported to {export_path}"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Export error")
|
||||
return f"❌ Export failed: {str(e)}"
|
||||
return f"<EFBFBD> Export failed: {str(e)}"
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue