"""Build a NumPy recurrent policy (``DXNNPolicy``) from a genotype snapshot."""

from typing import Any, Callable, Dict, List, Sequence, Tuple

import numpy as np

from mathema.genotype.neo4j.genotype import load_genotype_snapshot, neo4j
from mathema.envs.openai_car_racing import CarRacing


def af_tanh(x):
    """Hyperbolic tangent activation."""
    return np.tanh(x)


def af_cos(x):
    """Cosine activation."""
    return np.cos(x)


def af_gauss(x):
    """Gaussian activation, ``exp(-x**2)``."""
    return np.exp(-np.square(x))


def af_abs(x):
    """Absolute-value activation."""
    return np.abs(x)


# Activation name -> function. The insertion order of this mapping defines the
# integer activation indices produced by build_policy_from_snapshot.
AF_MAP = {
    "tanh": af_tanh,
    "cos": af_cos,
    "gauss": af_gauss,
    "abs": af_abs,
}

# Activation used when a neuron names none, or names one missing from AF_MAP.
_DEFAULT_AF = "tanh"


class DXNNPolicy:
    """Stateful network evaluated one input vector at a time.

    With ``H = W_ff.shape[0]`` neurons, each call to :meth:`step` computes

        z = b_h + W_in @ x_t + W_ff @ h + W_rec @ h_prev

    where ``h`` holds the activations produced by the previous call and
    ``h_prev`` those produced by the call before that. Neuron ``j`` then
    applies ``af_funcs[af_idx[j]]`` to ``z[j]``.

    Args:
        W_in: Input-to-neuron weights; built as ``(H, I)`` by
            ``build_policy_from_snapshot``.
        W_ff: Neuron-to-neuron weights for non-recurrent links, ``(H, H)``.
        W_rec: Neuron-to-neuron weights for links flagged recurrent, ``(H, H)``.
        b_h: Per-neuron bias, length ``H``.
        af_idx: For each neuron, an index into ``af_funcs``.
        af_funcs: Activation callables; each is called with one element of ``z``.
        out_index: Neuron indices whose activations form the output vector.
    """

    def __init__(
        self,
        W_in: np.ndarray,
        W_ff: np.ndarray,
        W_rec: np.ndarray,
        b_h: np.ndarray,
        af_idx: Sequence[int],
        af_funcs: Sequence[Callable[[Any], Any]],
        out_index: Sequence[int],
    ):
        self.W_in, self.W_ff, self.W_rec = W_in, W_ff, W_rec
        self.b_h = b_h
        self.af_idx = np.array(af_idx, dtype=np.int32)
        self.af_funcs = af_funcs
        self.out_index = list(out_index)
        H = W_ff.shape[0]
        self.h = np.zeros((H,), dtype=np.float32)
        self.h_prev = np.zeros_like(self.h)

    def reset_state(self):
        """Zero both stored activation vectors in place."""
        self.h.fill(0.0)
        self.h_prev.fill(0.0)

    def step(self, x_t: np.ndarray) -> np.ndarray:
        """Advance the network by one input and return the output activations.

        Args:
            x_t: Input vector; must be compatible with ``W_in @ x_t``.

        Returns:
            A new float32 array holding the activations of the neurons listed
            in ``out_index``, in that order.
        """
        z = self.b_h + self.W_in @ x_t + self.W_ff @ self.h + self.W_rec @ self.h_prev
        h_new = np.empty_like(self.h)
        # Activations are applied element by element so that scalar-only
        # callables passed in ``af_funcs`` keep working.
        for j, afi in enumerate(self.af_idx):
            h_new[j] = self.af_funcs[afi](z[j])
        # Integer-array indexing returns a copy, so the result never aliases
        # the internal state that h_new becomes below.
        gather = np.asarray(self.out_index, dtype=np.intp)
        y = h_new[gather].astype(np.float32, copy=False)
        self.h_prev, self.h = self.h, h_new
        return y


