Ei kuvausta

onnx.tsx 15KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. import { useEffect, useMemo, useState } from 'react';
  2. import { StyleSheet, View } from 'react-native';
  3. import { Image } from 'expo-image';
  4. import * as ImagePicker from 'expo-image-picker';
  5. import * as ImageManipulator from 'expo-image-manipulator';
  6. import jpeg from 'jpeg-js';
  7. import ParallaxScrollView from '@/components/parallax-scroll-view';
  8. import { ThemedText } from '@/components/themed-text';
  9. import { ThemedView } from '@/components/themed-view';
  10. import { IconSymbol } from '@/components/ui/icon-symbol';
  11. import { Colors, Fonts } from '@/constants/theme';
  12. import { useTranslation } from '@/localization/i18n';
  13. import { ThemedButton } from '@/components/themed-button';
  14. import { dbPromise, initCoreTables } from '@/services/db';
  15. import { useColorScheme } from '@/hooks/use-color-scheme';
  // Display labels for the classifier outputs; handleRun looks them up as classNames[item.index],
  // i.e. by position in the model's score vector.
  const classNames: string[] = require('@/assets/class_names.json');
  /** One ranked classifier output as shown in the "top predictions" card. */
  type Prediction = {
    /** Position of this class in the model's score vector. */
    index: number,
    /** Human-readable class name. */
    label: string,
    /** Score for the class (a softmax probability where handleRun builds it). */
    score: number,
  };

  /**
   * Row shape returned by loadLogs (ml_inferences LEFT JOIN images).
   * Everything except `id` may be null; `uri` is null when no image row matches.
   */
  type InferenceLog = {
    id: number,
    top_label: string | null,
    top_score: number | null,
    created_at: string | null,
    uri: string | null,
  };
  /**
   * Demo screen for the on-device PlantVillage classifier.
   *
   * Flow: pick a photo, convert it to a planar RGB Float32Array (imageToTensor),
   * run the lazily imported inference module, show the top-3 softmax classes and
   * persist the run so the "Recent scans" list (loadLogs) can show it later.
   */
  export default function OnnxScreen() {
    const { t } = useTranslation();
    const theme = useColorScheme() ?? 'light';
    const palette = Colors[theme];
    const [status, setStatus] = useState(t('onnx.status.pick'));
    const [imageUri, setImageUri] = useState<string | null>(null);
    const [predictions, setPredictions] = useState<Prediction[]>([]);
    const [logs, setLogs] = useState<InferenceLog[]>([]);
    // True while a run is in flight; disables the buttons so a second tap cannot start an
    // overlapping preprocessing/inference/logging pass or swap the image mid-run.
    const [running, setRunning] = useState(false);
    const canRun = useMemo(() => Boolean(imageUri) && !running, [imageUri, running]);

    useEffect(() => {
      // Prevents state updates after the screen unmounts while DB work is still pending.
      let active = true;
      initCoreTables()
        .then(async () => {
          const rows = await loadLogs();
          if (active) setLogs(rows);
        })
        .catch((error: unknown) => {
          if (active) setStatus(`Error: ${describeError(error)}`);
        });
      return () => {
        active = false;
      };
    }, []);

    async function pickImage() {
      try {
        const result = await ImagePicker.launchImageLibraryAsync({
          mediaTypes: getImageMediaTypes(),
          quality: 1,
        });
        if (result.canceled) return;
        const asset = result.assets[0];
        // Defensive: treat an empty asset list like a cancel instead of crashing on `.uri`.
        if (!asset) return;
        setImageUri(asset.uri);
        setPredictions([]);
        setStatus(t('onnx.status.ready'));
      } catch (error: unknown) {
        // Without this, a rejection (presumably e.g. a permission failure) escapes onPress unhandled.
        setStatus(`Error: ${describeError(error)}`);
      }
    }

    async function handleRun() {
      if (!imageUri || running) return;
      setRunning(true);
      setStatus(t('onnx.status.preprocessing'));
      try {
        // Imported lazily, presumably so a missing native ONNX runtime surfaces as a catchable
        // error here rather than at screen load (see the 'onnxruntime' check below).
        const { runPlantVillageInference } = await import('@/services/onnx/runPlantVillage');
        const inputTensor = await imageToTensor(imageUri);
        setStatus(t('onnx.status.running'));
        const scores = await runPlantVillageInference(inputTensor);
        const probabilities = softmax(scores);
        const top = topK(probabilities, 3).map((item) => ({
          ...item,
          label: humanizeLabel(classNames[item.index] ?? `Class ${item.index}`),
        }));
        setPredictions(top);
        await logInference({
          uri: imageUri,
          top,
          scores: probabilities,
        });
        setLogs(await loadLogs());
        setStatus(t('onnx.status.done'));
      } catch (error: unknown) {
        const message = describeError(error);
        // Heuristic: errors mentioning install/onnxruntime are reported as "native module missing".
        if (message.includes('install') || message.includes('onnxruntime')) {
          setStatus(t('onnx.status.nativeMissing'));
        } else {
          setStatus(`Error: ${message}`);
        }
      } finally {
        setRunning(false);
      }
    }

    return (
      <ParallaxScrollView
        headerBackgroundColor={{ light: '#E9E4DB', dark: '#2A2A2A' }}
        headerImage={
          <Image
            source={require('@/assets/images/leafscan.jpg')}
            style={styles.headerImage}
            contentFit="cover"
          />
        }>
        <ThemedView style={[styles.hero, { backgroundColor: palette.surface, borderColor: palette.border }]}>
          <View style={[styles.heroIcon, { backgroundColor: palette.card }]}>
            <IconSymbol size={44} color={palette.tint} name="leaf.fill" />
          </View>
          <View style={styles.heroText}>
            <ThemedText type="title" style={styles.heroTitle}>
              {t('onnx.title')}
            </ThemedText>
            <ThemedText style={[styles.heroSubtitle, { color: palette.muted }]}>
              {t('onnx.howBody')}
            </ThemedText>
          </View>
        </ThemedView>
        <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
          <View style={styles.cardHeader}>
            <ThemedText type="subtitle">{t('onnx.testTitle')}</ThemedText>
            <View style={[styles.statusPill, { backgroundColor: palette.surface }]}>
              <ThemedText style={[styles.statusText, { color: palette.muted }]}>{status}</ThemedText>
            </View>
          </View>
          {!imageUri ? (
            // TODO: move this hard-coded English hint into the localization catalog.
            <ThemedText style={[styles.inlineHint, { color: palette.muted }]}>
              Pick an image to run the model.
            </ThemedText>
          ) : null}
          <View
            style={[
              styles.previewWrap,
              { backgroundColor: palette.surface, borderColor: palette.border },
            ]}>
            {imageUri ? (
              <Image source={{ uri: imageUri }} style={styles.preview} contentFit="cover" />
            ) : (
              <View style={styles.previewPlaceholder}>
                <IconSymbol size={36} color={palette.muted} name="photo.on.rectangle.angled" />
                <ThemedText style={[styles.placeholderText, { color: palette.muted }]}>
                  {t('onnx.status.pick')}
                </ThemedText>
              </View>
            )}
          </View>
          <View style={styles.actionRow}>
            <ThemedButton
              title={t('onnx.pickImage')}
              onPress={pickImage}
              variant="secondary"
              disabled={running}
            />
            <ThemedButton title={t('onnx.runModel')} onPress={handleRun} disabled={!canRun} />
          </View>
        </ThemedView>
        {predictions.length > 0 ? (
          <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
            <ThemedText type="subtitle">{t('onnx.topPredictions')}</ThemedText>
            <View style={styles.results}>
              {predictions.map((pred) => {
                const percent = Number((pred.score * 100).toFixed(2));
                return (
                  <View key={pred.index} style={styles.predRow}>
                    <View style={styles.predLabelRow}>
                      <ThemedText style={styles.predLabel}>{pred.label}</ThemedText>
                      <ThemedText style={styles.predPercent}>{percent}%</ThemedText>
                    </View>
                    <View style={styles.barTrack}>
                      <View style={[styles.barFill, { width: `${Math.min(percent, 100)}%` }]} />
                    </View>
                  </View>
                );
              })}
            </View>
          </ThemedView>
        ) : null}
        {logs.length > 0 ? (
          <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
            {/* TODO: localize this heading. */}
            <ThemedText type="subtitle">Recent scans</ThemedText>
            <View style={styles.logList}>
              {logs.map((item) => {
                // A score of 0 is a real value, so compare against null instead of relying on truthiness.
                const score = item.top_score != null ? Math.round(item.top_score * 100) : null;
                return (
                  <View key={item.id} style={[styles.logRow, { backgroundColor: palette.surface }]}>
                    <View style={[styles.logThumb, { backgroundColor: palette.card }]}>
                      {item.uri ? (
                        <Image source={{ uri: item.uri }} style={styles.logImage} contentFit="cover" />
                      ) : (
                        <IconSymbol size={18} color={palette.muted} name="photo" />
                      )}
                    </View>
                    <View style={styles.logMeta}>
                      <ThemedText style={styles.logLabel}>
                        {item.top_label ?? t('onnx.topPredictions')}
                      </ThemedText>
                      <ThemedText style={[styles.logSub, { color: palette.muted }]}>
                        {item.created_at ? formatDate(item.created_at) : ''}
                      </ThemedText>
                    </View>
                    <ThemedText style={[styles.logScore, { color: palette.tint }]}>
                      {score !== null ? `${score}%` : '--'}
                    </ThemedText>
                  </View>
                );
              })}
            </View>
          </ThemedView>
        ) : null}
        <ThemedView style={[styles.card, { backgroundColor: palette.card, borderColor: palette.border }]}>
          <ThemedText type="subtitle">{t('onnx.howTitle')}</ThemedText>
          <ThemedText>{t('onnx.howBody')}</ThemedText>
        </ThemedView>
      </ParallaxScrollView>
    );
  }

  /**
   * Turns a thrown value into a readable message. Error instances contribute their
   * message, avoiding the "Error: Error: …" doubling that String(error) produces
   * once the caller adds its own "Error: " prefix.
   */
  function describeError(error: unknown): string {
    if (error instanceof Error) return error.message || String(error);
    return String(error);
  }
  /**
   * Picks the `mediaTypes` value for launchImageLibraryAsync. It prefers a
   * `MediaType.Image` / `MediaType.Images` runtime value when the installed
   * expo-image-picker exposes one, and otherwise falls back to `['images']`.
   *
   * NOTE(review): in recent expo-image-picker releases `MediaType` appears to be a
   * type-only export. If so, this always resolves to ['images']; confirm against the
   * installed SDK and simplify if so.
   */
  function getImageMediaTypes() {
    type PickerWithMediaType = { MediaType?: { Image?: unknown; Images?: unknown } };
    const candidates = (ImagePicker as PickerWithMediaType).MediaType;
    const fallback = ['images'];
    return candidates?.Image ?? candidates?.Images ?? fallback;
  }
  /**
   * Returns the `k` highest-scoring entries of `scores`, best first.
   *
   * Each entry carries its original position as `index` and a placeholder label
   * `Class <index>` that callers may replace. Ties keep ascending index order
   * because Array.prototype.sort is stable.
   *
   * @param scores - Per-class scores (any array-like of numbers, e.g. a Float32Array).
   * @param k - Number of entries to return. Fractions are floored; values <= 0 or NaN
   *   yield an empty list, and values above the length return everything.
   */
  function topK(scores: ArrayLike<number>, k: number): Prediction[] {
    // Clamp first: a negative k would otherwise hit slice(0, -n) and return "all but the last n".
    const limit = Math.max(0, Math.floor(k));
    return Array.from(scores, (score, index) => ({ index, label: `Class ${index}`, score }))
      .sort((a, b) => b.score - a.score)
      .slice(0, limit);
  }
  /**
   * Converts raw logits into probabilities that sum to 1.
   *
   * The maximum logit is subtracted before exponentiating so large logits do not
   * overflow Math.exp. The maximum is found with reduce() rather than
   * Math.max(...scores), because spreading a large array into a call can exceed the
   * engine's argument/stack limit and throw a RangeError.
   *
   * @param scores - Logits (any array-like of numbers, e.g. a Float32Array).
   * @returns A Float32Array of the same length; empty input gives an empty array.
   */
  function softmax(scores: ArrayLike<number>): Float32Array {
    const values = Array.from(scores);
    const max = values.reduce((acc, value) => Math.max(acc, value), -Infinity);
    const exps = values.map((value) => Math.exp(value - max));
    const sum = exps.reduce((acc, value) => acc + value, 0);
    return Float32Array.from(exps, (value) => value / sum);
  }
  /**
   * Loads the image at `uri`, resizes it to 224x224, and returns its pixels as a
   * planar (all R, then all G, then all B) Float32Array scaled to [0, 1].
   *
   * The resize is re-encoded as a max-quality JPEG with base64 output, then decoded
   * in JS with jpeg-js. Only the first three bytes of each 4-byte pixel are read.
   *
   * @throws Error if the manipulator returns no base64 data or the decoded size is not 224x224.
   */
  async function imageToTensor(uri: string) {
    const side = 224;
    const { base64 } = await ImageManipulator.manipulateAsync(
      uri,
      [{ resize: { width: side, height: side } }],
      {
        base64: true,
        format: ImageManipulator.SaveFormat.JPEG,
        compress: 1,
      }
    );
    if (!base64) {
      throw new Error('Failed to read image data.');
    }

    const { data, width, height } = jpeg.decode(base64ToUint8Array(base64), { useTArray: true });
    if (width !== side || height !== side) {
      throw new Error(`Unexpected image size: ${width}x${height}`);
    }

    const planeSize = width * height;
    const tensor = new Float32Array(3 * planeSize);
    // `offset` walks the interleaved 4-byte pixels; channel c of pixel p lands in plane c.
    for (let pixel = 0, offset = 0; pixel < planeSize; pixel += 1, offset += 4) {
      for (let channel = 0; channel < 3; channel += 1) {
        tensor[channel * planeSize + pixel] = data[offset + channel] / 255;
      }
    }
    return tensor;
  }
  /**
   * Decodes a standard base64 string (alphabet A-Z a-z 0-9 + /) into bytes.
   *
   * Characters outside the alphabet (e.g. line breaks) are ignored. Trailing "="
   * padding is optional: running out of input is treated the same as padding, so
   * unpadded input no longer produces spurious trailing zero bytes. A final group
   * with fewer than two characters carries no complete byte and is dropped.
   */
  function base64ToUint8Array(base64: string) {
    const cleaned = base64.replace(/[^A-Za-z0-9+/=]/g, '');
    const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/';
    // Each 4-character group yields at most 3 bytes; trimmed to the real length at the end.
    const out = new Uint8Array(Math.ceil(cleaned.length / 4) * 3);
    let length = 0;
    for (let i = 0; i < cleaned.length; i += 4) {
      const enc1 = decodeBase64Char(cleaned.charAt(i), chars);
      const enc2 = decodeBase64Char(cleaned.charAt(i + 1), chars);
      const enc3 = decodeBase64Char(cleaned.charAt(i + 2), chars);
      const enc4 = decodeBase64Char(cleaned.charAt(i + 3), chars);
      // A byte needs two 6-bit values; a group with padding in its first two slots has no data.
      if (enc1 === 64 || enc2 === 64) continue;
      out[length++] = (enc1 << 2) | (enc2 >> 4);
      if (enc3 === 64) continue;
      out[length++] = ((enc2 & 15) << 4) | (enc3 >> 2);
      if (enc4 === 64) continue;
      out[length++] = ((enc3 & 3) << 6) | enc4;
    }
    return out.slice(0, length);
  }

  /**
   * Maps one base64 character to its 6-bit value, returning 64 as a padding
   * sentinel for "=", for unknown characters, and for the empty string (what
   * charAt returns past the end of input). The empty string needs an explicit
   * check because chars.indexOf('') is 0, which would decode it as 'A'.
   */
  function decodeBase64Char(char: string, chars: string) {
    if (char === '' || char === '=') return 64;
    const index = chars.indexOf(char);
    return index === -1 ? 64 : index;
  }
  /**
   * Fetches the five most recent inference rows, newest first by created_at.
   * Each row includes the URI of the image it was run on (null when no images
   * row matches the LEFT JOIN).
   */
  async function loadLogs() {
    const query = [
      'SELECT mi.id, mi.top_label, mi.top_score, mi.created_at, img.uri',
      'FROM ml_inferences mi',
      'LEFT JOIN images img ON img.id = mi.image_id',
      'ORDER BY mi.created_at DESC',
      'LIMIT 5;',
    ].join('\n');
    const db = await dbPromise;
    return db.getAllAsync<InferenceLog>(query);
  }
  /**
   * Persists one inference run: an `images` row for the picked photo plus an
   * `ml_inferences` row referencing it through lastInsertRowId.
   *
   * Both inserts run in one transaction, so a failure in the second insert does
   * not leave an orphaned `images` row.
   *
   * @param params.uri - URI of the image that was classified, stored as-is.
   * @param params.top - Ranked predictions. The first entry fills top_label/top_score
   *   (null when empty), and all entries are serialized into scores_json.
   * @param params.scores - Full probability vector. NOTE(review): currently unused,
   *   because scores_json only stores the `top` entries; confirm whether the full
   *   vector should be persisted.
   */
  async function logInference(params: {
    uri: string;
    top: Prediction[];
    scores: Float32Array;
  }) {
    const db = await dbPromise;
    const now = new Date().toISOString();
    const topHit = params.top[0];
    const scoresJson = JSON.stringify(
      params.top.map((item) => ({ index: item.index, label: item.label, score: item.score }))
    );
    await db.withTransactionAsync(async () => {
      // NOTE(review): 224x224 is the model input size, but `uri` is the image as picked
      // (the resized copy is not stored), so these width/height values may not describe it.
      const insert = await db.runAsync(
        'INSERT INTO images (observation_id, uri, thumbnail_uri, width, height, created_at) VALUES (?, ?, ?, ?, ?, ?);',
        null,
        params.uri,
        null,
        224,
        224,
        now
      );
      await db.runAsync(
        'INSERT INTO ml_inferences (image_id, model_name, model_version, top_label, top_score, scores_json, created_at) VALUES (?, ?, ?, ?, ?, ?, ?);',
        insert.lastInsertRowId,
        'plantvillage_mnv3s_224',
        '1',
        topHit?.label ?? null,
        topHit?.score ?? null,
        scoresJson,
        now
      );
    });
  }
  /**
   * Turns a raw class name such as "Apple___Apple_scab" into display text
   * ("Apple Apple Scab"). Runs of underscores, slashes and whitespace become single
   * spaces, leading and trailing separators are dropped, and the first character of
   * each word is upper-cased (the rest of each word is left untouched).
   */
  function humanizeLabel(label: string) {
    const words = label.split(/[_/\s]+/).filter((word) => word.length > 0);
    const capitalized = words.map((word) => word.charAt(0).toUpperCase() + word.slice(1));
    return capitalized.join(' ');
  }
  /**
   * Formats a stored timestamp string for display in the device locale.
   *
   * Returns the raw string unchanged when it cannot be parsed. `new Date(bad)` does
   * not throw; it produces an Invalid Date whose toLocaleString() is "Invalid Date".
   * The validity therefore has to be checked explicitly, and the try/catch alone
   * never reached its fallback.
   */
  function formatDate(value: string) {
    const date = new Date(value);
    if (Number.isNaN(date.getTime())) return value;
    try {
      return date.toLocaleString();
    } catch {
      return value;
    }
  }
  // Literals shared by several style entries below.
  const PILL_RADIUS = 999;
  const MUTED_TEXT = '#6B736D';
  const ACCENT_TEXT = '#3D5F4A';

  /**
   * Static styles for OnnxScreen. Many colors here are light-theme defaults that
   * the component overrides inline with palette values (e.g. card
   * backgroundColor/borderColor, statusText color). barTrack, barFill and
   * predPercent are not overridden and therefore keep these colors in dark mode.
   */
  const styles = StyleSheet.create({
    headerImage: { height: '100%', width: '100%' },
    hero: {
      alignItems: 'center',
      borderRadius: 20,
      borderWidth: 1,
      flexDirection: 'row',
      gap: 14,
      marginBottom: 16,
      padding: 16,
    },
    heroIcon: { borderRadius: PILL_RADIUS, padding: 10 },
    heroText: { flex: 1, gap: 6 },
    heroTitle: { fontFamily: Fonts.rounded },
    heroSubtitle: { color: '#516458' },
    card: {
      backgroundColor: '#FFFFFF',
      borderColor: '#E3DED6',
      borderRadius: 18,
      borderWidth: 1,
      gap: 12,
      marginBottom: 16,
      padding: 16,
    },
    cardHeader: { alignItems: 'center', flexDirection: 'row', flexWrap: 'wrap', gap: 12 },
    statusPill: {
      backgroundColor: '#E6F0E2',
      borderRadius: PILL_RADIUS,
      paddingHorizontal: 10,
      paddingVertical: 4,
    },
    statusText: { color: ACCENT_TEXT, fontSize: 12 },
    previewWrap: {
      backgroundColor: '#F3F0EA',
      borderColor: '#E1DBD1',
      borderRadius: 16,
      borderWidth: 1,
      overflow: 'hidden',
    },
    preview: { height: 220, width: '100%' },
    previewPlaceholder: { alignItems: 'center', gap: 8, height: 200, justifyContent: 'center' },
    placeholderText: { color: MUTED_TEXT },
    inlineHint: { color: MUTED_TEXT, fontSize: 13 },
    actionRow: { flexDirection: 'row', flexWrap: 'wrap', gap: 12 },
    results: { gap: 12 },
    logList: { gap: 10 },
    logRow: {
      alignItems: 'center',
      backgroundColor: '#F7F4EE',
      borderRadius: 12,
      flexDirection: 'row',
      gap: 12,
      padding: 10,
    },
    logThumb: {
      alignItems: 'center',
      backgroundColor: '#ECE7DE',
      borderRadius: 10,
      height: 48,
      justifyContent: 'center',
      overflow: 'hidden',
      width: 48,
    },
    logImage: { height: '100%', width: '100%' },
    logMeta: { flex: 1, gap: 2 },
    logLabel: { fontWeight: '600' },
    logSub: { color: MUTED_TEXT, fontSize: 12 },
    logScore: { color: ACCENT_TEXT, fontWeight: '600' },
    predRow: { gap: 6 },
    predLabelRow: {
      alignItems: 'center',
      flexDirection: 'row',
      gap: 12,
      justifyContent: 'space-between',
    },
    predLabel: { flex: 1 },
    predPercent: { color: ACCENT_TEXT },
    barTrack: {
      backgroundColor: '#E7E1D7',
      borderRadius: PILL_RADIUS,
      height: 8,
      overflow: 'hidden',
    },
    barFill: { backgroundColor: '#4B7B57', borderRadius: PILL_RADIUS, height: 8 },
  });