久久ER99热精品一区二区-久久精品99国产精品日本-久久精品免费一区二区三区-久久综合九色综合欧美狠狠

新聞中心

EEPW首頁 > 智能計算 > 設計應用 > 使用Tensorflow訓練模型并且部署到ESP32S3進行推理

使用Tensorflow訓練模型并且部署到ESP32S3進行推理

作者:御坂美琴 時間:2025-09-28 來源:EEPW 收藏

1   簡介

最近這一段時間在學習機器學習,也嘗試將一個SkLearn的模型部署到了PocketBeagle 2 上(感謝論壇提供的試用機會),發現效果是真的不錯。所以我就在想有沒有什么便捷的方式能夠將一個簡單的模型部署到單片機上來實現某種行為的邊緣計算。于是經過我的搜索后找到了一個ESP32 的基于 lite 的庫。如下圖所示。

1759024430892080.png

上述的這個庫是基于TF官方的TF-lite進行ESP32-S3的適配。所以對應TF官方的介紹如下圖所示。

1759024468138168.png

至此,邏輯關系已經被整理清楚了,TFlite是針對資源受限的Machine Learing庫,而的esp-tflite-micro是TFlite對ESP32設備的一個具體實現。

那么在本篇文章中將帶著大家從零開始進行Demo的燒錄(測試意圖)、模型訓練,轉換成C 語言數組。然后到模型的部署,最終實現和模型一樣的效果。閱讀這篇文章你最好具備一些基礎的Machine learning和Deep learning知識。

2 Hello World Demo燒錄

1-首先,你本地已經安裝好了IDF 的環境,你只需要在任意一個IDF 的項目目錄下執行下述命令來添加esp-tflite-micro的依賴

view plaincopy to clipboardprint?

1.idf.py add-dependency “esp-tflite-micro”

2-基于現在的項目新建Helloworld 的項目

view plaincopy to clipboardprint?

1. idf.py create-project-from-example “esp-tflitemicro:hello_world”

之后便可以對當前的demo 進行燒錄了。當然重點不在這里。下述截圖為實際Demo 的實際運行效果:

1759024545106498.png

3   訓練模型

對于模型的訓練,我這里環境依賴是被Anaconda進行管理的,使用的是TF 的完整版進行訓練。Demo的HelloWorld 訓練代碼來自于TFlite,可以在HelloWorldDemo中的readme中找到對應的鏈接。我們對其進行少量的修改使其可以直接在Jupyter notebook 中運行。即Python代碼。移除外部參數傳遞。

下面是代碼的核心步驟:

view plaincopy to clipboardprint?

1. def get_data():

2.     “””

3.     Generate a set of random `x` values and calculate their sine values.

4.     “””

5.     x_values = np.random.uniform(low=0, high=2 *math.pi, size=1000).astype(np.float32)

6.     np.random.shuffle(x_values)

7.     y_values = np.sin(x_values).astype(np.float32)

8.     return (x_values, y_values)

首先生成隨機的正選隨機數,進行打亂,然后返回總體的X 和Y 向量。

view plaincopy to clipboardprint?

1. def create_model() -> tf.keras.Model:

2.      model = tf.keras.Sequential([

3.         tf.keras.Input(shape=(1,)),

4.         tf.keras.layers.Dense(16, activation=”relu”),

5.         tf.keras.layers.Dense(16, activation=”relu”),

6.         tf.keras.layers.Dense(1)

7.     ])

8.     model.compile(optimizer=”adam”, loss=”mse”, metrics=[“mae”])

9.     return model

模型采用的是一個三層的神經網絡,輸入1,輸出1。其中兩層每層一共16 個神經元用來學習特征。

view plaincopy to clipboardprint?

1. def main():

2.     x_values, y_values = get_data()

3.         trained_model = train_model(EPOCHS, x_values, y_values)

4.

5.     # Convert and save the model to .tflite

6.     tflite_model = convert_tflite_model(trained_model)

7.     save_tfl ite_model(tfl ite_model, SAVE_DIR,model_name=”hello_world_float.tflite”)

然后對模型進行訓練,同時轉換成tflite 的格式。

1759024656141736.png

之后使用xxd將這個tflite的模型抓換成C 語言的數組。

view plaincopy to clipboardprint?

1.xxd -i hello_world_int8.tfl ite > hello_world_model_data.cc

至此模型的訓練和轉換已經完成了。

4   部署模型

對于模型的部署,HelloWorld 給了我們一個很好的示例。我們只需要把我們轉換成CC 文件中的c 語言數組拷貝到Model.CC文件中即可。

1759024706827210.png

注意,并不能全拷貝,只拷貝數組部分即可。和下方的數組長度。

1759024750306536.png

注意數組的類型,不要全拷貝。然后修改Model內extern暴露的數組名稱和模型數組名稱一致。

1759024777369395.png

然后修改SetUpfunction中的數組名稱為模型的名稱。

1759024812803555.png

由于我們訓練的模型沒有進行量化,所以直接使用未經量化的float類型即可。將代碼修改成下述代碼。使用了Float 類型進行輸入和輸出:

view plaincopy to clipboardprint?

1. // The name of this function is important for Arduino compatibility.

2. void loop()

3. {

4.     // Calculate an x value to feed into the model. We compare the current

5.     // inference_count to the number of inferences per cycle to determine

6.     // our position within the range of possible x values the model was

7.     // trained on, and use this to calculate a value.

8.     float position = static_cast<float>(inference_count) /

9.     static_cast<float>(kInferencesPerCycle);

10.     float x = position * kXrange;

11.

12.     input->data.f[0] = x;

13.

14.     // Run inference, and report any error

15.     TfLiteStatus invoke_status = interpreter->Invoke();

16.     if (invoke_status != kTfLiteOk)

17.     {

18.         MicroPrintf(“Invoke failed on x: %fn”,

19.         static_cast<double>(x));

20.         return;

21.     }

22.

23.     float y = output->data.f[0];

24.

25.     // Output the results. A custom HandleOutput function can be implemented

26.     // for each supported hardware target.

27.     HandleOutput(x, y);

28.

29.     // Increment the inference_counter, and reset it if we have reached

30.     // the total number per cycle

31.     inference_count += 1;

32.     if (inference_count >= kInferencesPerCycle)

33.     inference_count = 0;

34. }

需要注意的是,如果你的模型進行過量化,那就根據對應的量化參數進行傳遞。否則模型精度將會很低。

5   實驗效果

1759024893577793.png

模型的X 輸入和Y 輸出。滿足預期。

(本文來源于《EEPW》


評論


技術專區

關閉