Skip to content
Snippets Groups Projects
Commit 24e7d4cf authored by Jonathan Klimt's avatar Jonathan Klimt :cowboy:
Browse files

Fixed formatting, lints and typing

parent 4fc5d280
No related branches found
No related tags found
No related merge requests found
from flask import Flask, render_template, jsonify, Response, request, redirect, url_for, flash from flask import (
Flask,
render_template,
jsonify,
request,
redirect,
url_for,
flash,
)
from flask_socketio import SocketIO from flask_socketio import SocketIO
from flask_login import LoginManager, login_user, login_required, logout_user, current_user from flask_login import ( # type: ignore
LoginManager,
login_user,
login_required,
logout_user,
current_user,
)
from flask_limiter import Limiter from flask_limiter import Limiter
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from flask.typing import ResponseReturnValue
from pyghmi.ipmi import command # type: ignore from pyghmi.ipmi import command # type: ignore
import json import json
import time import time
...@@ -11,25 +26,27 @@ import logging ...@@ -11,25 +26,27 @@ import logging
import socket import socket
import sys import sys
import argparse import argparse
from typing import Dict, List, Optional, Tuple, TypedDict, Union from typing import Dict, List, Optional, Tuple, TypedDict
from auth import LDAPAuth, User from auth import LDAPAuth, User, LDAPConfig
import os import os
from datetime import timedelta from datetime import timedelta
from typing import Any
# Configure logging to show all levels # Configure logging to show all levels
logging.basicConfig( logging.basicConfig(
level=logging.DEBUG, level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
stream=sys.stdout stream=sys.stdout,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Silence Werkzeug logs # Silence Werkzeug logs
logging.getLogger('werkzeug').setLevel(logging.WARNING) logging.getLogger("werkzeug").setLevel(logging.WARNING)
# Set socket timeout # Set socket timeout
socket.setdefaulttimeout(5) # 5 seconds timeout socket.setdefaulttimeout(5) # 5 seconds timeout
class ServerConfig(TypedDict): class ServerConfig(TypedDict):
name: str name: str
ipmi_ip: str ipmi_ip: str
...@@ -37,6 +54,7 @@ class ServerConfig(TypedDict): ...@@ -37,6 +54,7 @@ class ServerConfig(TypedDict):
ipmi_pass: str ipmi_pass: str
locked: bool locked: bool
class ServerStatus(TypedDict): class ServerStatus(TypedDict):
name: str name: str
status: str status: str
...@@ -45,54 +63,55 @@ class ServerStatus(TypedDict): ...@@ -45,54 +63,55 @@ class ServerStatus(TypedDict):
power_consumption: Optional[float] power_consumption: Optional[float]
locked: bool locked: bool
class GroupConfig(TypedDict): class GroupConfig(TypedDict):
name: str name: str
servers: List[ServerConfig] servers: List[ServerConfig]
class LDAPConfig(TypedDict):
server: str
base_dn: str
user_dn: str
group_dn: str
admin_group: str
class Config(TypedDict): class Config(TypedDict):
ldap: LDAPConfig ldap: LDAPConfig
groups: List[GroupConfig] groups: List[GroupConfig]
class ServerData(TypedDict): class ServerData(TypedDict):
config: ServerConfig config: ServerConfig
connection: Optional[command.Command] connection: Optional[command.Command]
last_error: Optional[str] last_error: Optional[str]
class ServerGroup(TypedDict): class ServerGroup(TypedDict):
name: str name: str
servers: List[ServerConfig] servers: List[ServerConfig]
class GroupStatus(TypedDict): class GroupStatus(TypedDict):
name: str name: str
servers: List[ServerStatus] servers: List[ServerStatus]
app = Flask(__name__) app = Flask(__name__)
app.secret_key = os.urandom(24) app.secret_key = os.urandom(24)
app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(hours=1) # Session timeout after 1 hour app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(
hours=1
) # Session timeout after 1 hour
# app.config['SESSION_COOKIE_SECURE'] = True # Only send cookies over HTTPS # app.config['SESSION_COOKIE_SECURE'] = True # Only send cookies over HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True # Prevent JavaScript access to cookies app.config["SESSION_COOKIE_HTTPONLY"] = True # Prevent JavaScript access to cookies
app.config['SESSION_COOKIE_SAMESITE'] = 'Strict' app.config["SESSION_COOKIE_SAMESITE"] = "Strict"
socketio = SocketIO(app, cors_allowed_origins="*") socketio = SocketIO(app, cors_allowed_origins="*")
# Initialize LoginManager # Initialize LoginManager
login_manager = LoginManager() login_manager = LoginManager()
login_manager.init_app(app) login_manager.init_app(app)
login_manager.login_view = 'login' login_manager.login_view = "login"
# Load configuration # Load configuration
with open('config.json', 'r') as f: with open("config.json", "r") as f:
config: Config = json.load(f) config: Config = json.load(f)
# Initialize LDAP authentication # Initialize LDAP authentication
ldap_auth = LDAPAuth(config['ldap']) ldap_auth = LDAPAuth(config["ldap"])
# Initialize rate limiter # Initialize rate limiter
limiter = Limiter( limiter = Limiter(
...@@ -100,13 +119,15 @@ limiter = Limiter( ...@@ -100,13 +119,15 @@ limiter = Limiter(
key_func=get_remote_address, key_func=get_remote_address,
default_limits=["200 per day", "50 per hour"], default_limits=["200 per day", "50 per hour"],
storage_uri="memory://", storage_uri="memory://",
strategy="fixed-window" strategy="fixed-window",
) )
@login_manager.user_loader @login_manager.user_loader
def load_user(user_id: str) -> Optional[User]: def load_user(user_id: str) -> Optional[User]:
return User(user_id) return User(user_id)
class ServerManager: class ServerManager:
def __init__(self, config: Config) -> None: def __init__(self, config: Config) -> None:
self.servers: Dict[str, ServerData] = {} self.servers: Dict[str, ServerData] = {}
...@@ -118,17 +139,19 @@ class ServerManager: ...@@ -118,17 +139,19 @@ class ServerManager:
def load_config(self, config: Config) -> None: def load_config(self, config: Config) -> None:
try: try:
self.groups = config['groups'] self.groups = config["groups"]
ipmi_ips = set() ipmi_ips = set()
for group in self.groups: for group in self.groups:
for server in group['servers']: for server in group["servers"]:
if server['ipmi_ip'] in ipmi_ips: if server["ipmi_ip"] in ipmi_ips:
raise ValueError(f"Duplicate IPMI IP address found in config: {server['ipmi_ip']}") raise ValueError(
ipmi_ips.add(server['ipmi_ip']) f"Duplicate IPMI IP address found in config: {server['ipmi_ip']}"
)
ipmi_ips.add(server["ipmi_ip"])
self.servers[f"{group['name']}/{server['name']}"] = { self.servers[f"{group['name']}/{server['name']}"] = {
'config': server, "config": server,
'connection': None, "connection": None,
'last_error': None "last_error": None,
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to load server configuration: {str(e)}") logger.error(f"Failed to load server configuration: {str(e)}")
...@@ -140,27 +163,36 @@ class ServerManager: ...@@ -140,27 +163,36 @@ class ServerManager:
for server_name, server_data in self.servers.items(): for server_name, server_data in self.servers.items():
try: try:
logger.debug(f"Initializing connection to {server_name}") logger.debug(f"Initializing connection to {server_name}")
server_data['connection'] = command.Command( server_data["connection"] = command.Command(
bmc=server_data['config']['ipmi_ip'], bmc=server_data["config"]["ipmi_ip"],
userid=server_data['config']['ipmi_user'], userid=server_data["config"]["ipmi_user"],
password=server_data['config']['ipmi_pass'], password=server_data["config"]["ipmi_pass"],
keepalive=True keepalive=True,
) )
logger.debug(f"Successfully initialized connection to {server_name}") logger.debug(f"Successfully initialized connection to {server_name}")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize connection to {server_name}: {str(e)}") logger.error(
server_data['last_error'] = str(e) f"Failed to initialize connection to {server_name}: {str(e)}"
)
def _create_status_response(self, group_name: str, server_name: str, status: str, error: Optional[str] = None, power_consumption: Optional[float] = None) -> ServerStatus: server_data["last_error"] = str(e)
def _create_status_response(
self,
group_name: str,
server_name: str,
status: str,
error: Optional[str] = None,
power_consumption: Optional[float] = None,
) -> ServerStatus:
"""Create a standardized status response for a server.""" """Create a standardized status response for a server."""
server: ServerData = self.servers[f"{group_name}/{server_name}"] server: ServerData = self.servers[f"{group_name}/{server_name}"]
return { return {
'name': server_name, "name": server_name,
'status': status, "status": status,
'error': error, "error": error,
'ipmi_ip': server['config']['ipmi_ip'], "ipmi_ip": server["config"]["ipmi_ip"],
'power_consumption': power_consumption, "power_consumption": power_consumption,
'locked': server['config'].get('locked', False) "locked": server["config"].get("locked", False),
} }
def _update_server_status(self, group_name: str, server_name: str) -> None: def _update_server_status(self, group_name: str, server_name: str) -> None:
...@@ -169,50 +201,53 @@ class ServerManager: ...@@ -169,50 +201,53 @@ class ServerManager:
return return
server: ServerData = self.servers[cache_key] server: ServerData = self.servers[cache_key]
if not server['connection']: if not server["connection"]:
self.status_cache[cache_key] = self._create_status_response( self.status_cache[cache_key] = self._create_status_response(
group_name, group_name,
server_name, server_name,
'unknown', "unknown",
server['last_error'] or 'No connection established' server["last_error"] or "No connection established",
) )
return return
try: try:
status = server['connection'].get_power() status = server["connection"].get_power()
server['last_error'] = None server["last_error"] = None
# Get power consumption # Get power consumption
power_consumption: Optional[float] = None power_consumption: Optional[float] = None
try: try:
power_consumption = server['connection'].get_system_power_watts() power_consumption = server["connection"].get_system_power_watts()
logger.debug(f"Power consumption for {server_name}: {power_consumption}W") logger.debug(
f"Power consumption for {server_name}: {power_consumption}W"
)
except Exception as e: except Exception as e:
logger.debug(f"Could not get power consumption for {server_name}: {str(e)}") logger.debug(
f"Could not get power consumption for {server_name}: {str(e)}"
)
self.status_cache[cache_key] = self._create_status_response( self.status_cache[cache_key] = self._create_status_response(
group_name, group_name,
server_name, server_name,
'on' if status['powerstate'] == 'on' else 'off', "on" if status["powerstate"] == "on" else "off",
power_consumption=power_consumption power_consumption=power_consumption,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get power status for {server_name}: {str(e)}") logger.error(f"Failed to get power status for {server_name}: {str(e)}")
server['last_error'] = str(e) server["last_error"] = str(e)
self.status_cache[cache_key] = self._create_status_response( self.status_cache[cache_key] = self._create_status_response(
group_name, group_name, server_name, "unknown", str(e)
server_name,
'unknown',
str(e)
) )
def _status_update_loop(self) -> None: def _status_update_loop(self) -> None:
while True: while True:
try: try:
for group in self.groups: for group in self.groups:
for server in group['servers']: for server in group["servers"]:
self._update_server_status(group['name'], server['name']) self._update_server_status(group["name"], server["name"])
socketio.emit('server_status_update', {'groups': self.get_all_statuses()}) socketio.emit(
"server_status_update", {"groups": self.get_all_statuses()}
)
except Exception as e: except Exception as e:
logger.error(f"Error in status update thread: {str(e)}") logger.error(f"Error in status update thread: {str(e)}")
time.sleep(5) time.sleep(5)
...@@ -221,32 +256,40 @@ class ServerManager: ...@@ -221,32 +256,40 @@ class ServerManager:
thread = threading.Thread(target=self._status_update_loop, daemon=True) thread = threading.Thread(target=self._status_update_loop, daemon=True)
thread.start() thread.start()
def toggle_power(self, group_name: str, server_name: str) -> Tuple[Optional[bool], Optional[str]]: def toggle_power(
self, group_name: str, server_name: str
) -> Tuple[Optional[bool], Optional[str]]:
cache_key = f"{group_name}/{server_name}" cache_key = f"{group_name}/{server_name}"
if cache_key not in self.servers: if cache_key not in self.servers:
logger.error(f"Power toggle failed: Server {server_name} not found in group {group_name}") logger.error(
f"Power toggle failed: Server {server_name} not found in group {group_name}"
)
return None, "Server not found" return None, "Server not found"
server = self.servers[cache_key] server = self.servers[cache_key]
if server['config'].get('locked', False): if server["config"].get("locked", False):
logger.info(f"Power toggle blocked: Server {server_name} is locked") logger.info(f"Power toggle blocked: Server {server_name} is locked")
return None, "Server is locked and cannot be controlled via this tool" return None, "Server is locked and cannot be controlled via this tool"
if not server['connection']: if not server["connection"]:
logger.error(f"Power toggle failed: No connection established for server {server_name}") logger.error(
f"Power toggle failed: No connection established for server {server_name}"
)
return None, "No connection established" return None, "No connection established"
try: try:
current_status = server['connection'].get_power() current_status = server["connection"].get_power()
if current_status['powerstate'] == 'on': if current_status["powerstate"] == "on":
logger.info(f"Initiating power off for server {server_name}") logger.info(f"Initiating power off for server {server_name}")
server['connection'].set_power('off') server["connection"].set_power("off")
else: else:
logger.info(f"Initiating power on for server {server_name}") logger.info(f"Initiating power on for server {server_name}")
server['connection'].set_power('on') server["connection"].set_power("on")
return True, None return True, None
except Exception as e: except Exception as e:
logger.error(f"Failed to initiate power toggle for server {server_name}: {str(e)}") logger.error(
f"Failed to initiate power toggle for server {server_name}: {str(e)}"
)
return None, str(e) return None, str(e)
def get_all_statuses(self) -> List[GroupStatus]: def get_all_statuses(self) -> List[GroupStatus]:
...@@ -255,78 +298,88 @@ class ServerManager: ...@@ -255,78 +298,88 @@ class ServerManager:
for group in self.groups: for group in self.groups:
group_servers: List[ServerStatus] = [] group_servers: List[ServerStatus] = []
for server in group['servers']: for server in group["servers"]:
cache_key = f"{group['name']}/{server['name']}" cache_key = f"{group['name']}/{server['name']}"
if cache_key in self.status_cache: if cache_key in self.status_cache:
group_servers.append(self.status_cache[cache_key]) group_servers.append(self.status_cache[cache_key])
group_statuses.append({ group_statuses.append({"name": group["name"], "servers": group_servers})
'name': group['name'],
'servers': group_servers
})
return group_statuses return group_statuses
config_path = "config.json" config_path = "config.json"
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='IPMI Server Management') parser = argparse.ArgumentParser(description="IPMI Server Management")
parser.add_argument('-c', '--config', default='config.json', help='Path to the configuration file') parser.add_argument(
"-c", "--config", default="config.json", help="Path to the configuration file"
)
args = parser.parse_args() args = parser.parse_args()
config_path = args.config config_path = args.config
with open(config_path, 'r') as f: with open(config_path, "r") as f:
config = json.load(f) config = json.load(f)
server_manager = ServerManager(config) server_manager = ServerManager(config)
@app.route('/login', methods=['GET', 'POST'])
@app.route("/login", methods=["GET", "POST"])
@limiter.limit("5 per minute") # Rate limit login attempts @limiter.limit("5 per minute") # Rate limit login attempts
def login(): def login() -> ResponseReturnValue:
if request.method == 'POST': if request.method == "POST":
username = request.form.get('username') username = request.form.get("username")
password = request.form.get('password') password = request.form.get("password")
if not username or not password: if not username or not password:
flash('Please enter both username and password') flash("Please enter both username and password")
return redirect(url_for('login')) return redirect(url_for("login"))
user = ldap_auth.authenticate(username, password) user = ldap_auth.authenticate(username, password)
if user: if user:
login_user(user) login_user(user)
return redirect(url_for('index')) return redirect(url_for("index"))
else: else:
flash('Invalid username or password') flash("Invalid username or password")
return render_template("login.html")
return render_template('login.html')
@app.route('/logout') @app.route("/logout")
@login_required @login_required
def logout(): def logout() -> ResponseReturnValue:
logout_user() logout_user()
return redirect(url_for('login')) return redirect(url_for("login"))
@app.route('/')
@app.route("/")
def index() -> str: def index() -> str:
return render_template('index.html') return render_template("index.html")
@app.route('/api/servers', methods=['GET']) @app.route("/api/servers", methods=["GET"])
@limiter.limit("90 per minute") # Rate limit server status checks @limiter.limit("90 per minute") # Rate limit server status checks
def get_servers() -> Response: def get_servers() -> Any:
return jsonify({'groups': server_manager.get_all_statuses()}) return jsonify({"groups": server_manager.get_all_statuses()})
@app.route('/api/servers/<group_name>/<server_name>/power', methods=['POST'])
@app.route("/api/servers/<group_name>/<server_name>/power", methods=["POST"])
@login_required @login_required
@limiter.limit("4 per minute", key_func=lambda: f"{request.remote_addr}:{request.view_args['group_name']}:{request.view_args['server_name']}") @limiter.limit(
def power_action(group_name: str, server_name: str) -> Union[Response, Tuple[Response, int]]: "4 per minute",
key_func=lambda: f"{request.remote_addr}:{request.view_args['group_name']}:{request.view_args['server_name']}",
)
def power_action(group_name: str, server_name: str) -> Any:
success, error = server_manager.toggle_power(group_name, server_name) success, error = server_manager.toggle_power(group_name, server_name)
if not success: if not success:
return jsonify({'error': error}), 500 return jsonify({"error": error}), 500
return jsonify({'status': 'success'}) return jsonify({"status": "success"})
@app.route("/api/auth/status", methods=["GET"])
def auth_status() -> Any:
return jsonify({"authenticated": current_user.is_authenticated})
@app.route('/api/auth/status', methods=['GET'])
def auth_status() -> Response:
return jsonify({'authenticated': current_user.is_authenticated})
if __name__ == '__main__': if __name__ == "__main__":
socketio.run(app, debug=True) socketio.run(app, debug=True)
import ldap import ldap # type: ignore
from flask_login import UserMixin from flask_login import UserMixin # type: ignore
from typing import Optional, Dict, Any from typing import Optional, TypedDict
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LDAPConfig(TypedDict):
server: str
base_dn: str
user_dn: str
user_password: str
admin_group: str
class User(UserMixin): class User(UserMixin):
def __init__(self, username: str): def __init__(self, username: str):
self.id = username self.id = username
class LDAPAuth: class LDAPAuth:
def __init__(self, config: Dict[str, Any]): def __init__(self, config: 'LDAPConfig'):
self.server = config['server'] self.server = config["server"]
self.base_dn = config['base_dn'] self.base_dn = config["base_dn"]
self.user_dn = config['user_dn'] self.user_dn = config["user_dn"]
self.admin_group = config['admin_group'] self.admin_group = config["admin_group"]
self.user_password = config['user_password'] self.user_password = config["user_password"]
self._service_conn = self._establish_service_connection() self._service_conn = self._establish_service_connection()
def _establish_service_connection(self) -> ldap.ldapobject.LDAPObject: def _establish_service_connection(self) -> ldap.ldapobject.LDAPObject:
...@@ -35,10 +44,14 @@ class LDAPAuth: ...@@ -35,10 +44,14 @@ class LDAPAuth:
try: try:
# Search for the user using service connection # Search for the user using service connection
search_filter = f"(uid={username})" search_filter = f"(uid={username})"
result = self._service_conn.search_s(self.base_dn, ldap.SCOPE_SUBTREE, search_filter, ['dn', 'memberOf']) result = self._service_conn.search_s(
self.base_dn, ldap.SCOPE_SUBTREE, search_filter, ["dn", "memberOf"]
)
if not result: if not result:
logger.warning(f"Unsuccessful login attempt from user {username}: not found in LDAP") logger.warning(
f"Unsuccessful login attempt from user {username}: not found in LDAP"
)
return None return None
# Get user's DN and try to bind with user's credentials # Get user's DN and try to bind with user's credentials
...@@ -55,20 +68,25 @@ class LDAPAuth: ...@@ -55,20 +68,25 @@ class LDAPAuth:
return None return None
# Check if user is authorized # Check if user is authorized
if 'memberOf' in result[0][1]: if "memberOf" in result[0][1]:
for group_dn in result[0][1]['memberOf']: for group_dn in result[0][1]["memberOf"]:
if group_dn.decode('utf-8').split(',')[0].split('=')[1] == self.admin_group: if (
group_dn.decode("utf-8").split(",")[0].split("=")[1]
== self.admin_group
):
logger.info(f"Successful login of user {username}") logger.info(f"Successful login of user {username}")
return User(username) return User(username)
logger.warning(f"User {username} is not a member of the authorized group {self.admin_group}") logger.warning(
f"User {username} is not a member of the authorized group {self.admin_group}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"LDAP authentication error: {str(e)}") logger.error(f"LDAP authentication error: {str(e)}")
return None return None
def __del__(self): def __del__(self) -> None:
"""Clean up service connection when object is destroyed.""" """Clean up service connection when object is destroyed."""
if self._service_conn: if self._service_conn:
self._service_conn.unbind_s() self._service_conn.unbind_s()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment