# control_server.py
from flask import Flask, request, jsonify
from flask_cors import CORS
import paho.mqtt.publish as publish
import time
import threading
import paho.mqtt.client as mqtt
import json

# Flask application that serves the JSON API defined below. CORS(app) is
# called with no options (flask_cors defaults) — presumably so a browser UI
# served from a different origin can call this API; confirm the intended
# origin policy.
app = Flask(__name__)
CORS(app)

# Shared connectivity / zone state. The MQTT callbacks (running on the
# background listener thread) write it; the Flask handlers read it.
#
# Everything starts in the "not connected / never heard from" state: the
# broker link is only marked up once on_connect reports success, and the ESP
# only counts as online after a heartbeat arrives. Starting optimistic would
# make the API report a live broker/ESP before either has been seen.
status = {
    "mqtt_connected": False,
    "esp_online": False,
    # 0.0 means "no heartbeat yet": the age computed from it is far beyond
    # the offline / display thresholds used by the status endpoint.
    "last_heartbeat": 0.0,
    "last_zones": 0,
    # From the esp32/<node>/zones retained JSON.
    "zones": {str(i): False for i in range(1, 9)},
    "active_zones": [],
    "auto_mode": None,
    "flush_mode": None,
    "hold_mode": None,
}

@app.route('/api/command', methods=['POST'])
def command():
    """Publish the ``cmd`` from the JSON body to ``pi/control/TROF``.

    Expects a JSON object such as ``{"cmd": "..."}``. The command is sent as
    the payload of one MQTT message through the broker on ``localhost``.

    Returns:
        200 ``{'status': 'ok', 'sent': cmd}`` when the message was published.
        400 ``{'error': ...}`` when the body is not a JSON object, or ``cmd``
        is missing/empty or is not a scalar (str/int/float).
        503 ``{'error': ...}`` when the broker cannot be reached.
    """
    # silent=True: an absent or malformed JSON body yields None instead of
    # Flask's own (HTML) error response, so every failure below returns the
    # same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Expected a JSON object body'}), 400

    cmd = data.get('cmd')
    if not cmd:
        return jsonify({'error': 'Missing command'}), 400
    # paho only accepts str/bytes/int/float payloads; a JSON list/object
    # would raise inside publish.single and surface as a 500.
    if not isinstance(cmd, (str, int, float)):
        return jsonify({'error': 'Invalid command'}), 400

    try:
        publish.single("pi/control/TROF", payload=cmd, hostname="localhost")
    except OSError as e:
        # Connection refused / timeout talking to the broker.
        print(f"⚠️ MQTT publish failed: {e}")
        return jsonify({'error': 'MQTT publish failed'}), 503

    print(f"📡 Sent MQTT command: {cmd}")
    return jsonify({'status': 'ok', 'sent': cmd})

@app.route('/api/status', methods=['GET'])
def get_status():
    """Return a JSON snapshot of broker/ESP connectivity and zone state.

    ``esp_online`` is True only when BOTH hold:
      * a heartbeat arrived within the last 45 s, and
      * the most recent heartbeat / state-topic message did not report
        "offline" (the MQTT callbacks keep ``status["esp_online"]`` as that
        last-reported value).

    ``heartbeat`` is the heartbeat age as ``"<N>s ago"``, or ``"—"`` once it
    is 600 s or older.
    """
    # The ESP heartbeat interval is 30 s; allow 15 s of slack before
    # declaring it offline.
    offline_after_s = 45
    # Past this age the heartbeat timestamp is not worth displaying.
    show_age_below_s = 600

    now = time.time()
    hb_delta = now - status["last_heartbeat"]

    # Combine freshness with the MQTT-reported flag so an explicit "offline"
    # (LWT) is honoured. The result is deliberately NOT written back into
    # status: doing so would discard the LWT state and could overwrite a
    # True the listener thread has just stored.
    esp_online = bool(status["esp_online"]) and hb_delta < offline_after_s

    heartbeat_str = f"{int(hb_delta)}s ago" if hb_delta < show_age_below_s else "—"

    return jsonify({
        "mqtt_connected": status["mqtt_connected"],
        "esp_online": esp_online,
        "heartbeat": heartbeat_str,

        # Zone / mode state as last reported by the ESP.
        "zones": status["zones"],
        "active_zones": status["active_zones"],
        "auto_mode": status["auto_mode"],
        "flush_mode": status["flush_mode"],
        "hold_mode": status["hold_mode"],
    })

def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback: record broker state and (re)subscribe.

    ``rc`` is the connect result code; 0 means the broker accepted the
    connection. Subscribing here, rather than once after ``connect``, means
    the subscriptions are re-established on every reconnect.
    """
    connected = rc == 0
    status["mqtt_connected"] = connected
    if not connected:
        # A refused connection cannot carry subscriptions; wait for the next
        # connect attempt.
        print(f"⚠️ MQTT connect refused rc={rc}")
        return
    client.subscribe([
        ("esp32/TROF/hb", 0),
        ("esp32/TROF/zones", 0),
        ("esp32/TROF/state", 0),  # optional: online/offline LWT
    ])

def on_disconnect(client, userdata, rc):
    """paho-mqtt disconnect callback: mark the broker link as down.

    ``rc`` is not inspected, so every disconnect, whatever its cause,
    clears the ``mqtt_connected`` flag.
    """
    status.update(mqtt_connected=False)

def _zone_flag(data, key):
    """Return ``int(data[key])``, or None if the key is absent or unusable."""
    value = data.get(key)
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        print(f"⚠️ bad {key!r} value in zones JSON: {value!r}")
        return None


def on_message(client, userdata, msg):
    """paho-mqtt message callback: fold ESP32 telemetry into ``status``.

    Topics handled:
      * ``esp32/TROF/hb``: heartbeat; refreshes ``last_heartbeat`` and marks
        the ESP online.
      * ``esp32/TROF/state``: ``"online"`` / ``"offline"`` string (LWT).
      * ``esp32/TROF/zones``: JSON object with an optional ``z`` list (at
        least 8 zone states) and optional ``manual``/``flush``/``hold`` 0/1
        flags.

    Malformed payloads are logged and ignored, never raised: with paho's
    default settings an exception escaping this callback propagates out of
    ``loop_forever`` and stops the listener thread, freezing all status.
    """
    topic = msg.topic
    payload = msg.payload.decode(errors="ignore")

    if topic == "esp32/TROF/hb":
        status["last_heartbeat"] = time.time()
        status["esp_online"] = True
        return

    if topic == "esp32/TROF/state":
        # LWT online/offline string; anything else leaves the flag alone.
        state = payload.strip()
        if state == "online":
            status["esp_online"] = True
        elif state == "offline":
            status["esp_online"] = False
        return

    if topic != "esp32/TROF/zones":
        return

    try:
        data = json.loads(payload)
    except Exception as e:  # callback boundary: log and drop the message
        print(f"⚠️ bad zones JSON: {e} payload={payload}")
        return
    if not isinstance(data, dict):
        print(f"⚠️ zones JSON is not an object: payload={payload}")
        return

    z = data.get("z")
    if isinstance(z, list) and len(z) >= 8:
        zones = {str(i + 1): bool(z[i]) for i in range(8)}
        status["zones"] = zones
        status["active_zones"] = [int(k) for k, v in zones.items() if v]

    # ESP publishes manual/flush/hold as 0/1; unusable values are skipped.
    manual = _zone_flag(data, "manual")
    flush = _zone_flag(data, "flush")
    hold = _zone_flag(data, "hold")

    if manual is not None:
        status["auto_mode"] = manual == 0
    if flush is not None:
        status["flush_mode"] = flush == 1
    if hold is not None:
        status["hold_mode"] = hold == 1

    status["last_zones"] = time.time()

def start_mqtt_listener(host="localhost", port=1883, keepalive=60):
    """Connect to the MQTT broker and run paho's network loop forever.

    Blocks for the life of the process, so run it on a background thread.

    ``connect_async`` together with ``loop_forever(retry_first_connection=True)``
    makes the *initial* connection retry with paho's reconnect back-off. A
    plain blocking ``connect`` raises when the broker is not up yet (e.g.
    both started at boot), which would kill this thread and leave the
    server without MQTT status for good.

    Args:
        host: Broker hostname.
        port: Broker TCP port.
        keepalive: MQTT keepalive interval in seconds.
    """
    client = mqtt.Client()
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_message = on_message
    client.connect_async(host, port, keepalive)
    client.loop_forever(retry_first_connection=True)

# Launch the MQTT listener when the module is loaded, not only under
# __main__, so it runs whether this file is executed directly or imported.
# daemon=True lets the interpreter exit without joining the thread.
threading.Thread(
    daemon=True,
    target=start_mqtt_listener,
).start()

if __name__ == "__main__":
    # Bind every interface so other devices on the network can reach the API.
    # NOTE(review): /api/command has no authentication, so anything that can
    # reach port 5000 can publish control commands — confirm this is intended.
    app.run(port=5000, host="0.0.0.0")