def _build_sensor_index(
    sensors: List[Dict[str, Any]],
) -> Tuple[Dict[str, Tuple[int, int]], int]:
    """Lay the sensors out back to back in one flat input vector.

    Args:
        sensors: Sensor records providing ``"id"`` and ``"vector_length"``.

    Returns:
        ``(index, total)`` where ``index`` maps ``str(sensor id)`` to its
        ``(offset, length)`` slice and ``total`` is the summed length.
    """
    offset = 0
    idx = {}
    for s in sensors:
        L = int(s["vector_length"])
        idx[str(s["id"])] = (offset, L)
        offset += L
    return idx, offset


def _add_sensor_weights(
    W_in: np.ndarray, row: int, slot: Tuple[int, int], ws: List[float]
) -> None:
    """Accumulate ``ws`` into ``W_in[row]`` over the sensor slice ``slot``.

    ``ws`` is truncated or zero-padded to the slice length first.
    """
    off, length = slot
    if len(ws) > length:
        ws = ws[:length]
    elif len(ws) < length:
        ws = ws + [0.0] * (length - len(ws))
    W_in[row, off:off + length] += np.asarray(ws, dtype=np.float32)


def build_policy_from_snapshot(snap: Dict[str, Any]) -> Tuple[DXNNPolicy, int]:
    """Translate a genotype snapshot into a ``DXNNPolicy``.

    Args:
        snap: Mapping with the keys ``"sensors"``, ``"neurons"`` and
            ``"actuators"``. Neurons may carry ``"bias"``,
            ``"activation_function"`` and ``"input_weights"``; each input
            weight entry has ``"input_id"`` and optionally ``"weights"`` and
            ``"recurrent"``. Actuators may carry ``"fanin_ids"``.

    Returns:
        ``(policy, I)`` where ``I`` is the total sensor vector length, i.e.
        the number of columns of the policy's ``W_in``.

    Raises:
        KeyError: If a required key (``"sensors"``, ``"neurons"``,
            ``"actuators"``, a neuron ``"id"``, a sensor ``"id"`` /
            ``"vector_length"`` or an ``"input_id"``) is missing.
    """
    sensors = snap["sensors"]
    neurons = snap["neurons"]
    actuators = snap["actuators"]

    s_idx, I = _build_sensor_index(sensors)
    H = len(neurons)
    nid2ix = {n["id"]: i for i, n in enumerate(neurons)}

    W_in = np.zeros((H, I), dtype=np.float32)
    W_ff = np.zeros((H, H), dtype=np.float32)
    W_rec = np.zeros((H, H), dtype=np.float32)
    b_h = np.zeros((H,), dtype=np.float32)

    af_names = []
    for j, n in enumerate(neurons):
        b = n.get("bias")
        b_h[j] = float(b) if b is not None else 0.0
        af_names.append((n.get("activation_function") or _DEFAULT_AF).lower())

        for inp in n.get("input_weights", []):
            src = inp["input_id"]
            ws = [float(x) for x in (inp.get("weights") or [])]

            if src in s_idx:
                _add_sensor_weights(W_in, j, s_idx[src], ws)
            elif src in nid2ix:
                i = nid2ix[src]
                # Neuron-to-neuron links carry a single weight.
                w = ws[0] if ws else 0.0
                if inp.get("recurrent", False):
                    W_rec[j, i] += w
                else:
                    W_ff[j, i] += w
            elif str(src) in s_idx:
                # The sensor index is keyed by str(id), so a non-string
                # sensor id (e.g. an int) only matches once stringified.
                # Tried last so ids that already resolved above are unaffected.
                _add_sensor_weights(W_in, j, s_idx[str(src)], ws)
            # Any other source is neither a known sensor nor a known neuron;
            # such dangling inputs are deliberately ignored.

    # Derive both activation tables from AF_MAP so they cannot drift apart;
    # dict insertion order fixes the indices.
    af_funcs = list(AF_MAP.values())
    af_name2idx = {name: k for k, name in enumerate(AF_MAP)}
    default_ix = af_name2idx.get(_DEFAULT_AF, 0)
    af_idx = [af_name2idx.get(nm, default_ix) for nm in af_names]

    out_index = []
    for a in actuators:
        for nid in a.get("fanin_ids", []):
            # Fan-in ids that are not neurons of this snapshot are skipped.
            if nid in nid2ix:
                out_index.append(nid2ix[nid])

    policy = DXNNPolicy(W_in, W_ff, W_rec, b_h, af_idx, af_funcs, out_index)
    return policy, I