From efa497dfd86bc64b3f9b991f6fc1a10ff23f7596 Mon Sep 17 00:00:00 2001
From: Kevin O'Connor <kevin@koconnor.net>
Date: Thu, 18 Feb 2021 14:01:40 -0500
Subject: [PATCH] msgproto: Avoid peeking into the msgproto class members

Update callers to only use exported methods of the msgproto objects.
This makes it easier to make internal changes to the code.

Signed-off-by: Kevin O'Connor <kevin@koconnor.net>
---
 klippy/console.py        | 16 +++++----
 klippy/mcu.py            | 17 ++++++----
 klippy/msgproto.py       | 70 ++++++++++++++++++++++++----------------
 scripts/buildcommands.py | 35 ++++++++++----------
 4 files changed, 78 insertions(+), 60 deletions(-)

diff --git a/klippy/console.py b/klippy/console.py
index 4b237222b..69095345d 100755
--- a/klippy/console.py
+++ b/klippy/console.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python2
 # Script to implement a test console with firmware over serial port
 #
-# Copyright (C) 2016,2017  Kevin O'Connor <kevin@koconnor.net>
+# Copyright (C) 2016-2021  Kevin O'Connor <kevin@koconnor.net>
 #
 # This file may be distributed under the terms of the GNU GPLv3 license.
 import sys, optparse, os, re, logging
@@ -54,11 +54,12 @@ class KeyboardReader:
         self.output("="*20 + " attempting to connect " + "="*20)
         self.ser.connect()
         msgparser = self.ser.get_msgparser()
-        self.output("Loaded %d commands (%s / %s)" % (
-            len(msgparser.messages_by_id),
-            msgparser.version, msgparser.build_versions))
+        message_count = len(msgparser.get_messages())
+        version, build_versions = msgparser.get_version_info()
+        self.output("Loaded %d commands (%s / %s)"
+                    % (message_count, version, build_versions))
         self.output("MCU config: %s" % (" ".join(
-            ["%s=%s" % (k, v) for k, v in msgparser.config.items()])))
+            ["%s=%s" % (k, v) for k, v in msgparser.get_constants().items()])))
         self.clocksync.connect(self.ser)
         self.ser.handle_default = self.handle_default
         self.ser.register_response(self.handle_output, '#output')
@@ -137,9 +138,10 @@ class KeyboardReader:
     def command_LIST(self, parts):
         self.update_evals(self.reactor.monotonic())
         mp = self.ser.get_msgparser()
+        cmds = [msgformat for msgid, msgtype, msgformat in mp.get_messages()
+                if msgtype == 'command']
         out = "Available mcu commands:"
-        out += "\n  ".join([""] + sorted([
-            mp.messages_by_id[i].msgformat for i in mp.command_ids]))
+        out += "\n  ".join([""] + sorted(cmds))
         out += "\nAvailable artificial commands:"
         out += "\n  ".join([""] + [n for n in sorted(self.local_commands)])
         out += "\nAvailable local variables:"
diff --git a/klippy/mcu.py b/klippy/mcu.py
index 7032df5db..6f6c80024 100644
--- a/klippy/mcu.py
+++ b/klippy/mcu.py
@@ -1,6 +1,6 @@
 # Interface to Klipper micro-controller code
 #
-# Copyright (C) 2016-2020  Kevin O'Connor <kevin@koconnor.net>
+# Copyright (C) 2016-2021  Kevin O'Connor <kevin@koconnor.net>
 #
 # This file may be distributed under the terms of the GNU GPLv3 license.
 import sys, os, zlib, logging, math
@@ -564,10 +564,11 @@ class MCU:
         return config_params
     def _log_info(self):
         msgparser = self._serial.get_msgparser()
+        message_count = len(msgparser.get_messages())
+        version, build_versions = msgparser.get_version_info()
         log_info = [
-            "Loaded MCU '%s' %d commands (%s / %s)" % (
-                self._name, len(msgparser.messages_by_id),
-                msgparser.version, msgparser.build_versions),
+            "Loaded MCU '%s' %d commands (%s / %s)"
+            % (self._name, message_count, version, build_versions),
             "MCU '%s' config: %s" % (self._name, " ".join(
                 ["%s=%s" % (k, v) for k, v in self.get_constants().items()]))]
         return "\n".join(log_info)
@@ -635,8 +636,9 @@ class MCU:
         mbaud = msgparser.get_constant('SERIAL_BAUD', None)
         if self._restart_method is None and mbaud is None and not ext_only:
             self._restart_method = 'command'
-        self._get_status_info['mcu_version'] = msgparser.version
-        self._get_status_info['mcu_build_versions'] = msgparser.build_versions
+        version, build_versions = msgparser.get_version_info()
+        self._get_status_info['mcu_version'] = version
+        self._get_status_info['mcu_build_versions'] = build_versions
         self._get_status_info['mcu_constants'] = msgparser.get_constants()
         self.register_response(self._handle_shutdown, 'shutdown')
         self.register_response(self._handle_shutdown, 'is_shutdown')
