stm32f4-uart-bootloader

Simple UART bootloader for STM32F4 MCUs
git clone git://git.mdnr.space/stm32f4-uart-bootloader
Log | Files | Refs | Submodules | README | LICENSE

bootloader.py (16577B)


      1 #!/usr/bin/python
      2 
      3 from construct import *
      4 import serial
      5 import time
      6 import ast
      7 import os
      8 import sys
      9 
     10 DMA_RX_BUFFER_SIZE = 128
     11 PACKET_MAX_DATA_LEN = DMA_RX_BUFFER_SIZE - 1
     12 
     13 PACK_TAG_INDEX = 0
     14 PACK_LENGTH_INDEX = 1
     15 
     16 fwInfoStruct = Struct(
     17         "sentinel" / Int32ul,
     18         "devID" / Int32ul,
     19         "fwVersion" / Struct(
     20                 "major" / Int8ul,
     21                 "minor" / Int8ul,
     22                 "patch" / Int8ul,
     23                 "sha1" / Array(5, Int8ul)
     24                 ),
     25         "fwLength" / Int32ul,
     26         "crc32" / Int32ul,
     27         )
     28 
     29 paramStruct = Struct(
     30         "data" / Struct(
     31             "float_param_1" / Float32l,
     32             "float_param_2" / Float32l,
     33             "dw_param" / Int32ul,
     34             "w_param" / Int16ul,
     35             "bool_param_1" / Int8ul,
     36             "bool_param_2" / Int8ul,
     37             )
     38         )
     39 
     40 packetStruct = Struct(
     41         "tag" / Enum(Byte, 
     42                      PR=0, 
     43                      PW=1,
     44                      FIR=2,
     45                      RST=3,
     46                      ACK=4,
     47                      NACK=5,
     48                      LOCK=6,
     49                      UNLOCK=7,
     50                      DIAG=8,
     51                      SYNC=9,
     52                      UPDATE_REQ=10,
     53                      DEVID_CHECK=11,
     54                      UPDATE_SIZE=12,
     55                      ERASE_REQ=13,
     56                      ),
     57         "len" / Int8ul,
     58         "data" / Array(this.len, Byte),
     59         "crc" / Byte[2],
     60         )
     61 
     62 fwChunkStruct = Struct(
     63         "tag" / Enum(Byte,
     64                      FW_NEW=14,
     65                      FW_REP=15,
     66                      ),
     67         "len" / Int8ul,
     68         "data" / Array(this.len, Byte),
     69         "crc32" / Byte[4],
     70         )
     71 
     72 port = input('\nEnter the serial port:\nInput: ')
     73 # port = '/dev/ttyUSB0'
     74 ser = serial.Serial(port)
     75 ser.baudrate = 921600
     76 action = 0
     77 waitForPacket = False
     78 
     79 packetType = Enum(Byte, readInfo=0, readParam=1, reset=2, pUpdate=3, serialLock=4, serialFree=5, readDiag=6)
     80 
     81 def calculate_crc16(data: bytes, poly=0x8005):
     82     '''
     83     CRC-16-CCITT Algorithm
     84     '''
     85     data = bytearray(data)
     86     crc = 0x0000
     87     for b in data:
     88         cur_byte = 0xFF & b
     89         for i in range(0, 8):
     90             bit_flag = (crc >> 15) & 0x01
     91             crc = crc << 1
     92             crc = crc | (cur_byte >> (7 - i)) & 1
     93             if bit_flag:
     94                 crc = crc ^ poly
     95     return crc & 0xFFFF
     96 
     97 def calculate_crc32(data: bytes, poly=0xEDB88320):
     98     '''
     99     CRC-16-CCITT Algorithm
    100     '''
    101     data = bytearray(data)
    102     crc = 0x00000000
    103     for b in data:
    104         cur_byte = 0xFF & b
    105         for i in range(0, 8):
    106             bit_flag = (crc >> 31) & 0x01
    107             crc = crc << 1
    108             crc = crc | (cur_byte >> (7 - i)) & 1
    109             if bit_flag:
    110                 crc = crc ^ poly
    111     return crc & 0xFFFFFFFF
    112 
    113 def create_packet(tag: Enum, data = [], isFirmware = False):
    114     if not isFirmware:
    115         tempPacket = packetStruct.build(dict(tag = tag, len = len(data), data = data, crc = [0, 0]))
    116         crcVal = calculate_crc16(tempPacket)
    117         packet = packetStruct.build(dict(tag = tag, len = len(data), data = data, 
    118                                          crc = [(crcVal >> 8) & 0xff, crcVal & 0xff]))
    119     elif isFirmware:
    120         fwChunk = fwChunkStruct.build(dict(tag = tag, len = len(data), data = data, crc32 = [0, 0, 0, 0]))
    121         crc32Val = calculate_crc32(fwChunk)
    122         packet = fwChunkStruct.build(dict(tag = tag, len = len(data), data = data, 
    123             crc32 = [(crc32Val >> 24) & 0xff, (crc32Val >> 16) & 0xff, (crc32Val >> 8) & 0xff, crc32Val & 0xff]))
    124     return packet
    125 
    126 def update_state_machine():
    127     time.sleep(1)
    128     retry = 0
    129     state = "SYNC"
    130     while state != 'DONE':
    131         if retry > 0:
    132             print("\nRetrying...",retry)
    133             if retry >= 5:
    134                 state = 'DONE'
    135         match state:
    136             case 'SYNC':
    137                 print("[INFO] Sending sync request...")
    138                 packet = create_packet("SYNC", [])
    139                 ser.write(packet)
    140                 ser.timeout = 1
    141                 response = ser.read(5)
    142                 if len(response) != 0 and calculate_crc16(response) == 0:
    143                     packet = packetStruct.parse(response)
    144                     if packet.tag == 'ACK':
    145                         retry = 0
    146                         print("[SUCCESS] Synced")
    147                         state = 'UPDATE_REQ'
    148                     else: 
    149                         print("[ERROR] NACK received")
    150                         retry += 1
    151                 else:
    152                     print("[TIMEOUT] No response received")
    153                     retry += 1
    154             case 'UPDATE_REQ':
    155                 print("[INFO] Sending update request...")
    156                 packet = create_packet("UPDATE_REQ", [])
    157                 ser.write(packet)
    158                 ser.timeout = 1
    159                 response = ser.read(5)
    160                 if len(response) != 0 and calculate_crc16(response) == 0:
    161                     packet = packetStruct.parse(response)
    162                     if packet.tag == 'ACK':
    163                         retry = 0
    164                         print("[SUCCESS] Update request acknowledged")
    165                         state = 'DEVID_CHECK'
    166                     else: 
    167                         print("[ERROR] NACK received")
    168                         retry += 1
    169                 else:
    170                     print("[TIMEOUT] No response received")
    171                     retry += 1
    172             case 'DEVID_CHECK':
    173                 try: 
    174                     devIDFile = open('devID', 'r')
    175                 except OSError:
    176                     print("[ERROR] devID file not found.")
    177                     print("[INFO] Crete a file named \"devID\" in the current directory and write target device id in it.")
    178                     sys.exit()
    179                 with devIDFile:
    180                     devID = ast.literal_eval(devIDFile.read())
    181                     data = devID.to_bytes(4, byteorder = 'little')
    182                 print("[INFO] Checking device id...")
    183                 packet = create_packet("DEVID_CHECK", data)
    184                 ser.write(packet)
    185                 ser.timeout = 1
    186                 response = ser.read(5)
    187                 if len(response) != 0 and calculate_crc16(response) == 0:
    188                     packet = packetStruct.parse(response)
    189                     if packet.tag == 'ACK':
    190                         retry = 0
    191                         print("[SUCCESS] Device ID is correct")
    192                         state = 'UPDATE_SIZE'
    193                     else: 
    194                         print("[ERROR] NACK received")
    195                         retry += 1
    196                 else:
    197                     print("[TIMEOUT] No response received")
    198                     retry += 1
    199             case 'UPDATE_SIZE':
    200                 try: 
    201                     firmware = open('./build/fw.bin', 'r')
    202                 except OSError:
    203                     print("[ERROR] File \"./build/fw.bin\" not found.")
    204                     sys.exit()
    205                 with firmware:
    206                     FWSize = os.path.getsize('./build/fw.bin')
    207                     data = FWSize.to_bytes(4, byteorder = 'little')
    208                 print("[INFO] Sending update size...")
    209                 packet = create_packet("UPDATE_SIZE", data)
    210                 ser.write(packet)
    211                 ser.timeout = 1
    212                 response = ser.read(5)
    213                 if len(response) != 0 and calculate_crc16(response) == 0:
    214                     packet = packetStruct.parse(response)
    215                     if packet.tag == 'ACK':
    216                         retry = 0
    217                         print("[SUCCESS] Update size deliviered")
    218                         state = 'ERASE_REQ'
    219                     else: 
    220                         print("[ERROR] NACK received")
    221                         retry += 1
    222                 else:
    223                     print("[TIMEOUT] No response received")
    224                     retry += 1
    225             case 'ERASE_REQ':
    226                 print("[INFO] Sending erase request...")
    227                 packet = create_packet("ERASE_REQ", [])
    228                 ser.write(packet)
    229                 ser.timeout = 4
    230                 response = ser.read(4)
    231                 if len(response) != 0 and calculate_crc16(response) == 0:
    232                     packet = packetStruct.parse(response)
    233                     if packet.tag == 'ACK':
    234                         retry = 0
    235                         print("[SUCCESS] Chip erase done")
    236                         state = 'FIRMWARE_UPLOAD'
    237                     else: 
    238                         print("[ERROR] NACK received")
    239                         retry += 1
    240                 else:
    241                     print("[TIMEOUT] No response received")
    242                     retry += 1
    243             case 'FIRMWARE_UPLOAD':
    244                 try: 
    245                     firmware = open('./build/fw.bin', 'rb')
    246                 except OSError:
    247                     print("[ERROR] File \"./build/fw.bin\" not found.")
    248                     sys.exit()
    249                 with firmware:
    250                     totalChunk = (FWSize // 128) + 1
    251                     pack = 1
    252                     while True and retry <= 5:
    253                         if retry == 0:
    254                             chunk = firmware.read(128)
    255                         if not chunk: 
    256                             break
    257                         persent = pack * 100 // totalChunk
    258                         progress = persent // 2
    259                         print("Writing chunk",pack,"/",totalChunk,"[",
    260                               ''.join('#'*progress),
    261                               ''.join(' '*(50 - progress)),"]",persent,"%",end = "\r")
    262                         if retry == 0:
    263                             pack += 1
    264                             packet = create_packet("FW_NEW", chunk, True)
    265                         else:
    266                             packet = create_packet("FW_REP", chunk, True)
    267                         ser.write(packet)
    268                         ser.timeout = 4
    269                         response = ser.read(4)
    270                         if len(response) != 0 and calculate_crc16(response) == 0:
    271                             packet = packetStruct.parse(response)
    272                             if packet.tag == 'ACK':
    273                                 retry = 0
    274                             else: 
    275                                 print("\n[ERROR] NACK received. Resending chunk:",pack)
    276                                 retry += 1
    277                         else:
    278                             print("[TIMEOUT] No response received. Resending chunk:",pack)
    279                             retry += 1
    280                 if pack == totalChunk + 1:
    281                     print("\n[INFO] Firmware update was successfull")
    282                 else: 
    283                     print("\n[ERROR] Firmware upload failed")
    284                 state = 'DONE'
    285 
    286 def protocol_state_machine(prx: Bytes, type: Enum):
    287     # print("received: {:x}",prx)
    288     # print (''.join('{:02x}, '.format(x) for x in prx))
    289     if calculate_crc16(prx) != 0:
    290         print("CRC Error\n")
    291     else:
    292         packet = packetStruct.parse(prx)
    293         match packet.tag:
    294             case 'ACK':
    295                 if type == 'readInfo':
    296                     fwInfo = fwInfoStruct.parse(prx[2:len(packet.data)+2])
    297                     print("####################################")
    298                     print("Sentinel: %#4x"% fwInfo.sentinel)
    299                     print("Device ID: %d"% fwInfo.devID)
    300                     print("FW Version: v{0:1d}.{1:1d}-{2:1d}-{3:1}{4:1}{5:1}{6:1}{7:1}".format(fwInfo.fwVersion.major, fwInfo.fwVersion.minor, fwInfo.fwVersion.patch, chr(fwInfo.fwVersion.sha1[0]), chr(fwInfo.fwVersion.sha1[1]), chr(fwInfo.fwVersion.sha1[2]), chr(fwInfo.fwVersion.sha1[3]), chr(fwInfo.fwVersion.sha1[4])))
    301                     print("FW Size: %4d"% fwInfo.fwLength)
    302                     print("CRC-32: %#4x"% fwInfo.crc32)
    303                     print("####################################")
    304                 elif type == 'readParam':
    305                     print("####################################")
    306                     params = paramStruct.parse(prx[2:len(packet.data)+2])
    307                     print("Current Instance: ",params.currentInstance)
    308                     print("Entrance Threshold: %d"% params.entranceTh)
    309                     print("Exit Threshold: %d"% params.exitTh)
    310                     print("Entrance Debounce: %d"% params.entranceDB)
    311                     print("Exit Debounce: %d"% params.exitDB)
    312                     print("Layout: ",params.layout)
    313                     print("capture config: ",params.captureConfig)
    314                     print("Filter Avg Window: ",params.filterAvgWindow)
    315                     print("####################################")
    316                 elif type == 'pUpdate':
    317                     print("####################################")
    318                     print("Parameters updated successfully.")
    319                     print("####################################")
    320                 elif type == 'reset':
    321                     print("Core rebooted. Sending sync command ...")
    322                     update_state_machine()
    323                 elif type == 'serialLock' or type == 'serialFree':
    324                     print("####################################")
    325                     print("Process successfull\n")
    326                     print("####################################")
    327                 elif type == 'readDiag':
    328                     if packet.len == 0:
    329                         print("####################################")
    330                         print("no diagnostics saved\n")
    331                         print("####################################")
    332                     else:
    333                         print("####################################")
    334                         print("Diagnostics:\n")
    335                         for i in range(0, packet.len):
    336                             print(i+1,"->","".join(hex(packet.data[i])))
    337                         print("####################################")
    338 
    339             case 'NACK':
    340                 print("####################################")
    341                 print("Nack received\n")
    342                 print("####################################")
    343 
    344 
    345 while (action != '0'):
    346     action = input('\nSelect an action to perform:\n\
    347             1. Read parameters\n\
    348             2. Update parameters\n\
    349             3. Read firmware info\n\
    350             4. Reset the core\n\
    351             5. Lock serial line\n\
    352             6. Unlock serial line\n\
    353             7. Read diagnostics\n\
    354             0. Abort\n\
    355             Input: ')
    356     match action:
    357         case '1':
    358             packet = create_packet("PR", [0, paramStruct.sizeof()])
    359             ser.write(packet)
    360             queryType = "readParam"
    361             waitForPacket = True
    362         case '2':
    363             print('\nLoading parameters from config file ...')
    364             time.sleep(0.2)
    365             fileParam = ast.literal_eval(open('param.txt', 'r').read())
    366             packet = create_packet("PW", paramStruct.build(fileParam))
    367             print('\nUpdating parameters ...\n')
    368             time.sleep(0.2)
    369             ser.write(packet)
    370             queryType = "pUpdate"
    371             waitForPacket = True
    372 
    373         case '3':
    374             packet = create_packet("FIR", [0, fwInfoStruct.sizeof()])
    375             ser.write(packet)
    376             queryType = "readInfo"
    377             waitForPacket = True
    378 
    379         case '4':
    380             print('Sending reboot command...\n')
    381             packet = create_packet("RST", [])
    382             queryType = "reset"
    383             ser.write(packet)
    384             waitForPacket = True
    385 
    386         case '5':
    387             print('Sending lock command...\n')
    388             packet = create_packet("LOCK", [])
    389             queryType = "serialLock"
    390             ser.write(packet)
    391             waitForPacket = True
    392 
    393         case '6':
    394             print('Unlocking serial line...\n')
    395             packet = create_packet("UNLOCK", [])
    396             queryType = "serialFree"
    397             ser.write(packet)
    398             waitForPacket = True
    399 
    400         case '7':
    401             print('Sending diag read command...\n')
    402             packet = create_packet("DIAG", [])
    403             queryType = "readDiag"
    404             ser.write(packet)
    405             waitForPacket = True
    406 
    407         case '0':
    408             print('Aborting ...')
    409             time.sleep(0.5)
    410             break
    411 
    412         case _:
    413             print('[ERROR] Invalid input')
    414             time.sleep(0.5)
    415         
    416     if waitForPacket:
    417         ser.timeout = 0.2
    418         prx = ser.read(1024)
    419         if len(prx) != 0:
    420             waitForPacket = False
    421             protocol_state_machine(prx, queryType)
    422         else:
    423             print("\n[TIMEOUT] No data received\n")