Add editorconfig and fix training handler encoding

This commit is contained in:
Rich 2026-02-06 15:05:00 +00:00 committed by Gong Junmin
parent 5c123466a3
commit c3039332a4
2 changed files with 66 additions and 71 deletions

16
.editorconfig Normal file
View file

@ -0,0 +1,16 @@
# Top-most EditorConfig file: editors stop searching parent directories here.
root = true
# Defaults applied to every file in the repository.
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
# Files with the .bat, .cmd and .ps1 extensions override the default and use CRLF line endings.
[*.{bat,cmd,ps1}]
end_of_line = crlf
# PNG files: clear or disable every text setting inherited from [*] above.
[*.png]
charset = unset
end_of_line = unset
insert_final_newline = false
trim_trailing_whitespace = false

View file

@ -1,4 +1,4 @@
"""
"""
Event Handlers for Training Tab
Contains all event handler functions for the dataset builder and training UI.
@ -40,7 +40,7 @@ def scan_directory(
Tuple of (table_data, status, slider_update, builder_state)
"""
if not audio_dir or not audio_dir.strip():
return [], " Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
return [], "<EFBFBD> Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
# Create or use existing builder
builder = builder_state if builder_state else DatasetBuilder()
@ -97,17 +97,17 @@ def auto_label_all(
Tuple of (table_data, status, builder_state)
"""
if builder_state is None:
return [], " Please scan a directory first", builder_state
return [], "<EFBFBD> Please scan a directory first", builder_state
if not builder_state.samples:
return [], " No samples to label. Please scan a directory first.", builder_state
return [], "<EFBFBD> No samples to label. Please scan a directory first.", builder_state
# Check if handlers are initialized
if dit_handler is None or dit_handler.model is None:
return builder_state.get_samples_dataframe_data(), " Model not initialized. Please initialize the service first.", builder_state
return builder_state.get_samples_dataframe_data(), "<EFBFBD> Model not initialized. Please initialize the service first.", builder_state
if llm_handler is None or not llm_handler.llm_initialized:
return builder_state.get_samples_dataframe_data(), " LLM not initialized. Please initialize the service with LLM enabled.", builder_state
return builder_state.get_samples_dataframe_data(), "<EFBFBD> LLM not initialized. Please initialize the service with LLM enabled.", builder_state
def progress_callback(msg):
if progress:
@ -208,7 +208,7 @@ def save_sample_edit(
Tuple of (table_data, status, builder_state)
"""
if builder_state is None:
return [], " No dataset loaded", builder_state
return [], "<EFBFBD> No dataset loaded", builder_state
idx = int(sample_idx)
@ -279,13 +279,13 @@ def save_dataset(
Status message
"""
if builder_state is None:
return " No dataset to save. Please scan a directory first.", gr.update()
return "<EFBFBD> No dataset to save. Please scan a directory first.", gr.update()
if not builder_state.samples:
return " No samples in dataset.", gr.update()
return "<EFBFBD> No samples in dataset.", gr.update()
if not save_path or not save_path.strip():
return " Please enter a save path.", gr.update()
return "<EFBFBD> Please enter a save path.", gr.update()
save_path = save_path.strip()
if not save_path.lower().endswith(".json"):
@ -294,7 +294,7 @@ def save_dataset(
# Check if any samples are labeled
labeled_count = builder_state.get_labeled_count()
if labeled_count == 0:
return " Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
return "<EFBFBD> Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
return builder_state.save_dataset(save_path, dataset_name), gr.update(value=save_path)
@ -319,13 +319,13 @@ def load_existing_dataset_for_preprocess(
if not dataset_path or not dataset_path.strip():
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
return (" Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
return ("<EFBFBD> Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
dataset_path = dataset_path.strip()
if not os.path.exists(dataset_path):
updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
return (f" Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
return (f"<EFBFBD> Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
# Create new builder (don't reuse old state when loading a file)
builder = DatasetBuilder()
@ -345,12 +345,12 @@ def load_existing_dataset_for_preprocess(
# Create info text
labeled_count = builder.get_labeled_count()
info = f" Loaded dataset: {builder.metadata.name}\n"
info += f" Samples: {len(samples)} ({labeled_count} labeled)\n"
info += f" Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
info += " Ready for preprocessing! You can also edit samples below."
info = f"<EFBFBD> Loaded dataset: {builder.metadata.name}\n"
info += f"<EFBFBD> Samples: {len(samples)} ({labeled_count} labeled)\n"
info += f"<EFBFBD><EFBFBD><EFBFBD> Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
info += "<EFBFBD> Ready for preprocessing! You can also edit samples below."
if any((s.formatted_lyrics and not s.lyrics) for s in builder.samples):
info += "\n Showing formatted lyrics where lyrics are empty."
info += "\n<EFBFBD> Showing formatted lyrics where lyrics are empty."
# Get first sample preview
first_sample = builder.samples[0]
@ -408,20 +408,20 @@ def preprocess_dataset(
Status message
"""
if builder_state is None:
return " No dataset loaded. Please scan a directory first."
return "<EFBFBD> No dataset loaded. Please scan a directory first."
if not builder_state.samples:
return " No samples in dataset."
return "<EFBFBD> No samples in dataset."
labeled_count = builder_state.get_labeled_count()
if labeled_count == 0:
return " No labeled samples. Please auto-label or manually label samples first."
return "<EFBFBD> No labeled samples. Please auto-label or manually label samples first."
if not output_dir or not output_dir.strip():
return " Please enter an output directory."
return "<EFBFBD> Please enter an output directory."
if dit_handler is None or dit_handler.model is None:
return " Model not initialized. Please initialize the service first."
return "<EFBFBD> Model not initialized. Please initialize the service first."
def progress_callback(msg):
if progress:
@ -449,15 +449,15 @@ def load_training_dataset(
Info text about the dataset
"""
if not tensor_dir or not tensor_dir.strip():
return " Please enter a tensor directory path"
return "<EFBFBD> Please enter a tensor directory path"
tensor_dir = tensor_dir.strip()
if not os.path.exists(tensor_dir):
return f" Directory not found: {tensor_dir}"
return f"<EFBFBD> Directory not found: {tensor_dir}"
if not os.path.isdir(tensor_dir):
return f" Not a directory: {tensor_dir}"
return f"<EFBFBD> Not a directory: {tensor_dir}"
# Check for manifest
manifest_path = os.path.join(tensor_dir, "manifest.json")
@ -471,9 +471,9 @@ def load_training_dataset(
name = metadata.get("name", "Unknown")
custom_tag = metadata.get("custom_tag", "")
info = f" Loaded preprocessed dataset: {name}\n"
info += f" Samples: {num_samples} preprocessed tensors\n"
info += f" Custom Tag: {custom_tag or '(none)'}"
info = f"<EFBFBD> Loaded preprocessed dataset: {name}\n"
info += f"<EFBFBD> Samples: {num_samples} preprocessed tensors\n"
info += f"<EFBFBD><EFBFBD><EFBFBD> Custom Tag: {custom_tag or '(none)'}"
return info
except Exception as e:
@ -483,10 +483,10 @@ def load_training_dataset(
pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
if not pt_files:
return f" No .pt tensor files found in {tensor_dir}"
return f"<EFBFBD> No .pt tensor files found in {tensor_dir}"
info = f" Found {len(pt_files)} tensor files in {tensor_dir}\n"
info += " No manifest.json found - using all .pt files"
info = f"<EFBFBD> Found {len(pt_files)} tensor files in {tensor_dir}\n"
info += "<EFBFBD> No manifest.json found - using all .pt files"
return info
@ -511,7 +511,6 @@ def _format_duration(seconds):
def start_training(
tensor_dir: str,
dit_handler,
llm_handler,
lora_rank: int,
lora_alpha: int,
lora_dropout: float,
@ -531,37 +530,17 @@ def start_training(
This is a generator function that yields progress updates.
"""
if not tensor_dir or not tensor_dir.strip():
yield " Please enter a tensor directory path", "", None, training_state
yield "<EFBFBD> Please enter a tensor directory path", "", None, training_state
return
tensor_dir = tensor_dir.strip()
if not os.path.exists(tensor_dir):
yield f" Tensor directory not found: {tensor_dir}", "", None, training_state
yield f"<EFBFBD> Tensor directory not found: {tensor_dir}", "", None, training_state
return
if dit_handler is None or dit_handler.model is None:
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
return
# Check for training-incompatible settings
incompatible = []
if getattr(dit_handler, 'offload_to_cpu', False):
incompatible.append('⚠️ "Offload to CPU" is enabled — this will slow down training. Please uncheck it.')
if getattr(dit_handler, 'offload_dit_to_cpu', False):
incompatible.append('⚠️ "Offload DiT to CPU" is enabled — this will slow down training. Please uncheck it.')
if getattr(dit_handler, 'compiled', False):
incompatible.append('⚠️ "Compile Model" is enabled — this is incompatible with LoRA training. Please uncheck it.')
if getattr(dit_handler, 'quantization', None) is not None:
incompatible.append('⚠️ "INT8 Quantization" is enabled — this is incompatible with LoRA training. Please uncheck it.')
if llm_handler is not None and getattr(llm_handler, 'llm_initialized', False):
incompatible.append('⚠️ "5Hz LM" is initialized — it occupies GPU memory needed for training. Please uncheck "Initialize 5Hz LM".')
if incompatible:
msg = "❌ Training cannot start due to incompatible settings:\n\n"
msg += "\n".join(incompatible)
msg += "\n\nPlease fix the above settings and re-initialize the DiT model before training."
yield msg, "", None, training_state
yield "<EFBFBD> Model not initialized. Please initialize the service first.", "", None, training_state
return
# Check for required training dependencies
@ -569,7 +548,7 @@ def start_training(
from lightning.fabric import Fabric
from peft import get_peft_model, LoraConfig
except ImportError as e:
yield f" Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
yield f"<EFBFBD> Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
return
training_state["is_training"] = True
@ -606,7 +585,7 @@ def start_training(
# Start timer
start_time = time.time()
yield f" Starting training from {tensor_dir}...", "", loss_data, training_state
yield f"<EFBFBD> Starting training from {tensor_dir}...", "", loss_data, training_state
# Create trainer
trainer = LoRATrainer(
@ -623,7 +602,7 @@ def start_training(
for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
# Calculate elapsed time and ETA
elapsed_seconds = time.time() - start_time
time_info = f" Elapsed: {_format_duration(elapsed_seconds)}"
time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
# Parse "Epoch x/y" from status to calculate ETA
match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
@ -656,14 +635,14 @@ def start_training(
yield display_status, log_text, loss_data, training_state
if training_state.get("should_stop", False):
logger.info(" Training stopped by user")
log_lines.append(" Training stopped by user")
yield f" Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
logger.info("⏹️ Training stopped by user")
log_lines.append("⏹️ Training stopped by user")
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
break
total_time = time.time() - start_time
training_state["is_training"] = False
completion_msg = f" Training completed! Total time: {_format_duration(total_time)}"
completion_msg = f"<EFBFBD> Training completed! Total time: {_format_duration(total_time)}"
logger.info(completion_msg)
log_lines.append(completion_msg)
@ -675,7 +654,7 @@ def start_training(
training_state["is_training"] = False
import pandas as pd
empty_df = pd.DataFrame({"step": [], "loss": []})
yield f" Error: {str(e)}", str(e), empty_df, training_state
yield f"<EFBFBD> Error: {str(e)}", str(e), empty_df, training_state
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
@ -685,10 +664,10 @@ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
Tuple of (status, training_state)
"""
if not training_state.get("is_training", False):
return " No training in progress", training_state
return "<EFBFBD> No training in progress", training_state
training_state["should_stop"] = True
return " Stopping training...", training_state
return "⏹️ Stopping training...", training_state
def export_lora(
@ -701,7 +680,7 @@ def export_lora(
Status message
"""
if not export_path or not export_path.strip():
return " Please enter an export path"
return "<EFBFBD> Please enter an export path"
# Check if there's a trained model to export
final_dir = os.path.join(lora_output_dir, "final")
@ -714,13 +693,13 @@ def export_lora(
# Find the latest checkpoint
checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
if not checkpoints:
return " No checkpoints found"
return "<EFBFBD> No checkpoints found"
checkpoints.sort(key=lambda x: int(x.split("_")[1]))
latest = checkpoints[-1]
source_path = os.path.join(checkpoint_dir, latest)
else:
return f" No trained model found in {lora_output_dir}"
return f"<EFBFBD> No trained model found in {lora_output_dir}"
try:
import shutil
@ -733,11 +712,11 @@ def export_lora(
shutil.copytree(source_path, export_path)
return f" LoRA exported to {export_path}"
return f"<EFBFBD> LoRA exported to {export_path}"
except Exception as e:
logger.exception("Export error")
return f" Export failed: {str(e)}"
return f"<EFBFBD> Export failed: {str(e)}"