@@ -693,7 +695,8 @@ class MCU:
         except self._serial.get_msgparser().error as e:
             return None
     def lookup_command_id(self, msgformat):
-        return self._serial.get_msgparser().lookup_command(msgformat).msgid
+        all_msgs = self._serial.get_msgparser().get_messages()
+        return {msgfmt: msgid for msgid, msgtype, msgfmt in all_msgs}[msgformat]
     def get_enumerations(self):
         return self._serial.get_msgparser().get_enumerations()
     def get_constants(self):
diff --git a/klippy/msgproto.py b/klippy/msgproto.py
index 72bdd534d..d1f461b87 100644
--- a/klippy/msgproto.py
+++ b/klippy/msgproto.py
@@ -1,6 +1,6 @@
 # Protocol definitions for firmware communication
 #
-# Copyright (C) 2016-2019  Kevin O'Connor <kevin@koconnor.net>
+# Copyright (C) 2016-2021  Kevin O'Connor <kevin@koconnor.net>
 #
 # This file may be distributed under the terms of the GNU GPLv3 license.
 import json, zlib, logging
@@ -128,6 +128,25 @@ def lookup_params(msgformat, enumerations={}):
         out.append((name, pt))
     return out
 
+# Lookup the message types for a debugging "output()" format string
+def lookup_output_params(msgformat):
+    param_types = []
+    args = msgformat
+    while 1:
+        pos = args.find('%')
+        if pos < 0:
+            break
+        if pos+1 >= len(args) or args[pos+1] != '%':
+            for i in range(4):
+                t = MessageTypes.get(args[pos:pos+1+i])
+                if t is not None:
+                    param_types.append(t)
+                    break
+            else:
+                raise error("Invalid output format for '%s'" % (msgformat,))
+        args = args[pos+1:]
+    return param_types
+
 # Update the message format to be compatible with python's % operator
 def convert_msg_format(msgformat):
     for c in ['%u', '%i', '%hu', '%hi', '%c', '%.*s', '%*s']:
@@ -177,21 +196,7 @@ class OutputFormat:
         self.msgid = msgid
         self.msgformat = msgformat
         self.debugformat = convert_msg_format(msgformat)
-        self.param_types = []
-        args = msgformat
-        while 1:
-            pos = args.find('%')
-            if pos < 0:
-                break
-            if pos+1 >= len(args) or args[pos+1] != '%':
-                for i in range(4):
-                    t = MessageTypes.get(args[pos:pos+1+i])
-                    if t is not None:
-                        self.param_types.append(t)
-                        break
-                else:
-                    raise error("Invalid output format for '%s'" % (msgformat,))
-            args = args[pos+1:]
+        self.param_types = lookup_output_params(msgformat)
     def parse(self, s, pos):
         pos += 1
         out = []
@@ -219,7 +224,7 @@ class MessageParser:
     def __init__(self):
         self.unknown = UnknownFormat()
         self.enumerations = {}
-        self.command_ids = []
+        self.messages = []
         self.messages_by_id = {}
         self.messages_by_name = {}
         self.config = {}
@@ -334,7 +339,7 @@ class MessageParser:
             #logging.exception("Unable to encode")
             raise error("Unable to encode: %s" % (msgname,))
         return cmd
-    def _fill_enumerations(self, enumerations):
+    def fill_enumerations(self, enumerations):
         for add_name, add_enums in enumerations.items():
             enums = self.enumerations.setdefault(add_name, {})
             for enum, value in add_enums.items():
@@ -352,30 +357,35 @@ class MessageParser:
                 start_value, count = value
                 for i in range(count):
                     enums[enum_root + str(start_enum + i)] = start_value + i
-    def _init_messages(self, messages, output_ids=[]):
+    def _init_messages(self, messages, command_ids=[], output_ids=[]):
         for msgformat, msgid in messages.items():
-            msgid = int(msgid)
-            if msgid in output_ids:
+            msgtype = 'response'
+            if msgid in command_ids:
+                msgtype = 'command'
+            elif msgid in output_ids:
+                msgtype = 'output'
+            self.messages.append((msgid, msgtype, msgformat))
+            if msgtype == 'output':
                 self.messages_by_id[msgid] = OutputFormat(msgid, msgformat)
-                continue
-            msg = MessageFormat(msgid, msgformat, self.enumerations)
-            self.messages_by_id[msgid] = msg
-            self.messages_by_name[msg.name] = msg
+            else:
+                msg = MessageFormat(msgid, msgformat, self.enumerations)
+                self.messages_by_id[msgid] = msg
+                self.messages_by_name[msg.name] = msg
     def process_identify(self, data, decompress=True):
         try:
             if decompress:
                 data = zlib.decompress(data)
             self.raw_identify_data = data
             data = json.loads(data)
-            self._fill_enumerations(data.get('enumerations', {}))
+            self.fill_enumerations(data.get('enumerations', {}))
             commands = data.get('commands')
             responses = data.get('responses')
             output = data.get('output', {})
             all_messages = dict(commands)
             all_messages.update(responses)
             all_messages.update(output)
