last workig state.
This commit is contained in:
120
mathema/viz_replay.py
Normal file
120
mathema/viz_replay.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import numpy as np
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from mathema.genotype.neo4j.genotype import load_genotype_snapshot, neo4j
|
||||
from mathema.envs.openai_car_racing import CarRacing
|
||||
|
||||
|
||||
def af_tanh(x): return np.tanh(x)
|
||||
|
||||
|
||||
def af_cos(x): return np.cos(x)
|
||||
|
||||
|
||||
def af_gauss(x): return np.exp(-np.square(x))
|
||||
|
||||
|
||||
def af_abs(x): return np.abs(x)
|
||||
|
||||
|
||||
AF_MAP = {
|
||||
"tanh": af_tanh,
|
||||
"cos": af_cos,
|
||||
"gauss": af_gauss,
|
||||
"abs": af_abs,
|
||||
}
|
||||
|
||||
|
||||
class DXNNPolicy:
|
||||
def __init__(self, W_in, W_ff, W_rec, b_h, af_idx, af_funcs, out_index):
|
||||
self.W_in, self.W_ff, self.W_rec = W_in, W_ff, W_rec
|
||||
self.b_h = b_h
|
||||
self.af_idx = np.array(af_idx, dtype=np.int32)
|
||||
self.af_funcs = af_funcs
|
||||
self.out_index = list(out_index)
|
||||
|
||||
H = W_ff.shape[0]
|
||||
self.h = np.zeros((H,), dtype=np.float32)
|
||||
self.h_prev = np.zeros_like(self.h)
|
||||
|
||||
def reset_state(self):
|
||||
self.h.fill(0.0)
|
||||
self.h_prev.fill(0.0)
|
||||
|
||||
def step(self, x_t: np.ndarray) -> np.ndarray:
|
||||
z = self.b_h + self.W_in @ x_t + self.W_ff @ self.h + self.W_rec @ self.h_prev
|
||||
h_new = np.empty_like(self.h)
|
||||
for j, afi in enumerate(self.af_idx):
|
||||
h_new[j] = self.af_funcs[afi](z[j])
|
||||
|
||||
y = np.array([h_new[j] for j in self.out_index], dtype=np.float32)
|
||||
self.h_prev, self.h = self.h, h_new
|
||||
return y
|
||||
|
||||
|
||||
def _build_sensor_index(sensors: List[Dict[str, Any]]) -> Tuple[Dict[str, Tuple[int, int]], int]:
|
||||
offset = 0
|
||||
idx = {}
|
||||
for s in sensors:
|
||||
L = int(s["vector_length"])
|
||||
idx[str(s["id"])] = (offset, L)
|
||||
offset += L
|
||||
return idx, offset
|
||||
|
||||
|
||||
def build_policy_from_snapshot(snap: Dict[str, Any]) -> Tuple[DXNNPolicy, int]:
|
||||
sensors = snap["sensors"]
|
||||
neurons = snap["neurons"]
|
||||
actuators = snap["actuators"]
|
||||
|
||||
s_idx, I = _build_sensor_index(sensors)
|
||||
H = len(neurons)
|
||||
nid2ix = {n["id"]: i for i, n in enumerate(neurons)}
|
||||
|
||||
W_in = np.zeros((H, I), dtype=np.float32)
|
||||
W_ff = np.zeros((H, H), dtype=np.float32)
|
||||
W_rec = np.zeros((H, H), dtype=np.float32)
|
||||
b_h = np.zeros((H,), dtype=np.float32)
|
||||
|
||||
af_names = []
|
||||
for j, n in enumerate(neurons):
|
||||
b = n.get("bias")
|
||||
b_h[j] = (float(b) if b is not None else 0.0)
|
||||
af_names.append((n.get("activation_function") or "tanh").lower())
|
||||
|
||||
for inp in n.get("input_weights", []):
|
||||
src = inp["input_id"]
|
||||
ws = [float(x) for x in (inp.get("weights") or [])]
|
||||
if src in s_idx:
|
||||
|
||||
off, L = s_idx[src]
|
||||
if len(ws) != L:
|
||||
|
||||
if len(ws) > L:
|
||||
ws = ws[:L]
|
||||
else:
|
||||
ws = ws + [0.0] * (L - len(ws))
|
||||
W_in[j, off:off + L] += np.asarray(ws, dtype=np.float32)
|
||||
elif src in nid2ix:
|
||||
i = nid2ix[src]
|
||||
w = float(ws[0]) if ws else 0.0
|
||||
if bool(inp.get("recurrent", False)):
|
||||
W_rec[j, i] += w
|
||||
else:
|
||||
W_ff[j, i] += w
|
||||
else:
|
||||
|
||||
pass
|
||||
|
||||
af_funcs = [af_tanh, af_cos, af_gauss, af_abs]
|
||||
af_name2idx = {"tanh": 0, "cos": 1, "gauss": 2, "abs": 3}
|
||||
af_idx = [af_name2idx.get(nm, 0) for nm in af_names]
|
||||
|
||||
out_index = []
|
||||
for a in actuators:
|
||||
|
||||
for nid in a.get("fanin_ids", []):
|
||||
if nid in nid2ix:
|
||||
out_index.append(nid2ix[nid])
|
||||
|
||||
policy = DXNNPolicy(W_in, W_ff, W_rec, b_h, af_idx, af_funcs, out_index)
|
||||
return policy, I
|
||||
Reference in New Issue
Block a user