liumaolin committed on
Commit
dff81cd
·
1 Parent(s): 49308e9

feat(inference): add `_parse_list_file` to handle default values for `ref_audio_path` and `ref_text`

Browse files

- Parse the first line of `asr_opt/slicer_opt.list` to obtain default `ref_audio_path` and `ref_text` values when either one is not provided in the config.

training_pipeline/stages/inference.py CHANGED
@@ -165,10 +165,38 @@ class InferenceStage(BaseStage):
165
  sovits_paths = self.config.sovits_paths()
166
  return len(gpt_paths) > 0 and len(sovits_paths) > 0
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  def run(self) -> Generator[Dict[str, Any], None, None]:
169
  self._status = StageStatus.RUNNING
170
  cfg = self.config
171
 
 
 
 
 
 
 
172
  # 确保输出目录存在
173
  os.makedirs(cfg.output_dir, exist_ok=True)
174
 
 
165
  sovits_paths = self.config.sovits_paths()
166
  return len(gpt_paths) > 0 and len(sovits_paths) > 0
167
 
168
def _parse_list_file(self) -> tuple[str, str]:
    """Read default reference audio/text from ``asr_opt/slicer_opt.list``.

    Only the first line of ``<exp_dir>/asr_opt/slicer_opt.list`` is
    inspected. It is expected to have the form
    ``{audio path}|{folder name}|{language}|{recognized text}``.

    Returns:
        A ``(ref_audio_path, ref_text)`` tuple. Both items are empty
        strings when the file is missing, when its first line is blank,
        or when that line has fewer than four ``|``-separated fields.
    """
    list_path = os.path.join(self.config.exp_dir, 'asr_opt', 'slicer_opt.list')

    # Open directly instead of exists()+open(): a file removed between the
    # check and the open can no longer raise. Only "path is missing" errors
    # are treated as "no defaults"; other I/O errors still propagate.
    try:
        with open(list_path, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
    except (FileNotFoundError, NotADirectoryError):
        return "", ""

    if not first_line:
        return "", ""

    # Format: {audio path}|{folder name}|{language}|{recognized text}
    # maxsplit=3 keeps any '|' inside the recognized text in the last
    # field instead of silently truncating the text at that character.
    parts = first_line.split('|', 3)
    if len(parts) < 4:
        return "", ""
    return parts[0], parts[3]
189
+
190
  def run(self) -> Generator[Dict[str, Any], None, None]:
191
  self._status = StageStatus.RUNNING
192
  cfg = self.config
193
 
194
+ # 如果 ref_text 或 ref_audio_path 为空,从 .list 文件解析默认值
195
+ if not cfg.ref_text or not cfg.ref_audio_path:
196
+ parsed_audio, parsed_text = self._parse_list_file()
197
+ cfg.ref_audio_path = parsed_audio
198
+ cfg.ref_text = parsed_text
199
+
200
  # 确保输出目录存在
201
  os.makedirs(cfg.output_dir, exist_ok=True)
202