-            self.command_ids = sorted(commands.values())
-            self._init_messages(all_messages, output.values())
+            self._init_messages(all_messages, commands.values(),
+                                output.values())
             self.config.update(data.get('config', {}))
             self.version = data.get('version', '')
             self.build_versions = data.get('build_versions', '')
@@ -384,6 +394,10 @@ class MessageParser:
         except Exception as e:
             logging.exception("process_identify error")
             raise error("Error during identify: %s" % (str(e),))
+    def get_version_info(self):
+        return self.version, self.build_versions
+    def get_messages(self):
+        return list(self.messages)
     def get_enumerations(self):
         return dict(self.enumerations)
     def get_constants(self):
diff --git a/scripts/buildcommands.py b/scripts/buildcommands.py
index eaa6d5859..9e811485b 100644
--- a/scripts/buildcommands.py
+++ b/scripts/buildcommands.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python2
 # Script to handle build time requests embedded in C code.
 #
-# Copyright (C) 2016-2018  Kevin O'Connor <kevin@koconnor.net>
+# Copyright (C) 2016-2021  Kevin O'Connor <kevin@koconnor.net>
 #
 # This file may be distributed under the terms of the GNU GPLv3 license.
 import sys, os, subprocess, optparse, logging, shlex, socket, time, traceback
@@ -172,8 +172,8 @@ class HandleInitialPins:
         if not self.initial_pins:
             return []
         mp = msgproto.MessageParser()
-        mp._fill_enumerations(HandlerEnumerations.enumerations)
-        pinmap = mp.enumerations.get('pin', {})
+        mp.fill_enumerations(HandlerEnumerations.enumerations)
+        pinmap = mp.get_enumerations().get('pin', {})
         out = []
         for p in self.initial_pins:
             flag = "IP_OUT_HIGH"
@@ -304,13 +304,15 @@ class HandleCommandGeneration:
                    if msgid not in command_ids and msgid not in response_ids }
         if output:
             data['output'] = output
-    def build_parser(self, parser, iscmd):
-        if parser.name == "#output":
-            comment = "Output: " + parser.msgformat
+    def build_parser(self, msgid, msgformat, msgtype):
+        if msgtype == "output":
+            param_types = msgproto.lookup_output_params(msgformat)
+            comment = "Output: " + msgformat
         else:
-            comment = parser.msgformat
+            param_types = [t for name, t in msgproto.lookup_params(msgformat)]
+            comment = msgformat
         params = '0'
-        types = tuple([t.__class__.__name__ for t in parser.param_types])
+        types = tuple([t.__class__.__name__ for t in param_types])
         if types:
             paramid = self.all_param_types.get(types)
             if paramid is None:
@@ -322,15 +324,15 @@ class HandleCommandGeneration:
     .msg_id=%d,
     .num_params=%d,
     .param_types = %s,
-""" % (comment, parser.msgid, len(types), params)
-        if iscmd:
+""" % (comment, msgid, len(types), params)
+        if msgtype == 'response':
             num_args = (len(types) + types.count('PT_progmem_buffer')
                         + types.count('PT_buffer'))
             out += "    .num_args=%d," % (num_args,)
         else:
             max_size = min(msgproto.MESSAGE_MAX,
                            (msgproto.MESSAGE_MIN + 1
-                            + sum([t.max_length for t in parser.param_types])))
+                            + sum([t.max_length for t in param_types])))
             out += "    .max_size=%d," % (max_size,)
         return out
     def generate_responses_code(self):
@@ -342,17 +344,15 @@ class HandleCommandGeneration:
             msgid = self.msg_to_id[msg]
             if msgid in did_output:
                 continue
-            s = msg
             did_output[msgid] = True
             code = ('    if (__builtin_strcmp(str, "%s") == 0)\n'
-                    '        return &command_encoder_%s;\n' % (s, msgid))
+                    '        return &command_encoder_%s;\n' % (msg, msgid))
             if msgname is None:
-                parser = msgproto.OutputFormat(msgid, msg)
+                parsercode = self.build_parser(msgid, msg, 'output')
                 output_code.append(code)
             else:
-                parser = msgproto.MessageFormat(msgid, msg)
+                parsercode = self.build_parser(msgid, msg, 'command')
                 encoder_code.append(code)
-            parsercode = self.build_parser(parser, 0)
             encoder_defs.append(
                 "const struct command_encoder command_encoder_%s PROGMEM = {"
                 "    %s\n};\n" % (
@@ -392,8 +392,7 @@ ctr_lookup_output(const char *str)
             funcname, flags, msgname = cmd_by_id[msgid]
             msg = self.messages_by_name[msgname]
             externs[funcname] = 1
-            parser = msgproto.MessageFormat(msgid, msg)
-            parsercode = self.build_parser(parser, 1)
+            parsercode = self.build_parser(msgid, msg, 'response')
             index.append(" {%s\n    .flags=%s,\n    .func=%s\n}," % (
                 parsercode, flags, funcname))
         index = "".join(index).strip()