import gradio as gr
import os
import time
import re
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from openai import OpenAI
import random
from concurrent.futures import TimeoutError as FuturesTimeoutError
from openai import APIStatusError, APITimeoutError, APIConnectionError
import traceback
from dotenv import load_dotenv
from prompts import (
    USER_PROMPT,
    WRAPPER_PROMPT,
    CALL_1_SYSTEM_PROMPT,
    CALL_2_SYSTEM_PROMPT,
    CALL_3_SYSTEM_PROMPT,
)
import difflib
import csv
from threading import Lock
import threading


load_dotenv()

BASE_URL = "https://api.upstage.ai/v1"
API_KEY = os.getenv("OPENAI_API_KEY")

client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def postprocess_pronoun(text: str) -> str:
| | """ |
| | '์ด์ฌ๋ช
๋ํ'๊ฐ ํฌํจ๋ ๋ชจ๋ ๋จ์ด๋ฅผ '์ด์ฌ๋ช
๋ํต๋ น'์ผ๋ก ๊ต์ฒดํ๋ฉฐ, |
| | ๋ค๋ฐ๋ฅด๋ ์กฐ์ฌ๊ฐ ์์ ๊ฒฝ์ฐ ํจ๊ป ์์ ํฉ๋๋ค. |
| | """ |
| | |
| | |
| | correction_map = { |
| | '๋': '์', '๊ฐ': '์ด', '๋ฅผ': '์', '์': '๊ณผ', '๋ก': '์ผ๋ก', |
| | '์ฌ': '์ด์ฌ', '๋ผ': '์ด๋ผ', '๋': '์ด๋', |
| | '๋ค': '์ด๋ค', '์๋ค': '์ด์๋ค', '๋ผ๋ฉด': '์ด๋ผ๋ฉด', '๋ผ์': '์ด๋ผ์' |
| | } |
| | |
| | |
| | all_target_particles = list(correction_map.keys()) + ['๋ก๋ถํฐ', '๋ง', '๋', '๊ป์'] |
| | |
| | particle_pattern = "|".join(re.escape(p) for p in all_target_particles) |
| | |
| | |
| | regex = re.compile(f"(์ด์ฌ๋ช
\s*๋ํ)({particle_pattern})?") |
| |
|
| | def replace_func(match): |
| | particle = match.group(2) |
| | new_phrase = "์ด์ฌ๋ช
๋ํต๋ น" |
| | |
| | if particle: |
| | new_phrase += correction_map.get(particle, particle) |
| | |
| | return new_phrase |
| |
|
| | return regex.sub(replace_func, text) |
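
# Illustrative example for postprocess_pronoun (not executed by the pipeline):
#   postprocess_pronoun("이재명 대표가 발표했다.")
#   -> "이재명 대통령이 발표했다."   (the particle '가' becomes '이')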


def extract_json_from_text(text):
    """
    Extract the JSON portion of a model response.
    Several patterns are tried in order: a ```json fenced block, a plain
    fenced block, a bare object, a bare array, and finally the raw text.

    Args:
        text: text that may contain JSON

    Returns:
        dict: the parsed JSON object, or None if nothing could be parsed
    """
    if not text or not text.strip():
        return None

    text = text.strip()

    # 1) ```json fenced code block
    json_code_block_pattern = r'```json\s*(.*?)\s*```'
    match = re.search(json_code_block_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(1).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass

    # 2) generic fenced code block
    code_block_pattern = r'```\s*(.*?)\s*```'
    match = re.search(code_block_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(1).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass

    # 3) bare JSON object
    json_object_pattern = r'\{.*\}'
    match = re.search(json_object_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(0).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass

    # 4) bare JSON array
    json_array_pattern = r'\[.*\]'
    match = re.search(json_array_pattern, text, re.DOTALL)
    if match:
        try:
            extracted = match.group(0).strip()
            if extracted:
                return json.loads(extracted)
        except json.JSONDecodeError:
            pass

    # 5) the text itself may already be JSON
    try:
        if text.startswith('{') or text.startswith('['):
            return json.loads(text)
    except json.JSONDecodeError:
        pass

    return None
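
# Illustrative example for extract_json_from_text (not executed by the pipeline):
#   extract_json_from_text('```json\n{"output": "교정된 문장"}\n```')
#   -> {"output": "교정된 문장"}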


def load_vocabulary():
    vocabulary = {}
    with open("Vocabulary.csv", "r", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Log the column names once, on the first row only.
            if len(vocabulary) == 0:
                print("CSV columns:", list(row.keys()))
            vocabulary[row["original"]] = row["corrected"]
    return vocabulary
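
# Expected Vocabulary.csv layout: a header row with "original" and "corrected"
# columns and one replacement pair per row. The entry below is a hypothetical
# example, not taken from the real file:
#   original,corrected
#   컨텐츠,콘텐츠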


VOCABULARY = load_vocabulary()

# Shared progress counters used by the worker threads.
counter_lock = Lock()
processed_count = 0
total_bulks = 0


def apply_vocabulary_correction(text):
    # Deterministic find-and-replace pass driven by Vocabulary.csv.
    for original, corrected in VOCABULARY.items():
        text = text.replace(original, corrected)
    return text


def create_bulk_paragraphs(text, max_chars=500):
    """
    Split the text into "bulks" of roughly 500 characters, keeping paragraph
    boundaries intact.

    Args:
        text: input text
        max_chars: maximum characters per bulk (default: 500)

    Returns:
        List[str]: the text split into bulk-sized chunks
    """
    paragraphs = [p.strip() for p in text.split("\n") if p.strip()]

    if not paragraphs:
        return []

    bulks = []
    current_bulk = []
    current_length = 0

    for para in paragraphs:
        para_length = len(para)

        # A paragraph longer than the limit becomes a bulk of its own.
        if para_length > max_chars:
            # Flush whatever has been accumulated so far.
            if current_bulk:
                bulks.append("\n".join(current_bulk))
                current_bulk = []
                current_length = 0

            bulks.append(para)
        else:
            # Account for the newline separators when checking the limit.
            if (
                current_length + para_length + len(current_bulk) > max_chars
                and current_bulk
            ):
                bulks.append("\n".join(current_bulk))
                current_bulk = [para]
                current_length = para_length
            else:
                current_bulk.append(para)
                current_length += para_length

    # Flush the final bulk.
    if current_bulk:
        bulks.append("\n".join(current_bulk))

    return bulks
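
# Illustrative behaviour: with max_chars=500, three 200-character paragraphs
# yield two bulks -- the first two paragraphs joined by "\n" (401 characters
# including the separator) and the third paragraph on its own.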


def process_bulk(bulk_text, bulk_index, max_retries=3, article_info=""):
    """
    Run one bulk through the correction pipeline.
    If an API error occurs, the output of the last successful step is returned.
    """
    global processed_count
    thread_id = threading.get_ident()
    start = time.time()

    # Intermediate results for each pipeline stage.
    step0, proofread_result, step1, step1_explanation, step2, step2_explanation, step3, step4, step5 = (None,) * 9

    # Fallback: whatever the last successful stage produced.
    last_successful_output = bulk_text

    for attempt in range(max_retries):
        try:
            # Step 0: deterministic vocabulary correction.
            step0 = apply_vocabulary_correction(bulk_text)
            last_successful_output = step0

            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling proofread...")
            proofread_result = call_proofread(step0)

            # Protect paragraph breaks while the text passes through the LLM steps.
            step0 = step0.replace("\n", "<paragraph_separator>")
            proofread_result = proofread_result.replace("\n", "<paragraph_separator>")

            # Step 1: first LLM pass; the user message carries both the original
            # text and the proofread draft.
            system_step1 = WRAPPER_PROMPT.format(system_prompt=CALL_1_SYSTEM_PROMPT)
            user_step1 = USER_PROMPT.format(original=step0, proofread=proofread_result)

            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step1...")
            step1_json = call_solar_pro2(system_step1, user_step1)
            try:
                parsed_json = json.loads(step1_json)
                step1 = parsed_json.get('output', step0)
                step1_explanation = parsed_json.get('explanation', '')
                last_successful_output = step1
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step1 JSON parse failed, trying extraction...")
                extracted_json = extract_json_from_text(step1_json)
                if extracted_json and 'output' in extracted_json:
                    step1 = extracted_json['output']
                    step1_explanation = extracted_json.get('explanation', '')
                    last_successful_output = step1
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction succeeded")
                else:
                    step1 = step0
                    step1_explanation = ""
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction failed")

            # Step 2: second LLM pass.
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step2...")
            step2_json = call_solar_pro2(CALL_2_SYSTEM_PROMPT, step1)
            try:
                parsed_json = json.loads(step2_json)
                step2 = parsed_json.get('output', step1)
                step2_explanation = parsed_json.get('explanation', '')
                last_successful_output = step2
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step2 JSON parse failed, trying extraction...")
                extracted_json = extract_json_from_text(step2_json)
                if extracted_json and 'output' in extracted_json:
                    step2 = extracted_json['output']
                    step2_explanation = extracted_json.get('explanation', '')
                    last_successful_output = step2
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction succeeded")
                else:
                    step2 = step1
                    step2_explanation = ""
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction failed")

            # Step 3: third LLM pass.
            print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Calling step3...")
            step3_json = call_solar_pro2(CALL_3_SYSTEM_PROMPT, step2)
            try:
                parsed_json = json.loads(step3_json)
                step3 = parsed_json.get('output', step2)
                last_successful_output = step3
            except json.JSONDecodeError:
                print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - Step3 JSON parse failed, trying extraction...")
                extracted_json = extract_json_from_text(step3_json)
                if extracted_json and 'output' in extracted_json:
                    step3 = extracted_json['output']
                    last_successful_output = step3
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction succeeded")
                else:
                    step3 = step2
                    print(f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} Attempt {attempt+1} - JSON extraction failed")

            # Step 4: run the vocabulary correction again on the LLM output.
            step4 = apply_vocabulary_correction(step3)

            # Step 5: rule-based pronoun/particle post-processing.
            step5 = postprocess_pronoun(step4)
            last_successful_output = step5

            # Restore the original paragraph breaks.
            step5 = step5.replace("<paragraph_separator>", "\n")
            last_successful_output = last_successful_output.replace("<paragraph_separator>", "\n")

            elapsed = time.time() - start

            with counter_lock:
                processed_count += 1

            return {
                "bulk_index": bulk_index,
                "original": bulk_text,
                "final": last_successful_output,
                "processing_time": elapsed,
                "character_count": len(bulk_text),
                "attempts": attempt + 1,
            }

        except Exception as e:
            if attempt < max_retries - 1:
                print(
                    f"{article_info}[Thread-{thread_id}] Bulk {bulk_index+1} attempt {attempt+1} failed, retrying: {type(e).__name__}"
                )
                time.sleep(1 * (attempt + 1))
                continue
            else:
                print(f"🔥🔥🔥 {article_info}[Thread-{thread_id}] Bulk {bulk_index+1} failed for good; using the last successful output. 🔥🔥🔥")
                traceback.print_exc()

                return {
                    "bulk_index": bulk_index,
                    "original": bulk_text,
                    # Strip any leftover separator tokens from the fallback output.
                    "final": last_successful_output.replace("<paragraph_separator>", "\n"),
                    "processing_time": time.time() - start,
                    "character_count": len(bulk_text),
                    "error": traceback.format_exc(),
                    "attempts": max_retries,
                }

    # Defensive fallback; the loop above always returns before reaching here.
    return {"bulk_index": bulk_index, "final": bulk_text, "error": "unknown_flow_error"}


def call_solar_pro2(system, user, temperature=0.0, model_name="solar-pro2"):
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        stream=False,
        temperature=temperature,
    )
    return response.choices[0].message.content


def call_proofread(paragraph):
    # Fine-tuned proofreading model; the system prompt asks it to produce a
    # corrected version of the input document.
    prompt = "입력된 문서에 대한 교열 결과를 생성해 주세요."
    response = client.chat.completions.create(
        model="ft:solar-news-correction-dev",
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": paragraph},
        ],
        stream=False,
        temperature=0.0,
    )
    return response.choices[0].message.content


def highlight_diff(original, corrected):
    matcher = difflib.SequenceMatcher(None, original, corrected)
    result_html = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            result_html.append(f"<span>{original[i1:i2]}</span>")
        elif tag == "replace":
            result_html.append(
                f'<span style="background:#ffecec;text-decoration:line-through;">{original[i1:i2]}</span>'
            )
            result_html.append(
                f'<span style="background:#e6ffec;">{corrected[j1:j2]}</span>'
            )
        elif tag == "delete":
            result_html.append(
                f'<span style="background:#ffecec;text-decoration:line-through;">{original[i1:i2]}</span>'
            )
        elif tag == "insert":
            result_html.append(
                f'<span style="background:#e6ffec;">{corrected[j1:j2]}</span>'
            )
    return "".join(result_html)


def process_text_parallel(input_text, max_workers=10):
    """Process the text in parallel, one bulk at a time."""
    global processed_count, total_bulks

    # Split the input into bulks and reset the shared progress counters.
    bulks = create_bulk_paragraphs(input_text)
    total_bulks = len(bulks)
    processed_count = 0

    if not bulks:
        return []

    results = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit one job per bulk and remember which future maps to which index.
        future_to_bulk = {
            executor.submit(process_bulk, bulk, i): i for i, bulk in enumerate(bulks)
        }

        # Collect results as they finish; failures fall back to the original text.
        for future in as_completed(future_to_bulk):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                bulk_index = future_to_bulk[future]
                print(f"Exception while processing bulk {bulk_index+1}: {e}")
                results.append(
                    {
                        "bulk_index": bulk_index,
                        "original": bulks[bulk_index],
                        "final": bulks[bulk_index],
                        "processing_time": 0,
                        "character_count": len(bulks[bulk_index]),
                        "error": str(e),
                    }
                )

    # Restore the original document order.
    results.sort(key=lambda x: x["bulk_index"])

    return results
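
# Illustrative usage: process_text_parallel(article_text) returns one result
# dict per bulk, sorted by "bulk_index" so the corrected pieces can simply be
# joined back together in document order.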


def demo_fn(input_text):
    # Run the full pipeline over the input.
    bulk_results = process_text_parallel(input_text, max_workers=10)

    if not bulk_results:
        return input_text, input_text

    # Reassemble the corrected bulks in order.
    final_texts = [r["final"] for r in bulk_results]
    final_result = "\n".join(final_texts)

    # Build an HTML view that highlights what changed.
    highlighted = highlight_diff(input_text, final_result)

    return final_result, highlighted


with gr.Blocks() as demo:
    gr.Markdown("# 교열 모델 데모")
    input_text = gr.Textbox(
        label="원문 입력", lines=10, placeholder="문단 단위로 입력해 주세요."
    )
    btn = gr.Button("교열하기")
    output_corrected = gr.Textbox(label="교열 결과", lines=10)
    output_highlight = gr.HTML(label="수정된 부분 강조")

    btn.click(
        fn=demo_fn, inputs=input_text, outputs=[output_corrected, output_highlight]
    )


if __name__ == "__main__":
    demo.launch()