From a90e01f9c3c0ad2b6215a9a2c855cf6897b3d265 Mon Sep 17 00:00:00 2001
From: "christoph.von.oy" <christoph.von.oy@rwth-aachen.de>
Date: Tue, 29 Oct 2024 11:42:20 +0100
Subject: [PATCH] Refactored component matching

---
 component/core.py | 189 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 142 insertions(+), 47 deletions(-)

diff --git a/component/core.py b/component/core.py
index 4b56d47..96fb5fd 100644
--- a/component/core.py
+++ b/component/core.py
@@ -25,10 +25,12 @@ THE SOFTWARE.
 from Model_Library.utility import design_annuity, operational_annuity
 from Model_Library.optimization_model import VariableKind
 
+from dataclasses import dataclass
 from enum import Enum
 import json
 import pyomo.environ as pyo
 import warnings
+from typing import List, Union
 
 
 class ComponentKind(Enum):
@@ -53,6 +55,103 @@ class ComponentCommodity(Enum):
     SPACE_HEAT = 7
 
 
+class ComponentPart(Enum):
+    ALL = (1,)
+    DESIGN = (2,)
+    STATE = (3,)
+    NONE_STATE = 4
+
+
+class ComponentPartPattern:
+    def __init__(
+        self,
+        kind=ComponentKind.ALL,
+        type="all",
+        commodity=ComponentCommodity.ALL,
+        part=ComponentPart.ALL,
+    ):
+        if isinstance(kind, List) and ComponentKind.ALL in kind:
+            self.kind = ComponentKind.ALL
+        else:
+            self.kind = kind
+
+        if isinstance(type, List) and "all" in type:
+            self.type = "all"
+        else:
+            self.type = type
+
+        if isinstance(commodity, List) and ComponentCommodity.ALL in commodity:
+            self.commodity = ComponentCommodity.ALL
+        else:
+            self.commodity = commodity
+
+        if isinstance(part, List) and ComponentPart.ALL in part:
+            self.part = ComponentPart.ALL
+        else:
+            self.part = part
+
+    def match(
+        self,
+        kind=ComponentKind.ALL,
+        type="all",
+        commodity=ComponentCommodity.ALL,
+        part=ComponentPart.ALL,
+    ):
+        if kind != ComponentKind.ALL and self.kind != ComponentKind.ALL:
+            if isinstance(self.kind, List):
+                if kind not in self.kind:
+                    return False
+            else:
+                if kind != self.kind:
+                    return False
+
+        if type != "all" and self.type != "all":
+            if isinstance(self.type, List):
+                if type not in self.type:
+                    return False
+            else:
+                if type != self.type:
+                    return False
+
+        if (
+            commodity != ComponentCommodity.ALL
+            and self.commodity != ComponentCommodity.ALL
+        ):
+            if isinstance(self.commodity, List):
+                if commodity not in self.commodity:
+                    return False
+            else:
+                if commodity != self.commodity:
+                    return False
+
+        if part != ComponentPart.ALL and self.part != ComponentPart.ALL:
+            if isinstance(self.part, List):
+                if part not in self.part:
+                    return False
+            else:
+                if part != self.part:
+                    return False
+
+        return True
+
+    @staticmethod
+    def match_simple(
+        component_kind,
+        component_type,
+        component_commodities,
+        pattern_kind=ComponentKind.ALL,
+        pattern_type="all",
+        pattern_commodity=ComponentCommodity.ALL,
+    ):
+        match_kind = pattern_kind == ComponentKind.ALL or pattern_kind == component_kind
+        match_type = pattern_type == "all" or pattern_type == component_type
+        match_commodity = (
+            pattern_commodity == ComponentCommodity.ALL
+            or pattern_commodity in component_commodities
+        )
+        return match_kind and match_type and match_commodity
+
+
 class ComponentCapacity(Enum):
     NONE = 1
     OPTIONAL = 2
@@ -309,14 +408,16 @@ class BaseBusBar(AbstractComponent):
             capacity=ComponentCapacity.NONE,
         )
 
+        self.commodity = commodity
+
     def match(self, kind=ComponentKind.ALL, commodity=ComponentCommodity.ALL):
-        match_kind = kind == ComponentKind.ALL or kind == ComponentKind.BUSBAR
-        match_commodity = (
-            commodity == ComponentCommodity.ALL
-            or commodity == self.commodity
-            or (isinstance(commodity, list) and self.commodity in commodity)
+        return ComponentPartPattern.match_simple(
+            ComponentKind.BUSBAR,
+            self.__class__.__name__,
+            [self.commodity],
+            pattern_kind=kind,
+            pattern_commodity=commodity,
         )
-        return match_kind and match_commodity
 
     def operational_base_variable_names(self):
         return [
@@ -382,24 +483,18 @@ class BaseComponent(AbstractComponent):
         configuration["conversion_2"] = None
 
     def match(self, kind=ComponentKind.ALL, commodity=ComponentCommodity.ALL):
-        match_kind = kind == ComponentKind.ALL or kind == ComponentKind.BASE
-        match_commodity = (
-            commodity == ComponentCommodity.ALL
-            or commodity == self.input_commodity_1
-            or commodity == self.input_commodity_2
-            or commodity == self.output_commodity_1
-            or commodity == self.output_commodity_2
-            or (
-                isinstance(commodity, list)
-                and (
-                    self.input_commodity_1 in commodity
-                    or self.input_commodity_2 in commodity
-                    or self.output_commodity_1 in commodity
-                    or self.output_commodity_2 in commodity
-                )
-            )
+        return ComponentPartPattern.match_simple(
+            ComponentKind.BASE,
+            self.__class__.__name__,
+            [
+                self.input_commodity_1,
+                self.input_commodity_2,
+                self.output_commodity_1,
+                self.output_commodity_2,
+            ],
+            pattern_kind=kind,
+            pattern_commodity=commodity,
         )
-        return match_kind and match_commodity
 
     def operational_base_variable_names(self):
         return self.operational_variables
@@ -522,13 +617,13 @@ class BaseConsumption(AbstractComponent):
         self.consumption = configuration["consumption"]
 
     def match(self, kind=ComponentKind.ALL, commodity=ComponentCommodity.ALL):
-        match_kind = kind == ComponentKind.ALL or kind == ComponentKind.CONSUMPTION
-        match_commodity = (
-            commodity == ComponentCommodity.ALL
-            or commodity == self.commodity
-            or (isinstance(commodity, list) and self.commodity in commodity)
+        return ComponentPartPattern.match_simple(
+            ComponentKind.CONSUMPTION,
+            self.__class__.__name__,
+            [self.commodity],
+            pattern_kind=kind,
+            pattern_commodity=commodity,
         )
-        return match_kind and match_commodity
 
     def operational_base_variable_names(self):
         return [(self.name + ".input_1", VariableKind.INDEXED)]
@@ -564,13 +659,13 @@ class BaseGeneration(AbstractComponent):
         self.generation = configuration["generation"]
 
     def match(self, kind=ComponentKind.ALL, commodity=ComponentCommodity.ALL):
-        match_kind = kind == ComponentKind.ALL or kind == ComponentKind.GENERATION
-        match_commodity = (
-            commodity == ComponentCommodity.ALL
-            or commodity == self.commodity
-            or (isinstance(commodity, list) and self.commodity in commodity)
+        return ComponentPartPattern.match_simple(
+            ComponentKind.GENERATION,
+            self.__class__.__name__,
+            [self.commodity],
+            pattern_kind=kind,
+            pattern_commodity=commodity,
         )
-        return match_kind and match_commodity
 
     def operational_base_variable_names(self):
         return [(self.name + ".output_1", VariableKind.INDEXED)]
@@ -621,13 +716,13 @@ class BaseGrid(AbstractComponent):
             self.co2_emissions_ = 0
 
     def match(self, kind=ComponentKind.ALL, commodity=ComponentCommodity.ALL):
-        match_kind = kind == ComponentKind.ALL or kind == ComponentKind.GRID
-        match_commodity = (
-            commodity == ComponentCommodity.ALL
-            or commodity == self.commodity
-            or (isinstance(commodity, list) and self.commodity in commodity)
+        return ComponentPartPattern.match_simple(
+            ComponentKind.GRID,
+            self.__class__.__name__,
+            [self.commodity],
+            pattern_kind=kind,
+            pattern_commodity=commodity,
         )
-        return match_kind and match_commodity
 
     def operational_base_variable_names(self):
         return [
@@ -752,13 +847,13 @@ class BaseStorage(AbstractComponent):
             )
 
     def match(self, kind=ComponentKind.ALL, commodity=ComponentCommodity.ALL):
-        match_kind = kind == ComponentKind.ALL or kind == ComponentKind.STORAGE
-        match_commodity = (
-            commodity == ComponentCommodity.ALL
-            or commodity == self.commodity
-            or (isinstance(commodity, list) and self.commodity in commodity)
+        return ComponentPartPattern.match_simple(
+            ComponentKind.STORAGE,
+            self.__class__.__name__,
+            [self.commodity],
+            pattern_kind=kind,
+            pattern_commodity=commodity,
         )
-        return match_kind and match_commodity
 
     def operational_base_variable_names(self):
         return [
-- 
GitLab