Philips656 commited on
Commit
960ee52
·
verified ·
1 Parent(s): abb43a8

Update shield.py

Browse files
Files changed (1) hide show
  1. shield.py +30 -54
shield.py CHANGED
@@ -3,9 +3,9 @@ from flask import Flask, request, Response
3
 
4
  app = Flask(__name__)
5
 
6
- # --- CONFIGURATION ---
7
  NEWAPI_INTERNAL = "http://127.0.0.1:3000"
8
- OPENAI_MOD_KEY = "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"
9
 
10
  TIDB_CONFIG = {
11
  "host": "gateway01.eu-central-1.prod.aws.tidbcloud.com",
@@ -19,70 +19,46 @@ TIDB_CONFIG = {
19
  "ssl_verify_cert": True
20
  }
21
 
22
- def log_violation(user_id, prompt):
23
  try:
24
  conn = mysql.connector.connect(**TIDB_CONFIG)
25
  cursor = conn.cursor()
26
- cursor.execute("INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)",
27
- (str(user_id), str(prompt)[:1000]))
28
- conn.commit()
29
- print(f"✅ LOGGED VIOLATION: {user_id}")
30
  conn.close()
31
- except Exception as e:
32
- print(f"❌ DATABASE ERROR: {e}")
33
 
34
- # 1. SERVE UI
35
- @app.route('/')
36
- def home():
37
- with open('index.html', 'r') as f:
38
- return f.read()
39
-
40
- # 2. PROXY AUTH & MODEL LISTS DIRECTLY (Don't Block These)
41
- @app.route('/api/<path:subpath>', methods=['GET', 'POST', 'PUT', 'DELETE'])
42
- def proxy_api(subpath):
43
- return forward_request(f"{NEWAPI_INTERNAL}/api/{subpath}")
44
-
45
- @app.route('/v1/models', methods=['GET'])
46
- def proxy_models():
47
- return forward_request(f"{NEWAPI_INTERNAL}/v1/models")
48
-
49
- # 3. INTERCEPT & PROTECT CHAT
50
  @app.route('/v1/chat/completions', methods=['POST'])
51
  def protect_chat():
52
  data = request.json
53
- # Extract text from latest message
54
- msgs = data.get('messages', [])
55
- last_msg = msgs[-1].get('content', '') if msgs else ""
56
 
57
- # Check Safety (using your OpenAI Key for moderation)
58
  try:
59
- mod = requests.post("https://api.openai.com/v1/moderations",
60
- headers={"Authorization": f"Bearer {OPENAI_MOD_KEY}"},
61
- json={"input": last_msg}, timeout=3).json()
62
-
63
- if mod.get('results', [{}])[0].get('categories', {}).get('sexual/minors'):
64
- user_token = request.headers.get('Authorization', 'Anon')
65
- log_violation(user_token, last_msg)
66
- return {"error": {"message": "Policy Violation: Blocked by Shield.", "type": "safety_error"}}, 403
67
- except Exception as e:
68
- print(f"Moderation Warning: {e}")
69
 
70
- # If Safe -> Forward to NewAPI (using the USER'S token from the UI)
71
- return forward_request(f"{NEWAPI_INTERNAL}/v1/chat/completions")
72
 
73
- def forward_request(url):
74
- try:
75
- resp = requests.request(
76
- method=request.method,
77
- url=url,
78
- headers={k: v for k, v in request.headers if k.lower() != 'host'},
79
- data=request.get_data(),
80
- allow_redirects=False
81
- )
82
- return Response(resp.content, resp.status_code, resp.headers.items())
83
- except Exception as e:
84
- return {"error": str(e)}, 502
 
 
 
85
 
86
  if __name__ == '__main__':
87
- # Start Shield on 7860
88
  app.run(host='0.0.0.0', port=7860)
 
3
 
4
  app = Flask(__name__)
5
 
6
+ # --- SETTINGS ---
7
  NEWAPI_INTERNAL = "http://127.0.0.1:3000"
8
+ MODERATION_KEY = "sk-proj-LYW3iVcaE..." # Your OpenAI Key for the safety check
9
 
10
  TIDB_CONFIG = {
11
  "host": "gateway01.eu-central-1.prod.aws.tidbcloud.com",
 
19
  "ssl_verify_cert": True
20
  }
21
 
22
+ def log_violation(user, prompt):
23
  try:
24
  conn = mysql.connector.connect(**TIDB_CONFIG)
25
  cursor = conn.cursor()
26
+ cursor.execute("INSERT INTO safety_violations (user_id, prompt_content) VALUES (%s, %s)", (user, prompt[:1000]))
 
 
 
27
  conn.close()
28
+ except: pass
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  @app.route('/v1/chat/completions', methods=['POST'])
31
  def protect_chat():
32
  data = request.json
33
+ content = " ".join([m.get('content', '') for m in data.get('messages', [])])
 
 
34
 
35
+ # Safety Check
36
  try:
37
+ res = requests.post("https://api.openai.com/v1/moderations",
38
+ headers={"Authorization": f"Bearer {MODERATION_KEY}"},
39
+ json={"input": content}, timeout=3).json()
40
+ if res.get('results', [{}])[0].get('categories', {}).get('sexual/minors'):
41
+ log_violation(request.headers.get('Authorization', 'Anon'), content)
42
+ return {"error": {"message": "Shield Block: Safety Violation"}}, 403
43
+ except: pass
 
 
 
44
 
45
+ return forward_to_newapi(request.path)
 
46
 
47
+ @app.route('/', defaults={'path': ''}, methods=['GET', 'POST', 'PUT', 'DELETE'])
48
+ @app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
49
+ def catch_all(path):
50
+ return forward_to_newapi(path)
51
+
52
+ def forward_to_newapi(path):
53
+ resp = requests.request(
54
+ method=request.method,
55
+ url=f"{NEWAPI_INTERNAL}/{path}",
56
+ headers={k: v for k, v in request.headers if k.lower() != 'host'},
57
+ data=request.get_data(),
58
+ params=request.args,
59
+ allow_redirects=False
60
+ )
61
+ return Response(resp.content, resp.status_code, resp.headers.items())
62
 
63
  if __name__ == '__main__':
 
64
  app.run(host='0.0.0.0', port=7860)